In [44]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import copy
import GPUtil
import tree
from openfold.utils import rigid_utils as ru
from scipy.spatial.transform import Rotation
from omegaconf import OmegaConf
import importlib
import logging
import sys
from collections import defaultdict
import matplotlib.cm as cm


from analysis import utils as au
from analysis import plotting
from data import utils as du
from data import se3_diffuser
from data import so3_diffuser
from data import r3_diffuser
from data import digs_data_loader
from model import loss
from experiments import train_se3_diffusion

# Enable logging
date_strftime_format = "%Y-%m-%y %H:%M:%S"
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(asctime)s %(message)s", datefmt=date_strftime_format)

In [45]:
# Reloads any code changes in the source file.
importlib.reload(digs_data_loader)
importlib.reload(se3_diffuser)
importlib.reload(so3_diffuser)
importlib.reload(r3_diffuser)
importlib.reload(du)
importlib.reload(plotting)
importlib.reload(train_se3_diffusion)

<module 'experiments.train_se3_diffusion' from '/data/rsg/chemistry/jyim/projects/protein_diffusion/experiments/train_se3_diffusion.py'>

In [46]:
# Load config.
conf = OmegaConf.load('../config/base.yaml')

# Redirect cache from notebook directory.
exp_conf = conf.experiment
exp_conf.data_location = 'rosetta'
exp_conf.ckpt_dir = None
exp_conf.num_loader_workers = 0
exp_conf.dist_mode = 'single'
exp_conf.use_wandb = False

# Data settings
data_conf = conf.data
# data_conf.rosetta.filtering.subset = 1

diff_conf = conf.diffuser
diff_conf.se3.r3.min_b = 4.0
diff_conf.se3.r3.max_b = 10.0

# Figure out data loading for PDB on rosetta
exp = train_se3_diffusion.Experiment(conf=conf)
train_loader, valid_loader = exp.create_rosetta_dataset(0, 1)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
exp._model = exp._model.to(device)

INFO: Using cached IGSO3.
INFO: Number of model parameters 15402858
INFO: Checkpoint not being saved.
INFO: Evaluation saved to: ./results/baseline/26D_10M_2022Y_13h_57m_08s
INFO: Training: 10172 examples
INFO: Validation: 40 examples with lengths [ 60  82 104 126 148 171 193 215 237 260]


In [47]:
so3_diff = exp.diffuser.se3_diffuser._so3_diffuser
train_loader, valid_loader = exp.create_rosetta_dataset(0, 1)
data_iter = iter(train_loader)
next_item = next(data_iter)
rigids_0 = next_item['rigids_0']
rigids_t = next_item['rigids_t']
gt_score_t = next_item['rot_score']
t = next_item['t']
print(t)

INFO: Training: 10172 examples
INFO: Validation: 40 examples with lengths [ 60  82 104 126 148 171 193 215 237 260]


tensor([0.0558, 0.1348, 0.2841, 0.0569, 0.6498, 0.5951, 0.6113, 0.4394, 0.7901,
        0.5515, 0.6680])


In [48]:
extract_rots = lambda r: ru.Rigid.from_tensor_7(r).get_rots()
rots_0 = extract_rots(rigids_0)
rots_t = extract_rots(rigids_t)

In [49]:
# R_t = R_0 * R_{0t}
# R_{0t} = inv(R_0) * R_t
rots_0_inv = rots_0.invert()
rots_0t = rots_0_inv.compose_r(rots_t)
# rots_0t = rots_t.compose_r(rots_0_inv)

In [50]:
quats_0t = rots_0t.get_quats()
mats_0t = rots_0t.get_rot_mats()
rotvec_0t = rots_0t.get_rotvec()

In [51]:
np_quats_0t = du.move_to_np(quats_0t)
np_mats_0t = du.move_to_np(mats_0t)
np_rotvec_0t = du.move_to_np(rotvec_0t)

### Sandbox

In [127]:
# Note that scipy switches the first and last indices.
# print(mats_0t[0, 0])
print(quats_0t[0, 0])
print(Rotation.from_matrix(mats_0t[0, 0]).as_quat())
print(Rotation.from_matrix(mats_0t[0, 0]).as_rotvec())

[ 1.0000000e+00  8.1956372e-08 -1.0430810e-07 -1.4901158e-08]
[ 8.19563719e-08 -1.04308110e-07 -1.49011585e-08  1.00000000e+00]
[ 1.63912744e-07 -2.08616219e-07 -2.98023171e-08]


In [None]:
# # Check that torch implementation matches scipy.
# quat = quats_0t[0, 0]

