In [44]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import copy
import GPUtil
from collections import defaultdict
from analysis import utils as au
from analysis import plotting
from data import utils as du
from data import se3_diffuser
from data import r3_diffuser
from data import so3_diffuser
import seaborn as sns
from model import loss
import tree
from model import reverse_se3_diffusion
from data import rosetta_data_loader
from data import digs_data_loader
from data import all_atom
from openfold.data import data_transforms
from experiments import train_se3_diffusion
from experiments import inference_se3_diffusion
from openfold.utils import rigid_utils as ru
from scipy.spatial.transform import Rotation
from openfold.np import residue_constants
from openfold.utils import feats
from scipy.spatial.transform import Rotation

from omegaconf import DictConfig, OmegaConf
import importlib

# Enable logging
import logging
import sys
# strftime pattern for log timestamps: 4-digit year, month, day of month, then
# 24-hour time. The third field must be %d (day of month); %y is the 2-digit
# year and would render 5 April 2023 as "2023-04-23".
date_strftime_format = "%Y-%m-%d %H:%M:%S"
# Route INFO-and-above records to stdout so they show up in the cell output.
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(asctime)s %(message)s", datefmt=date_strftime_format)

# Set environment variables for which GPUs to use.
# Enumerate CUDA devices by PCI bus ID rather than CUDA's default ordering.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# CUDA_VISIBLE_DEVICES expects a comma-separated list of device IDs. Joining
# with an empty string would collapse IDs such as [0, 1] into the single,
# non-existent device "01".
# NOTE(review): when GPUtil reports no available device this is the empty
# string (as in the recorded output), which hides every GPU from CUDA.
chosen_gpu = ','.join(
    str(x) for x in GPUtil.getAvailable(order='memory'))
os.environ["CUDA_VISIBLE_DEVICES"] = chosen_gpu
print(f"Using GPU: {chosen_gpu}")

Using GPU: 


In [20]:
# Re-import project modules that may have been edited on disk since the kernel
# started. The reload order is significant and is kept exactly as written.
for stale_module in (
    rosetta_data_loader,
    digs_data_loader,
    se3_diffuser,
    so3_diffuser,
    r3_diffuser,
    du,
    au,
    all_atom,
    loss,
    reverse_se3_diffusion,
):
    importlib.reload(stale_module)
# Reloaded last and left as the final expression so the notebook echoes it.
importlib.reload(train_se3_diffusion)

<module 'experiments.train_se3_diffusion' from '/data/rsg/chemistry/jyim/projects/protein_diffusion/experiments/train_se3_diffusion.py'>

In [21]:
# Load the base configuration (the path is relative to the notebook's working
# directory) and take handles on the three sections this notebook touches.
conf = OmegaConf.load('../config/base.yaml')
exp_conf = conf.experiment
data_conf = conf.data
diff_conf = conf.diffuser

# Experiment overrides for an interactive, single-process session:
# Rosetta data source, no checkpoint directory, one loader worker, no W&B.
exp_conf.data_location = 'rosetta'
exp_conf.ckpt_dir = None
exp_conf.num_loader_workers = 1
exp_conf.dist_mode = 'single'
exp_conf.use_wandb = False
exp_conf.normalize_pred_score = True

# Diffuser overrides: diffuse both translations and rotations, with a linear
# rotation schedule and an exponential translation schedule.
diff_conf.diffuse_trans = True
diff_conf.diffuse_rot = True
diff_conf.rot_schedule = 'linear'
diff_conf.trans_schedule = 'exponential'

# Uncomment to dump the resolved configuration:
# print(OmegaConf.to_yaml(conf))

In [22]:
# Build the training experiment from the overridden config, then ask it for the
# train and validation data loaders of the Rosetta dataset.
# NOTE(review): the meaning of the positional arguments (0, 1) is not visible
# here -- confirm against Experiment.create_rosetta_dataset.
exp = train_se3_diffusion.Experiment(conf=conf)
train_loader, valid_loader = exp.create_rosetta_dataset(0, 1)

