In [10]:
import gc
import torch
import torch.nn as nn
import requests
import glob, os, math
import numpy as np
from transformers import BertModel, BertTokenizer, AutoModel
from autogluon.tabular import TabularDataset, TabularPredictor
from utils.top_accuracies import calculate_top_n_accuracies, topk, topN

from utils.extract_sequence import extract_sequence
from utils.pocket_feature import pocket_feature
from utils.sequence_indices import sequence_indices
from utils.pocket_coordinates import pocket_coordinates

In [14]:
# Threshold compared (>=) against each pocket's matched-atom count when do_it builds labels.
N_ATOMS = 9
# NOTE(review): not referenced by do_it in the cells visible here (do_it hard-codes its
# checkpoint path) — confirm whether this constant is still needed.
MODEL_PATH = "/home/mkhokhar21/Documents/COSBI/Allostery_Paper/prot_bert_mtl"
# URL prefix used to download "<pdb_id>.pdb" when the file is not already on disk.
base_url = "https://files.rcsb.org/download"
# Directory where "<pdb_id>.pdb" files are looked up / saved.
pdb_dir = "/home/mkhokhar21/Documents/COSBI/Allostery_Paper/data/pdbs/"
# Directory holding the "<pdb_id>_out" fpocket result folders.
pocket_dir = "/home/mkhokhar21/Documents/COSBI/Allostery_Paper/data/pockets/"
# Structure and chain passed to do_it by the driver cell below.
pdb_id = "3PEE"
chain_id = "A"

In [3]:
class MultiTaskModel(nn.Module):
    """One pretrained encoder shared by two linear heads.

    Each head is an ``nn.Linear`` applied to the encoder's
    ``last_hidden_state``; ``head1`` serves ``input1`` and ``head2`` serves
    ``input2``.
    """

    def __init__(self, model_name, num_labels_task1, num_labels_task2):
        """Load the encoder via ``AutoModel.from_pretrained`` and build both heads.

        Args:
            model_name: identifier/path forwarded to ``AutoModel.from_pretrained``.
            num_labels_task1: output width of ``head1``.
            num_labels_task2: output width of ``head2``.
        """
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden_size = self.encoder.config.hidden_size
        self.head1 = nn.Linear(hidden_size, num_labels_task1)
        self.head2 = nn.Linear(hidden_size, num_labels_task2)

    def _run_task(self, batch, head):
        """Encode ``batch`` and apply ``head``; returns ``(None, None)`` when batch is None."""
        if batch is None:
            return None, None
        hidden_states = self.encoder(**batch).last_hidden_state
        return head(hidden_states), hidden_states

    def forward(self, input1=None, input2=None):
        """Run whichever tasks received an input.

        Args:
            input1: keyword-argument dict for the encoder (task 1), or None.
            input2: keyword-argument dict for the encoder (task 2), or None.

        Returns:
            ``((output1, output2), (encoder_output1, encoder_output2))`` where
            the entries of a task that got no input are ``None``.
        """
        output1, encoder_output1 = self._run_task(input1, self.head1)
        output2, encoder_output2 = self._run_task(input2, self.head2)
        return (output1, output2), (encoder_output1, encoder_output2)

In [4]:
def get_res_data(poc_res_emb, pocket_coord, pocket_features, labels):
    """Build one feature row and one label per pocket.

    For pocket ``i`` the row is the mean of its residue embeddings
    concatenated with ``pocket_features[i]``. ``pocket_coord`` is only used
    for its lengths: the number of pockets, and the number of residues
    averaged per pocket, are capped by it.

    Args:
        poc_res_emb: per-pocket sequences of residue embedding vectors.
        pocket_coord: per-pocket sequences; only ``len`` is taken.
        pocket_features: per-pocket 1-D feature vectors.
        labels: per-pocket labels.

    Returns:
        ``(X, Y)`` — list of concatenated feature arrays and list of labels.
    """
    n_pockets = min(len(poc_res_emb), len(pocket_coord))
    X, Y = [], []

    for pocket_idx in range(n_pockets):
        n_residues = min(len(poc_res_emb[pocket_idx]), len(pocket_coord[pocket_idx]))
        residue_vectors = [poc_res_emb[pocket_idx][k] for k in range(n_residues)]
        pooled = np.array(residue_vectors).mean(axis=0)
        descriptor = pocket_features[pocket_idx]
        X.append(np.concatenate((pooled, descriptor)))
        Y.append(labels[pocket_idx])

    return X, Y

_DEFAULT_ASD_PATH = "/home/mkhokhar21/Documents/COSBI/Allostery_Paper/data/source_data/ASD_Release_201909_AS.txt"
_DEFAULT_CHECKPOINT_PATH = "/home/mkhokhar21/Documents/COSBI/Allostery_Paper/prot_bert_mtl/prot_bert_mtl.bin"
_DEFAULT_PREDICTOR_PATH = "/home/mkhokhar21/Documents/COSBI/Allostery_Paper/src/AutogluonModels/MTL_All"


def _find_asd_entry(asd_path, pdb_id):
    """Return ``(chain, modulator, mod_id, residues)`` for the first usable ASD row.

    A row is usable when its PDB column equals ``pdb_id`` and both its chain
    and modulator columns hold a single distinct value. ``residues`` is a list
    of ``[chain, residue_name, residue_number]`` restricted to that chain.

    Raises:
        ValueError: no usable row exists for ``pdb_id``.
    """
    with open(asd_path, "r") as f:
        rows = f.readlines()[1:]  # the first line is skipped as a header

    for row in rows:
        fields = row.strip().split("\t")
        if len(fields) < 12:
            continue  # blank or truncated row: cannot hold the columns read below
        row_pdb, row_modulator, row_chain, row_mod_id = fields[4], fields[6], fields[7], fields[11]

        if row_pdb != pdb_id:
            continue
        if len(set(row_chain.split(";"))) != 1:
            continue
        if len(set(row_modulator.split(";"))) != 1:
            continue

        chain = row_chain[0]
        modulator = row_modulator.split(";")[0]

        # Last column: "; "-separated groups, each "<... chain>:<res>,<res>,...".
        groups = [grp.replace(":", ",").split(",") for grp in fields[-1].split("; ")]
        # Each residue token is residue type (3 chars) followed by residue number.
        residues = [[grp[0][-1], tok[:3], tok[3:]] for grp in groups for tok in grp[1:]]
        # Keep only residues on the modulator's chain.
        residues = [res for res in residues if res[0] == chain]
        return chain, modulator, row_mod_id, residues

    raise ValueError(f"No usable ASD entry found for {pdb_id} in {asd_path}.")


def _ensure_pdb_file(pdb_id, pdb_path):
    """Download ``<pdb_id>.pdb`` from ``base_url`` to ``pdb_path`` unless it already exists."""
    if os.path.exists(pdb_path):
        return
    # A timeout keeps a stalled connection from hanging the notebook forever.
    response = requests.get(f"{base_url}/{pdb_id}.pdb", timeout=60)
    if response.status_code != 200:
        raise Exception(f"Failed to download {pdb_id}.pdb. Check if the PDB ID is correct.")
    os.makedirs(os.path.dirname(pdb_path) or ".", exist_ok=True)
    with open(pdb_path, 'wb') as file:
        file.write(response.content)
    print(f"PDB file {pdb_id}.pdb downloaded successfully.")


def _ensure_pockets(pdb_id, pdb_path, chain_id, pocket_path):
    """Run fpocket and move its ``<pdb_id>_out`` folder to ``pocket_path`` if missing.

    Raises:
        RuntimeError: fpocket did not leave an output folder next to the PDB file.
    """
    import shutil
    import subprocess

    if os.path.exists(pocket_path):
        return
    # Argument list instead of a shell string: no quoting issues with odd paths.
    subprocess.run(["fpocket", "-f", pdb_path, "-k", chain_id])
    produced = f"{os.path.join(pdb_dir, pdb_id)}_out"
    if not os.path.isdir(produced):
        raise RuntimeError(f"fpocket did not produce {produced}.")
    if os.path.abspath(produced) != os.path.abspath(pocket_path):
        os.makedirs(pocket_dir, exist_ok=True)
        shutil.move(produced, pocket_path)


def _ligand_centroid(pdb_path, modulator, chain_id, mod_id):
    """Mean x/y/z of the HETATM records matching modulator name, chain and residue id.

    Raises:
        ValueError: no HETATM record matches (previously a ZeroDivisionError).
    """
    sum_x = sum_y = sum_z = 0.0
    n_atoms = 0
    with open(pdb_path, "r") as f:
        for line in f:
            if (
                line[:6] == "HETATM" and modulator == line[17:20].strip()
                and line[21] == chain_id and mod_id == line[22:26].strip()
            ):
                sum_x += float(line[30:38])
                sum_y += float(line[38:46])
                sum_z += float(line[46:54])
                n_atoms += 1

    if n_atoms == 0:
        raise ValueError(
            f"No HETATM atoms found for modulator {modulator} {mod_id} "
            f"on chain {chain_id} in {pdb_path}."
        )
    return sum_x / n_atoms, sum_y / n_atoms, sum_z / n_atoms


