## Imports

# Clustering GrASP atoms into pockets

In [1]:
import os
import logging
import warnings
import pickle

warnings.filterwarnings('ignore') # ignores warnings
logging.disable(logging.CRITICAL) # disables logging

from sklearn.cluster import AgglomerativeClustering
import numpy as np
from prointvar.pdbx import PDBXreader, PDBXwriter

## Functions

In [2]:
def cluster_atoms_average(all_coords, predicted_probs, threshold=.5, score_type='mean', **kwargs):
    """Cluster atoms into sites with average linkage clustering.

    Parameters
    ----------
    all_coords: numpy.ndarray
        Protein atomic coordinates.
        
    predicted_probs: numpy.ndarray
        Model predicted binding site probabilities for each atom.

    threshold: float
        Probability threshold to classify atoms as site/non-site.

    score_type: str
        Option for atom score aggregation.

    **kwargs
        Additional agglomerative clustering kwargs.

    Returns
    -------
    numpy.ndarray
        Coordinates of all binding site atoms.
    
    numpy.ndarray
        Sorted cluster ids for binding site atoms.


    numpy.ndarray
        Sorted cluster ids for all atoms with -1 representing non-site.
    """

    predicted_labels = predicted_probs > threshold
    if np.sum(predicted_labels) == 0:
        # No positive predictions were made with specified cutoff
        return None, None, None, None
    bind_coords = all_coords[predicted_labels]
    if bind_coords.shape[0] != 1:
        link_clustering = AgglomerativeClustering(linkage='average', **kwargs).fit(bind_coords)
        cluster_ids = link_clustering.labels_
        sorted_ids, site_scores = sort_clusters(cluster_ids, predicted_probs, predicted_labels, score_type=score_type)
    else:
        # Under rare circumstances only one atom may be predicted as the binding pocket. In this case
        # the clustering fails so we'll just call this one atom our best 'cluster'.
        sorted_ids = np.zeros(1)
        singleton_score = np.sum((predicted_probs[predicted_labels])**2)
        site_scores = {0: round(singleton_score, 3)}

    all_ids = -1*np.ones(predicted_labels.shape) 
    all_ids[predicted_labels] = sorted_ids

    return bind_coords, sorted_ids, all_ids, site_scores

def sort_clusters(cluster_ids, probs, labels, score_type='mean'):
    """Sort clusters according to binding site scores.

    Parameters
    ----------
    cluster_ids : numpy.ndarray
        Cluster label for each atom.
        
    probs: numpy.ndarray
        Model predicted binding site probabilities for each atom.
        
    labels: numpy.ndarray
        Model predicting classes (binding/non-binding) for each atom.

    score_type: str
        Option for atom score aggregation.

    Returns
    -------
    numpy.ndarray
        Resorted clusters labels with highest scoring first.
    """

    c_probs = []
    unique_ids = np.unique(cluster_ids)

    for c_id in unique_ids[unique_ids >= 0]:
        if score_type == 'mean':
            c_prob = np.mean(probs[labels][cluster_ids==c_id])
        elif score_type == 'sum':
            c_prob = np.sum(probs[labels][cluster_ids==c_id])
        elif score_type == 'square':
            c_prob = np.sum((probs[labels][cluster_ids==c_id])**2)
        else:
            print('sort_clusters score_type must be mean, sum, or square.')
        c_probs.append(c_prob)
    c_order = np.argsort(c_probs)
    c_probs_sorted = {}
    site_scores = {i: round(prob, 3) for i, prob in enumerate(c_probs)}
    sorted_ids = -1*np.ones(cluster_ids.shape)
    for new_c in range(len(c_order)):
        old_c = c_order[new_c]
        sorted_ids[cluster_ids == old_c] = new_c
        c_probs_sorted[new_c] = round(c_probs[old_c], 3)
    #print(site_scores, c_probs_sorted)
        
    return sorted_ids, c_probs_sorted

def save_to_pickle(variable, file_path):
    with open(file_path, 'wb') as file:
        pickle.dump(variable, file)

def read_from_pickle(file_path):
    with open(file_path, 'rb') as file:
        return pickle.load(file)

## Input directories

In [4]:
preds_dir = "./../DATA/GrASP_PDB_preds_V2"
pred_files = os.listdir(preds_dir)

In [7]:
rep_chains_files = read_from_pickle("./results/PDB_rep_chains_files.pkl")

In [8]:
len(rep_chains_files)

4037

In [9]:
rep_chains_LIGYSIS = [el.split(".")[0] for el in rep_chains_files]

