In [None]:
from Bio.PDB.PDBParser import PDBParser
from Bio.PDB.Polypeptide import three_to_one
from Bio.PDB.Polypeptide import is_aa
from Bio import pairwise2
from multiprocessing import Pool, cpu_count
from functools import partial
import scipy.cluster.hierarchy
import numpy as np
import sys
import argparse
import bisect
import re
import os
import fnmatch
import pickle
import collections
from rdkit.Chem import AllChem as Chem
from rdkit.Chem import AllChem
from rdkit.DataStructs import FingerprintSimilarity as fs
from rdkit.Chem.Fingerprints import FingerprintMols
import rdkit
import tqdm.auto as tqdm

In [None]:
import pandas as pd

In [None]:
# Read every line of the PocketMatch score file into memory. The context
# manager closes the handle as soon as the lines are read (the original left
# the file open for the lifetime of the kernel).
with open('PocketMatch_v2.0/PocketMatch_score.txt', 'r') as f:
    data = f.readlines()

In [None]:
# Empty placeholder list (not referenced by the other cells shown here).
final_data = list()

In [None]:
# Collect the unique identifiers found in the first two whitespace-separated
# fields of every score row (data[0] is skipped as the header line).
unique_keys = set()
for row in tqdm.tqdm(data[1:]):
    v = row.split()
    unique_keys.add(v[0])
    unique_keys.add(v[1])

# Sort instead of relying on set iteration order: string hashing is randomized
# per interpreter process, so list(set(...)) would give a different key order
# (and therefore a different matrix row order) on every kernel restart.
keys = sorted(unique_keys)

In [None]:
# Map each key to its position in `keys`.
# A single enumerating pass replaces the original keys.index(key) lookup inside
# the loop, which rescanned the list for every key (quadratic in len(keys)).
# setdefault keeps the first position if a key were ever repeated, exactly as
# list.index would.
key_index = {}
for position, key in enumerate(keys):
    key_index.setdefault(key, position)

In [None]:
# Square matrix of zeros with one row and one column per key.
n_keys = len(keys)
distance_mat = np.zeros((n_keys, n_keys))
distance_mat.shape

In [None]:
# Fill the matrix symmetrically: for every score row (header skipped) the
# distance between the two named entries is 1 minus the value in the fourth
# whitespace-separated field.
for row in tqdm.tqdm(data[1:]):
    fields = row.split()
    first, second, score = fields[0], fields[1], fields[3]
    dist = 1 - float(score)
    a = key_index[first]
    b = key_index[second]
    distance_mat[a, b] = dist
    distance_mat[b, a] = dist


In [None]:
# Keep three decimal places in every distance.
distance_mat = distance_mat.round(decimals=3)

In [None]:
# Every entry is at distance 0 from itself. np.fill_diagonal writes the
# diagonal in place in one call, replacing the element-by-element Python loop
# (and its progress bar) over the rows of the square matrix.
np.fill_diagonal(distance_mat, 0)

In [None]:
# distance_mat.dump('deeplytough_dist_mat.npy')

# with open('pocketmatch_dist_mat.pkl', 'wb') as f:
#     pickle.dump(distance_mat, f)

# with open('pocketmatch_keys.pkl', 'wb') as f:
#     pickle.dump(keys, f)
    
# with open('pocketmatch_dist_mat.pkl', 'rb') as f:
#     distance_mat = pickle.load(f)

# with open('pocketmatch_keys.pkl', 'rb') as f:
#     keys = pickle.load(f)

In [None]:
# for i in tqdm.tqdm(range(len(distance_mat))):
#     for j in range(len(distance_mat)):
#         distance_mat[i][j] = 1 - distance_mat[i][j]

In [None]:
def calcClusterGroups(dists, ligandsim, target_names, t, t2, ligandt, all_pairs):
    '''Partition the targets into groups using the full distance matrix dists.

    Row/column i of dists corresponds to target_names[i].  Rows are visited in
    order; each row not yet placed seeds a new group that is expanded by
    assignGroup.  ligandsim, t, t2 and ligandt are only forwarded to
    assignGroup; all_pairs is accepted but not used here.

    Returns a list of sets of target names, one set per group.
    '''
    seen = set()
    index_groups = []
    for seed in range(dists.shape[0]):
        if seed in seen:
            continue
        members = assignGroup(dists, ligandsim, t, t2,
                              ligandt, {seed}, target_names)
        index_groups.append(members)
        seen |= members
    # translate row indices back into target names
    return [{target_names[idx] for idx in members} for members in index_groups]