INFO: Number of model parameters 3685920
INFO: Using cached IGSO3.
INFO: Checkpoint not being saved.
INFO: Training: 18769 examples
INFO: Validation: 18769 examples


## Test all-atom reconstruction on a real protein

In [53]:
# Pull one batch from the training loader. Re-running this cell restarts the
# iterator, so `next_item` is always the first batch the loader yields.
train_iter = iter(train_loader)
next_item = next(train_iter)

In [54]:
# Extract features
# Take example `b_idx` out of every tensor in the batch, then keep only the
# residues selected by the boolean residue mask.
b_idx = 0
batch_item = tree.map_structure(lambda x: x[b_idx], next_item)
res_mask = batch_item['res_mask'].bool()

# Per-residue ground-truth tensors, restricted to the masked residues.
gt_atom37_pos = batch_item['atom37_pos'][res_mask]
gt_aatype = batch_item['aatype'][res_mask]
gt_atom37_mask = batch_item['atom37_mask'][res_mask].long()
gt_torsions = batch_item['torsion_angles_sin_cos'][res_mask]
gt_atom14_pos = batch_item['atom14_pos'][res_mask]

# Frames built from the 4x4 'rigidgroups_0' feature.
gt_rigids = ru.Rigid.from_tensor_4x4(batch_item['rigidgroups_0'][res_mask])

# Frames built from the 7-component 'rigids_0' feature, with translations
# multiplied by 10.
# NOTE(review): the factor presumably undoes the dataset's coordinate scaling
# (another cell uses exp._data_conf.scale_factor for the same step) -- confirm.
gt_bb_rigids = ru.Rigid.from_tensor_7(batch_item['rigids_0'][res_mask])
gt_bb_rigids = gt_bb_rigids.apply_trans_fn(lambda x: x * 10.0)


In [76]:
# Build a torsion tensor in which every slot along axis 1 holds a copy of the
# entry at index 2: select [:, 2], re-insert the axis, and tile it 7 times.
# NOTE(review): assumes gt_torsions is [num_res, 7, 2] (sin/cos pairs) and that
# index 2 is psi in OpenFold's torsion ordering -- confirm.
sub_torsion = torch.tile(gt_torsions[:, 2][:, None, :], (1, 7, 1))

In [78]:
# Convert OpenFold's per-residue-type default rigid-group frames to a tensor
# and use OpenFold's own torsion_angles_to_frames to build the frames from the
# scaled backbone rigids, the substituted torsions and the residue types.
# NOTE(review): default_frames keeps the dtype/device torch.tensor infers from
# the constant; confirm it matches the other inputs.
default_frames = torch.tensor(residue_constants.restype_rigid_group_default_frame)
pred_all_frames = feats.torsion_angles_to_frames(
    gt_bb_rigids,
    sub_torsion,
    gt_aatype,
    default_frames
)

In [79]:
# Disabled alternative: build the frames with the project's own implementation
# and the real torsions instead of reusing `pred_all_frames` from the OpenFold
# call in the preceding cell.
# pred_all_frames = all_atom.torsion_angles_to_frames(
#     gt_bb_rigids,
#     gt_torsions,
#     gt_aatype,
# )

# Place atoms (atom14 layout) from whichever `pred_all_frames` is currently
# bound in the kernel.
pred_atom14 = all_atom.frames_to_atom14_pos(
    pred_all_frames,
    gt_aatype,
)

In [80]:
# Show the predicted atom14 coordinates of the first residue.
# NOTE(review): the recorded output has only five rows, so it looks stale
# relative to this expression.
pred_atom14[0]

tensor([[ -1.2260, -12.0409, -10.2621],
        [ -2.5045, -11.5818, -10.7942],
        [ -3.6368, -12.5411, -10.4386],
        [ -3.9657, -13.4363, -11.2188],
        [ -2.4203, -11.3989, -12.3111]])

