In [248]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import copy
import GPUtil
from collections import defaultdict
from analysis import utils as au
from analysis import plotting
from data import utils as du
from data import se3_diffuser
from data import r3_diffuser
from data import so3_diffuser
import seaborn as sns
from model import loss
from model import reverse_se3_diffusion
import tree
from data import rosetta_data_loader
from data import digs_data_loader
from data import all_atom
from experiments import train_se3_diffusion
from experiments import inference_se3_diffusion
from openfold.utils import rigid_utils as ru
from openfold.np import residue_constants
from scipy.spatial.transform import Rotation

from omegaconf import DictConfig, OmegaConf
import importlib

# Enable logging
import logging
import sys
# Timestamp layout for log records: year-month-day hour:minute:second.
# The day field must be %d; %y is the two-digit year and would print the year twice.
date_strftime_format = "%Y-%m-%d %H:%M:%S"
# Send INFO-and-above records to stdout so they appear in the cell output.
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(asctime)s %(message)s", datefmt=date_strftime_format)

In [249]:
# Re-import the project modules so that edits on disk are picked up without
# restarting the kernel. The reload order is unchanged.
for _mod in (
    rosetta_data_loader,
    digs_data_loader,
    se3_diffuser,
    so3_diffuser,
    r3_diffuser,
    du,
    au,
    all_atom,
    plotting,
    loss,
    reverse_se3_diffusion,
    inference_se3_diffusion,
):
    importlib.reload(_mod)
# Kept as a bare trailing expression so the cell still echoes the reloaded module.
importlib.reload(train_se3_diffusion)

<module 'experiments.train_se3_diffusion' from '/data/rsg/chemistry/jyim/projects/protein_diffusion/experiments/train_se3_diffusion.py'>

In [250]:
# Load model
# Start from the inference config on disk, then override paths for this notebook session.
inference_conf_path = '../config/inference.yaml'
inference_conf = OmegaConf.load(inference_conf_path)
# Alternative checkpoint directory, kept for reference:
# inference_conf.ckpt_dir = '../pkl_jar/ckpt/fixed_scaling/28D_09M_2022Y_12h_03m_45s'
inference_conf.ckpt_dir = '../pkl_jar/ckpt/inpainting_0/14D_10M_2022Y_00h_00m_47s'

# These are relative paths, so they resolve against the notebook's working directory.
inference_conf.default_conf_path = '../config/base.yaml'
inference_conf.output_dir = '../results'

# Uncomment to inspect the config as YAML:
# print(OmegaConf.to_yaml(inference_conf))

# Set up sampler
sampler = inference_se3_diffusion.Sampler(inference_conf)
# NOTE(review): the meaning of the (0, 1) arguments is not visible here —
# presumably replica index and replica count; confirm in train_se3_diffusion.
train_loader, valid_loader = sampler.exp.create_rosetta_dataset(0, 1)
# Keep a handle on the training dataset's `csv` attribute.
train_csv = train_loader.dataset.csv

INFO: Loading checkpoint from ../pkl_jar/ckpt/inpainting_0/14D_10M_2022Y_00h_00m_47s/step_210000.pth
INFO: Saving results to ../results/14D_10M_2022Y_00h_00m_47s
INFO: Number of model parameters 10005987
INFO: Using cached IGSO3.
INFO: Checkpoints saved to: ./pkl_jar/ckpt/inpainting_0/14D_10M_2022Y_00h_00m_47s/inpainting_0/14D_10M_2022Y_14h_56m_20s
INFO: Evaluation saved to: ./results/inpainting_0/14D_10M_2022Y_00h_00m_47s/inpainting_0/14D_10M_2022Y_14h_56m_20s
INFO: Training: 3798 examples
INFO: Validation: 40 examples with lengths [ 60  67  75  82  90  97 105 112 120 128]


### Sample using data

In [251]:
# Sample an example
# Draw a single batch; the validation loader yields a (features, pdb_names) pair.

# To draw from the training loader instead (it yields only the features):
# data_iter = iter(train_loader)
# next_item = next(data_iter)

