In [15]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import copy

from collections import defaultdict
from analysis import utils as au
from analysis import plotting
from data import utils as du
from data import se3_diffuser
from data import r3_diffuser
from data import so3_diffuser
from model import loss
from model import reverse_se3_diffusion
import tree
from data import rosetta_data_loader
from data import digs_data_loader
from experiments import train_se3_diffusion
from openfold.utils import rigid_utils as ru
from scipy.spatial.transform import Rotation

from omegaconf import OmegaConf
import importlib

# Enable logging
import logging
import sys
# Timestamp format for log records: year-month-day hour:minute:second.
# (Was "%Y-%m-%y", which put the two-digit year where the day belongs.)
date_strftime_format = "%Y-%m-%d %H:%M:%S"
# NOTE: logging.basicConfig does nothing if the root logger already has
# handlers (e.g. when this cell is re-run, or an imported module configured
# logging first). The INFO lines printed further down carry no timestamp,
# which suggests that happened here; pass force=True (Python 3.8+) to replace
# existing handlers if this format is wanted.
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(asctime)s %(message)s", datefmt=date_strftime_format)

In [16]:
# Re-import project modules so source edits are picked up without restarting
# the kernel. Order is unchanged from before.
# NOTE(review): names bound via `from x import y` inside already-loaded modules
# keep pointing at the old objects unless the importing module is reloaded
# after `x`; `%autoreload 2` is an alternative worth considering.
for _module in (digs_data_loader, se3_diffuser, so3_diffuser, r3_diffuser, du):
    importlib.reload(_module)
importlib.reload(train_se3_diffusion)

<module 'experiments.train_se3_diffusion' from '/home/jyim/Projects/protein_diffusion_v2/experiments/train_se3_diffusion.py'>

In [17]:
# Load the base configuration; the path is relative to the notebook's
# working directory.
conf = OmegaConf.load('../config/base.yaml')

# Short aliases for the config sections overridden below.
exp_conf = conf.experiment
data_conf = conf.data
diff_conf = conf.diffuser

# Experiment: read data from the 'digs' location.
exp_conf.data_location = 'digs'

# Data: filtering subset 1, and fraction_fb set to 0.
data_conf.digs.filtering.subset = 1
data_conf.digs.fraction_fb = 0.0

# Diffuser: logarithmic rotation schedule, linear translation schedule.
diff_conf.rot_schedule = 'logarithmic'
diff_conf.trans_schedule = 'linear'

In [18]:
# Build the experiment object and its DIGS train/validation loaders.
exp = train_se3_diffusion.Experiment(conf=conf)
# NOTE(review): (0, 1) look like (rank, world_size) for a single process — confirm.
(train_loader, train_sampler,
 valid_loader, valid_sampler) = exp.create_digs_dataset(0, 1)

# The experiment's SE(3) diffuser and its translation (R3) / rotation (SO3) parts.
se3_diff = exp.diffuser
r3_diff, so3_diff = se3_diff._r3_diffuser, se3_diff._so3_diffuser

# SDE time parameters: num_t evenly spaced times in [1e-3, 1], plus the same
# grid traversed backwards.
num_t = 100
forward_t = np.linspace(1e-3, 1, num_t)
reverse_t = np.flip(forward_t)
# NOTE(review): the actual grid spacing is (1 - 1e-3) / (num_t - 1), not
# 1 / num_t; confirm which step size consumers of dt expect.
dt = 1 / num_t

INFO: Number of model parameters 3488030
INFO: Using cached IGSO3.
INFO: Loaded data at ./pkl_jar/dataset_5.0_260_60_80.0_100_2020-Apr-30_1.pkl
INFO: Loaded data at ./pkl_jar/dataset_5.0_260_60_80.0_100_2020-Apr-30_1.pkl


