In [68]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import copy
import GPUtil
from collections import defaultdict
from analysis import utils as au
from analysis import plotting
from data import utils as du
from data import se3_diffuser
from data import r3_diffuser
from data import so3_diffuser
import seaborn as sns
from model import loss
from model import reverse_se3_diffusion
import tree
from data import rosetta_data_loader
from data import digs_data_loader
from data import all_atom
from experiments import train_se3_diffusion
from experiments import inference_se3_diffusion
from openfold.utils import rigid_utils as ru
from openfold.np import residue_constants
from scipy.spatial.transform import Rotation

from omegaconf import DictConfig, OmegaConf
import importlib

# Enable logging
import logging
import sys
date_strftime_format = "%Y-%m-%y %H:%M:%S"
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(asctime)s %(message)s", datefmt=date_strftime_format)

In [69]:
importlib.reload(rosetta_data_loader)
importlib.reload(digs_data_loader)
importlib.reload(se3_diffuser)
importlib.reload(so3_diffuser)
importlib.reload(r3_diffuser)
importlib.reload(du)
importlib.reload(au)
importlib.reload(all_atom)
importlib.reload(plotting)
importlib.reload(loss)
importlib.reload(reverse_se3_diffusion)
importlib.reload(inference_se3_diffusion)
importlib.reload(train_se3_diffusion)

<module 'experiments.train_se3_diffusion' from '/data/rsg/chemistry/jyim/projects/protein_diffusion/experiments/train_se3_diffusion.py'>

In [70]:
# Load model
inference_conf_path = '../config/inference.yaml'
inference_conf = OmegaConf.load(inference_conf_path)
# inference_conf.ckpt_dir = '../pkl_jar/ckpt/exp_lin_with_aatype_0/11D_09M_2022Y_13h_12m_28s'
inference_conf.ckpt_dir = '../pkl_jar/ckpt/with_aatype_0/11D_09M_2022Y_11h_44m_55s'
# inference_conf.ckpt_dir = '../pkl_jar/ckpt/baseline_0/11D_09M_2022Y_13h_06m_45s'
inference_conf.default_conf_path = '../config/base.yaml'
inference_conf.output_dir = '../results'

# print(OmegaConf.to_yaml(inference_conf))

# Set up sampler
sampler = inference_se3_diffusion.Sampler(inference_conf)
train_loader, valid_loader = sampler.exp.create_rosetta_dataset(0, 1)
train_csv = train_loader.dataset.csv

INFO: Loading checkpoint from ../pkl_jar/ckpt/with_aatype_0/11D_09M_2022Y_11h_44m_55s/step_470000.pth
INFO: Saving results to ../results/11D_09M_2022Y_11h_44m_55s
INFO: Number of model parameters 3703648
INFO: Using cached IGSO3.
INFO: Checkpoints saved to: ./pkl_jar/ckpt/with_aatype_0/11D_09M_2022Y_11h_44m_55s/with_aatype_0/13D_09M_2022Y_21h_57m_10s
INFO: Evaluation saved to: ./results/with_aatype_0/11D_09M_2022Y_11h_44m_55s/with_aatype_0/13D_09M_2022Y_21h_57m_10s
INFO: Training: 1000 examples
INFO: Validation: 40 examples with lengths [ 60  90 113 134 156 177 198 218 239 260]


In [71]:
train_loader, valid_loader = sampler.exp.create_rosetta_dataset(0, 1)
train_csv = train_loader.dataset.csv

INFO: Training: 1000 examples
INFO: Validation: 40 examples with lengths [ 60  90 113 134 156 177 198 218 239 260]


### Sample using data

In [72]:
# Sample an example

# data_iter = iter(train_loader)
# next_item = next(data_iter)

data_iter = iter(valid_loader)
next_item, _ = next(data_iter)

In [73]:
# Run sampler
save = True
batch_size = 4
res_mask = next_item['res_mask'][:batch_size]
aatype = next_item['aatype'][:batch_size]
res_idx = next_item['res_idx'][:batch_size]
samples_traj = sampler.sample(
    res_mask=res_mask,
    aatype=aatype,
    save=save,
    res_idx=res_idx,
    file_prefix='./samples/'
)

INFO: Saved sample to ./samples/len_90_1.pdb
INFO: Saved trajectory to ./samples/len_90_traj_1.pdb
INFO: Saved sample to ./samples/len_177_1.pdb
INFO: Saved trajectory to ./samples/len_177_traj_1.pdb
INFO: Saved sample to ./samples/len_217_1.pdb
INFO: Saved trajectory to ./samples/len_217_traj_1.pdb
INFO: Saved sample to ./samples/len_134_1.pdb
INFO: Saved trajectory to ./samples/len_134_traj_1.pdb


In [74]:
from Bio import PDB

In [75]:
parser = PDB.PDBParser(QUIET=True)

pdb_path = './samples/len_90_1.pdb'
pdb_name = 'test'
structure = parser.get_structure(pdb_name, pdb_path)

struct_chains = {
    chain.id: chain
    for chain in structure.get_chains() if chain.id == 'A'}
# TODO: Add logic for handling multiple chains.
assert len(struct_chains) == 1

# chain_prot = process_chain(struct_chains['A'], 'A')

In [None]:
sample_metrics = metrics.protein_metrics(saved_path, unpad_prot)

### Sample without data

In [63]:
# Run sampler
batch_size = 4
num_res = 150
res_mask = torch.ones((batch_size, num_res))
save = True
samples_traj = sampler.sample(
    res_mask=res_mask,
    save=save,
    file_prefix='./samples/'
)

INFO: Saved sample to ./samples/len_150_1.pdb
INFO: Saved trajectory to ./samples/len_150_traj_1.pdb
INFO: Saved sample to ./samples/len_150_2.pdb
INFO: Saved trajectory to ./samples/len_150_traj_2.pdb
INFO: Saved sample to ./samples/len_150_3.pdb
INFO: Saved trajectory to ./samples/len_150_traj_3.pdb
INFO: Saved sample to ./samples/len_150_4.pdb
INFO: Saved trajectory to ./samples/len_150_traj_4.pdb
