In [29]:
import torch
import os, sys
from argparse import ArgumentParser

sys.path.append('../')

from torch_geometric.loader import DataLoader

from utils.utils import get_model
from utils.parsing import parse_train_args
from model.flow import SE3VerletFlow
from datasets.pdbbind import PDBBind, InitializeVelocity

WORKDIR = '/data/scratch/erives/verlet_flows/'
BEST_MODEL = 'workdir/best_model.pt'
BEST_MODEL_PATH = os.path.join(WORKDIR, BEST_MODEL)

CACHE_PATH = os.path.join(WORKDIR, 'data/cache')
SPLIT_TRAIN_PATH = os.path.join(WORKDIR, 'data/splits/timesplit_no_lig_overlap_train')
DATA_PATH = os.path.join(WORKDIR, 'data/PDBBind_processed/')


In [14]:
# Default training arguments (args.pin_memory is used by the data loader below).
args = parse_train_args('')
# Prefer the first CUDA device; fall back to the CPU when no GPU is visible.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
model = SE3VerletFlow(device)
# Deserialize the checkpoint's tensors onto the CPU, whatever device saved them.
state_dict = torch.load(BEST_MODEL_PATH, map_location='cpu')

In [15]:
# Copy the checkpoint weights into the model; strict=True (the default) requires an exact key match.
model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

In [30]:
# Training split of PDBBind with one conformer per complex; InitializeVelocity
# is applied to every sample as the dataset transform.
transform = InitializeVelocity()
train_dataset = PDBBind(cache_path=CACHE_PATH, split_path=SPLIT_TRAIN_PATH, keep_original=True,
                        num_conformers=1, root=DATA_PATH, c_alpha_max_neighbors=10, transform=transform)
loader_class = DataLoader
# Seed the shuffle so the batch drawn for the invertibility check is the same
# after every kernel restart; without it a failure cannot be reproduced.
# The seed value (0) is arbitrary.
train_loader = loader_class(dataset=train_dataset, batch_size=1,
                            num_workers=1, shuffle=True, pin_memory=args.pin_memory,
                            generator=torch.Generator().manual_seed(0))

loading data from memory:  /data/scratch/erives/verlet_flows/data/cache/limit0_INDEXtimesplit_no_lig_overlap_train_maxLigSizeNone_H1_recRad30_recMax10/heterographs.pkl
Number of complexes:  16360
radius protein: mean 35.422821044921875, std 11.115978240966797, max 140.3852081298828
radius molecule: mean 8.115903854370117, std 3.1345317363739014, max 29.449182510375977
distance protein-mol: mean 13.007867813110352, std 6.231020450592041, max 70.82417297363281
rmsd matching: mean 0.0, std 0.0, max 0


In [31]:
# Check that the flow is invertible on one training batch.
# nn.Modules start in training mode; any train-only / stochastic layers
# (e.g. dropout) would make the forward and inverse passes disagree, so switch
# to eval mode first.
# NOTE(review): unconfirmed whether SE3VerletFlow has such layers, i.e. whether
# this explains the AssertionError this check has raised.
model.eval()
sample_batch = next(iter(train_loader))  # kept in a name so it can be inspected if the check fails
# TODO(review): confirm check_invertible moves the batch to the model's device itself.
model.check_invertible(sample_batch)

AssertionError: 