def _collect_pockets(pocket_names, atom_target, ligand_center):
    """Parse pocket PDB files, skipping pockets without ATOM records.

    Returns:
        ``(selected_idxs, pocket_residue_indices, dists, target_atom_counts)``,
        all aligned: position of the pocket in ``pocket_names``, its residue
        ids, the distance from its atom centroid to ``ligand_center``, and the
        number of its atoms whose residue (name + number, chain) is in
        ``atom_target``.
    """
    lig_x, lig_y, lig_z = ligand_center
    selected_idxs, pocket_residue_indices = [], []
    dists, target_atom_counts = [], []

    for idx, pocket_name in enumerate(pocket_names):
        with open(pocket_name, "r") as f:
            pocket_lines = f.readlines()

        poc_x = poc_y = poc_z = 0.0
        poc_cnt = 0
        target_atoms = 0
        residue_indices = set()

        for line in pocket_lines:
            if line[:4] != "ATOM":
                continue
            poc_cnt += 1
            residue_index = line[22:26].strip()
            residue_indices.add(residue_index)
            poc_x += float(line[30:38])
            poc_y += float(line[38:46])
            poc_z += float(line[46:54])
            # atom_target maps "<residue name><residue number>" to its chain.
            if atom_target.get(line[17:20] + residue_index) == line[21]:
                target_atoms += 1

        if poc_cnt == 0:
            continue

        poc_x /= poc_cnt
        poc_y /= poc_cnt
        poc_z /= poc_cnt
        dists.append(math.sqrt(
            (poc_x - lig_x) ** 2 + (poc_y - lig_y) ** 2 + (poc_z - lig_z) ** 2
        ))
        target_atom_counts.append(target_atoms)
        selected_idxs.append(idx)
        pocket_residue_indices.append(list(residue_indices))

    return selected_idxs, pocket_residue_indices, dists, target_atom_counts


def _embed_residues(sequence, checkpoint_path):
    """Return the encoder's per-token embeddings for ``sequence`` (special tokens stripped)."""
    tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert_bfd", do_lower_case=False)
    model = MultiTaskModel("Rostlab/prot_bert_bfd", 2, 3)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device).eval()

    try:
        with torch.no_grad():
            # padding='max_length' was dropped: without a max_length the
            # tokenizer only warned and applied no padding (see cell output).
            encoding = tokenizer.batch_encode_plus(
                [" ".join(sequence)],
                add_special_tokens=True,
            )
            input_ids = torch.tensor(encoding['input_ids']).to(device)
            attention_mask = torch.tensor(encoding['attention_mask']).to(device)
            inputs = {'input_ids': input_ids, 'attention_mask': attention_mask}
            _, (last_hidden_state, _) = model(input1=inputs)
            embedding = last_hidden_state.cpu().numpy()

        seq_len = int((attention_mask[0] == 1).sum())
        # Drop the first and last attended positions (the added special tokens).
        return embedding[0][1:seq_len - 1]
    finally:
        del model
        torch.cuda.empty_cache()
        gc.collect()


def do_it(pdb_id, chain_id,
          asd_path=_DEFAULT_ASD_PATH,
          checkpoint_path=_DEFAULT_CHECKPOINT_PATH,
          predictor_path=_DEFAULT_PREDICTOR_PATH):
    """Score every fpocket pocket of a structure and return scores with labels.

    Steps: look up the structure in the ASD file, fetch the PDB file and run
    fpocket if needed, label pockets (count of annotated-residue atoms
    >= ``N_ATOMS``, plus the pocket closest to the modulator), embed the
    sequence, and score each pocket with the saved AutoGluon predictor.

    Args:
        pdb_id: PDB identifier; must have a usable row in the ASD file.
        chain_id: requested chain. NOTE: as in the original code, the chain
            recorded in the matching ASD row replaces this value.
        asd_path: tab-separated ASD release file.
        checkpoint_path: state dict loaded into ``MultiTaskModel``.
        predictor_path: directory passed to ``TabularPredictor.load``.

    Returns:
        ``(y_pred, Y_Test, pocket_residue_indices)`` — column 1 of
        ``predict_proba``, the label array, and the residue ids of each
        pocket. Labels and residue lists are restricted to the pockets that
        were actually embedded, so the three line up index by index.

    Raises:
        ValueError: no usable ASD row, or no modulator atoms in the PDB file.
        RuntimeError: fpocket produced no output folder.
        Exception: download failure, sequence of length <= 10, two or fewer
            non-empty pockets, or no pocket mappable onto the sequence.
    """
    pdb_path = os.path.join(pdb_dir, f"{pdb_id}.pdb")
    pocket_path = os.path.join(pocket_dir, f"{pdb_id}_out")

    # Only a matching row may set these; non-matching rows no longer leak in.
    chain_id, modulator, mod_id, residues = _find_asd_entry(asd_path, pdb_id)

    _ensure_pdb_file(pdb_id, pdb_path)

    sequence = extract_sequence(pdb_path, chain_id)
    if len(sequence) <= 10:
        raise Exception("Sequence is too short.")

    _ensure_pockets(pdb_id, pdb_path, chain_id, pocket_path)

    ligand_center = _ligand_centroid(pdb_path, modulator, chain_id, mod_id)

    # Order pocket files by the number that follows the last "pocket" in the path.
    pocket_names = sorted(
        glob.glob(f"{pocket_path}/pockets/*.pdb"),
        key=lambda path: int(path.split("pocket")[-1].split("_")[0])
    )
    pockets_feats = pocket_feature(f"{pocket_path}/{pdb_id}_info.txt")

    atom_target = {f'{res[1]}{res[2]}': res[0] for res in residues}
    selected_idxs, pocket_residue_indices, dists, target_atom_counts = _collect_pockets(
        pocket_names, atom_target, ligand_center
    )

    if len(selected_idxs) <= 2:
        raise Exception("Too few pockets extracted.")

    pocket_features = [pockets_feats[idx] for idx in selected_idxs]
    seq_indices = sequence_indices(pdb_id, chain_id)

    labels = [1 if count >= N_ATOMS else 0 for count in target_atom_counts]
    labels[int(np.argmin(dists))] = 1  # the pocket nearest the modulator is always positive

    pocket_coord = pocket_coordinates(pdb_path, f"{pocket_path}/pockets/", pdb_id, chain_id, pocket_residue_indices)

    token_emb = _embed_residues(sequence, checkpoint_path)

    # A pocket is dropped as soon as one of its residues cannot be mapped.
    kept, poc_res_emb = [], []
    for pocket_pos, residue_ids in enumerate(pocket_residue_indices):
        try:
            cur_poc_emb = [token_emb[seq_indices[idx]] for idx in residue_ids]
        except (KeyError, IndexError):
            continue
        kept.append(pocket_pos)
        poc_res_emb.append(cur_poc_emb)

    if not poc_res_emb:
        raise Exception("No pocket could be mapped onto the sequence.")

    if len(kept) != len(pocket_residue_indices):
        # Without this, row i of the features/labels would describe a
        # different pocket than embedding i whenever a pocket was dropped.
        # NOTE(review): pocket_coord is realigned only when it has one entry
        # per pocket — confirm what pocket_coordinates returns.
        if len(pocket_coord) == len(pocket_residue_indices):
            pocket_coord = [pocket_coord[k] for k in kept]
        pocket_features = [pocket_features[k] for k in kept]
        labels = [labels[k] for k in kept]
        pocket_residue_indices = [pocket_residue_indices[k] for k in kept]

    X_Test, Y_Test = get_res_data(poc_res_emb, pocket_coord, pocket_features, labels)
    X_Test, Y_Test = np.array(X_Test), np.array(Y_Test)

    # Columns are named "1".."n+1"; the last one carries the label.
    test_data = np.concatenate((X_Test, Y_Test.reshape(-1, 1)), axis=1)
    test_data = TabularDataset(test_data)
    test_data.columns = [str(i) for i in range(1, X_Test.shape[1] + 2)]
    label = str(X_Test.shape[1] + 1)
    predictor = TabularPredictor.load(predictor_path)

    y_pred = predictor.predict_proba(test_data.drop(columns=[label]))

    return y_pred.to_numpy()[:, 1], Y_Test, pocket_residue_indices

In [15]:
# Score every pocket of the configured structure.
y_pred, Y_Test, pocket_residue_indices = do_it(pdb_id, chain_id)

# Pair each score with its label and residue list, then rank by score (highest first).
paired = list(zip(y_pred, Y_Test, pocket_residue_indices))
paired_sorted = sorted(paired, key=lambda triple: triple[0], reverse=True)

# Keep at most the three best-scoring pockets.
top3 = paired_sorted[:3]

for top in top3:
    print(top)

Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.


