In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import yaml
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import copy
import pickle
from seaborn import histplot

In [4]:
os.chdir('/data/rsg/nlp/sdobers/amine/diffdock-protein/src')

from args import parse_args
from data import load_data, get_data
from data.data import BindingDataset
from model import load_model, to_cuda
from utils import printt, print_res, log, get_unixtime, compute_rmsd
from train import train, evaluate, evaluate_pose
from helpers import WandbLogger, TensorboardLogger
from sample import sample
from evaluation.compute_rmsd import evaluate_all_rmsds

In [5]:
from notebooks.utils_notebooks import Dict2Class

In [6]:
PATH = '/data/rsg/nlp/sdobers/amine/diffdock-protein/ckpts/dips_largest_model' # largest model

In [15]:
# load args
with open(os.path.join(PATH, 'args.yaml')) as f:
    args = yaml.safe_load(f)
args = Dict2Class(args)

args.num_gpu = 1
args.gpu = 5
args.data_file = args.data_file.replace('data_file', 'data_file_100_test')
args.checkpoint_path = PATH
#args.use_orientation_features = False
#args.recache = False
args.batch_size = 32
args.use_randomized_confidence_data = False
args.mode = "test"
if 'large' in PATH:
    args.batch_size = 4
#args.cross_cutoff_weight = 3
#args.cross_cutoff_bias = 40

In [16]:
args.data_file

'/data/rsg/nlp/sdobers/data/DIPS/data_file_100_test.csv'

In [29]:
data = load_data(args)
loaders = get_data(data, 0, args, for_reverse_diffusion=True)

data loading: 100%|█| 100/100 [00:00<00:00, 465516


05:28:20 Loaded cached ESM embeddings
05:28:20 finished tokenizing residues with ESM
05:28:20 finished tokenizing all inputs
05:28:20 100 entries loaded


In [37]:
loaders["test"].data = sorted(loaders["test"].data, key=lambda x:x['receptor_xyz'].shape[0] + x['ligand_xyz'].shape[0])


In [40]:
[comp["receptor"].pos.shape[0] + comp["ligand"].pos.shape[0] for comp in loaders["test"]]

[122,
 126,
 156,
 159,
 193,
 207,
 214,
 219,
 220,
 222,
 225,
 234,
 235,
 237,
 248,
 255,
 258,
 259,
 264,
 265,
 278,
 280,
 280,
 286,
 296,
 298,
 298,
 300,
 300,
 302,
 306,
 308,
 310,
 311,
 320,
 329,
 334,
 338,
 340,
 344,
 393,
 395,
 427,
 428,
 431,
 431,
 432,
 443,
 443,
 446,
 452,
 506,
 518,
 520,
 522,
 523,
 523,
 526,
 530,
 530,
 533,
 562,
 562,
 597,
 604,
 605,
 605,
 608,
 621,
 624,
 626,
 638,
 639,
 659,
 660,
 678,
 680,
 760,
 776,
 818,
 818,
 840,
 880,
 924,
 928,
 928,
 960,
 1019,
 1024,
 1052,
 1074,
 1074,
 1143,
 1178,
 1339,
 1436,
 1454,
 1458,
 1474,
 1474]

In [42]:
loaders["test"][0:64]

BindingDataset(64)

In [43]:
[2,3,4] + [5,7,8]

[2, 3, 4, 5, 7, 8]

In [26]:
loaders["test"] = sorted(loaders["test"], key=lambda x:x["receptor"].pos.shape[0] + x["ligand"].pos.shape[0])

In [28]:
loaders["test"]

[HeteroData(
   name='kq/1kq1.pdb1_2.dill',
   center=[1, 3],
   [1mreceptor[0m={
     pos=[61, 3],
     x=[61, 1281]
   },
   [1mligand[0m={
     pos=[61, 3],
     x=[61, 1281]
   },
   [1m(receptor, contact, receptor)[0m={ edge_index=[2, 1220] },
   [1m(ligand, contact, ligand)[0m={ edge_index=[2, 1220] }
 ),
 HeteroData(
   name='kq/1kq1.pdb1_1.dill',
   center=[1, 3],
   [1mreceptor[0m={
     pos=[66, 3],
     x=[66, 1281]
   },
   [1mligand[0m={
     pos=[60, 3],
     x=[60, 1281]
   },
   [1m(receptor, contact, receptor)[0m={ edge_index=[2, 1320] },
   [1m(ligand, contact, ligand)[0m={ edge_index=[2, 1200] }
 ),
 HeteroData(
   name='cf/5cff.pdb2_1.dill',
   center=[1, 3],
   [1mreceptor[0m={
     pos=[87, 3],
     x=[87, 1281]
   },
   [1mligand[0m={
     pos=[69, 3],
     x=[69, 1281]
   },
   [1m(receptor, contact, receptor)[0m={ edge_index=[2, 1740] },
   [1m(ligand, contact, ligand)[0m={ edge_index=[2, 1380] }
 ),
 HeteroData(
   name='dm/5dm7.pdb1_22

In [27]:
[comp["receptor"].pos.shape[0] + comp["ligand"].pos.shape[0] for comp in loaders["test"]]

[122,
 126,
 156,
 159,
 193,
 207,
 214,
 219,
 220,
 222,
 225,
 234,
 235,
 237,
 248,
 255,
 258,
 259,
 264,
 265,
 278,
 280,
 280,
 286,
 296,
 298,
 298,
 300,
 300,
 302,
 306,
 308,
 310,
 311,
 320,
 329,
 334,
 338,
 340,
 344,
 393,
 395,
 427,
 428,
 431,
 431,
 432,
 443,
 443,
 446,
 452,
 506,
 518,
 520,
 522,
 523,
 523,
 526,
 530,
 530,
 533,
 562,
 562,
 597,
 604,
 605,
 605,
 608,
 621,
 624,
 626,
 638,
 639,
 659,
 660,
 678,
 680,
 760,
 776,
 818,
 818,
 840,
 880,
 924,
 928,
 928,
 960,
 1019,
 1024,
 1052,
 1074,
 1074,
 1143,
 1178,
 1339,
 1436,
 1454,
 1458,
 1474,
 1474]