In [13]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import copy
import GPUtil
from collections import defaultdict
from analysis import utils as au
from analysis import plotting
from data import utils as du
from data import se3_diffuser
from data import r3_diffuser
from data import so3_diffuser
import seaborn as sns
from model import loss
from model import reverse_se3_diffusion
import tree
from data import rosetta_data_loader
from data import digs_data_loader
from data import all_atom
from experiments import train_se3_diffusion
from experiments import inference_se3_diffusion
from openfold.utils import rigid_utils as ru
from openfold.np import residue_constants
from scipy.spatial.transform import Rotation

from omegaconf import DictConfig, OmegaConf
import importlib

# Enable logging
import logging
import sys
# Timestamp layout for log records: four-digit year, month, day-of-month, then
# 24-hour time. The day field must be %d; %y would print the two-digit year again.
date_strftime_format = "%Y-%m-%d %H:%M:%S"
# Route INFO-and-above records to stdout so they show up in the notebook cell output.
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(asctime)s %(message)s", datefmt=date_strftime_format)

In [14]:
# Re-import the project modules so source edits on disk are picked up without
# restarting the kernel. Modules are reloaded in the order listed.
for stale_module in (
    rosetta_data_loader,
    digs_data_loader,
    se3_diffuser,
    so3_diffuser,
    r3_diffuser,
    du,
    au,
    all_atom,
    plotting,
    loss,
    reverse_se3_diffusion,
    inference_se3_diffusion,
):
    importlib.reload(stale_module)

# Kept as the final bare expression so the cell still displays the reloaded module.
importlib.reload(train_se3_diffusion)

<module 'experiments.train_se3_diffusion' from '/data/rsg/chemistry/jyim/projects/protein_diffusion/experiments/train_se3_diffusion.py'>

In [15]:
# Load model: read the inference configuration from YAML, then override a few entries.
inference_conf_path = '../config/inference.yaml'
inference_conf = OmegaConf.load(inference_conf_path)

# Checkpoint directory to sample from. Previously used alternative:
#   '../pkl_jar/ckpt/fixed_scaling/28D_09M_2022Y_12h_03m_45s'
inference_conf['ckpt_dir'] = '../pkl_jar/ckpt/no_aatype_0/02D_10M_2022Y_08h_35m_44s'

# Relative paths, resolved against the kernel's current working directory.
inference_conf['default_conf_path'] = '../config/base.yaml'
inference_conf['output_dir'] = '../results'

# Uncomment to inspect the final configuration:
# print(OmegaConf.to_yaml(inference_conf))

# Set up sampler, then build the train/validation loaders through its experiment object.
sampler = inference_se3_diffusion.Sampler(inference_conf)
train_loader, valid_loader = sampler.exp.create_rosetta_dataset(0, 1)
# CSV attached to the training dataset, kept for later inspection.
train_csv = train_loader.dataset.csv

INFO: Loading checkpoint from ../pkl_jar/ckpt/no_aatype_0/02D_10M_2022Y_08h_35m_44s/step_550000.pth
INFO: Saving results to ../results/02D_10M_2022Y_08h_35m_44s
INFO: Number of model parameters 17486935
INFO: Using cached IGSO3.
INFO: Checkpoints saved to: ./pkl_jar/ckpt/no_aatype_0/02D_10M_2022Y_08h_35m_44s/no_aatype_0/03D_10M_2022Y_09h_56m_56s
INFO: Evaluation saved to: ./results/no_aatype_0/02D_10M_2022Y_08h_35m_44s/no_aatype_0/03D_10M_2022Y_09h_56m_56s
INFO: Training: 3798 examples
INFO: Validation: 40 examples with lengths [ 60  67  75  82  90  97 105 112 120 128]


### Sample using data

In [16]:
# Sample an example: pull a single batch that the sampling cells below read from.

# Alternative source (training loader). Note this variant keeps the whole item
# returned by the loader in `next_item` instead of unpacking it.
# data_iter = iter(train_loader)
# next_item = next(data_iter)

# Take the first item from the validation loader. It is unpacked as a pair: the
# first element is kept as `next_item` (indexed by string keys such as 'res_idx'
# in later cells) and the second element is discarded into `_`.
data_iter = iter(valid_loader)
next_item, _ = next(data_iter)

In [36]:
# Preview the residue indices of the first `batch_size` examples in the batch.
# `batch_size` is defined here because no earlier cell defines it; without this
# the cell raises NameError on a fresh kernel (Restart & Run All).
batch_size = 4
next_item['res_idx'][:batch_size]

tensor([[  3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,  15,  16,
          17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,  29,  30,
          31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,
          45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,
          59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,  72,
          73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0],
        [344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357,
         358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371,
         372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385,
         386, 387, 388, 389, 390

In [35]:
# Run sampler on the first `batch_size` examples of the batch held in `next_item`.
save = True
add_noise = False
batch_size = 4
res_mask = next_item['res_mask'][:batch_size]
# Multiplying by 0 sets every amino-acid type entry to 0, so the batch's own
# aatype values are not passed to the sampler.
aatype = next_item['aatype'][:batch_size] * 0
# Always rebuild res_idx from the batch. With this line commented out the cell
# read an undefined `res_idx` on a fresh kernel and, on re-runs, shifted its own
# previous result again (the cell was not idempotent).
res_idx = next_item['res_idx'][:batch_size]
# Shift each row so that its smallest non-zero index becomes 0. Entries equal to
# 0 get +1000 before the row-wise min so they cannot be selected as the minimum;
# this assumes every row has a non-zero index below 1000.
# NOTE(review): entries that were 0 end up negative after the shift. They look
# like padding (trailing zeros in the res_idx preview) and are presumably
# ignored via res_mask; confirm.
res_idx = res_idx - torch.min(res_idx + (res_idx == 0).long() * 1000, dim=-1).values[..., None]
samples_traj = sampler.sample(
    res_mask=res_mask,
    aatype=aatype,
    save=save,
    res_idx=res_idx,
    add_noise=add_noise,
    file_prefix='./samples/'
)

INFO: Saved sample to ./samples/len_82_2.pdb
INFO: Saved trajectory to ./samples/len_82_traj_2.pdb
INFO: Saved sample to ./samples/len_112_2.pdb
INFO: Saved trajectory to ./samples/len_112_traj_2.pdb
INFO: Saved sample to ./samples/len_105_2.pdb
INFO: Saved trajectory to ./samples/len_105_traj_2.pdb
INFO: Saved sample to ./samples/len_60_2.pdb
INFO: Saved trajectory to ./samples/len_60_traj_2.pdb


### Sample without data

In [None]:
# Run sampler without a data batch: the residue mask is all ones, so every one
# of the `num_res` positions in each of the `batch_size` rows is switched on.
save = True
batch_size, num_res = 4, 150
res_mask = torch.ones(batch_size, num_res)
samples_traj = sampler.sample(res_mask=res_mask, save=save, file_prefix='./samples/')