In [66]:
# Predicted coordinates of the first five atom14 slots of the first residue.
pred_atom14[0, :5]

tensor([[ -1.2260, -12.0409, -10.2621],
        [ -2.5045, -11.5818, -10.7942],
        [ -3.6368, -12.5411, -10.4386],
        [ -3.9657, -13.4363, -11.2188],
        [ -2.4203, -11.3989, -12.3111]])

In [67]:
# Ground-truth coordinates of the same five atom14 slots, for comparison with
# the prediction shown in the previous cell.
gt_atom14_pos[0, :5]

tensor([[ -1.1145, -11.8868, -10.3332],
        [ -2.5045, -11.5818, -10.7942],
        [ -3.6505, -12.5528, -10.4342],
        [ -4.0525, -13.4908, -11.1642],
        [ -2.4865, -11.3118, -12.2992]], dtype=torch.float64)

In [None]:
# Overlay predicted and ground-truth positions of the five leading atom14
# slots (N, CA, C, O, CB) for the first two residues, with one scatter trace
# per atom name and source.
pred_2 = pred_atom14[:2]
gt_2 = gt_atom14_pos[:2]

all_traces = []
for i, n in enumerate(['N', 'CA', 'C', 'O', 'CB']):
    # Slot i along axis 1 is plotted under the label n.
    pred_bb_3d = plotting.create_scatter(
        pred_2[:, i],
        mode='markers',
        marker_size=3,
        opacity=1.0,
        name=f'pred: {n}',
    )
    gt_bb_3d = plotting.create_scatter(
        gt_2[:, i],
        mode='markers',
        marker_size=3,
        opacity=1.0,
        name=f'GT: {n}',
    )
    # Keep the pred-then-GT ordering within each atom.
    all_traces += [pred_bb_3d, gt_bb_3d]

plotting.plot_traces(all_traces)

In [None]:
# Order of the first five atoms in each layout (note that O and CB swap places):
# bb_atom14 = ['N', 'CA', 'C', 'O', 'CB']
# bb_atom37 = ['N', 'CA', 'C', 'CB', 'O']

In [None]:

# NOTE(review): `prot_feats` is only assigned in a cell further down the
# notebook, so this cell fails on a fresh top-to-bottom run. The result
# `torsion_angles_feats` is not read anywhere in the visible notebook.
torsion_angles_feats = data_transforms.atom37_to_torsion_angles()(prot_feats)

# Build the frames with the project implementation, passing the null default
# frames explicitly as a 4x4 tensor.
# NOTE(review): other calls in this notebook pass only three arguments to
# all_atom.torsion_angles_to_frames -- confirm the current signature.
pred_all_frames = all_atom.torsion_angles_to_frames(
    gt_rigids,
    gt_torsions,
    gt_aatype,
    all_atom.NULL_DEFAULT_FRAMES.to_tensor_4x4()
)

# Place atoms (atom14 layout) from those frames.
pred_atom14 = all_atom.frames_to_atom14_pos(
    pred_all_frames,
    gt_aatype,
)

In [42]:
# Predicted coordinates of the first five atom14 slots of the first residue.
pred_atom14[0, :5]

tensor([[-17.2163,  -2.8270,   2.6925],
        [-16.1672,  -3.8068,   2.9572],
        [-14.7886,  -3.1502,   2.9521],
        [-15.1279,  -5.8374,   3.9399],
        [-16.2156,  -4.9367,   1.9266]])

In [43]:
# Ground-truth coordinates of the same five slots.
# NOTE(review): `gt_atom14` is assigned in a cell further down the notebook.
gt_atom14[0, :5]

tensor([[-17.5132,  -4.4478,   2.9622],
        [-16.1672,  -3.8068,   2.9572],
        [-16.2932,  -2.3168,   2.6762],
        [-17.1842,  -1.8818,   1.9472],
        [-15.2712,  -4.4498,   1.8942]], dtype=torch.float64)