In [None]:
def assignGroup(dists, ligandsim, t, t2, ligandt, explore, names):
    '''Return the set of indices reachable from explore through distances below t.

    Starting from the indices in explore, any index j whose distance
    dists[i][j] to a current member i is strictly less than t joins the group,
    and the search then continues from the newly added indices until nothing
    new is found.

    The ligand-based criterion (ligandsim[i][j] > ligandt and dists[i][j] < t2)
    is disabled, so ligandsim, t2, ligandt and names are currently unused.
    '''
    members = set(explore)
    pending = explore
    n_cols = dists.shape[1]
    while pending:
        newly_added = set()
        for src in pending:
            row = dists[src]
            for dst in range(n_cols):
                # skip indices already grouped or not strictly closer than t
                if dst in members or not (row[dst] < t):
                    continue
                members.add(dst)
                newly_added.add(dst)
        # only the indices added in this round need exploring next
        pending = newly_added
    return members

In [None]:
# similarity = 0.95
threshold =  0.25 # distance cutoff passed as t to calcClusterGroups; author's note: the distances range from 0 to 1.75

In [None]:
# Cluster on distance alone: every ligand-related argument is passed as None.
# og_clusters and clusters are two names for the same list object.
clusters = og_clusters = calcClusterGroups(
    distance_mat,
    target_names=keys,
    t=threshold,
    ligandsim=None,
    ligandt=None,
    t2=None,
    all_pairs=None,
)

In [None]:
# Cluster sizes, largest first.
pl = sorted(map(len, clusters), reverse=True)
print(pl)

In [None]:
def get_target_lines(vec):
    '''Group the entries of vec by target, using each entry as its own target.

    Returns a defaultdict(list) mapping every entry to a list holding one
    [entry, 'smile'] pair per occurrence of that entry in vec.
    '''
    by_target = collections.defaultdict(list)
    for target in vec:
        by_target[target].append([target, 'smile'])
    return by_target

In [None]:
# Map every key to a list of [key, 'smile'] entries; createFolds uses the
# length of each list as that target's pose count.
target_lines = get_target_lines(keys)

In [None]:
def createFolds(cluster_groups, numfolds, target_lines, randomize):
    '''Distribute target clusters over numfolds folds, balancing pose counts.

    A cluster's pose count is the sum of len(target_lines[target]) over its
    targets.  Clusters are placed one at a time, largest first.  Without
    randomize, each cluster goes to the fold currently holding the fewest
    poses.  With randomize, the fold is drawn through np.random, weighted by
    how far each fold is below the fullest one, so the balance is less tight.

    Returns (folds, foldmap): folds is a list of numfolds sorted lists of
    targets, foldmap maps each target to the index of its fold.  The final
    pose count per fold is printed.
    '''
    fold_members = [[] for _ in range(numfolds)]
    poses_in_fold = [0] * numfolds
    # total number of poses contributed by each cluster
    poses_in_group = [sum(len(target_lines[target]) for target in group)
                      for group in cluster_groups]
    assignment = {}
    for _ in range(len(cluster_groups)):
        # next cluster to place: the largest one left (placed ones are marked -1)
        biggest = poses_in_group.index(np.max(poses_in_group))
        if not randomize:
            target_fold = poses_in_fold.index(np.min(poses_in_fold))
        else:
            free_space = np.max(poses_in_fold) - np.array(poses_in_fold)
            total_free = np.sum(free_space)
            if total_free == 0:
                # all folds are equally full: pick one uniformly
                target_fold = np.random.choice(numfolds)
            else:
                # weighted draw: a fold's chance is proportional to its free space
                draw = np.random.choice(total_free)
                running = 0
                for slot, free in enumerate(free_space):
                    running += free
                    if draw < running:
                        target_fold = slot
                        break
        placed = cluster_groups[biggest]
        fold_members[target_fold].extend(placed)
        poses_in_fold[target_fold] += poses_in_group[biggest]
        poses_in_group[biggest] = -1
        for target in placed:
            assignment[target] = target_fold
    print('Poses per fold: {}'.format(poses_in_fold))
    for members in fold_members:
        members.sort()
    return fold_members, assignment

def index(a, x):
    '''Return the position of the leftmost item of sorted sequence a equal to x, or -1 if absent.'''
    pos = bisect.bisect_left(a, x)
    found = pos < len(a) and a[pos] == x
    return pos if found else -1

In [None]:

# createFolds draws from NumPy's global random state when randomize is set.
# Seed it first so that, for the same cluster list, re-running the notebook
# produces the same fold assignment (the original call ran unseeded).
np.random.seed(0)
folds, foldmap = createFolds(
    clusters, numfolds=3, target_lines=target_lines, randomize=True)


In [None]:
# Build the output name from the threshold actually used for clustering, so the
# file name cannot go stale if `threshold` is changed (for 0.25 this yields the
# same name as before: folds_using_0.25_thresh_pocketmatch.pkl).
folds_path = 'folds_using_{}_thresh_pocketmatch.pkl'.format(threshold)
with open(folds_path, 'wb') as f:
    pickle.dump(folds, f)
# To reload previously saved folds instead of recomputing them:
# with open(folds_path, 'rb') as f:
#     folds = pickle.load(f)