(0.8430039286613464, 1, ['153', '104', '232', '203', '224', '36', '57', '34', '54', '28', '249', '209', '151', '211', '32', '208', '221', '231'])
(0.013739123940467834, 0, ['153', '46', '110', '109', '154', '202', '108', '50', '155', '47', '203', '201'])
(0.01281462050974369, 0, ['230', '200', '198', '199', '197', '229', '211', '212', '231'])


In [16]:
# Print "score - label - select :r1,r2,..." for each of the top pockets.
for top in top3:
    score, label, pocket_residues = top[0], top[1], top[-1]
    cur_res = "select :" + "".join(f"{res}," for res in pocket_residues)
    # The slice removes the trailing comma left by the join above.
    print(f"{score} - {label} - {cur_res[:-1]}")

0.8430039286613464 - 1 - select :153,104,232,203,224,36,57,34,54,28,249,209,151,211,32,208,221,231
0.013739123940467834 - 0 - select :153,46,110,109,154,202,108,50,155,47,203,201
0.01281462050974369 - 0 - select :230,200,198,199,197,229,211,212,231


In [13]:
# Run the pipeline for several structures and keep the three best-scoring
# pockets of each. do_it returns (scores, labels, residue lists); they are
# ranked here so that "top3" really holds (score, label, residues) triples,
# which is what the cells indexing result[i]["top3"][k][-1] expect.
result = []
pdbs = ["1Q5O", "2FPL", "3BCR", "3PEE", "4HO6"]
for pdb in pdbs:
    scores, true_labels, pocket_residues = do_it(pdb, "A")
    ranked = sorted(
        zip(scores, true_labels, pocket_residues),
        key=lambda triple: triple[0],
        reverse=True,
    )
    result.append({"pdb": pdb, "top3": ranked[:3]})

Some weights of the model checkpoint at ../prot_bert_allosteric were not used when initializing BertModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at ../prot_bert_allosteric and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum leng

(0.99999523, 1, ['574', '591', '564', '545', '635', '595', '584', '580', '636', '592', '583', '582', '593', '632', '581'])
(0.00015850604, 0, ['631', '641', '633', '630', '640', '634', '639', '643', '638'])
(0.00011590509, 0, ['569', '547', '570', '635', '636', '638', '592', '567', '566', '568'])


Some weights of the model checkpoint at ../prot_bert_allosteric were not used when initializing BertModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at ../prot_bert_allosteric and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum leng

(0.9992513, 1, ['112', '111', '110', '113', '108'])
(0.00021088084, 0, ['240', '193', '243', '244', '192', '247', '248', '196'])
(0.00018359804, 0, ['139', '206', '202', '203', '249', '204', '247', '248', '199', '205'])


Some weights of the model checkpoint at ../prot_bert_allosteric were not used when initializing BertModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at ../prot_bert_allosteric and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum leng

(0.9918006, 1, ['569', '280', '610', '570', '572', '612', '613', '283', '382', '611', '284', '285', '282', '281', '571', '614', '287', '770'])
(0.98959947, 0, ['377', '574', '484', '292', '385', '341', '284', '282', '455', '672', '135', '573', '283', '383', '339', '569', '136', '286', '285', '676', '675', '674', '133', '673', '280', '88', '378', '287'])
(0.8901069, 0, ['71', '64', '193', '240', '60', '68', '75', '67', '191', '72', '227'])


Some weights of the model checkpoint at ../prot_bert_allosteric were not used when initializing BertModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at ../prot_bert_allosteric and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum leng

(0.27388626, 1, ['36', '209', '232', '54', '221', '203', '28', '208', '231', '249', '151', '153', '211', '57', '104', '34', '224', '32'])
(0.010215629, 0, ['139', '136', '135', '91', '37', '103', '93', '105'])
(0.0075611584, 0, ['185', '144', '183', '143', '141', '140', '186', '184'])


Some weights of the model checkpoint at ../prot_bert_allosteric were not used when initializing BertModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at ../prot_bert_allosteric and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum leng

(0.99999034, 1, ['257', '40', '247', '112', '111', '253', '249', '114', '113', '254', '43', '39', '248', '116', '117'])
(0.0047794143, 0, ['86', '23', '137', '197', '196', '222', '143', '198', '159', '172', '106', '142', '223', '138', '170', '160', '13', '224', '192', '108', '109', '225', '139', '174', '221', '144'])
(0.0007168664, 0, ['8', '86', '106', '83', '6', '194', '108', '24', '80', '9', '85'])


In [16]:
# Print "score - label - select :r1,r2,..." for each of the top pockets.
for top in top3:
    score, label, pocket_residues = top[0], top[1], top[-1]
    cur_res = "select :" + "".join(f"{res}," for res in pocket_residues)
    # The slice removes the trailing comma left by the join above.
    print(f"{score} - {label} - {cur_res[:-1]}")

0.8491735458374023 - 1 - select :580,593,564,632,581,583,591,592,635,574,582,595,545,636,584
0.028418462723493576 - 0 - select :569,547,567,568,570,638,566,592,635,636
0.012626996263861656 - 0 - select :579,632,517,577,624,578,628,516


In [44]:
# For each top pocket of the fifth structure, print the residues that are NOT
# part of the best-scoring pocket. Residues keep the order of the pocket's own
# list (instead of set order, which varies between runs) and are joined
# without a trailing comma, matching the other "select :" cells.
first_set = set(result[4]["top3"][0][-1])
for top in result[4]["top3"]:
    extra_residues = [res for res in top[-1] if res not in first_set]
    print(f"select :{','.join(map(str, extra_residues))}")

select :
select :86,23,137,197,196,222,143,198,159,172,106,142,223,138,170,160,13,224,192,108,109,225,139,174,221,144,
select :8,86,83,106,6,194,24,108,80,9,85,


In [7]:
# Collect one "select :..." string per top pocket; cur_set ends up holding the
# residues of the LAST pocket that are absent from the best-scoring one.
residues = []
first_set = set(top3[0][-1])
cur_set = None
for top in top3:
    cur_set = set(top[-1]) - first_set
    cur_res = 'select :' + "".join(f"{res}," for res in top[-1])
    residues.append(cur_res[:-1])  # drop the trailing comma

residues, cur_set

(['574,591,564,545,635,595,584,580,636,592,583,582,593,632,581',
  '631,641,633,630,640,634,639,643,638',
  '569,547,570,635,636,638,592,567,566,568'],
 {'547', '566', '567', '568', '569', '570', '638'})

In [4]:
import gc
import torch
import requests
import glob, os, math
import numpy as np
from transformers import BertModel, BertTokenizer
import xgboost as xgb

from utils.extract_sequence import extract_sequence
from utils.pocket_feature import pocket_feature
from utils.sequence_indices import sequence_indices
from utils.pocket_coordinates import pocket_coordinates

# Threshold compared (>=) against each pocket's countsPockets entry when labels are built (test mode).
N_ATOMS = 9
# Location passed to BertTokenizer.from_pretrained / BertModel.from_pretrained below.
MODEL_PATH = "../prot_bert_allosteric"
# URL prefix used to download "<pdb_id>.pdb" when it is not on disk.
base_url = "https://files.rcsb.org/download"
# Directory where "<pdb_id>.pdb" is looked up / saved.
pdb_dir = "../data/pdbs/"
# Directory holding the "<pdb_id>_out" fpocket result folders.
pocket_dir = "../data/pockets/"
# Structure and chain processed by this cell.
pdb_id = "5DKK"
chain_id = "A"

# When False, every section guarded by `if is_test:` (ASD lookup, ligand
# centroid, label construction) is skipped and only predictions are produced.
is_test = False

# Derived file-system locations for the PDB file and the fpocket output folder.
pdb_path = os.path.join(pdb_dir, f"{pdb_id}.pdb")
pocket_path = os.path.join(pocket_dir, f"{pdb_id}_out")

#### Test - begin ####
# Look up the structure in the ASD file. Row values are read into row-local
# names and only committed to chain_id / modulator / mod_id / residues once a
# row passes every check, so non-matching rows can no longer overwrite the
# configured chain_id. A missing entry is reported immediately.
if is_test:
    ASD_path = "../data/source_data/ASD_Release_201909_AS.txt"

    with open(ASD_path, "r") as f:
        asd = f.readlines()

    mod_id, modulator, residues = None, None, None
    for line in asd[1:]:  # the first line is skipped as a header
        line = line.strip().split("\t")
        row_pdb, row_modulator, row_chain, row_mod_id = line[4], line[6], line[7], line[11]

        if row_pdb != pdb_id:
            continue

        # Only rows with a single distinct chain and a single distinct modulator are used.
        if len(set(row_chain.split(";"))) != 1:
            continue
        if len(set(row_modulator.split(";"))) != 1:
            continue

        chain_id = row_chain[0]
        modulator = row_modulator.split(";")[0]
        mod_id = row_mod_id

        # extract residues
        res_raw = [
            res.replace(":", ",").split(",") for res in line[-1].split("; ")
        ]
        # residue_clean format: chain id + residue type + residue number
        residues = [
            [res[0][-1], ch[:3], ch[3:]] for res in res_raw for ch in res[1:]
        ]
        # select only residues in the same chain of modulator
        residues = [res for res in residues if res[0] == chain_id]

        break
    else:
        raise ValueError(f"No usable ASD entry found for {pdb_id} in {ASD_path}.")