In [44]:
# Order of the first five atoms in each layout (note that O and CB swap places):
# bb_atom14 = ['N', 'CA', 'C', 'O', 'CB']
# bb_atom37 = ['N', 'CA', 'C', 'CB', 'O']

In [28]:
# Plot the five leading atom14 slots (N, CA, C, O, CB) of the first two
# residues. No alignment is performed in this cell.
# NOTE(review): `gt_atom14` is assigned in a cell further down the notebook.
pred_2 = pred_atom14[:2]
gt_2 = gt_atom14[:2]

all_traces = []
for i, n in enumerate(['N', 'CA', 'C', 'O', 'CB']):
    # One scatter trace per atom name for the prediction.
    pred_bb_3d = plotting.create_scatter(
        pred_2[:, i], mode='markers', marker_size=3, opacity=1.0, name=f'pred: {n}')
    all_traces.append(pred_bb_3d)

    # The ground-truth trace is still created on every iteration...
    gt_bb_3d = plotting.create_scatter(
        gt_2[:, i], mode='markers', marker_size=3, opacity=1.0, name=f'GT: {n}')
    # ...but is never added to the figure because the append below is disabled.
    # NOTE(review): confirm whether dropping the GT traces is intentional.
    # Disabled variant that drew the GT points with connecting lines:
    #     gt_bb_3d = plotting.create_scatter(
    #         gt_2[:, i], mode='markers+lines', marker_size=3, opacity=1.0, name=f'GT: {n}')
    # all_traces.append(gt_bb_3d)

plotting.plot_traces(all_traces)
# Disabled matplotlib variants; none of the names they use are defined in the
# visible notebook:
# plotting.plt_3d(final_sample_ca_pos, ax, color='r', mode='line')
# plotting.plt_3d(gt_ca, ax, color='b', s=100, mode='scatter')
# plotting.plt_3d(gt_ca, ax, color='b', mode='line')

In [247]:
# Earlier variant that used only slot 3 (the O atom in atom14 order), which is
# where the `_o` suffix comes from:
# pred_o = du.move_to_np(pred_2[:, 3])
# gt_o = du.move_to_np(gt_2[:, 3])

# Current variant: take the first five atom slots of each residue and flatten
# residues x atoms into a single list of 3D points.
pred_o = du.move_to_np(pred_2)[:, :5].reshape((-1, 3))
gt_o = du.move_to_np(gt_2)[:, :5].reshape((-1, 3))

In [248]:
# Rigidly align the predicted points onto the ground-truth points. The helper
# returns four values, unpacked here as the aligned points, R, t and a
# reflection flag.
# NOTE(review): the recorded output reports that a reflection was detected and
# corrected for this input, which suggests a handedness mismatch worth checking.
aligned_pred_o, R, t, reflection = du.rigid_transform_3D(pred_o, gt_o)

det(R) < R, reflection detected!, correcting for it ...


In [249]:
# Express the fitted rotation matrix as 'xyz' Euler angles in degrees, to see
# how far the prediction had to be rotated onto the ground truth.
Rotation.from_matrix(R).as_euler('xyz', degrees=True)

array([-9.99802933, 10.52790198, -9.39521421])

In [250]:
# Plot the aligned prediction against the ground truth. Each trace is named
# after the data it actually holds: `aligned_pred_o` is the prediction after
# alignment and `gt_o` is the ground truth.
pred_bb_3d = plotting.create_scatter(aligned_pred_o, mode='markers+lines', marker_size=3, opacity=1.0, name='pred')
gt_bb_3d = plotting.create_scatter(gt_o, mode='markers+lines', marker_size=3, opacity=1.0, name='GT')

# The aligned prediction is drawn first and the ground truth second.
plotting.plot_traces([pred_bb_3d, gt_bb_3d])

