In [2]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import copy
import GPUtil
from collections import defaultdict
from analysis import utils as au
from analysis import plotting
from data import utils as du
from data import se3_diffuser
from data import r3_diffuser
from data import so3_diffuser
from model import loss
from model import reverse_se3_diffusion
import seaborn as sns
import omegafold as of
import random

import tree
import sympy as sym
from data import rosetta_data_loader
from data import digs_data_loader
from experiments import train_se3_diffusion
from experiments import inference_se3_diffusion
from openfold.utils import rigid_utils as ru
from data import all_atom
from scipy.spatial.transform import Rotation
from model import basis_utils
from scipy.special import gamma
import pandas as pd

from omegaconf import OmegaConf
import importlib

# Short aliases for the rigid-body classes from openfold's rigid_utils.
Rigid = ru.Rigid
# NOTE(review): this rebinds `Rotation`, which was imported from
# scipy.spatial.transform earlier in this cell; a later cell imports the
# scipy class again, so which class the name refers to depends on cell
# execution order — confirm which one is intended.
Rotation = ru.Rotation

# Enable logging: configure the root logger to write INFO-and-above records
# to stdout, each prefixed with a timestamp.
import logging
import sys
# Year-month-day. The day field must be %d; %y is the two-digit year, which
# would render e.g. "2022-10-22" regardless of the actual day of the month.
date_strftime_format = "%Y-%m-%d %H:%M:%S"
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(asctime)s %(message)s", datefmt=date_strftime_format)

INFO: Using numpy backend
INFO: Note: detected 80 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
INFO: Note: NumExpr detected 80 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO: NumExpr defaulting to 8 threads.


In [3]:
from scipy.spatial.transform import Rotation

In [4]:
# Reloads any code changes to the project modules imported above, so edits to
# their source files take effect without restarting the kernel.
# The last reload is the cell's final expression, so its module repr is shown
# as the cell output.
importlib.reload(rosetta_data_loader)
importlib.reload(digs_data_loader)
importlib.reload(se3_diffuser)
importlib.reload(so3_diffuser)
importlib.reload(r3_diffuser)
importlib.reload(du)
importlib.reload(reverse_se3_diffusion)
importlib.reload(train_se3_diffusion)

<module 'experiments.train_se3_diffusion' from '/data/rsg/chemistry/jyim/projects/protein_diffusion/experiments/train_se3_diffusion.py'>

In [5]:
# Load config.
conf = OmegaConf.load('../config/base.yaml')

# Experiment overrides for running inside the notebook: data location set to
# 'rosetta', no checkpoint directory, zero loader workers, dist_mode 'single',
# and wandb disabled.
# NOTE(review): these assignments go through `exp_conf`; they are presumably
# meant to modify `conf` in place, since `conf` is what is passed to the
# experiment later — confirm.
exp_conf = conf.experiment
exp_conf.data_location = 'rosetta'
exp_conf.ckpt_dir = None
exp_conf.num_loader_workers = 0
exp_conf.dist_mode = 'single'
exp_conf.use_wandb = False

# Data settings
# NOTE(review): the units/meaning of `subset` and `max_len` are defined by the
# data loader, which is not visible here — presumably a subset size and a
# maximum sequence length; confirm.
data_conf = conf.data
data_conf.rosetta.filtering.subset = 1
data_conf.rosetta.filtering.max_len = 80

# Diffusion settings
diff_conf = conf.diffuser
diff_conf.diffuse_trans = True  # whether to diffuse translations
diff_conf.diffuse_rot = True  # whether to diffuse rotations
# Noise schedules
diff_conf.rot_schedule = 'linear'
diff_conf.trans_schedule = 'exponential'

diff_conf.trans_align_t = True

# Uncomment to dump the fully resolved config as YAML.
# print(OmegaConf.to_yaml(conf))

### Load data

In [6]:
# Build the experiment from the notebook config and create its training and
# validation loaders.
exp = train_se3_diffusion.Experiment(conf=conf)
train_loader, valid_loader = exp.create_rosetta_dataset(0, 1)

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
exp._model = exp._model.to(device)