#### Test - end ####


# Download the PDB file only when it is not already on disk.
if not os.path.exists(pdb_path):
    response = requests.get(f"{base_url}/{pdb_id}.pdb")
    if response.status_code == 200:  # Check if the request was successful
        with open(pdb_path, 'wb') as file:
            file.write(response.content)
        print(f"PDB file {pdb_id}.pdb downloaded successfully.")
    else:
        raise Exception(f"Failed to download {pdb_id}.pdb. Check if the PDB ID is correct.")

# Sequence of the selected chain, as returned by the project helper.
sequence = extract_sequence(pdb_path, chain_id)

# Sequences of length 10 or less are rejected.
if len(sequence) <= 10:
    raise Exception("Sequence is too short.")

# Run fpocket once per structure; its "<pdb_id>_out" folder is expected next
# to the PDB file and is moved into pocket_dir.
# NOTE(review): both commands go through the shell and their exit status is ignored.
if not os.path.exists(pocket_path):
    os.system(f"fpocket -f {pdb_path} -k {chain_id}")
    os.system(f"mv {os.path.join(pdb_dir, pdb_id)}_out {pocket_dir}")

#### Test - begin ####
# Centroid of the modulator: mean coordinates of the HETATM records whose
# residue name, chain and residue number match the ASD entry.
if is_test:
    protein = None
    lig_x, lig_y, lig_z, lig_cnt = 0, 0, 0, 0

    with open(pdb_path, "r") as f:
        protein = f.readlines()

    for line in protein:
        if (
            line[:6] == "HETATM" and modulator == line[17:20].strip()
            and line[21] == chain_id and mod_id == line[22:26].strip()
        ):
            lig_x += float(line[30:38])
            lig_y += float(line[38:46])
            lig_z += float(line[46:54])
            lig_cnt += 1

    # Fail with a clear message instead of a ZeroDivisionError below.
    if lig_cnt == 0:
        raise ValueError(
            f"No HETATM atoms found for modulator {modulator} {mod_id} "
            f"on chain {chain_id} in {pdb_path}."
        )

    lig_x /= lig_cnt
    lig_y /= lig_cnt
    lig_z /= lig_cnt
#### Test - end ####

# Pocket PDB files written by fpocket, ordered by the integer that follows the
# last occurrence of "pocket" in the path (up to the next "_").
pocket_names = glob.glob(f"{pocket_path}/pockets/*.pdb")
pocket_names = sorted(
    pocket_names,
    key=lambda x: int(x.split("pocket")[-1].split("_")[0])
)

# Per-pocket descriptors read by the project helper from fpocket's info file.
pockets_feats = pocket_feature(f"{pocket_path}/{pdb_id}_info.txt")
# Positions (in pocket_names) of the pockets that are kept, and their residue ids.
selected_idxs = []
pocket_residue_indices = []

#### Test - begin ####
if is_test:
    # Maps "<residue type><residue number>" of each annotated residue to its chain.
    atomTarget = {}
    for res in residues:
        atomTarget[f'{res[1]}{res[2]}'] = res[0]

    # Per kept pocket: distance to the ligand centroid and matched-atom count.
    dists = []
    countsPockets = [] # for atom count
#### Test - end ####

for idx, pocket_name in enumerate(pocket_names):
    pocket = None
    with open(pocket_name, "r") as f:
        pocket = f.readlines()

#### Test - begin ####
    if is_test:
        # Running coordinate sums and matched-atom counter for this pocket.
        poc_x, poc_y, poc_z = 0, 0, 0
        pocketAtomCount = 0
#### Test - end ####

    poc_cnt = 0
    residue_indices = set()

    for line in pocket:
        if line[:4] == "ATOM":
            poc_cnt += 1
            # Fixed-column fields of the ATOM record: [22:26] is used as the
            # residue index, [17:20] as the residue type.
            residue_index = line[22:26].strip()
            atom = line[17:20] + residue_index
            residue_indices.add(residue_index)

#### Test - begin ####
            if is_test:
                poc_x += float(line[30:38])
                poc_y += float(line[38:46])
                poc_z += float(line[46:54])
                chainID = line[21]
                # Count atoms whose residue is annotated on the same chain.
                if atom in atomTarget and atomTarget[atom] == chainID:
                    pocketAtomCount += 1
#### Test - end ####

    # Pockets without any ATOM record are skipped entirely.
    if poc_cnt == 0:
        continue

#### Test - begin ####
    if is_test:
        # Pocket centroid, then its Euclidean distance to the ligand centroid.
        poc_x /= poc_cnt
        poc_y /= poc_cnt
        poc_z /= poc_cnt
        dist = math.sqrt(
            (poc_x - lig_x) ** 2 + (poc_y - lig_y) ** 2 +
            (poc_z - lig_z) ** 2
        )

        dists.append(dist)
        countsPockets.append(pocketAtomCount)
#### Test - end ####

    selected_idxs.append(idx)
    # NOTE: converting the set to a list gives no guaranteed residue order.
    pocket_residue_indices.append(list(residue_indices))

# At least three non-empty pockets are required to continue.
if len(selected_idxs) <= 2:
    raise Exception("Too few pockets extracted.")

# Descriptors of the kept pockets, in the same order as pocket_residue_indices.
pocket_features = [pockets_feats[idx] for idx in selected_idxs]

# Project helper; used below as a mapping from residue id to sequence position.
seq_indices = sequence_indices(pdb_id, chain_id)

#### Test - begin ####
if is_test:
    # A pocket is positive when it holds at least N_ATOMS matched atoms; the
    # pocket closest to the ligand centroid is always marked positive.
    dist_min_idx = np.argmin(dists)
    labels = [1 if item >= N_ATOMS else 0 for item in countsPockets] # for atom count
    labels[dist_min_idx] = 1

    # Per-residue 'Y'/'N' string labels: 'Y' for residues of positive pockets
    # whose mapped position falls inside the sequence.
    # NOTE(review): seq_labels is not read again in this cell.
    seq_labels = ['N'] * len(sequence)
    for i in range(len(labels)):
            if labels[i] == 1:
                for residue_index in pocket_residue_indices[i]:
                    if residue_index in seq_indices and seq_indices[residue_index] < len(sequence):
                        seq_labels[seq_indices[residue_index]] = 'Y'
#### Test - end ####

pocket_coord = pocket_coordinates(pdb_path, f"{pocket_path}/pockets/", pdb_id, chain_id, pocket_residue_indices)

# Load the fine-tuned encoder and put it in inference mode on GPU when available.
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH, do_lower_case=False)
model = BertModel.from_pretrained(MODEL_PATH)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model = model.eval()

seq_emb = None
poc_res_emb = []
# Positions (in pocket_residue_indices) of the pockets that could be embedded.
kept_pockets = []

with torch.no_grad():
    seq = " ".join(sequence)
    # padding='max_length' was dropped: without a max_length the tokenizer
    # only warned and applied no padding (see the cell output).
    encoding = tokenizer.batch_encode_plus(
        [seq],
        add_special_tokens=True
    )
    input_ids = torch.tensor(encoding['input_ids']).to(device)
    attention_mask = torch.tensor(encoding['attention_mask']).to(device)
    embedding = model(input_ids=input_ids, attention_mask=attention_mask)[0].cpu().numpy()

    # Drop the first and last attended positions (the added special tokens).
    seq_len = int((attention_mask[0] == 1).sum())
    token_emb = embedding[0][1:seq_len-1]
    seq_emb = token_emb.mean(axis=0)

    for i in range(len(pocket_residue_indices)):
        # A pocket is dropped as soon as one of its residues cannot be mapped
        # onto the embedded sequence.
        try:
            cur_poc_emb = [token_emb[seq_indices[idx]] for idx in pocket_residue_indices[i]]
        except (KeyError, IndexError):
            continue
        kept_pockets.append(i)
        poc_res_emb.append(cur_poc_emb)

# Keep every per-pocket list aligned with poc_res_emb. Without this, entry i
# of pocket_features / labels / pocket_residue_indices describes a different
# pocket than embedding i whenever a pocket was dropped above.
if len(kept_pockets) != len(pocket_residue_indices):
    # NOTE(review): pocket_coord is realigned only when it has one entry per
    # pocket — confirm what pocket_coordinates returns.
    if len(pocket_coord) == len(pocket_residue_indices):
        pocket_coord = [pocket_coord[k] for k in kept_pockets]
    pocket_features = [pocket_features[k] for k in kept_pockets]
    #### Test - begin ####
    if is_test:
        labels = [labels[k] for k in kept_pockets]
    #### Test - end ####
    pocket_residue_indices = [pocket_residue_indices[k] for k in kept_pockets]

