In [2]:
tankbind_src_folder_path = "./tankbind/"
import sys
sys.path.insert(0, tankbind_src_folder_path)

In [3]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm

In [4]:
df = pd.read_csv('Mcule_10000.csv', index_col=0)

In [5]:
from feature_utils import get_protein_feature, get_clean_res_list
from Bio.PDB import PDBParser

In [22]:
def extract_protein_names(file_path):
    protein_names = []
    with open(file_path, 'r') as file:
        for line in file:
            # Get the base name without directory path
            base_name = os.path.basename(line.strip())
            # Split the base name to get the protein name without extension
            protein_name = os.path.splitext(base_name)[0]
            protein_names.append(protein_name)
    return protein_names

file_path = 'pdb_dataset.ds'
protein_names = extract_protein_names(file_path)
print(protein_names[0:5])

['1A07', '1A08', '1A09', '1A0N', '1A1A']


In [23]:
protein_names.__len__()

5963

In [24]:
import random
# n = len(protein_names) // 5
n = 596
sampled_protein_names = random.sample(protein_names, n)
print(sampled_protein_names[0:15], len(sampled_protein_names))

['2R3G', '1P4O', '8AO3', '6W4P', '3OJM', '8D7P', '6OMU', '6HOP', '6P5M', '3FKO', '4BBF', '7QB2', '2F4J', '3EZR', '6YYG'] 596


In [25]:
def process_proteins(protein_names, pdb_directory):
    parser = PDBParser(QUIET=True)
    protein_dict = {}
    i = 0
    for proteinName in protein_names:
        try:
            # print(i, ': ', proteinName)
            proteinFile = f"{pdb_directory}/{proteinName}.pdb"
            s = parser.get_structure(proteinName, proteinFile)
            res_list = list(s.get_residues())
            clean_res_list = get_clean_res_list(res_list, ensure_ca_exist=True)
            protein_dict[proteinName] = get_protein_feature(clean_res_list)
        except Exception as e:
            print(f"Error processing {proteinName}: {e}")
        finally:
            i += 1
    
    return protein_dict

pdb_directory = "PDB_files"  # Directory containing PDB files

protein_dict = process_proteins(sampled_protein_names, pdb_directory)

In [26]:

protein_dict.keys().__len__() # doesn't work for 2 proteins 1IAN & 6XR4 (indexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)) --> Cause: likely different format of PDB file

596

In [11]:
# import torch
# torch.save(protein_dict, 'precalculated_protein_features_small.pt')

In [27]:
protein_names_new = list(protein_dict.keys())
protein_names_new.__len__()

596

In [28]:
info = []
for protein_name in protein_names_new:
    for i, line in tqdm(df.iterrows(), total=len(df)):
        smiles = line['smiles']
        compund_name = ''
        protein_name = protein_name
        com = ",".join([str(x.round(3)) for x in protein_dict[protein_name][0].mean(axis=0).numpy()])
        info.append([protein_name, compund_name, smiles, "protein_center", com])

info = pd.DataFrame(info, columns=['protein_name', 'compound_name', 'smiles', 'pocket_name', 'pocket_com'])