data_iter = iter(valid_loader)
next_item, pdb_names = next(data_iter)

In [252]:
# Run sampler on one example taken from the batch drawn above.
b_idx = 0  # position of the example inside the batch
save = True
add_noise = True

pdb_name = pdb_names[b_idx]
print(pdb_name)

# Slice the per-example features out of the batched dict in one pass.
res_mask, fixed_mask, aatype_impute, psi_impute = (
    next_item[feat_key][b_idx]
    for feat_key in ('res_mask', 'fixed_mask', 'aatype_0', 'torsion_angles_sin_cos')
)
# The rigid frames are stored as 7-element tensors and rebuilt into a Rigid object.
rigid_impute = ru.Rigid.from_tensor_7(next_item['rigids_0'][b_idx])

# The sampler returns both trajectories plus the paths of the files it wrote.
rigids_traj, aatype_traj, pdb_path, traj_path, fasta_path = sampler.sample(
    res_mask=res_mask,
    fixed_mask=fixed_mask,
    aatype_impute=aatype_impute,
    rigid_impute=rigid_impute,
    psi_impute=psi_impute,
    save=save,
    add_noise=add_noise,
    file_prefix='./samples/',
)

5n35


  trans=torch.tensor(trans))
INFO: Saved sample to ./samples/len_60_diffuse_083_1.pdb
INFO: Saved trajectory to ./samples/len_60_diffuse_083_traj_1.pdb
INFO: Saved sequence to ./samples/len_60_diffuse_083_of.fasta


In [254]:
# Run the sampler's self-consistency check on the saved sample structure and its
# FASTA file; per the captured log it folds the sequence with OmegaFold.
# Returns a score (`sc`) and `plddt` (presumably per-residue confidence — confirm).
sc, plddt = sampler.run_self_consistency(pdb_path, fasta_path)

INFO:root:Loading weights from /data/rsg/chemistry/jyim/third_party/omegafold/release1.pt
INFO:root:Constructing OmegaFold
INFO:root:Reading ./samples/len_60_diffuse_083_of.fasta
INFO:root:Predicting 1th chain in ./samples/len_60_diffuse_083_of.fasta
INFO:root:60 residues in this chain.
INFO:root:Finished prediction in 3.49 seconds.
INFO:root:Saving prediction to ./samples/len_60_diffuse_083_of.pdb
INFO:root:Saved
INFO:root:Done!


In [258]:
# NOTE(review): the saved cell source was cut off at `plddt[:, `, which is a
# SyntaxError. Column 0 is a placeholder completion — confirm which column was intended.
plddt[:, 0]

array([45.06, 45.06, 45.06, 45.06, 45.06,  0.  , 45.06, 45.06,  0.  ,
        0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,
        0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,
        0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,
        0.  ])

In [256]:
# Display the score returned by run_self_consistency.
sc

0.3981160816368821

In [226]:
# Restrict every per-residue quantity to positions where res_mask is set
# (presumably this strips padding — confirm against the data loader).
keep = res_mask.bool()
unpad_gt_aatype = aatype_impute[keep]
pred_aatype_prob = aatype_traj[0][keep]

# Most likely residue type per position; the last channel is left out of the argmax.
pred_aatype = np.argmax(pred_aatype_prob[:, :-1], axis=-1)

# fixed_mask flags residues that were held fixed, so its complement marks the diffused ones.
diffuse_mask = du.move_to_np(1 - fixed_mask[keep]).astype(bool)

# Ground truth is stored as per-class scores; take the argmax before decoding to a sequence.
gt_aatype_idx = du.move_to_np(torch.argmax(unpad_gt_aatype, dim=-1))
gt_seq = du.aatype_to_seq(gt_aatype_idx)
pred_seq = du.aatype_to_seq(pred_aatype)


In [227]:
# Sequence recovery: among diffused positions, the fraction whose predicted
# residue equals the ground-truth residue.
match = sum(
    int(pred_res == gt_res)
    for pos, (pred_res, gt_res) in enumerate(zip(pred_seq, gt_seq))
    if diffuse_mask[pos]
)
print(match / np.sum(diffuse_mask))