#### Test - begin ####
if is_test:
    # Labels of the embedded pockets (identical to `labels` after realignment).
    poc_labels = list(labels)
#### Test - end ####

# Free the encoder before the downstream model is loaded.
del model
torch.cuda.empty_cache()
gc.collect()

def get_res_data(poc_res_emb, pocket_coord):
    """Build one feature row per pocket (and labels when ``is_test`` is set).

    Row ``i`` is the mean of pocket ``i``'s residue embeddings concatenated
    with the module-level ``pocket_features[i]``. ``pocket_coord`` is used
    only for its lengths, which cap the number of pockets and of residues
    averaged. Reads the globals ``pocket_features``, ``is_test`` and, in test
    mode, ``labels``.

    Returns:
        ``(X, Y)`` when ``is_test`` is true, otherwise just ``X``.
    """
    n_pockets = min(len(poc_res_emb), len(pocket_coord))
    X, Y = [], []

    for pocket_idx in range(n_pockets):
        n_residues = min(len(poc_res_emb[pocket_idx]), len(pocket_coord[pocket_idx]))
        residue_vectors = [poc_res_emb[pocket_idx][k] for k in range(n_residues)]
        pooled = np.array(residue_vectors).mean(axis=0)
        descriptor = pocket_features[pocket_idx]
        X.append(np.concatenate((pooled, descriptor)))
        #### Test - begin ####
        if is_test:
            Y.append(labels[pocket_idx])
        #### Test - end ####

    return (X, Y) if is_test else X

# Build the feature matrix (with labels in test mode) and score it with the
# saved XGBoost booster.
if is_test:
    X_Test, Y_Test = get_res_data(poc_res_emb, pocket_coord)
    dtest = xgb.DMatrix(X_Test, label=Y_Test)
else:
    X_Test = get_res_data(poc_res_emb, pocket_coord)
    dtest = xgb.DMatrix(X_Test)
bst = xgb.Booster()
bst.load_model('./xgboost.model')
y_pred = bst.predict(dtest)

# Pair each score with its pocket (and label in test mode); rank by score, highest first.
if is_test:
    paired = list(zip(y_pred, Y_Test, pocket_residue_indices))
else:
    paired = list(zip(y_pred, pocket_residue_indices))
paired_sorted = sorted(paired, key=lambda x: x[0], reverse=True)

# Keep at most the three best-scoring pockets.
top3 = paired_sorted[:3]

# "score - select :r1,r2,..." for each top pocket.
for top in top3:
    print(f"{top[0]} - select :{','.join(map(str, top[-1]))}")

# Residues of each top pocket that are NOT in the best-scoring pocket, in the
# pocket's own order (set order varies between runs) and without the trailing
# comma the previous version left behind.
first_set = set(top3[0][-1])
for top in top3:
    extra_residues = [res for res in top[-1] if res not in first_set]
    print(f"select :{','.join(map(str, extra_residues))}")

Some weights of the model checkpoint at ../prot_bert_allosteric were not used when initializing BertModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at ../prot_bert_allosteric and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum leng

(0.3962158, ['303', '286', '304', '319', '348', '346', '255', '308', '350', '253', '329', '307', '287', '300', '331', '291', '290', '261', '333', '262', '317', '288'])
(0.0033208348, ['266', '254', '265', '242', '246'])
(0.0008981182, ['353', '321', '326', '355', '328', '327', '318', '352'])


In [6]:
# Print "score - select :r1,r2,..." for each of the top pockets.
for top in top3:
    score, pocket_residues = top[0], top[-1]
    cur_res = "select :" + "".join(f"{res}," for res in pocket_residues)
    # The slice removes the trailing comma left by the join above.
    print(f"{score} - {cur_res[:-1]}")

0.3962157964706421 - select :303,286,304,319,348,346,255,308,350,253,329,307,287,300,331,291,290,261,333,262,317,288
0.003320834832265973 - select :266,254,265,242,246
0.0008981181890703738 - select :353,321,326,355,328,327,318,352


In [7]:
# For each top pocket, print the residues that are NOT part of the
# best-scoring pocket. Residues keep the order of the pocket's own list
# (instead of set order, which varies between runs) and are joined without a
# trailing comma, matching the other "select :" cells.
first_set = set(top3[0][-1])
for top in top3:
    extra_residues = [res for res in top[-1] if res not in first_set]
    print(f"select :{','.join(map(str, extra_residues))}")

select :
select :266,254,265,242,246,
select :321,326,328,353,355,327,318,352,


In [8]:
import gc
import torch
import requests
import glob, os, math
import numpy as np
from transformers import BertModel, BertTokenizer
import xgboost as xgb

from utils.extract_sequence import extract_sequence
from utils.pocket_feature import pocket_feature
from utils.sequence_indices import sequence_indices
from utils.pocket_coordinates import pocket_coordinates

# ---- Run configuration ----
# Minimum number of annotated-residue atoms a pocket must contain to be
# labelled positive in test mode.
N_ATOMS = 9
MODEL_PATH = "../prot_bert_allosteric"
base_url = "https://files.rcsb.org/download"
pdb_dir = "../data/pdbs/"
pocket_dir = "../data/pockets/"
pdb_id = "3LNY"
chain_id = "A"

# When True, ground-truth labels are derived from the ASD annotation file and
# carried along with the predictions.
is_test = False

pdb_path = os.path.join(pdb_dir, f"{pdb_id}.pdb")
pocket_path = os.path.join(pocket_dir, f"{pdb_id}_out")

#### Test - begin ####
if is_test:
    ASD_path = "../data/source_data/ASD_Release_201909_AS.txt"

    with open(ASD_path, "r") as f:
        asd = f.readlines()

    mod_id, modulator, residues = None, None, None
    # Skip the header row and look for the first usable entry of `pdb_id`.
    for line in asd[1:]:
        line = line.strip().split("\t")
        # Use row-local names: unpacking straight into chain_id / modulator /
        # mod_id would overwrite the configured values with data from rows
        # that belong to other structures.
        row_pdb, row_modulator, row_chain, row_mod_id = line[4], line[6], line[7], line[11]

        if row_pdb != pdb_id:
            continue

        # Only accept entries that name a single chain and a single modulator.
        if len(set(row_chain.split(";"))) != 1:
            continue
        if len(set(row_modulator.split(";"))) != 1:
            continue

        chain_id = row_chain[0]
        modulator = row_modulator.split(";")[0]
        mod_id = row_mod_id

        # extract residues
        res_raw = [
            res.replace(":", ",").split(",") for res in line[-1].split("; ")
        ]
        # residue_clean format: chain id + residue type + residue number
        residues = [
            [res[0][-1], ch[:3], ch[3:]] for res in res_raw for ch in res[1:]
        ]
        # select only residues in the same chain of modulator
        residues = [res for res in residues if res[0] == chain_id]

        break

    # Fail early with a clear message instead of a TypeError further down
    # when `residues` is iterated.
    if residues is None:
        raise Exception(f"No usable ASD entry found for {pdb_id}.")
#### Test - end ####


# Download the structure from RCSB only when it is not cached locally.
if not os.path.exists(pdb_path):
    os.makedirs(pdb_dir, exist_ok=True)
    # A timeout keeps the cell from hanging forever on a stalled connection.
    response = requests.get(f"{base_url}/{pdb_id}.pdb", timeout=60)
    if response.status_code == 200:  # Check if the request was successful
        with open(pdb_path, 'wb') as file:
            file.write(response.content)
        print(f"PDB file {pdb_id}.pdb downloaded successfully.")
    else:
        raise Exception(f"Failed to download {pdb_id}.pdb. Check if the PDB ID is correct.")

sequence = extract_sequence(pdb_path, chain_id)

if len(sequence) <= 10:
    raise Exception("Sequence is too short.")

# Run fpocket only when its output directory is not already present.
if not os.path.exists(pocket_path):
    os.makedirs(pocket_dir, exist_ok=True)
    os.system(f"fpocket -f {pdb_path} -k {chain_id}")
    # The "<pdb_id>_out" directory is expected next to the input PDB file and
    # is moved into pocket_dir.
    os.system(f"mv {os.path.join(pdb_dir, pdb_id)}_out {pocket_dir}")
    # os.system does not raise on failure, so verify the result explicitly
    # rather than failing later with a misleading error.
    if not os.path.exists(pocket_path):
        raise Exception(f"fpocket output not found at {pocket_path}.")