In [186]:
# Four candidate "null" coordinate sets, each obtained by applying a different
# inverse transform to the idealized positions and masking the result:
#   1: inverse of the per-atom frame-to-global transforms
#   2: inverse of the ideal backbone frames, with an axis added for the atoms
#   3: inverse of the ideal rotations, with an axis added for the atoms
#   4: the idealized positions themselves (no transform, no mask)
global_pos_1 = all_atom.IDEAL_FRAME_TO_GLOBAL.invert_apply(all_atom.IDEALIZED_POS) * all_atom.NULL_ATOM_MASK
global_pos_2 = all_atom.IDEAL_BB_FRAMES[:, None].invert_apply(all_atom.IDEALIZED_POS) * all_atom.NULL_ATOM_MASK
global_pos_3 = all_atom.IDEAL_ROTS[:, None].invert_apply(all_atom.IDEALIZED_POS) * all_atom.NULL_ATOM_MASK
global_pos_4 = all_atom.IDEALIZED_POS

In [191]:
# Plot the five leading atom slots (N, CA, C, O, CB) of entry 1 of
# `global_pos_2`, one trace per atom.
all_traces = []

# Index 1 selects a single entry; [None] restores a leading axis so the
# per-atom slice below is a one-point array.
pos = global_pos_2[1][None]
for i, n in enumerate(['N', 'CA', 'C', 'O', 'CB']):
    pred_bb_3d = plotting.create_scatter(
        pos[:, i],
        mode='markers+lines',
        marker_size=5,
        opacity=1.0,
        name=f'ideal: {n}',
    )
    all_traces.append(pred_bb_3d)

plotting.plot_traces(all_traces)

## Test that all-atom reconstruction recovers the ideal positions

In [19]:
# Test we can recover the idealized positions:
# build frames from the ideal backbone frames and ideal torsions for every
# residue type in ALL_AATYPE...
pred_all_frames = all_atom.torsion_angles_to_frames(
    all_atom.IDEAL_BB_FRAMES,
    all_atom.IDEAL_TORSION,
    all_atom.ALL_AATYPE,
)

# ...then place atoms from those frames. The result is compared against
# all_atom.IDEALIZED_POS in the following cell.
pred_atom14 = all_atom.frames_to_atom14_pos(
    pred_all_frames,
    all_atom.ALL_AATYPE,
)


In [None]:
# For each of the 21 residue-type indices, print the index, the reconstructed
# atom14 block and the idealized reference (both rounded to 3 decimals) so
# they can be compared by eye.
for i in range(21):
    print(i)
    recovered_block = du.move_to_np(pred_atom14[i]).round(3)
    print(recovered_block)
    reference_block = du.move_to_np(all_atom.IDEALIZED_POS[i]).round(3)
    print(reference_block)

In [None]:
# Draw one set of points in a 3D matplotlib axis, once as a scatter and once
# as a connected line.
# NOTE(review): `final_sample_ca_pos` is not defined anywhere in the visible
# notebook, so this cell raises NameError on a fresh run.
fig = plt.figure(figsize=(10, 10))
ax = plt.axes(projection='3d')

plotting.plt_3d(final_sample_ca_pos, ax, color='r', s=100, mode='scatter')
plotting.plt_3d(final_sample_ca_pos, ax, color='r', mode='line')
# Disabled ground-truth overlay (`gt_ca` is likewise undefined here):
# plotting.plt_3d(gt_ca, ax, color='b', s=100, mode='scatter')
# plotting.plt_3d(gt_ca, ax, color='b', mode='line')

## Test recovery of a real protein

In [46]:
# Restart the training iterator and take its first batch for this section.
train_iter = iter(train_loader)
next_item = next(train_iter)

In [111]:
# Pull example `b_idx` out of every tensor in the batch and keep only the
# residues selected by the boolean residue mask.
b_idx = 0
batch_item = tree.map_structure(lambda x: x[b_idx], next_item)
res_mask = batch_item['res_mask'].bool()