0.44660194174757284


In [228]:
# Fold the sampled sequence with the `omegafold` command-line tool; the
# prediction is written into the same directory as the FASTA file.
import subprocess
output_dir = os.path.dirname(fasta_path)
process = subprocess.Popen([
    'omegafold',
    fasta_path,
    output_dir,
])
# Block until folding has finished: later cells read the predicted PDB from
# disk, so returning immediately would race against the child process.
return_code = process.wait()
if return_code != 0:
    # Fail loudly instead of letting later cells parse a missing or stale file.
    raise RuntimeError(f'omegafold exited with status {return_code}')

INFO:root:Loading weights from /data/rsg/chemistry/jyim/third_party/omegafold/release1.pt
INFO:root:Constructing OmegaFold
INFO:root:Reading ./samples/len_112_diffuse_092_of.fasta
INFO:root:Predicting 1th chain in ./samples/len_112_diffuse_092_of.fasta
INFO:root:112 residues in this chain.
INFO:root:Finished prediction in 5.82 seconds.
INFO:root:Saving prediction to ./samples/len_112_diffuse_092_of.pdb
INFO:root:Saved
INFO:root:Done!


In [217]:
from tmtools import tm_align

In [229]:
def calc_tm_score(pos_1, pos_2, seq_1, seq_2):
    """Align two structures with ``tm_align`` and return both normalised scores.

    Args:
        pos_1: Coordinates of the first structure, forwarded to ``tm_align`` unchanged.
        pos_2: Coordinates of the second structure.
        seq_1: Sequence belonging to the first structure.
        seq_2: Sequence belonging to the second structure.

    Returns:
        Tuple ``(tm_norm_chain1, tm_norm_chain2)`` read from the alignment result
        (presumably the TM-score normalised by chain 1 and by chain 2
        respectively — confirm against the tmtools documentation).
    """
    alignment = tm_align(pos_1, pos_2, seq_1, seq_2)
    scores = (alignment.tm_norm_chain1, alignment.tm_norm_chain2)
    return scores

In [233]:
# The folded structure sits next to the FASTA file, with a .pdb extension.
of_pdb_path = fasta_path.replace('.fasta', '.pdb')

# Parse the generated sample and the OmegaFold prediction into feature dicts.
sample_feats = du.parse_pdb_feats('sample', pdb_path)
of_feats = du.parse_pdb_feats('sample', of_pdb_path)

In [241]:
# List the available feature names rather than dumping every array in the dict.
of_feats.keys()

dict_keys(['atom_positions', 'aatype', 'atom_mask', 'residue_index', 'bb_mask', 'bb_positions'])

In [237]:
# Decode the parsed sample's `aatype` feature into a sequence string.
sample_seq = du.aatype_to_seq(sample_feats['aatype'])

In [238]:
# Display the decoded sample sequence.
sample_seq

'AAELGHLKECLGNLKENLYASHWSAYYQFYEPVDAAGVGLHDIHDIYKHPMDLEKMKRKMENRDYTAAAFAAFVRLMFFNCYAKYNPPDHPVYAMAQKVRLVFAAYLADYDE'

In [239]:
# Take the second score (tm_norm_chain2) by index rather than unpacking into `_`,
# which IPython reserves for the previous cell output.
# NOTE(review): sample_seq is passed for both structures — presumably the
# prediction was folded from that same sequence; confirm.
tm_score = calc_tm_score(
    sample_feats['bb_positions'], of_feats['bb_positions'], sample_seq, sample_seq)[1]

In [240]:
tm_score

0.5041145979696957

### Sample without data

In [None]:
# Run sampler
# Sampling without data: only an all-ones residue mask of shape
# (batch_size, num_res) is supplied; no fixed_mask or *_impute inputs are passed.
# NOTE(review): this rebinds `res_mask` and `save` from the data-conditioned
# section above, so re-running those earlier cells afterwards will see these values.
batch_size = 4
num_res = 150
res_mask = torch.ones((batch_size, num_res))
save = True
samples_traj = sampler.sample(
    res_mask=res_mask,
    save=save,
    file_prefix='./samples/'
)