In [19]:
# Reference structure whose coordinates are noised below.
pdb_name = 'len_125_sample_2_sctm_sctm_90'
pdb_path = f'{pdb_name}.pdb'
pdb_data = du.parse_pdb(pdb_path)
# NOTE(review): assumes pdb_data[0] holds atom positions laid out as
# (num_res, num_atoms, 3) with atom index 1 = C-alpha, and that /10 converts
# Angstroms to nm (the trajectory is multiplied back by 10 before saving).
ca_pos = pdb_data[0][:, 1] / 10.0

In [30]:
# Per-step noise variances, linearly spaced from 1e-4 to 0.02 over num_t steps.
# NOTE(review): these endpoints match the DDPM linear schedule, which was designed
# for ~1000 steps. With num_t = 100, prod(1 - beta_t) is about 0.36, so the last
# forward frame still keeps roughly 60% of the original signal amplitude instead
# of being near pure noise — confirm this is intended.
beta_schedule = np.linspace(start=0.0001, stop=0.02, num=num_t)

def forward(
        x_t: np.ndarray,
        t: int,
        betas: "np.ndarray | None" = None,
        rng: "np.random.Generator | None" = None) -> np.ndarray:
    """Apply one step of the discrete Gaussian noising process.

    Returns ``sqrt(1 - beta_t) * x_t + sqrt(beta_t) * z`` where ``z`` is
    standard-normal noise with the same shape as ``x_t``.

    Args:
        x_t: Current state; any array shape. Not modified.
        t: Index into the variance schedule selecting ``beta_t``.
        betas: Variance schedule to index. Defaults to the notebook-level
            ``beta_schedule``, so existing calls behave exactly as before.
        rng: Optional ``np.random.Generator`` for reproducible noise. When
            omitted, noise is drawn from the legacy global ``np.random`` state
            (the original behavior).

    Returns:
        A new array with the same shape as ``x_t``.
    """
    # Only touch the global schedule when no explicit schedule is given.
    schedule = beta_schedule if betas is None else betas
    beta_t = schedule[t]
    # Both the np.random module and a Generator expose .normal(size=...).
    sampler = np.random if rng is None else rng
    z = sampler.normal(size=x_t.shape)
    return x_t * np.sqrt(1 - beta_t) + z * np.sqrt(beta_t)


In [31]:
# Forward diffusion: apply the noising step num_t times starting from the clean
# coordinates, recording every intermediate state (num_t + 1 frames in total).
# The global RNG is seeded so the trajectory, and the PDB written from it, is
# the same on every Restart & Run All.
FORWARD_SEED = 0
np.random.seed(FORWARD_SEED)

x_0 = np.copy(ca_pos)
forward_diffusion = [x_0]
for t in range(num_t):
    # forward() returns a new array, so no defensive copy of the previous frame is needed.
    forward_diffusion.append(forward(forward_diffusion[-1], t))

# Stack to shape (num_t + 1, *ca_pos.shape); the reverse trajectory is the same
# frames in the opposite order, ending at the clean structure.
forward_diffusion = np.stack(forward_diffusion)
reverse_diffusion = np.flip(forward_diffusion, 0)

In [43]:
# Multiply by 10 to undo the /10 applied at load time. Then append 80 copies of
# the last frame (the clean structure, since the trajectory is reversed),
# presumably so it is held on screen when the trajectory is played back.
ca_traj = reverse_diffusion * 10.0
final_state = np.repeat(ca_traj[-1:], 80, axis=0)
save_traj = np.concatenate([ca_traj, final_state], 0)

In [44]:
# Write the whole padded trajectory (not just the final frame) to a PDB file.
# The helper returns the path it actually wrote, which can differ from the
# requested one (the captured output shows a numeric suffix was appended).
file_path = f'./{pdb_name}_traj.pdb'
save_path = au.write_prot_to_pdb(save_traj, file_path)
print(f'Written to {save_path}')

Written to ./len_125_sample_2_sctm_sctm_90_traj_8.pdb