# Residue types, atom37 coordinates and the atom37 mask (as integers).
gt_aatype = batch_item['aatype'][res_mask]
gt_atom_pos37 = batch_item['atom_positions'][res_mask]
gt_atom37_mask = batch_item['atom_mask'][res_mask].long()

# Frames from the 7-component 'rigids_0' feature, with translations multiplied
# by the data config's scale factor.
gt_rigids = ru.Rigid.from_tensor_7(
    batch_item['rigids_0'][res_mask]
).apply_trans_fn(lambda x: x*exp._data_conf.scale_factor)

# Torsion angles derived from the atom37 data; only the first element of the
# helper's return value is kept.
gt_torsions = all_atom.prot_to_torsion_angles(
    gt_aatype, gt_atom_pos37, gt_atom37_mask
)[0]

In [112]:
# Run OpenFold's atom14 feature transforms on the ground-truth atom37 data:
# first the atom14 masks, then the atom14 positions, which consume the result
# of the first transform.
prot_feats = data_transforms.make_atom14_positions(
    data_transforms.make_atom14_masks({
        'aatype': gt_aatype.long(),
        'all_atom_positions': gt_atom_pos37,
        'all_atom_mask': gt_atom37_mask.double(),
    })
)
# Ground-truth coordinates in the atom14 layout.
gt_atom14 = prot_feats['atom14_gt_positions']

In [113]:
# Frames for the real protein, built from its backbone rigids, torsions and
# residue types.
gt_frames = all_atom.torsion_angles_to_frames(
    gt_rigids,
    gt_torsions,
    gt_aatype
)

# One-hot encoding of each atom's group index; the number of classes is read
# from the size of the third-from-last axis of the indexed default frames.
gt_group_mask = torch.nn.functional.one_hot(
    all_atom.GROUP_IDX[gt_aatype, ...],
    num_classes=all_atom.DEFAULT_FRAMES[gt_aatype, ...].shape[-3],
)

# Pick, for every atom, the frame of its group: broadcast the frames against
# the one-hot mask and sum over the last axis.
gt_frame_to_global = (gt_frames[..., None, :] * gt_group_mask).map_tensor_fn(
    lambda x: torch.sum(x, dim=-1)
)

# Apply the inverse of each atom's frame to the ground-truth atom14 positions
# and multiply by the per-residue-type atom mask.
gt_atom_mask = all_atom.ATOM_MASK[gt_aatype, ...].unsqueeze(-1)
null_gt_pos = gt_frame_to_global.invert_apply(gt_atom14) * gt_atom_mask

In [114]:
# Reference values: look up the module's null global positions by each
# residue's type, for comparison with `null_gt_pos`.
ideal_null_gt_pos = all_atom.NULL_GLOBAL_POS[gt_aatype]

In [None]:
# Print the computed and the reference null positions of the first five
# residues, rounded to 3 decimals, for a side-by-side check.
for i in range(5):
    computed_block = du.move_to_np(null_gt_pos[i]).round(3)
    print(f'gt: {i}\n', computed_block)
    reference_block = du.move_to_np(ideal_null_gt_pos[i]).round(3)
    print(f'ideal: {i}\n', reference_block)

## Test torsion angles of the null global positions

In [160]:
# Display the module's null global positions.
# NOTE(review): the recorded output is a torch.Size, so it was produced by a
# `.shape` expression and is stale relative to this line.
all_atom.NULL_GLOBAL_POS

torch.Size([21, 14, 3])

In [None]:
# NOTE(review): despite its name, this is computed from the ground-truth
# protein inputs (the same call that produces `gt_torsions`), not from
# NULL_GLOBAL_POS -- the cell looks unfinished.
null_global_torsion = all_atom.prot_to_torsion_angles(
    gt_aatype, gt_atom_pos37, gt_atom37_mask
)[0]