In [10]:
len(rep_chains_LIGYSIS)

4037

In [11]:
rep_chains_LIGYSIS[:5]

['6cph_D', '7sol_C', '4npj_A', '6y7f_B', '4iqy_B']

## Obtaining sites, memberships and scores

In [12]:
all_sites = {}
all_sites_scores = {}
no_sites = []

In [13]:
for i, pred_file in enumerate(pred_files):
    pred_file_id = "_".join(pred_file.split("_")[:2])
    print(pred_file_id)
    if pred_file_id not in rep_chains_LIGYSIS:
        continue
        
    if i % 100 == 0:
        print(i)
    #break
    #print(pred_file_id)
    # if pred_file_id != "1mkk_B":
    #     continue
    pred_df = PDBXreader(inputfile = os.path.join(preds_dir, pred_file)).atoms(format_type = "pdb", excluded=())
    pred_df['XYZ'] = list(zip(pred_df['Cartn_x'], pred_df['Cartn_y'], pred_df['Cartn_z']))
    coords = np.array(list(pred_df[['Cartn_x', 'Cartn_y', 'Cartn_z']].itertuples(index=False, name=None)))
    probs = np.array(pred_df.B_iso_or_equiv.tolist())
    bind_coords, sorted_ids, all_ids, site_scores = cluster_atoms_average(coords, probs, threshold=.3, n_clusters=None, distance_threshold=15, score_type='square')
    #print(sorted_ids, site_scores)
    if bind_coords is None:
        no_sites.append(pred_file_id)
        continue
    uids = set(sorted_ids)
    sites_ress_dict = {}
    for uid in uids:
        site_coords = bind_coords[sorted_ids == uid]
        site_coords_list = list(map(tuple, site_coords))
        site_ress = pred_df.query('XYZ in @site_coords_list').label_seq_id.unique().tolist()
        sites_ress_dict[int(uid)] = sorted(site_ress)


    all_sites[pred_file_id] = sites_ress_dict
    all_sites_scores[pred_file_id] = site_scores
    print(i, pred_file_id, site_scores)

2fj9_A
0
0 2fj9_A {0: 22.493}
3g4h_A
1 3g4h_A {0: 28.607}
2a5j_A
2 2a5j_A {0: 39.438}
6lc1_J
3 6lc1_J {0: 11.458}
6osa_R
4 6osa_R {0: 0.26, 1: 38.488}
5iyd_I
5 5iyd_I {0: 0.276, 1: 0.897, 2: 10.549}
6sss_D
6 6sss_D {0: 0.185, 1: 7.546, 2: 11.217}
4l7n_A
7 4l7n_A {0: 42.563}
2i7n_B
8 2i7n_B {0: 0.792, 1: 1.857, 2: 18.617}
7xt2_A
9 7xt2_A {0: 0.458, 1: 3.47, 2: 33.534}
2zrt_G
10 2zrt_G {0: 10.012, 1: 23.1}
7umv_A
11 7umv_A {0: 37.824}
3pm0_A
12 3pm0_A {0: 35.16}
7kjo_B
13 7kjo_B {0: 0.49, 1: 5.369}
5o76_A
14 5o76_A {0: 5.543}
2hhl_A
15 2hhl_A {0: 0.596, 1: 20.86}
5m5e_D
16 5m5e_D {0: 0.264, 1: 9.607}
1bhi_A
17 1bhi_A {0: 13.37}
7dpx_A
18 7dpx_A {0: 14.076}
6o1s_E
19 6o1s_E {0: 56.579}
1wyh_A
20 1wyh_A {0: 2.481}
8v8u_B
21 8v8u_B {0: 0.137, 1: 1.048, 2: 3.423, 3: 4.857}
7o4p_A
22 7o4p_A {0: 1.745}
8g0l_A
23 8g0l_A {0: 21.276, 1: 40.479}
6ajk_A
24 6ajk_A {0: 60.247}
7n1r_B
25 7n1r_B {0: 54.583}
7thi_A
26 7thi_A {0: 15.779}
6khf_A
27 6khf_A {0: 41.214}
5lg0_A
28 5lg0_A {0: 1.923, 1: 11.07}


In [14]:
print(no_sites)
print(len(all_sites))
print(len(all_sites_scores))

['2ea5_A', '2uzp_B', '6oqa_H', '4f14_A', '3eb5_A', '2jg9_F']
4030
4030