# # Formula for quat to rotvec
# angle = 2 * torch.atan2(
#     torch.linalg.norm(quat[1:]),
#     quat[:1]
# )
# angle2 = angle * angle
# scale = 2 + angle2 / 12 + 7 * angle2 * angle2 / 2880
# rotvec = scale * quat[1:]
# print(rotvec)

In [129]:
# Vectorized version
def quat_to_rotvec(quat, eps=1e-10):
    # w > 0 to ensure 0 <= angle <= pi
    flip = (quat[..., :1] < 0).float()
    quat = (-1 * quat) * flip + (1 - flip) * quat

    angle = 2 * torch.atan2(
        torch.linalg.norm(quat[..., 1:], dim=-1),
        quat[..., 0]
    )
    
    angle2 = angle * angle
    small_angle_scales = 2 + angle2 / 12 + 7 * angle2 * angle2 / 2880
    large_angle_scales = angle / torch.sin(angle / 2 + eps)
    
    small_angles = (angle <= 1e-3).float()
    rot_vec_scale = small_angle_scales * small_angles + (1 - small_angles) * large_angle_scales
    rot_vec = rot_vec_scale[..., None] * quat[..., 1:]
    return rot_vec

quats_0t = rots_0t.get_quats()
rotvec_0t = quat_to_rotvec(quats_0t)
rotvec_0t[0, 0]

tensor(False)


tensor([ 1.6391e-07, -2.0862e-07, -2.9802e-08])

In [130]:
# Use so3_diff to calculate score for rotvec_0t, check that it matches the original score
score_t = so3_diff.score(
    du.move_to_np(rotvec_0t)[0], 
    du.move_to_np(t)[0])
print(score_t[3])
print(next_item['rot_score'][0, 3])

In [None]:
so3_diff.discrete_omega

In [None]:
# vectoried write torch method for calculating the score
omega = torch.linalg.norm(rotvec_0t, dim=-1)
leading_dims = omega.shape
score_norms_t = so3_diff._score_norms[so3_diff.t_to_idx(du.move_to_np(t))]
omega_idx = torch.bucketize(omega, torch.tensor(so3_diff.discrete_omega))
omega_score_t = torch.gather(
    torch.tensor(score_norms_t), 1, omega_idx
)

### Test out the implemented functions

In [62]:
fixed_idx = torch.where(next_item['fixed_mask'])[1][:3]
diff_idx = torch.where(1 - next_item['fixed_mask'])[1][:3]

In [64]:
# Check that the numpy and pytorch conversions work.
print(Rotation.from_matrix(np_mats_0t[0, diff_idx]).as_rotvec())
print(np_rotvec_0t[0, diff_idx])
print(Rotation.from_matrix(np_mats_0t[0, fixed_idx]).as_rotvec())
print(np_rotvec_0t[0, fixed_idx])

[[ 0.23573646  0.16851528 -1.20901483]
 [ 0.15440957 -0.1131293   0.17023815]
 [ 0.64081511  0.20005367 -0.31668331]]
[[ 0.23573615  0.16851501 -1.2090131 ]
 [ 0.1544084  -0.11312843  0.17023686]
 [ 0.64081335  0.20005314 -0.31668246]]
[[-8.94069658e-08  2.32830640e-08 -3.72529024e-09]
 [ 8.75443220e-08  2.60770321e-08  2.98023224e-08]
 [-5.96046448e-08  2.32830644e-10  1.67638063e-08]]
[[-8.9406953e-08  2.3283063e-08 -3.7252899e-09]
 [ 8.7544315e-08  2.6077036e-08  2.9802326e-08]
 [-5.9604645e-08  2.3283064e-10  1.6763806e-08]]


In [63]:
# Check the score calculation is accurate
rot_score_t = so3_diff.torch_score(rotvec_0t, t)
print(rot_score_t[0, diff_idx])
print(gt_score_t[0, diff_idx])
print(rot_score_t[0, fixed_idx])
print(gt_score_t[0, fixed_idx])

tensor([[-0.6566, -0.4694,  3.3674],
        [-0.4325,  0.3169, -0.4769],
        [-1.7901, -0.5588,  0.8846]], dtype=torch.float64)
tensor([[-0.6561, -0.4690,  3.3652],
        [-0.4301,  0.3151, -0.4742],
        [-1.7846, -0.5571,  0.8819]], dtype=torch.float64)
tensor([[ 7.1620e-04, -1.8651e-04,  2.9842e-05],
        [-6.9897e-04, -2.0820e-04, -2.3795e-04],
        [ 4.9120e-04, -1.9188e-06, -1.3815e-04]], dtype=torch.float64)
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], dtype=torch.float64)