In [None]:
# NOTE(review): scratch code that uses bare names (ATOM_MASK, ALL_AATYPE, ...)
# which the notebook only imports as attributes of `all_atom`; it presumably
# mirrors module-level code and raises NameError if run here as-is.
# [21, 14, 1]
NULL_ATOM_MASK = ATOM_MASK[ALL_AATYPE, ...].unsqueeze(-1)
# [21, 14, 3]
NULL_GLOBAL_POS = IDEAL_FRAME_TO_GLOBAL.invert_apply(IDEALIZED_POS) * NULL_ATOM_MASK


In [None]:
# NOTE(review): scratch code using bare names that the notebook only imports
# as attributes of `all_atom`. `IDEAL_FRAMES` is assigned in a later cell.
# One-hot encoding of each atom's group index, per residue type.
# [21, 14, 8]
NULL_GROUP_MASK = torch.nn.functional.one_hot(
    GROUP_IDX[ALL_AATYPE, ...],
    num_classes=DEFAULT_FRAMES[ALL_AATYPE, ...].shape[-3],
)
# Select each atom's frame by masking with the one-hot and summing the last axis.
# [21, 14]
IDEAL_FRAME_TO_GLOBAL = (IDEAL_FRAMES[..., None, :] * NULL_GROUP_MASK).map_tensor_fn(
    lambda x: torch.sum(x, dim=-1)
)
# [21, 14, 1]
NULL_ATOM_MASK = ATOM_MASK[ALL_AATYPE, ...].unsqueeze(-1)
# Inverse-transform the idealized positions and mask the result.
# [21, 14, 3]
NULL_GLOBAL_POS = IDEAL_FRAME_TO_GLOBAL.invert_apply(IDEALIZED_POS) * NULL_ATOM_MASK


In [None]:
# Frames for every residue type, from the ideal backbone frames and torsions.
# NOTE(review): `torsion_angles_to_frames` and the constants are bare names
# here; elsewhere the notebook reaches them through `all_atom`.
IDEAL_FRAMES = torsion_angles_to_frames(
    IDEAL_BB_FRAMES,
    IDEAL_TORSION,
    ALL_AATYPE)

In [None]:
# NOTE(review): scratch code using bare names that the notebook only imports
# as attributes of `all_atom`.
# NOTE(review): presumably [21, 8, 4, 4] rather than [21, 14, 4, 4] -- shape[-3]
# is used as the one-hot class count below, and that one-hot is annotated as
# [21, 14, 8]. Confirm.
NULL_DEFAULT_4X4 = DEFAULT_FRAMES[ALL_AATYPE, ...]
# Group index of each atom, per residue type.
# [21, 14]
NULL_GROUP_MASK = GROUP_IDX[ALL_AATYPE, ...]
# The same name is then rebound to the one-hot encoding of those indices.
# [21, 14, 8]
NULL_GROUP_MASK = torch.nn.functional.one_hot(
    NULL_GROUP_MASK,
    num_classes=NULL_DEFAULT_4X4.shape[-3],
)

In [None]:
# NOTE(review): scratch code using bare names that the notebook only imports
# as attributes of `all_atom`. It repeats statements from earlier cells.
# Select each atom's frame by masking with the one-hot and summing the last axis.
# [21, 14]
IDEAL_FRAME_TO_GLOBAL = (IDEAL_FRAMES[..., None, :] * NULL_GROUP_MASK).map_tensor_fn(
    lambda x: torch.sum(x, dim=-1)
)
# [21, 14, 1]
NULL_ATOM_MASK = ATOM_MASK[ALL_AATYPE, ...].unsqueeze(-1)
# Inverse-transform the idealized positions and mask the result.
# [21, 14, 3]
NULL_GLOBAL_POS = IDEAL_FRAME_TO_GLOBAL.invert_apply(IDEALIZED_POS) * NULL_ATOM_MASK