#### Test - begin ####
if is_test:
    # Centroid (lig_x, lig_y, lig_z) of the modulator: the mean coordinate of
    # all HETATM records whose residue name, chain and residue number match
    # the ASD entry.
    lig_x, lig_y, lig_z, lig_cnt = 0, 0, 0, 0

    with open(pdb_path, "r") as f:
        protein = f.readlines()

    for line in protein:
        if (
            line[:6] == "HETATM" and modulator == line[17:20].strip()
            and line[21] == chain_id and mod_id == line[22:26].strip()
        ):
            lig_x += float(line[30:38])
            lig_y += float(line[38:46])
            lig_z += float(line[46:54])
            lig_cnt += 1

    # Without any matching atom the averages below would raise a bare
    # ZeroDivisionError; report the actual problem instead.
    if lig_cnt == 0:
        raise Exception(
            f"Modulator {modulator} {mod_id} not found in chain {chain_id} of {pdb_id}."
        )

    lig_x /= lig_cnt
    lig_y /= lig_cnt
    lig_z /= lig_cnt
#### Test - end ####

# Gather the per-pocket PDB files and order them by the integer that follows
# the last "pocket" in each path (up to the next underscore).
pocket_names = glob.glob(f"{pocket_path}/pockets/*.pdb")
pocket_names = sorted(
    pocket_names,
    key=lambda x: int(x.split("pocket")[-1].split("_")[0])
)

# Per-pocket descriptors read from the fpocket info file; indexed below by
# position in `pocket_names`.
# NOTE(review): assumes pocket_feature() returns entries in the same order as
# the sorted pocket files - confirm against utils.pocket_feature.
pockets_feats = pocket_feature(f"{pocket_path}/{pdb_id}_info.txt")
selected_idxs = []            # positions in pocket_names of pockets that are kept
pocket_residue_indices = []   # one list of residue-number strings per kept pocket

#### Test - begin ####
if is_test:
    # Lookup "<residue name><residue number>" -> chain id for the annotated
    # residues.
    atomTarget = {}
    for res in residues:
        atomTarget[f'{res[1]}{res[2]}'] = res[0]

    dists = []          # pocket-centroid to ligand-centroid distance per kept pocket
    countsPockets = [] # for atom count
#### Test - end ####

for idx, pocket_name in enumerate(pocket_names):
    pocket = None
    with open(pocket_name, "r") as f:
        pocket = f.readlines()

#### Test - begin ####
    if is_test:
        # Coordinate sums for the pocket centroid and the number of atoms that
        # belong to annotated residues.
        poc_x, poc_y, poc_z = 0, 0, 0
        pocketAtomCount = 0
#### Test - end ####

    poc_cnt = 0                 # number of ATOM records in this pocket file
    residue_indices = set()     # distinct residue numbers seen in this pocket

    for line in pocket:
        if line[:4] == "ATOM":
            poc_cnt += 1
            # Fixed-column fields of the record: residue number and residue name.
            residue_index = line[22:26].strip()
            atom = line[17:20] + residue_index
            residue_indices.add(residue_index)

#### Test - begin ####
            if is_test:
                poc_x += float(line[30:38])
                poc_y += float(line[38:46])
                poc_z += float(line[46:54])
                chainID = line[21]
                # Count the atom when its residue is annotated on the same chain.
                if atom in atomTarget and atomTarget[atom] == chainID:
                    pocketAtomCount += 1
#### Test - end ####

    # Pockets without any ATOM record are dropped entirely.
    if poc_cnt == 0:
        continue

#### Test - begin ####
    if is_test:
        # Turn the coordinate sums into the centroid, then measure its
        # Euclidean distance to the ligand centroid.
        poc_x /= poc_cnt
        poc_y /= poc_cnt
        poc_z /= poc_cnt
        dist = math.sqrt(
            (poc_x - lig_x) ** 2 + (poc_y - lig_y) ** 2 +
            (poc_z - lig_z) ** 2
        )

        dists.append(dist)
        countsPockets.append(pocketAtomCount)
#### Test - end ####

    selected_idxs.append(idx)
    pocket_residue_indices.append(list(residue_indices))

# At least three usable pockets are required to continue.
if len(selected_idxs) <= 2:
    raise Exception("Too few pockets extracted.")

# Keep only the descriptors of the pockets that were not dropped above.
pocket_features = [pockets_feats[idx] for idx in selected_idxs]

# Mapping used below to translate a PDB residue-number string into a position
# in `sequence` (it is queried with `in` and indexed by residue number).
# NOTE(review): presumably a dict built by utils.sequence_indices - confirm.
seq_indices = sequence_indices(pdb_id, chain_id)

#### Test - begin ####
if is_test:
    # A pocket is positive when it holds at least N_ATOMS annotated atoms; the
    # pocket whose centroid is closest to the ligand is always positive.
    dist_min_idx = np.argmin(dists)
    labels = [1 if item >= N_ATOMS else 0 for item in countsPockets] # for atom count
    labels[dist_min_idx] = 1

    # Per-residue labels: 'Y' for residues of any positive pocket, else 'N'.
    seq_labels = ['N'] * len(sequence)
    for i in range(len(labels)):
            if labels[i] == 1:
                for residue_index in pocket_residue_indices[i]:
                    # Ignore residues that cannot be mapped into the sequence.
                    if residue_index in seq_indices and seq_indices[residue_index] < len(sequence):
                        seq_labels[seq_indices[residue_index]] = 'Y'
#### Test - end ####

# NOTE(review): return structure is defined in utils.pocket_coordinates; later
# code only uses len(pocket_coord) and len(pocket_coord[i]).
pocket_coord = pocket_coordinates(pdb_path, f"{pocket_path}/pockets/", pdb_id, chain_id, pocket_residue_indices)

# Load the BERT tokenizer and encoder stored at MODEL_PATH and switch the
# encoder to inference mode on the GPU when one is available.
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH, do_lower_case=False)
model = BertModel.from_pretrained(MODEL_PATH)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model = model.eval()

seq_emb = None      # mean embedding over all residues of the sequence
poc_res_emb = []    # per kept pocket: list of per-residue embeddings

#### Test - begin ####
if is_test:
    poc_labels = []  # labels of the pockets that end up in poc_res_emb
#### Test - end ####

with torch.no_grad():
    # The tokenizer is fed one space-separated residue letter per token.
    seq = " ".join(sequence)
    # A single sequence needs no padding. Requesting 'max_length' without a
    # max_length only made the tokenizer warn and fall back to no padding.
    encoding = tokenizer.batch_encode_plus(
        [seq],
        add_special_tokens=True,
        padding=False
    )
    input_ids = torch.tensor(encoding['input_ids']).to(device)
    attention_mask = torch.tensor(encoding['attention_mask']).to(device)
    embedding = model(input_ids=input_ids, attention_mask=attention_mask)[0].cpu().numpy()

    # Plain int so the NumPy slice below does not depend on tensor-to-index
    # conversion (the mask may live on the GPU).
    seq_len = int((attention_mask[0] == 1).sum())
    # Drop the first and last attended positions (the special tokens).
    token_emb = embedding[0][1:seq_len-1]
    seq_emb = token_emb.mean(axis=0)

    # NOTE(review): a pocket skipped here is NOT removed from pocket_features,
    # pocket_residue_indices or labels, which later code indexes by position
    # in poc_res_emb - verify the alignment whenever a pocket is dropped.
    for i in range(len(pocket_residue_indices)):
        add_pocket = True
        cur_poc_emb = []

        for idx in pocket_residue_indices[i]:
            try:
                cur_poc_emb.append(token_emb[seq_indices[idx]])
            except Exception:
                # Best effort: a residue that cannot be mapped to a token
                # disqualifies the whole pocket.
                add_pocket = False
                break

        if add_pocket:
            poc_res_emb.append(cur_poc_emb)
#### Test - begin ####
            if is_test:
                poc_labels.append(labels[i])
#### Test - end ####

# Release the encoder and any cached GPU memory before the next stage.
del model
torch.cuda.empty_cache()
gc.collect()

def get_res_data(poc_res_emb, pocket_coord):
    """Build one feature vector per pocket.

    For pocket ``i`` the first ``min(len(poc_res_emb[i]), len(pocket_coord[i]))``
    residue embeddings are averaged and concatenated with the notebook-level
    ``pocket_features[i]``. Only the first
    ``min(len(poc_res_emb), len(pocket_coord))`` pockets are processed.

    Reads the notebook globals ``pocket_features``, ``is_test`` and (in test
    mode) ``labels``.

    Returns:
        ``(X, Y)`` when ``is_test`` is true, otherwise just ``X``; ``X`` is a
        list of 1-D arrays and ``Y`` the list of ``labels[i]`` values.
    """
    # NOTE(review): pocket_features / labels are indexed by position in
    # poc_res_emb; confirm they stay aligned if pockets were skipped upstream.
    features, targets = [], []
    n_pockets = min(len(poc_res_emb), len(pocket_coord))

    for poc_idx in range(n_pockets):
        residue_embs = poc_res_emb[poc_idx]
        n_res = min(len(residue_embs), len(pocket_coord[poc_idx]))
        pooled = np.array([residue_embs[k] for k in range(n_res)]).mean(axis=0)
        features.append(np.concatenate((pooled, pocket_features[poc_idx])))
