In [None]:
import torch
import os, sys
import importlib
from argparse import ArgumentParser

sys.path.append('../')

from torch_geometric.loader import DataLoader

from utils.utils import get_model
from utils.parsing import parse_train_args
from model.flow import SE3VerletFlow
from model.coupling_layer import SE3CouplingLayer
from datasets.pdbbind import PDBBind, InitializeVelocity

# Root of the verlet_flows working area. Set the VERLET_FLOWS_WORKDIR environment
# variable to run this notebook elsewhere; the default keeps the original location.
WORKDIR = os.environ.get('VERLET_FLOWS_WORKDIR', '/data/scratch/erives/verlet_flows/')

# Path to the best-model checkpoint file.
BEST_MODEL_PATH = os.path.join(WORKDIR, 'workdir/best_model.pt')
# Cache directory handed to PDBBind(cache_path=...).
CACHE_PATH = os.path.join(WORKDIR, 'data/cache')
# Training split file handed to PDBBind(split_path=...).
SPLIT_TRAIN_PATH = os.path.join(WORKDIR, 'data/splits/timesplit_no_lig_overlap_train')
# Processed PDBBind data directory handed to PDBBind(root=...); note the trailing slash.
DATA_PATH = os.path.join(WORKDIR, 'data/PDBBind_processed/')

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# NOTE: to evaluate a trained checkpoint, load BEST_MODEL_PATH only *after* the
# model object has been constructed (the invertibility cells below build it), e.g.
#     flow.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
# It cannot be done in this cell because no model exists yet. Use
# map_location=device so the weights land on the same device as the model.

In [None]:
# Fixed seed so the shuffled sample order, and the torch RNG used by model
# construction in later cells, are reproducible across "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

# InitializeVelocity is passed to the dataset as its `transform`.
transform = InitializeVelocity()
train_dataset = PDBBind(cache_path=CACHE_PATH, split_path=SPLIT_TRAIN_PATH, keep_original=True,
                        num_conformers=1, root=DATA_PATH, c_alpha_max_neighbors=10, transform=transform)
loader_class = DataLoader
# A dedicated, seeded generator fixes the shuffle order independently of any
# other RNG use in the notebook. NOTE(review): if InitializeVelocity samples from
# torch's RNG inside the worker, this should also make it deterministic via
# DataLoader's worker seeding — confirm.
train_loader = loader_class(dataset=train_dataset, batch_size=1,
                            num_workers=1, shuffle=True,
                            generator=torch.Generator().manual_seed(SEED))

In [None]:
# create model (constructed with the device selected above)
flow = SE3VerletFlow(device)

# Draw a single batch (batch_size=1) and place it on the same device as the model.
# The loader presumably yields CPU tensors; .to(device) is a no-op on CPU.
data = next(iter(train_loader)).to(device)

# Run the flow's own invertibility check on this batch
# (see SE3VerletFlow.check_invertible for what it asserts/reports).
flow.check_invertible(data)

In [None]:
# create a single coupling layer (constructed with the device selected above)
coupling_layer = SE3CouplingLayer(device=device)

# Draw a fresh single batch (batch_size=1) and place it on the layer's device.
# The loader presumably yields CPU tensors; .to(device) is a no-op on CPU.
data = next(iter(train_loader)).to(device)

# Run the coupling layer's own invertibility check on this batch
# (see SE3CouplingLayer.check_invertibility for what it asserts/reports).
coupling_layer.check_invertibility(data)