In [15]:
save_to_pickle(all_sites, "./results/GrASP_pockets_dict_RIGHT_CLUSTERING.pkl")
save_to_pickle(all_sites_scores, "./results/GrASP_pocket_scores_dict_RIGHT_CLUSTERING.pkl")

#### Old with wrong clustering

In [248]:
for i, pred_file in enumerate(pred_files):
    pred_file_id = "_".join(pred_file.split("_")[:2])
    print(pred_file_id)
    if pred_file_id not in rep_chains_LIGYSIS:
        continue
        
    if i % 100 == 0:
        print(i)
    #break
    #print(pred_file_id)
    # if pred_file_id != "1mkk_B":
    #     continue
    pred_df = PDBXreader(inputfile = os.path.join(preds_dir, pred_file)).atoms(format_type = "pdb", excluded=())
    pred_df['XYZ'] = list(zip(pred_df['Cartn_x'], pred_df['Cartn_y'], pred_df['Cartn_z']))
    coords = np.array(list(pred_df[['Cartn_x', 'Cartn_y', 'Cartn_z']].itertuples(index=False, name=None)))
    probs = np.array(pred_df.B_iso_or_equiv.tolist())
    bind_coords, sorted_ids, all_ids, site_scores = cluster_atoms_average(coords, probs, threshold=.3, n_clusters=None, distance_threshold=15, score_type='square')
    #print(sorted_ids, site_scores)
    if bind_coords is None:
        no_sites.append(pred_file_id)
        continue
    uids = set(sorted_ids)
    sites_ress_dict = {}
    for uid in uids:
        site_coords = bind_coords[sorted_ids == uid]
        site_coords_list = list(map(tuple, site_coords))
        site_ress = pred_df.query('XYZ in @site_coords_list').label_seq_id.unique().tolist()
        sites_ress_dict[int(uid)] = sorted(site_ress)


    all_sites[pred_file_id] = sites_ress_dict
    all_sites_scores[pred_file_id] = site_scores
    print(i, pred_file_id, site_scores)

2fj9_A
0
0 2fj9_A {0: 7.393, 1: 15.1}
3g4h_A
1 3g4h_A {0: 11.942, 1: 16.665}
2a5j_A
2 2a5j_A {0: 18.133, 1: 21.305}
6lc1_J
3 6lc1_J {0: 3.573, 1: 7.884}
6osa_R
4 6osa_R {0: 0.26, 1: 38.488}
5iyd_I
5 5iyd_I {0: 0.897, 1: 10.825}
6sss_D
6 6sss_D {0: 7.731, 1: 11.217}
4l7n_A
7 4l7n_A {0: 10.294, 1: 32.269}
2i7n_B
8 2i7n_B {0: 2.649, 1: 18.617}
7xt2_A
9 7xt2_A {0: 3.928, 1: 33.534}
2zrt_G
10 2zrt_G {0: 10.012, 1: 23.1}
7umv_A
11 7umv_A {0: 4.269, 1: 33.554}
3pm0_A
12 3pm0_A {0: 3.812, 1: 31.348}
7kjo_B
13 7kjo_B {0: 0.49, 1: 5.369}
5o76_A
14 5o76_A {0: 2.547, 1: 2.997}
2hhl_A
15 2hhl_A {0: 0.596, 1: 20.86}
5m5e_D
16 5m5e_D {0: 0.264, 1: 9.607}
1bhi_A
17 1bhi_A {0: 3.455, 1: 9.915}
7dpx_A
18 7dpx_A {0: 4.565, 1: 9.51}
6o1s_E
19 6o1s_E {0: 22.79, 1: 33.789}
1wyh_A
20 1wyh_A {0: 0.384, 1: 2.098}
8v8u_B
21 8v8u_B {0: 4.607, 1: 4.857}
7o4p_A
22 7o4p_A {0: 0.417, 1: 1.328}
8g0l_A
23 8g0l_A {0: 21.276, 1: 40.479}
6ajk_A
24 6ajk_A {0: 24.088, 1: 36.159}
7n1r_B
25 7n1r_B {0: 8.015, 1: 46.568}
7thi_

In [250]:
print(no_sites)
print(len(all_sites))
print(len(all_sites_scores))

['2ea5_A', '2uzp_B', '6oqa_H', '4f14_A', '3eb5_A', '2jg9_F']
4030
4030


In [251]:
save_to_pickle(all_sites, "./results/GrASP_pockets_dict_V2.pkl")
save_to_pickle(all_sites_scores, "./results/GrASP_pocket_scores_dict_V2.pkl")