INFO: Number of model parameters 14353879
INFO: Using cached IGSO3.
INFO: Checkpoint not being saved.
INFO: Evaluation saved to: ./results/baseline/04D_10M_2022Y_10h_51m_32s
INFO: Training: 1 examples
INFO: Validation: 4 examples with lengths [64 64 64 64 64 64 64 64 64 64]


In [7]:
# Pull one item from the training loader to experiment with in later cells.
train_iter = iter(train_loader)
next_item = next(train_iter)
# Disabled: would replace every entry of `next_item` with its element at index 0.
# next_item = tree.map_structure(lambda x: x[0], next_item)

## CATH

In [8]:
# Location of the CATH metadata table.
# NOTE(review): absolute, user-specific path — this only resolves on a machine
# with this filesystem mounted; consider making it configurable.
cath_csv = '/data/rsg/chemistry/jyim/large_data/cath/metadata.csv'

In [9]:
# Load the CATH metadata table into a DataFrame.
# NOTE(review): the preview of this frame shows an 'Unnamed: 0' column, which
# suggests the file was written with its index — index_col=0 may be wanted;
# confirm.
csv_df = pd.read_csv(cath_csv)

In [11]:
# Preview the first rows of the metadata table.
csv_df.head()

Unnamed: 0,chain_name,cath_code,cath_split,processed_path,raw_path,oligomeric_count,oligomeric_detail,resolution,structure_method,num_chains,seq_len,modeled_seq_len,coil_percent,helix_percent,strand_percent,radius_gyration
0,3fkf.A,['3.40.30'],test,/data/rsg/chemistry/jyim/large_data/cath/fk/3f...,/data/rsg/chemistry/jyim/large_data/pdb/30_08_...,4.0,tetrameric,2.2,x-ray diffraction,4,220,141,0.446809,0.319149,0.234043,[1.41676082]
1,2d9e.A,['1.20.920'],test,/data/rsg/chemistry/jyim/large_data/cath/d9/2d...,/data/rsg/chemistry/jyim/large_data/pdb/30_08_...,,,0.0,solution nmr,1,121,121,0.380165,0.603306,0.016529,[1.53328454]
2,2lkl.A,['1.10.1900'],test,/data/rsg/chemistry/jyim/large_data/cath/lk/2l...,/data/rsg/chemistry/jyim/large_data/pdb/30_08_...,1.0,monomeric,0.0,solution nmr,1,81,81,0.345679,0.654321,0.0,[1.33674072]
3,1ud9.A,['3.70.10'],test,/data/rsg/chemistry/jyim/large_data/cath/ud/1u...,/data/rsg/chemistry/jyim/large_data/pdb/30_08_...,11113.0,"monomeric,monomeric,monomeric,monomeric,trimeric",1.68,x-ray diffraction,4,333,242,0.322314,0.235537,0.442149,[1.91501391]
4,2rem.B,['3.40.30'],test,/data/rsg/chemistry/jyim/large_data/cath/re/2r...,/data/rsg/chemistry/jyim/large_data/pdb/30_08_...,112.0,"monomeric,monomeric,dimeric",1.9,x-ray diffraction,4,317,187,0.363636,0.502674,0.13369,[1.64780367]


In [None]:
# Keep only rows whose `oligomeric_detail` mentions "monomeric".
# The column contains missing values (NaN), and `'monomeric' in NaN` raises
# TypeError, so use the NaN-aware string accessor and treat missing values as
# non-matching. regex=False keeps this a plain substring test, like the
# original `in` check.
# NOTE(review): a substring test also keeps mixed entries such as
# "monomeric,monomeric,dimeric" — confirm that is intended.
monomer_df = csv_df[csv_df['oligomeric_detail'].str.contains('monomeric', na=False, regex=False)]

In [None]:
# Show how many rows (and columns) survived the monomer filter.
monomer_df.shape

## OmegaFold