#### Test - begin ####
        if is_test:
            targets.append(labels[poc_idx])
#### Test - end ####

    return (features, targets) if is_test else features

# Build the XGBoost input (with labels in test mode) and score every pocket
# with the saved booster.
if is_test:
    X_Test, Y_Test = get_res_data(poc_res_emb, pocket_coord)
    dtest = xgb.DMatrix(X_Test, label=Y_Test)
else:
    X_Test = get_res_data(poc_res_emb, pocket_coord)
    dtest = xgb.DMatrix(X_Test)
bst = xgb.Booster()
bst.load_model('./xgboost.model')
y_pred = bst.predict(dtest)

# Pair each score with its pocket residues (and its label in test mode) and
# rank by score, best first. The residue list is always the last element.
if is_test:
    paired = list(zip(y_pred, Y_Test, pocket_residue_indices))
else:
    paired = list(zip(y_pred, pocket_residue_indices))
paired_sorted = sorted(paired, key=lambda x: x[0], reverse=True)

top3 = paired_sorted[:3]

# "<score> - select :r1,r2,..." for each of the top pockets.
for top in top3:
    cur_res = "select :" + ",".join(str(res) for res in top[-1])
    print(f"{top[0]} - {cur_res}")

# Residues of each top pocket that are not already in the best pocket.
first_set = set(top3[0][-1])
for top in top3:
    # Joining avoids slicing off the last character, which removed the ':'
    # whenever the difference was empty. dict.fromkeys keeps the pocket's own
    # residue order so the output is stable between runs.
    cur_res = "select :" + ",".join(
        str(res) for res in dict.fromkeys(top[-1]) if res not in first_set
    )
    print(cur_res)

Some weights of the model checkpoint at ../prot_bert_allosteric were not used when initializing BertModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at ../prot_bert_allosteric and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum leng

0.04398220777511597 - select :78,16,79,17,13,21,71,75,22,18,20
0.00419106287881732 - select :63,62,8,86,60,88
0.0014705873327329755 - select :59,66,34,57,67,65,90,58,33,35
select :
select :63,62,8,86,60,88,
select :59,66,34,57,67,65,90,58,33,35,


In [1]:
import gc
import torch
import requests
import glob, os, math
import numpy as np
from transformers import BertModel, BertTokenizer
import xgboost as xgb

from utils.extract_sequence import extract_sequence
from utils.pocket_feature import pocket_feature
from utils.sequence_indices import sequence_indices
from utils.pocket_coordinates import pocket_coordinates

# ---- Run configuration ----
# Minimum number of annotated-residue atoms a pocket must contain to be
# labelled positive in test mode.
N_ATOMS = 9
MODEL_PATH = "../prot_bert_allosteric"
base_url = "https://files.rcsb.org/download"
pdb_dir = "../data/pdbs/"
pocket_dir = "../data/pockets/"
pdb_id = "3HJ0"
chain_id = "A"

# When True, ground-truth labels are derived from the ASD annotation file and
# carried along with the predictions.
is_test = True

pdb_path = os.path.join(pdb_dir, f"{pdb_id}.pdb")
pocket_path = os.path.join(pocket_dir, f"{pdb_id}_out")

#### Test - begin ####
if is_test:
    ASD_path = "../data/source_data/ASD_Release_201909_AS.txt"

    with open(ASD_path, "r") as f:
        asd = f.readlines()

    mod_id, modulator, residues = None, None, None
    # Skip the header row and look for the first usable entry of `pdb_id`.
    for line in asd[1:]:
        line = line.strip().split("\t")
        # Use row-local names: unpacking straight into chain_id / modulator /
        # mod_id would overwrite the configured values with data from rows
        # that belong to other structures.
        row_pdb, row_modulator, row_chain, row_mod_id = line[4], line[6], line[7], line[11]

        if row_pdb != pdb_id:
            continue

        # Only accept entries that name a single chain and a single modulator.
        if len(set(row_chain.split(";"))) != 1:
            continue
        if len(set(row_modulator.split(";"))) != 1:
            continue

        chain_id = row_chain[0]
        modulator = row_modulator.split(";")[0]
        mod_id = row_mod_id

        # extract residues
        res_raw = [
            res.replace(":", ",").split(",") for res in line[-1].split("; ")
        ]
        # residue_clean format: chain id + residue type + residue number
        residues = [
            [res[0][-1], ch[:3], ch[3:]] for res in res_raw for ch in res[1:]
        ]
        # select only residues in the same chain of modulator
        residues = [res for res in residues if res[0] == chain_id]

        break

    # Fail early with a clear message instead of a TypeError further down
    # when `residues` is iterated.
    if residues is None:
        raise Exception(f"No usable ASD entry found for {pdb_id}.")
#### Test - end ####


# Download the structure from RCSB only when it is not cached locally.
if not os.path.exists(pdb_path):
    os.makedirs(pdb_dir, exist_ok=True)
    # A timeout keeps the cell from hanging forever on a stalled connection.
    response = requests.get(f"{base_url}/{pdb_id}.pdb", timeout=60)
    if response.status_code == 200:  # Check if the request was successful
        with open(pdb_path, 'wb') as file:
            file.write(response.content)
        print(f"PDB file {pdb_id}.pdb downloaded successfully.")
    else:
        raise Exception(f"Failed to download {pdb_id}.pdb. Check if the PDB ID is correct.")

sequence = extract_sequence(pdb_path, chain_id)

if len(sequence) <= 10:
    raise Exception("Sequence is too short.")

# Run fpocket only when its output directory is not already present.
if not os.path.exists(pocket_path):
    os.makedirs(pocket_dir, exist_ok=True)
    os.system(f"fpocket -f {pdb_path} -k {chain_id}")
    # The "<pdb_id>_out" directory is expected next to the input PDB file and
    # is moved into pocket_dir.
    os.system(f"mv {os.path.join(pdb_dir, pdb_id)}_out {pocket_dir}")
    # os.system does not raise on failure, so verify the result explicitly
    # rather than failing later with a misleading error.
    if not os.path.exists(pocket_path):
        raise Exception(f"fpocket output not found at {pocket_path}.")

#### Test - begin ####
if is_test:
    # Centroid (lig_x, lig_y, lig_z) of the modulator: the mean coordinate of
    # all HETATM records whose residue name, chain and residue number match
    # the ASD entry.
    lig_x, lig_y, lig_z, lig_cnt = 0, 0, 0, 0

    with open(pdb_path, "r") as f:
        protein = f.readlines()

    for line in protein:
        if (
            line[:6] == "HETATM" and modulator == line[17:20].strip()
            and line[21] == chain_id and mod_id == line[22:26].strip()
        ):
            lig_x += float(line[30:38])
            lig_y += float(line[38:46])
            lig_z += float(line[46:54])
            lig_cnt += 1

    # Without any matching atom the averages below would raise a bare
    # ZeroDivisionError; report the actual problem instead.
    if lig_cnt == 0:
        raise Exception(
            f"Modulator {modulator} {mod_id} not found in chain {chain_id} of {pdb_id}."
        )

    lig_x /= lig_cnt
    lig_y /= lig_cnt
    lig_z /= lig_cnt
#### Test - end ####

# Gather the per-pocket PDB files and order them by the integer that follows
# the last "pocket" in each path (up to the next underscore).
pocket_names = glob.glob(f"{pocket_path}/pockets/*.pdb")
pocket_names = sorted(
    pocket_names,
    key=lambda x: int(x.split("pocket")[-1].split("_")[0])
)

# Per-pocket descriptors read from the fpocket info file; indexed below by
# position in `pocket_names`.
# NOTE(review): assumes pocket_feature() returns entries in the same order as
# the sorted pocket files - confirm against utils.pocket_feature.
pockets_feats = pocket_feature(f"{pocket_path}/{pdb_id}_info.txt")
selected_idxs = []            # positions in pocket_names of pockets that are kept
pocket_residue_indices = []   # one list of residue-number strings per kept pocket

#### Test - begin ####
if is_test:
    # Lookup "<residue name><residue number>" -> chain id for the annotated
    # residues.
    atomTarget = {}
    for res in residues:
        atomTarget[f'{res[1]}{res[2]}'] = res[0]

    dists = []          # pocket-centroid to ligand-centroid distance per kept pocket
    countsPockets = [] # for atom count
#### Test - end ####

for idx, pocket_name in enumerate(pocket_names):
    pocket = None
    with open(pocket_name, "r") as f:
        pocket = f.readlines()

#### Test - begin ####
    if is_test:
        # Coordinate sums for the pocket centroid and the number of atoms that
        # belong to annotated residues.
        poc_x, poc_y, poc_z = 0, 0, 0
        pocketAtomCount = 0