100%|██████████| 10000/10000 [00:01<00:00, 7407.95it/s]
100%|██████████| 10000/10000 [00:01<00:00, 7687.17it/s]
100%|██████████| 10000/10000 [00:01<00:00, 7659.70it/s]
100%|██████████| 10000/10000 [00:01<00:00, 7119.67it/s]
100%|██████████| 10000/10000 [00:01<00:00, 7673.28it/s]
100%|██████████| 10000/10000 [00:01<00:00, 7443.78it/s]
100%|██████████| 10000/10000 [00:01<00:00, 7605.66it/s]
100%|██████████| 10000/10000 [00:01<00:00, 7445.36it/s]
100%|██████████| 10000/10000 [00:01<00:00, 7696.93it/s]
100%|██████████| 10000/10000 [00:01<00:00, 6651.87it/s]
100%|██████████| 10000/10000 [00:01<00:00, 7430.36it/s]
100%|██████████| 10000/10000 [00:01<00:00, 8008.47it/s]
100%|██████████| 10000/10000 [00:01<00:00, 7880.91it/s]
100%|██████████| 10000/10000 [00:01<00:00, 7912.13it/s]
100%|██████████| 10000/10000 [00:01<00:00, 7497.69it/s]
100%|██████████| 10000/10000 [00:01<00:00, 7605.36it/s]
100%|██████████| 10000/10000 [00:01<00:00, 6106.46it/s]
100%|██████████| 10000/10000 [00:01<00:00, 7763.

In [29]:
info.__len__()

5960000

In [30]:
info

Unnamed: 0,protein_name,compound_name,smiles,pocket_name,pocket_com
0,2R3G,,CC1=C(C2=CC=C(F)C=C2)N=C(NC(C2=C(C)C3=C(CCCC3=...,protein_center,"1.924,29.866,21.385"
1,2R3G,,CC(N(C(NC1C2OC(CC2)C1)=O)C)CS(C)(=O)=O,protein_center,"1.924,29.866,21.385"
2,2R3G,,S(=O)(=O)(C1C=CC2=C(C=1)N=C(N2C1CCCC1)C)N1CCC(...,protein_center,"1.924,29.866,21.385"
3,2R3G,,N1CCCCC1CC1=CC=NC=C1,protein_center,"1.924,29.866,21.385"
4,2R3G,,N1(CCC2C=CC(CNC(=O)NC3=CC=CC=C3C)=CC1=2)C(C1=C...,protein_center,"1.924,29.866,21.385"
...,...,...,...,...,...
5959995,3TN8,,CCCN(C(OC(C)(C)C)=O)CCNC(C1=C(SC)C=CC=C1)=O,protein_center,"36.013,-8.1,-10.609"
5959996,3TN8,,N[C@H](C1C=NC=CC=1Br)C(F)(F)F,protein_center,"36.013,-8.1,-10.609"
5959997,3TN8,,C12N=CN(C(=O)C=1N=NN2CC1C=CC(=CC=1)F)CC(=O)NC1...,protein_center,"36.013,-8.1,-10.609"
5959998,3TN8,,N1(CCC[C@H](O)[C@@H]1CC1=CC=CC=C1F)CC1C=CC=CC=...,protein_center,"36.013,-8.1,-10.609"


### Construct the Dataset

In [6]:
import torch
torch.set_num_threads(1)

In [7]:
from torch_geometric.data import Dataset
from utils import construct_data_from_graph_gvp
import rdkit.Chem as Chem    # conda install rdkit -c rdkit if import failure.
from feature_utils import extract_torchdrug_feature_from_mol, get_canonical_smiles

In [9]:
class MyDataset_VS(Dataset):
    def __init__(self, root, data=None, protein_dict=None, precomputed_mol_dict=None, proteinMode=0, compoundMode=1,
                 pocket_radius=20, shake_nodes=None,
                 transform=None, pre_transform=None, pre_filter=None):
        self.data = data
        self.protein_dict = protein_dict
        self.precomputed_mol_dict = precomputed_mol_dict
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data = torch.load(self.processed_paths[0])
        self.protein_dict = torch.load(self.processed_paths[1])
        self.precomputed_mol_dict = torch.load(self.processed_paths[2])
        self.proteinMode = proteinMode
        self.pocket_radius = pocket_radius
        self.compoundMode = compoundMode
        self.shake_nodes = shake_nodes

    @property
    def processed_file_names(self):
        return ['data.pt', 'proteins.pt', 'molecules.pt']

    def process(self):
        # Save data and protein dictionary
        torch.save(self.data, self.processed_paths[0])
        torch.save(self.protein_dict, self.processed_paths[1])
        torch.save(self.precomputed_mol_dict, self.processed_paths[2])

    def len(self):
        return len(self.data)

    def get(self, idx):
        line = self.data.iloc[idx]
        smiles = line['smiles']
        pocket_com = line['pocket_com']
        pocket_com = np.array(pocket_com.split(",")).astype(float) if isinstance(pocket_com, str) else pocket_com
        pocket_com = pocket_com.reshape((1, 3))
        use_whole_protein = line.get('use_whole_protein', False)

        protein_name = line['protein_name']
        protein_data = self.protein_dict.get(protein_name)
        
        if protein_data is None:
            raise ValueError(f"Protein {protein_name} not found in pre-calculated protein dictionary")

        protein_node_xyz, protein_seq, protein_node_s, protein_node_v, protein_edge_index, protein_edge_s, protein_edge_v = protein_data

        # Load precomputed molecular features
        molecule_data = self.precomputed_mol_dict.get(smiles)
        if molecule_data is None:
            raise ValueError(f"SMILES {smiles} not found in precomputed molecular dictionary")
        
        coords, compound_node_features, input_atom_edge_list, input_atom_edge_attr_list, pair_dis_distribution = self.precomputed_mol_dict[smiles]

        data, input_node_list, keepNode = construct_data_from_graph_gvp(
            protein_node_xyz, protein_seq, protein_node_s, protein_node_v, 
            protein_edge_index, protein_edge_s, protein_edge_v,
            coords, compound_node_features, input_atom_edge_list, input_atom_edge_attr_list,
            pocket_radius=self.pocket_radius, use_whole_protein=use_whole_protein, includeDisMap=True,
            use_compound_com_as_pocket=False, chosen_pocket_com=pocket_com, compoundMode=self.compoundMode
        )
        data.compound_pair = pair_dis_distribution.reshape(-1, 16)
        
        return data


In [11]:
# create dataset dirrectory
dataset_path = f"dataset_10"
os.system(f"rm -r {dataset_path}")
os.system(f"mkdir -p {dataset_path}")

In [20]:
# # Load the dataset
# dataset_path = f"dataset"
# os.system(f"rm -r {dataset_path}")
# os.system(f"mkdir -p {dataset_path}")
# # dataset = MyDataset_VS(root=dataset_path, data=info, protein_dict='precalculated_protein_features_small.pt')

# dataset = MyDataset_VS(root=dataset_path, data=info, protein_dict=protein_dict)


Processing...
Done!


In [35]:
smiles_list = info['smiles'].to_list()
smiles_set = set(smiles_list)

In [36]:
molecule_dict = {}

for molecule in smiles_set:

    smiles = get_canonical_smiles(molecule)

    mol = Chem.MolFromSmiles(smiles)

    mol.Compute2DCoords()

    molecule_dict[molecule] = extract_torchdrug_feature_from_mol(mol, has_LAS_mask=True)



In [39]:
os.makedirs('dataset_10/processed')

torch.save(molecule_dict, 'dataset_10/processed/molecules.pt')

### Load Files for Dataset Creation

In [52]:
# if available, load the processed files
data = torch.load('dataset/processed/data.pt')
protein_dict = torch.load('dataset/processed/proteins.pt')
molecules = torch.load('dataset/processed/molecules.pt')

In [10]:
# Load the dataset
dataset_path = f"dataset_10"

# change to this line if no precomputed data/protein_dict files available
# dataset = MyDataset_VS(root=dataset_path, data=info, protein_dict=protein_dict, precomputed_mol_dict=molecule_dict)
dataset = MyDataset_VS(root=dataset_path)

In [68]:
updated_dataset = dataset[860290:]

In [69]:
updated_dataset.__len__()

5099710

### Model Testing (only for 5 proteins, very compute intensive, for more use google collab)

In [27]:
import logging
from torch_geometric.loader import DataLoader
from tqdm import tqdm    # pip install tqdm if fails.
from model import get_model

#### Filter dataset (only use 10% and remove large proteins/pockets)

In [16]:
mini_try = dataset[0:5690000]
mini_try.__len__()

5690000

In [None]:
batch_size = 5 # higher batchsize possible only if enough memmory is available (eg.: 10)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device= 'cpu'
logging.basicConfig(level=logging.INFO)
model = get_model(0, logging, device)

# self-dock model
modelFile = "./model/self_dock.pt"

model.load_state_dict(torch.load(modelFile, map_location=device))
_ = model.eval()

data_loader = DataLoader(updated_dataset, batch_size=batch_size, follow_batch=['x', 'y', 'compound_pair'], shuffle=False, num_workers=0)
affinity_pred_list_min = []
y_pred_list = []
for data in tqdm(data_loader):
    print(data.dis_map.shape[0])
    if data.dis_map.shape[0] < 50000:
        data = data.to(device)
        y_pred, affinity_pred = model(data)
        affinity_pred_list_min.append(affinity_pred.detach().cpu())
    else:
        print('to big')
        affinity_pred_list_min.append(torch.zeros(5, 1))

affinity_pred_list_min = torch.cat(affinity_pred_list_min)

### Check if further code works with small mini_try dataset 

In [11]:
new_aff_list = torch.load('affinity_pred_list_2.pt')

In [18]:
reshaped_aff_list = [tensor.view(5) for tensor in new_aff_list]

In [20]:
reshaped_aff_list_cat = torch.cat(reshaped_aff_list)

In [22]:
info_2 = dataset.data[340000:860290]
info_2['affinity'] = reshaped_aff_list_cat

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  info_2['affinity'] = reshaped_aff_list_cat


In [23]:
info_2

Unnamed: 0,protein_name,compound_name,smiles,pocket_name,pocket_com,affinity
340000,1LL8,,CC1=C(C2=CC=C(F)C=C2)N=C(NC(C2=C(C)C3=C(CCCC3=...,protein_center,"-0.571,-0.426,0.12",0.0
340001,1LL8,,CC(N(C(NC1C2OC(CC2)C1)=O)C)CS(C)(=O)=O,protein_center,"-0.571,-0.426,0.12",0.0
340002,1LL8,,S(=O)(=O)(C1C=CC2=C(C=1)N=C(N2C1CCCC1)C)N1CCC(...,protein_center,"-0.571,-0.426,0.12",0.0
340003,1LL8,,N1CCCCC1CC1=CC=NC=C1,protein_center,"-0.571,-0.426,0.12",0.0
340004,1LL8,,N1(CCC2C=CC(CNC(=O)NC3=CC=CC=C3C)=CC1=2)C(C1=C...,protein_center,"-0.571,-0.426,0.12",0.0
...,...,...,...,...,...,...
860285,6PTS,,N1(C=CC=C1)C(CC(=O)NCC1CCCN1CC)C1C=CSC=1,protein_center,"12.48,-2.579,0.324",0.0
860286,6PTS,,C1(CN(CCN2CCOCC2)C(=S)NC2=CC=CC(Cl)=C2)=CC2C(=...,protein_center,"12.48,-2.579,0.324",0.0
860287,6PTS,,C1(=O)N(C)C=NC2C=C(C=CC1=2)C(=O)N1CCCC(C1)C1=N...,protein_center,"12.48,-2.579,0.324",0.0
860288,6PTS,,N12N=NC=C1CO[C@@H]1CCN(CC3C=CC4C(=CC=C(F)C=4)N...,protein_center,"12.48,-2.579,0.324",0.0


In [24]:
info_2.to_csv('dataset_10/result_info_2.csv')

In [17]:
affinity_pred_list_min_cat = torch.cat(affinity_pred_list)

In [18]:
affinity_pred_list_min_cat.__len__()

30

In [19]:
info_mini = dataset.data.iloc[:30]
info_mini['affinity'] = affinity_pred_list_min_cat

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  info_mini['affinity'] = affinity_pred_list_min_cat


In [20]:
info_mini

Unnamed: 0,protein_name,compound_name,smiles,pocket_name,pocket_com,affinity
0,2R3G,,CC1=C(C2=CC=C(F)C=C2)N=C(NC(C2=C(C)C3=C(CCCC3=...,protein_center,"1.924,29.866,21.385",7.350928
1,2R3G,,CC(N(C(NC1C2OC(CC2)C1)=O)C)CS(C)(=O)=O,protein_center,"1.924,29.866,21.385",5.541882
2,2R3G,,S(=O)(=O)(C1C=CC2=C(C=1)N=C(N2C1CCCC1)C)N1CCC(...,protein_center,"1.924,29.866,21.385",7.370193
3,2R3G,,N1CCCCC1CC1=CC=NC=C1,protein_center,"1.924,29.866,21.385",4.777201
4,2R3G,,N1(CCC2C=CC(CNC(=O)NC3=CC=CC=C3C)=CC1=2)C(C1=C...,protein_center,"1.924,29.866,21.385",6.896791
5,2R3G,,N1(C(CCCC)C(=O)OCC(=O)C2C=CC(=CC=2)C)C(=O)C2=C...,protein_center,"1.924,29.866,21.385",6.10621
6,2R3G,,C1(C(OC)=O)C(C2=CC=CC(OC)=C2)NC(NC=1CN1CCC(CC1...,protein_center,"1.924,29.866,21.385",6.597138
7,2R3G,,COC(C(NC(CC1=CN(C)N=C1)=O)CC(C)C)=O,protein_center,"1.924,29.866,21.385",4.643211
8,2R3G,,N1(CCCC1C(=O)NC1C=CC=C(C=1)N1C=CC=N1)CC1=CC=CC...,protein_center,"1.924,29.866,21.385",6.915017
9,2R3G,,CN(C(C1=CC2=C(N=C(C=C2)C)C=C1)=O)C1CCN(CC2=CC=...,protein_center,"1.924,29.866,21.385",6.2585


In [21]:
chosen = info_mini.loc[info_mini.groupby('protein_name',sort=False)['affinity'].agg('idxmax')].reset_index()
chosen

Unnamed: 0,index,protein_name,compound_name,smiles,pocket_name,pocket_com,affinity
0,2,2R3G,,S(=O)(=O)(C1C=CC2=C(C=1)N=C(N2C1CCCC1)C)N1CCC(...,protein_center,"1.924,29.866,21.385",7.370193


In [34]:
from generation_utils import get_LAS_distance_constraint_mask, get_info_pred_distance, write_with_new_coords


In [27]:
# pick one with affinity greater than 7.
chosen = info_mini.loc[info_mini.groupby(['protein_name', 'smiles'],sort=False)['affinity'].agg('idxmax')].reset_index()
chosen = chosen.query("affinity > 9").reset_index(drop=True)
line = chosen.iloc[1]
idx = line['index']


In [60]:
selected_entries_dataset = MyDataset_VS(root=dataset_path, data=chosen, protein_dict=protein_dict, precomputed_mol_dict=molecule_dict) 


In [73]:
selected_entries_dataset.data = chosen
selected_entries_dataset.__len__()

1

In [10]:
info = pd.read_csv('dataset_10/result_info.csv')

chosen = info.loc[info.groupby('protein_name',sort=False)['affinity'].agg('idxmax')].reset_index()

chosen_data = chosen
chosen_idxs = chosen_data['index'].to_list()
chosen_smiles = chosen_data['smiles'].tolist()

dataset_path = f"dataset_10"

selected_entries_dataset = MyDataset_VS(root=dataset_path, data=chosen_data) 
selected_entries_dataset.data = chosen_data

print(selected_entries_dataset.__len__())


34


In [30]:
# info_2 = pd.read_csv('dataset_10/result_info_2.csv')

chosen_2 = info_2.loc[info_2.groupby('protein_name',sort=False)['affinity'].agg('idxmax')].reset_index()

chosen_data_2 = chosen_2[chosen_2['affinity'] != 0]
# chosen_idxs_2 = chosen_data_2['index'].to_list()
chosen_smiles_2 = chosen_data_2['smiles'].tolist()

dataset_path = f"dataset_10"

selected_entries_dataset_2 = MyDataset_VS(root=dataset_path, data=chosen_data_2) 
selected_entries_dataset_2.data = chosen_data_2

print(selected_entries_dataset_2.__len__())

51


In [35]:
# Create a DataLoader for the entire mini_try dataset
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device= 'cpu'
logging.basicConfig(level=logging.INFO)
model = get_model(0, logging, device)

# self-dock model
modelFile = "./model/self_dock.pt"

model.load_state_dict(torch.load(modelFile, map_location=device))
_ = model.eval()

data_loader = DataLoader(selected_entries_dataset_2, batch_size=1, follow_batch=['x', 'y', 'compound_pair'], shuffle=False, num_workers=0)

y_preds = []
tankbind_list = []

for i, data_with_batch_info in enumerate(tqdm(data_loader)):

    y_pred, affinity_pred = model(data_with_batch_info)

    coords = data_with_batch_info.coords.to(device)
    protein_nodes_xyz = data_with_batch_info.node_xyz.to(device)
    n_compound = coords.shape[0]  # Assuming coords is a batch of coordinates
    n_protein = protein_nodes_xyz.shape[0]  # Assuming node_xyz is a batch of node positions
    y_pred = y_pred.reshape(n_protein, n_compound).to(device).detach()
    y_preds.append(y_pred)
    y = data_with_batch_info.dis_map.reshape(n_protein, n_compound).to(device)  # Assuming dis_map is available in the batch
    compound_pair_dis_constraint = torch.cdist(coords, coords)

    # Handle the batch results as needed
    # Extract SMILES and generate 2D coordinates
    smiles = chosen_smiles_2[i]  # Assuming batch size is 1
    mol = Chem.MolFromSmiles(smiles)
    mol.Compute2DCoords()

    # Compute LAS distance constraint mask
    LAS_distance_constraint_mask = get_LAS_distance_constraint_mask(mol).bool()

    # Calculate information
    info = get_info_pred_distance(coords, y_pred, protein_nodes_xyz, compound_pair_dis_constraint,
                                  LAS_distance_constraint_mask=LAS_distance_constraint_mask,
                                  n_repeat=1, show_progress=False)

    # Save to file
    # toFile = f'KIBA_tankbind_{i}.sdf'  # Unique file name for each iteration
    new_coords = info.sort_values("loss")['coords'].iloc[0].astype(np.double)
    # write_with_new_coords(mol, new_coords, toFile)
    
    tankbind_list.append([mol, new_coords])


11:21:02   5 stack, readout2, pred dis map add self attention and GVP embed, compound model GIN


100%|██████████| 51/51 [05:17<00:00,  6.23s/it]


In [36]:
output_file = 'KIBA_tankbind_2.sdf'

# Write each molecule with its new coordinates to a temporary file and append to the final SDF file
with open(output_file, 'w') as final_sdf:
    for i, (mol, new_coords) in enumerate(tankbind_list):
        temp_file = f'temp_{i}.sdf'
        write_with_new_coords(mol, new_coords, temp_file)
        with open(temp_file, 'r') as temp_sdf:
            final_sdf.write(temp_sdf.read())
        os.remove(temp_file)

In [37]:
tankbind_list

[[<rdkit.Chem.rdchem.Mol at 0x223153c5ac0>,
  array([[-39.99023438, -43.70363235,  27.282547  ],
         [-40.02205276, -44.35167694,  25.94921303],
         [-39.06705475, -45.48439026,  25.68947983],
         [-39.01185226, -46.12120819,  24.33093643],
         [-37.80017471, -46.70598602,  23.70432854],
         [-37.36889267, -48.12618637,  23.81985855],
         [-38.3865242 , -49.2115593 ,  24.18803596],
         [-36.09908676, -48.31437683,  24.80363846],
         [-36.62612152, -47.02073288,  25.21534157],
         [-36.10723114, -46.55240631,  26.62067413],
         [-34.57904053, -46.94635391,  26.80072784],
         [-34.09067917, -47.29528809,  28.16699219],
         [-34.97949982, -47.05894852,  29.36775208],
         [-34.40142441, -47.27877045,  30.74562645],
         [-33.64228821, -48.51580048,  31.05593681],
         [-36.4438324 , -47.10255814,  29.15329552],
         [-36.48032379, -45.3850441 ,  30.63771057],
         [-36.77907944, -45.17002106,  32.10346985],
  

In [52]:
selected_entries_dataset_2[0].y

tensor([0., 0., 0.,  ..., 0., 0., 0.])

In [46]:
tankbind_list[0][1].shape

(35, 3)

In [48]:
y_preds[0].shape

torch.Size([176, 35])

In [55]:
from utils import evaulate
criterion = torch.nn.MSELoss()
eval_res = evaulate(data_loader, model, criterion, device)

AttributeError: 'tuple' object has no attribute 'size'

In [39]:
mol = Chem.MolFromSmiles(smiles)
mol.Compute2DCoords()
LAS_distance_constraint_mask = get_LAS_distance_constraint_mask(mol).bool()

In [41]:
chosen_info = get_info_pred_distance(coords, y_pred, protein_nodes_xyz, compound_pair_dis_constraint, 
                              LAS_distance_constraint_mask=LAS_distance_constraint_mask,
                              n_repeat=1, show_progress=False)

In [42]:
chosen_info

Unnamed: 0,repeat,rmsd,loss,coords
0,0,86.821625,2683.349121,"[[26.3994, 83.02271, 4.1265006], [26.461185, 8..."


In [44]:
new_coords = chosen_info.sort_values("loss")['coords'].iloc[0].astype(np.double)
new_coords

array([[26.39940071, 83.02271271,  4.12650061],
       [26.46118546, 84.20759583,  3.26819062],
       [26.80811501, 83.99140167,  1.75630832],
       [25.73953819, 84.19969177,  0.74874777],
       [24.26147652, 84.01294708,  1.17117512],
       [24.27444839, 85.58414459,  2.07344699],
       [22.80819511, 85.05859375,  2.00607419],
       [22.20603752, 84.86125946,  3.40974355],
       [20.76871872, 84.44417572,  3.51491451],
       [19.86489487, 84.90786743,  4.57274055],
       [18.37797737, 84.74586487,  4.53427601],
       [17.88718224, 83.58028412,  3.5690589 ],
       [16.97315025, 83.02701569, -1.05461371],
       [17.72415924, 82.39987946, -2.34311652],
       [19.15843582, 82.52580261, -1.87686586],
       [19.07336426, 82.67671967, -0.35581687],
       [17.89111519, 81.63816833,  3.98576665],
       [17.82991791, 82.72317505,  5.0145421 ],
       [16.74136734, 83.4278183 ,  5.77836514],
       [17.02000809, 84.85614014,  5.87290239],
       [17.17098236, 83.88796997,  6.999