In [8]:
# Decode the sequence of the first entry in the batch, keeping only the
# positions that are set in its residue mask.
aatype = next_item['aatype']
res_mask = next_item['res_mask'].long()
masked_idx = torch.where(res_mask[0])
masked_aatype = du.move_to_np(aatype[0][masked_idx])
seq = du.aatype_to_seq(masked_aatype)

In [9]:
# Map of FASTA record name -> sequence to be written out.
fasta_seqs = {'test': seq}

In [10]:
# Write every named sequence to a FASTA file, one record per entry.
# Each record ends with a newline so that, when several sequences are present,
# the next header starts on its own line instead of being glued to the end of
# the previous sequence.
with open('test.fasta', 'w') as f:
    for k, v in fasta_seqs.items():
        f.write(f'> {k}\n{v}\n')

## Spatial cropping

In [8]:
# Take index 0 on the first axis, the first 60 entries on the second axis, and
# index 1 on the third axis of the 'atom37_pos' tensor.
# NOTE(review): presumably this is (batch, residue, atom, xyz) and atom index 1
# is the CA atom; 60 looks like an arbitrary crop length — confirm both.
bb_pos = next_item['atom37_pos'][0, :60, 1]

In [9]:
# Pairwise Euclidean distances between all selected positions: broadcast the
# positions against themselves and take the norm over the last axis.
pairwise_diff = bb_pos.unsqueeze(1) - bb_pos.unsqueeze(0)
dist2d = torch.linalg.norm(pairwise_diff, dim=-1)

In [11]:
# Randomly select residue then sample a distance cutoff
crop_seed = random.randrange(dist2d.shape[0])
seed_dists = dist2d[crop_seed]
max_dist = torch.max(seed_dists)
min_dist = torch.min(seed_dists)
dist_cutoff = (max_dist - min_dist)*random.uniform(0.05, 0.95)

In [13]:
# Boolean mask of the residues that lie within the sampled cutoff of the seed.
condition_mask = torch.le(seed_dists, dist_cutoff)

In [18]:
# Inspect the resulting mask.
condition_mask

tensor([False, False, False, False,  True, False, False, False, False, False,
        False, False,  True,  True,  True,  True,  True,  True,  True, False,
        False,  True, False, False, False, False, False, False, False, False,
        False, False, False, False,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True, False, False, False,
        False, False, False, False, False, False, False, False, False, False])

In [17]:
# Spot-check a single position of the mask.
condition_mask[43]

tensor(True)

In [None]:
# Construct an OmegaFold model, load weights into it when a state dict is
# provided, then switch it to eval mode and move it to the requested device.
# NOTE(review): `pipeline` is not imported in any visible cell, so unless it is
# defined elsewhere this cell raises NameError — presumably it refers to
# omegafold's pipeline module; confirm and import it explicitly.
# NOTE(review): `forward_config` is not used in this cell.
args, state_dict, forward_config = pipeline.get_args()
# create the output directory
os.makedirs(args.output_dir, exist_ok=True)
# get the model
logging.info(f"Constructing OmegaFold")
model = of.OmegaFold(of.make_config())
if state_dict is None:
    logging.warning("Inferencing without loading weight")
else:
    # If the state dict has a "model" entry, use that entry as the state dict.
    if "model" in state_dict:
        state_dict = state_dict.pop("model")
    model.load_state_dict(state_dict)
model.eval()
model.to(args.device)

In [None]:
# Select the ground-truth CA positions under `bb_mask` and compute a TM-score
# between `ca_pos` and them, using the ground-truth sequence for both inputs.
# NOTE(review): gt_atom37_pos, CA_IDX, bb_mask, gt_aatype, ca_pos and
# calc_tm_score are not defined in any visible cell — this looks like a snippet
# pasted from elsewhere; define or import them before running.
# NOTE(review): this rebinds `seq`, which an earlier cell also assigns.
unpad_gt_ca_pos = gt_atom37_pos[..., CA_IDX, :][bb_mask]
seq = du.aatype_to_seq(gt_aatype)
_, tm_score = calc_tm_score(ca_pos, unpad_gt_ca_pos, seq, seq)