#### Test - end ####

    poc_cnt = 0                 # number of ATOM records in this pocket file
    residue_indices = set()     # distinct residue numbers seen in this pocket

    for line in pocket:
        if line[:4] == "ATOM":
            poc_cnt += 1
            # Fixed-column fields of the record: residue number and residue name.
            residue_index = line[22:26].strip()
            atom = line[17:20] + residue_index
            residue_indices.add(residue_index)

#### Test - begin ####
            if is_test:
                poc_x += float(line[30:38])
                poc_y += float(line[38:46])
                poc_z += float(line[46:54])
                chainID = line[21]
                # Count the atom when its residue is annotated on the same chain.
                if atom in atomTarget and atomTarget[atom] == chainID:
                    pocketAtomCount += 1
#### Test - end ####

    # Pockets without any ATOM record are dropped entirely.
    if poc_cnt == 0:
        continue

#### Test - begin ####
    if is_test:
        # Turn the coordinate sums into the centroid, then measure its
        # Euclidean distance to the ligand centroid.
        poc_x /= poc_cnt
        poc_y /= poc_cnt
        poc_z /= poc_cnt
        dist = math.sqrt(
            (poc_x - lig_x) ** 2 + (poc_y - lig_y) ** 2 +
            (poc_z - lig_z) ** 2
        )

        dists.append(dist)
        countsPockets.append(pocketAtomCount)
#### Test - end ####

    selected_idxs.append(idx)
    pocket_residue_indices.append(list(residue_indices))

# At least three usable pockets are required to continue.
if len(selected_idxs) <= 2:
    raise Exception("Too few pockets extracted.")

# Keep only the descriptors of the pockets that were not dropped above.
pocket_features = [pockets_feats[idx] for idx in selected_idxs]

# Mapping used below to translate a PDB residue-number string into a position
# in `sequence` (it is queried with `in` and indexed by residue number).
# NOTE(review): presumably a dict built by utils.sequence_indices - confirm.
seq_indices = sequence_indices(pdb_id, chain_id)

#### Test - begin ####
if is_test:
    # A pocket is positive when it holds at least N_ATOMS annotated atoms; the
    # pocket whose centroid is closest to the ligand is always positive.
    dist_min_idx = np.argmin(dists)
    labels = [1 if item >= N_ATOMS else 0 for item in countsPockets] # for atom count
    labels[dist_min_idx] = 1

    # Per-residue labels: 'Y' for residues of any positive pocket, else 'N'.
    seq_labels = ['N'] * len(sequence)
    for i in range(len(labels)):
            if labels[i] == 1:
                for residue_index in pocket_residue_indices[i]:
                    # Ignore residues that cannot be mapped into the sequence.
                    if residue_index in seq_indices and seq_indices[residue_index] < len(sequence):
                        seq_labels[seq_indices[residue_index]] = 'Y'
#### Test - end ####

# NOTE(review): return structure is defined in utils.pocket_coordinates; later
# code only uses len(pocket_coord) and len(pocket_coord[i]).
pocket_coord = pocket_coordinates(pdb_path, f"{pocket_path}/pockets/", pdb_id, chain_id, pocket_residue_indices)

# Load the BERT tokenizer and encoder stored at MODEL_PATH and switch the
# encoder to inference mode on the GPU when one is available.
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH, do_lower_case=False)
model = BertModel.from_pretrained(MODEL_PATH)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model = model.eval()

seq_emb = None      # mean embedding over all residues of the sequence
poc_res_emb = []    # per kept pocket: list of per-residue embeddings

#### Test - begin ####
if is_test:
    poc_labels = []  # labels of the pockets that end up in poc_res_emb
#### Test - end ####

with torch.no_grad():
    # The tokenizer is fed one space-separated residue letter per token.
    seq = " ".join(sequence)
    # A single sequence needs no padding. Requesting 'max_length' without a
    # max_length only made the tokenizer warn and fall back to no padding.
    encoding = tokenizer.batch_encode_plus(
        [seq],
        add_special_tokens=True,
        padding=False
    )
    input_ids = torch.tensor(encoding['input_ids']).to(device)
    attention_mask = torch.tensor(encoding['attention_mask']).to(device)
    embedding = model(input_ids=input_ids, attention_mask=attention_mask)[0].cpu().numpy()

    # Plain int so the NumPy slice below does not depend on tensor-to-index
    # conversion (the mask may live on the GPU).
    seq_len = int((attention_mask[0] == 1).sum())
    # Drop the first and last attended positions (the special tokens).
    token_emb = embedding[0][1:seq_len-1]
    seq_emb = token_emb.mean(axis=0)

    # NOTE(review): a pocket skipped here is NOT removed from pocket_features,
    # pocket_residue_indices or labels, which later code indexes by position
    # in poc_res_emb - verify the alignment whenever a pocket is dropped.
    for i in range(len(pocket_residue_indices)):
        add_pocket = True
        cur_poc_emb = []

        for idx in pocket_residue_indices[i]:
            try:
                cur_poc_emb.append(token_emb[seq_indices[idx]])
            except Exception:
                # Best effort: a residue that cannot be mapped to a token
                # disqualifies the whole pocket.
                add_pocket = False
                break

        if add_pocket:
            poc_res_emb.append(cur_poc_emb)
#### Test - begin ####
            if is_test:
                poc_labels.append(labels[i])
#### Test - end ####

# Release the encoder and any cached GPU memory before the next stage.
del model
torch.cuda.empty_cache()
gc.collect()

def get_res_data(poc_res_emb, pocket_coord):
    """Build one feature vector per pocket.

    For pocket ``i`` the first ``min(len(poc_res_emb[i]), len(pocket_coord[i]))``
    residue embeddings are averaged and concatenated with the notebook-level
    ``pocket_features[i]``. Only the first
    ``min(len(poc_res_emb), len(pocket_coord))`` pockets are processed.

    Reads the notebook globals ``pocket_features``, ``is_test`` and (in test
    mode) ``labels``.

    Returns:
        ``(X, Y)`` when ``is_test`` is true, otherwise just ``X``; ``X`` is a
        list of 1-D arrays and ``Y`` the list of ``labels[i]`` values.
    """
    # NOTE(review): pocket_features / labels are indexed by position in
    # poc_res_emb; confirm they stay aligned if pockets were skipped upstream.
    features, targets = [], []
    n_pockets = min(len(poc_res_emb), len(pocket_coord))

    for poc_idx in range(n_pockets):
        residue_embs = poc_res_emb[poc_idx]
        n_res = min(len(residue_embs), len(pocket_coord[poc_idx]))
        pooled = np.array([residue_embs[k] for k in range(n_res)]).mean(axis=0)
        features.append(np.concatenate((pooled, pocket_features[poc_idx])))
#### Test - begin ####
        if is_test:
            targets.append(labels[poc_idx])
#### Test - end ####

    return (features, targets) if is_test else features

# Build the XGBoost input (with labels in test mode) and score every pocket
# with the saved booster.
if is_test:
    X_Test, Y_Test = get_res_data(poc_res_emb, pocket_coord)
    dtest = xgb.DMatrix(X_Test, label=Y_Test)
else:
    X_Test = get_res_data(poc_res_emb, pocket_coord)
    dtest = xgb.DMatrix(X_Test)
bst = xgb.Booster()
bst.load_model('./xgboost.model')
y_pred = bst.predict(dtest)

# Pair each score with its pocket residues (and its label in test mode) and
# rank by score, best first. The residue list is always the last element.
if is_test:
    paired = list(zip(y_pred, Y_Test, pocket_residue_indices))
else:
    paired = list(zip(y_pred, pocket_residue_indices))
paired_sorted = sorted(paired, key=lambda x: x[0], reverse=True)

top3 = paired_sorted[:3]

# "<score> - <label> - select :r1,r2,..." for each of the top pockets.
for top in top3:
    cur_res = "select :" + ",".join(str(res) for res in top[-1])
    if is_test:
        print(f"{top[0]} - {top[1]} - {cur_res}")
    else:
        # Without labels top[1] is the residue list itself, so it is left out.
        print(f"{top[0]} - {cur_res}")

# Residues of each top pocket that are not already in the best pocket.
first_set = set(top3[0][-1])
for top in top3:
    # Joining avoids slicing off the last character, which removed the ':'
    # whenever the difference was empty. dict.fromkeys keeps the pocket's own
    # residue order so the output is stable between runs.
    cur_res = "select :" + ",".join(
        str(res) for res in dict.fromkeys(top[-1]) if res not in first_set
    )
    print(cur_res)

  from .autonotebook import tqdm as notebook_tqdm
Some weights of the model checkpoint at ../prot_bert_allosteric were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at ../prot_bert_allosteric and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Asking to pad to max_length but no maximum length is pro

0.9999921321868896 - 1 - select :119,109,15,17,117,108,110
0.143299400806427 - 0 - select :18,22,17,24,23,110,19
0.0005511316703632474 - 0 - select :115,79,116,112,86,113,111,88,75
select 
select :18,22,24,23,19
select :115,79,116,112,86,113,111,88,75
