In [41]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import copy
import GPUtil
from collections import defaultdict
from analysis import utils as au
from analysis import plotting
from data import utils as du
from data import se3_diffuser
from data import r3_diffuser
from data import so3_diffuser
from model import loss
from model import reverse_se3_diffusion

import tree
import sympy as sym
from data import rosetta_data_loader
from data import digs_data_loader
from experiments import train_se3_diffusion
from experiments import inference_se3_diffusion
from openfold.utils import rigid_utils as ru
from data import all_atom
from scipy.spatial.transform import Rotation
from model import basis_utils

from omegaconf import OmegaConf
import importlib

# Short aliases for the rigid-body classes from openfold's rigid_utils (`ru`).
Rigid = ru.Rigid
# NOTE: this rebinds `Rotation`, replacing the scipy.spatial.transform.Rotation
# imported above; from here on `Rotation` refers to ru.Rotation.
Rotation = ru.Rotation

# Enable logging: send INFO-and-above records to stdout so that they show up
# in the notebook cell output, each prefixed with a timestamp.
import logging
import sys
# strftime pattern for the timestamp: year-month-day hour:minute:second.
# (This used to be "%Y-%m-%y", which printed the two-digit year in the day slot.)
date_strftime_format = "%Y-%m-%d %H:%M:%S"
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt=date_strftime_format,
)

In [2]:
# Re-import the project modules listed below so that edits made to their source
# files since the first import take effect without restarting the kernel.
# The value of the last call (the reloaded module object) is what the cell displays.
importlib.reload(rosetta_data_loader)
importlib.reload(digs_data_loader)
importlib.reload(se3_diffuser)
importlib.reload(so3_diffuser)
importlib.reload(r3_diffuser)
importlib.reload(du)
importlib.reload(reverse_se3_diffusion)
importlib.reload(train_se3_diffusion)

<module 'experiments.train_se3_diffusion' from '/data/rsg/chemistry/jyim/projects/protein_diffusion/experiments/train_se3_diffusion.py'>

In [3]:
# Load the base configuration (path is relative to the notebook's working directory).
conf = OmegaConf.load('../config/base.yaml')

# Experiment overrides for running interactively from this notebook.
# NOTE(review): the meaning of each key is defined by the experiment code, not
# visible here — the remarks below are assumptions to confirm.
exp_conf = conf.experiment
exp_conf.data_location = 'rosetta'
exp_conf.ckpt_dir = None  # presumably disables checkpoint saving
exp_conf.num_loader_workers = 0  # presumably loads data in the main process
exp_conf.dist_mode = 'single'
exp_conf.use_wandb = False

# Data settings
# NOTE(review): `subset` and `max_len` look like dataset filters (small subset,
# length cap of 80) — confirm against the data loader.
data_conf = conf.data
data_conf.rosetta.filtering.subset = 1
data_conf.rosetta.filtering.max_len = 80

# Diffusion settings
diff_conf = conf.diffuser
diff_conf.diffuse_trans = True  # whether to diffuse translations
diff_conf.diffuse_rot = True  # whether to diffuse rotations
# Noise schedules
diff_conf.rot_schedule = 'linear'
diff_conf.trans_schedule = 'exponential'

diff_conf.trans_align_t = True

# Uncomment to inspect the full resolved configuration:
# print(OmegaConf.to_yaml(conf))

### Load data

In [4]:
# Build the training experiment from the config and get its train/validation loaders.
exp = train_se3_diffusion.Experiment(conf=conf)
# NOTE(review): the meaning of the (0, 1) arguments is not visible here —
# presumably (replica index, number of replicas); confirm in train_se3_diffusion.
train_loader, valid_loader = exp.create_rosetta_dataset(0, 1)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Move the experiment's model to that device (this reaches into the private
# `_model` attribute of the Experiment object).
exp._model = exp._model.to(device)

INFO: Number of model parameters 1464719
INFO: Using cached IGSO3.
INFO: Checkpoint not being saved.
INFO: Evaluation saved to: ./results/baseline/23D_09M_2022Y_18h_45m_41s
INFO: Training: 1 examples
INFO: Validation: 4 examples with lengths [64 64 64 64 64 64 64 64 64 64]


In [5]:
# Pull a single item from the training loader.
train_iter = iter(train_loader)
next_item = next(train_iter)
# Keep only index 0 of every leaf in the (possibly nested) structure —
# presumably this strips the batch dimension; confirm against the loader.
next_item = tree.map_structure(lambda x: x[0], next_item)

In [6]:
# Extract raw features
atom37_pos = next_item['atom37_pos']
res_mask =  next_item['res_mask']
# Keep only the first 5 entries along dim 0 and the first 4 along dim 1
# (assumed layout: [residue, atom slot, xyz] — TODO confirm).
bb_pos = atom37_pos[:5, :4, :]
# The res_mask filtering below is disabled, so `res_mask` is currently unused
# and, despite its name, `unpadded_bb_pos` is just bb_pos passed through
# du.move_to_np.
# unpadded_bb_pos = bb_pos[res_mask.bool()]
unpadded_bb_pos = du.move_to_np(bb_pos)

In [None]:
# Visualize backbone structure
fig = plt.figure(figsize=(10, 10))
ax = plt.axes(projection='3d')

# One scatter call per atom slot (index along dim 1), each in its own color.
# Slot 1 is the slot later used as the C-alpha in this notebook; it is drawn
# with larger markers and additionally as a connecting line.
plotting.plt_3d(unpadded_bb_pos[:, 0], ax, color='r', s=50, mode='scatter')
plotting.plt_3d(unpadded_bb_pos[:, 1], ax, color='b', s=100, mode='scatter')
plotting.plt_3d(unpadded_bb_pos[:, 1], ax, color='b', mode='line')
plotting.plt_3d(unpadded_bb_pos[:, 2], ax, color='g', s=50, mode='scatter')
plotting.plt_3d(unpadded_bb_pos[:, 3], ax, color='purple', s=50, mode='scatter')


### Step 1: Calculate angles

In [8]:
def calculate_neighbor_angles(R_ac, R_ab):
    """Calculate angles between atoms c <- a -> b.

    The angle is computed as atan2(|u x v|, u . v), which is numerically
    better behaved than acos of the normalized dot product.

    Parameters
    ----------
        R_ac: Tensor, shape = (..., 3)
            Vector from atom a to c.
        R_ab: Tensor, shape = (..., 3)
            Vector from atom a to b.

    Returns
    -------
        angle_cab: Tensor, shape = (...)
            Angle between atoms c <- a -> b in radians, between 0 and pi.
    """
    # |u|*|v|*cos(alpha) = u . v
    x = torch.sum(R_ac * R_ab, dim=-1)  # shape = (...)
    # |u|*|v|*sin(alpha) = |u x v|
    # `dim` must be given explicitly: without it torch.cross uses the FIRST
    # dimension of size 3, which is the wrong axis whenever N == 3.
    y = torch.cross(R_ac, R_ab, dim=-1).norm(dim=-1)  # shape = (...)
    # avoid that for y == (0,0,0) the gradient wrt. y becomes NaN
    y = torch.clamp(y, min=1e-9)
    angle = torch.atan2(y, x)
    return angle

def vector_projection(R_ab, P_n):
    """
    Project the vector R_ab onto a plane with normal vector P_n.

    The component of R_ab along P_n is subtracted, leaving the part of R_ab
    that lies in the plane. P_n does not need to be normalized, but it must
    be non-zero (a zero normal leads to a division by zero).

    Parameters
    ----------
        R_ab: Tensor, shape = (..., 3)
            Vector from atom a to b.
        P_n: Tensor, shape = (..., 3)
            Normal vector of a plane onto which to project R_ab.

    Returns
    -------
        R_ab_proj: Tensor, shape = (..., 3)
            Projected vector (orthogonal to P_n).
    """
    # keepdim=True keeps a trailing axis of size 1 so the scale factor
    # broadcasts against P_n for any number of leading dimensions.
    ab_dot_n = torch.sum(R_ab * P_n, dim=-1, keepdim=True)
    n_dot_n = torch.sum(P_n * P_n, dim=-1, keepdim=True)
    return R_ab - (ab_dot_n / n_dot_n) * P_n

In [12]:
# TODO: DOUBLE CHECK ALL CALCULATIONS!!

# GemNet diagram:
#  c     d
#  |     |
#  |     |
#  a - - b
#
# a, b are the C-alpha positions of two residues; c and d are the remaining
# atom slots of residue a and residue b respectively.

# Extract vectors
# Slot 1 along the atom axis is treated as the C-alpha.
ca_pos = bb_pos[..., 1, :]
# All remaining slots of bb_pos (which keeps 4 atom slots per residue).
# FIX: this used to be [0, 1, 2], which included the C-alpha slot itself and
# therefore produced a (near-)zero vector for it.
# NOTE(review): slot 3 is presumably C-beta, which may be absent for some
# residues — confirm how missing atoms are encoded.
non_ca_pos = bb_pos[..., [0, 2, 3], :]

# DISPLACEMENT CONVENTION:
# - first index: starting position
# - second index: ending position
# A displacement is always computed as (end - start).
#
# C-alpha pairwise displacement vectors: [N, N, 3]
# calpha_vecs[a, b]: a --> b
# calpha_vecs[b, a]: a <-- b
# The small offset keeps the a == b diagonal from being an exact zero vector.
calpha_vecs = (ca_pos[None, :, :] - ca_pos[:, None, :]) + 1e-10
# Backbone displacements from each C-alpha to its own other atoms:
# [N, num_non_ca, 3]
# bb_vecs[a, c]: a --> c
# FIX: this used to be (ca - non_ca), i.e. c --> a, which contradicts the
# convention above and yields pi minus the intended angle.
bb_vecs = (non_ca_pos - ca_pos[:, None, :]) + 1e-10

# Calculate phi angles: [N, N, num_non_ca]
# where num_non_ca is the number of non-Ca atoms per residue.
# phi_angles[a, b, c] is angle:
# c
# ^
# |
# |
# a - - > b
num_res, num_non_ca, _ = bb_vecs.shape
# Broadcast both vector sets to a common [N, N, num_non_ca, 3] grid so each
# (a, b, c) triplet has its a --> b and a --> c vectors side by side.
tiled_calpha_vecs = torch.tile(
    calpha_vecs[:, :, None, :], (1, 1, num_non_ca, 1))
tiled_bb_vecs = torch.tile(bb_vecs[:, None, :, :], (1, num_res, 1, 1))
phi_angles = calculate_neighbor_angles(
    tiled_calpha_vecs.reshape(-1, 3),
    tiled_bb_vecs.reshape(-1, 3)
).reshape(num_res, num_res, num_non_ca)

In [36]:
# Projection of backbone vectors onto C-alpha plane: [N, N, num_non_ca, 3]
# proj_vecs_1[a, b, c]:
# projection of a --> c onto plane with norm vector a --> b
proj_vecs_1 = vector_projection(
    tiled_bb_vecs.reshape(-1, 3),
    tiled_calpha_vecs.reshape(-1, 3)
).reshape(num_res, num_res, num_non_ca, 3)
# Swap the two residue axes of the backbone vectors so that entry [a, b, d]
# holds residue b's vector b --> d, then project it onto the same a --> b plane.
proj_vecs_2 = vector_projection(
    tiled_bb_vecs.transpose(0, 1).reshape(-1, 3),
    tiled_calpha_vecs.reshape(-1, 3)
).reshape(num_res, num_res, num_non_ca, 3)

# Calculate dihedral angle: [N, N, num_non_ca, num_non_ca]
# theta_angles[a, b, c, d] is the angle between the projected a --> c and
# b --> d vectors (unsigned, since calculate_neighbor_angles is used).
# Should be symmetric: theta_angles[a, b] == theta_angles[b, a].T
# Pair every c of residue a with every d of residue b.
proj_ac = torch.tile(
    proj_vecs_1[:, :, :, None, :], (1, 1, 1, num_non_ca, 1))
proj_bd = torch.tile(
    proj_vecs_2[:, :, None, :, :], (1, 1, num_non_ca, 1, 1))
theta_angles = calculate_neighbor_angles(
    proj_ac.reshape(-1, 3),
    proj_bd.reshape(-1, 3),
).reshape(num_res, num_res, num_non_ca, num_non_ca)

### Step 2: Embed with basis vectors

In [37]:
class Envelope(torch.nn.Module):
    """
    Polynomial envelope giving a smooth cutoff.

    Evaluates 1 + a*d^p + b*d^(p+1) + c*d^(p+2) for scaled distances d < 1
    and returns 0 everywhere else. The coefficients a, b, c depend only on p.

    Parameters
    ----------
        p: int
            Exponent of the envelope function. Must be positive.
        name: str
            Unused; kept for interface compatibility.
    """

    def __init__(self, p, name="envelope"):
        super().__init__()
        assert p > 0
        self.p = p
        # Polynomial coefficients, fixed once from the exponent.
        self.a = -(p + 1) * (p + 2) / 2
        self.b = p * (p + 2)
        self.c = -p * (p + 1) / 2

    def forward(self, d_scaled):
        """Return the envelope value for each entry of `d_scaled` (distance / cutoff)."""
        exponent = self.p
        # Accumulate the polynomial term by term, lowest power first.
        poly = 1 + self.a * d_scaled ** exponent
        poly = poly + self.b * d_scaled ** (exponent + 1)
        poly = poly + self.c * d_scaled ** (exponent + 2)
        beyond_cutoff = torch.zeros_like(d_scaled)
        inside = d_scaled < 1
        return torch.where(inside, poly, beyond_cutoff)


In [39]:
class SphericalBasisLayer(torch.nn.Module):
    """
    2D Fourier Bessel Basis

    Combines a radial (Bessel) basis evaluated on distances with the m = 0
    spherical harmonics evaluated on a single angle.

    Parameters
    ----------
    num_spherical: int
        Number of spherical-harmonic degrees l (controls maximum angular frequency).
    num_radial: int
        Number of radial functions per degree (controls maximum radial frequency).
    cutoff: float
        Cutoff distance in Angstrom.
    envelope_exponent: int = 5
        Exponent of the envelope function.
    efficient: bool
        Whether to use the (memory) efficient implementation or not.
        The efficient path requires `id3_reduce_ca` and `Kidx` in `forward`.
    name: str
        Unused; kept for interface compatibility.
    """

    def __init__(
        self,
        num_spherical: int = 7,
        num_radial: int = 6,
        cutoff: float = 5,
        envelope_exponent: int = 5,
        efficient: bool = False,
        name: str = "spherical_basis",
    ):
        super().__init__()

        assert num_radial <= 64
        self.efficient = efficient
        self.num_radial = num_radial
        self.num_spherical = num_spherical
        self.envelope = Envelope(envelope_exponent)
        self.inv_cutoff = 1 / cutoff

        # retrieve formulas (symbolic expressions)
        bessel_formulas = basis_utils.bessel_basis(num_spherical, num_radial)
        Y_lm = basis_utils.real_sph_harm(
            num_spherical, spherical_coordinates=True, zero_m_only=True
        )
        self.sph_funcs = []  # (num_spherical,)
        self.bessel_funcs = []  # (num_spherical * num_radial,)
        self.norm_const = self.inv_cutoff ** 1.5
        self.register_buffer(
            "device_buffer", torch.zeros(0), persistent=False
        )  # dummy buffer to get device of layer

        # convert the symbolic formulas to torch functions
        x = sym.symbols("x")
        theta = sym.symbols("theta")
        modules = {"sin": torch.sin, "cos": torch.cos, "sqrt": torch.sqrt}
        m = 0  # only single angle
        for l in range(len(Y_lm)):  # num_spherical
            if l == 0:
                # Y_00 is only a constant -> function returns value and not tensor,
                # so add it to a zero tensor shaped like the input.
                first_sph = sym.lambdify([theta], Y_lm[l][m], modules)
                self.sph_funcs.append(
                    lambda theta: torch.zeros_like(theta) + first_sph(theta)
                )
            else:
                self.sph_funcs.append(sym.lambdify([theta], Y_lm[l][m], modules))
            for n in range(num_radial):
                self.bessel_funcs.append(
                    sym.lambdify([x], bessel_formulas[l][n], modules)
                )

    def forward(self, D_ca, Angle_cab, id3_reduce_ca=None, Kidx=None):
        """
        Evaluate the basis.

        Parameters
        ----------
        D_ca: Tensor, shape = (nEdges,)
            Distances.
        Angle_cab: Tensor, shape = (nTriplets,)
            Angles. Without `id3_reduce_ca` (non-efficient mode) it must have
            the same length as `D_ca`.
        id3_reduce_ca: Tensor of indices, shape = (nTriplets,), optional
            Edge index of each triplet. In non-efficient mode it is used to
            gather the radial part per triplet; required in efficient mode.
        Kidx: Tensor of indices, shape = (nTriplets,), optional
            Position of each triplet among the triplets of its edge.
            Required in efficient mode.

        Returns
        -------
        Non-efficient: Tensor, shape = (nTriplets, num_spherical * num_radial).
        Efficient: tuple of
            rbf_env, shape = (num_spherical, nEdges, num_radial), and
            sph2, shape = (nEdges, Kmax, num_spherical), zero padded.

        Raises
        ------
        ValueError
            If `efficient` is set but `id3_reduce_ca` or `Kidx` is missing.
        """
        d_scaled = D_ca * self.inv_cutoff  # (nEdges,)
        u_d = self.envelope(d_scaled)
        rbf = [f(d_scaled) for f in self.bessel_funcs]
        # s: 0 0 0 0 1 1 1 1 ...
        # r: 0 1 2 3 0 1 2 3 ...
        rbf = torch.stack(rbf, dim=1)  # (nEdges, num_spherical * num_radial)
        rbf = rbf * self.norm_const
        rbf_env = u_d[:, None] * rbf  # (nEdges, num_spherical * num_radial)

        sph = [f(Angle_cab) for f in self.sph_funcs]
        sph = torch.stack(sph, dim=1)  # (nTriplets, num_spherical)

        if not self.efficient:
            if id3_reduce_ca is not None:
                # (nTriplets, num_spherical * num_radial)
                rbf_env = rbf_env[id3_reduce_ca]
            rbf_env = rbf_env.view(-1, self.num_spherical, self.num_radial)
            # e.g. num_spherical = 3, num_radial = 2
            # z_ln: l: 0 0  1 1  2 2
            #       n: 0 1  0 1  0 1
            sph = sph.view(-1, self.num_spherical, 1)  # (nTriplets, num_spherical, 1)
            # e.g. num_spherical = 3, num_radial = 2
            # Y_lm: l: 0 0  1 1  2 2
            #       m: 0 0  0 0  0 0
            out = (rbf_env * sph).view(-1, self.num_spherical * self.num_radial)
            return out  # (nTriplets, num_spherical * num_radial)

        # Efficient mode: the index tensors are mandatory. They used to be
        # referenced here without being parameters, which raised a NameError.
        if id3_reduce_ca is None or Kidx is None:
            raise ValueError(
                "efficient=True requires both id3_reduce_ca and Kidx"
            )

        rbf_env = rbf_env.view(-1, self.num_spherical, self.num_radial)
        rbf_env = torch.transpose(
            rbf_env, 0, 1
        )  # (num_spherical, nEdges, num_radial)

        # Zero padded dense matrix
        # maximum number of neighbors; 0 when there are no triplets at all
        Kmax = 0 if sph.shape[0] == 0 else max(int(torch.max(Kidx)) + 1, 0)
        nEdges = d_scaled.shape[0]

        # Allocate with the dtype and device of `sph` so the scatter below
        # never mixes devices.
        sph2 = sph.new_zeros(nEdges, Kmax, self.num_spherical)
        sph2[id3_reduce_ca, Kidx] = sph

        # (num_spherical, nEdges, num_radial), (nEdges, Kmax, num_spherical)
        return rbf_env, sph2

In [62]:

class TensorBasisLayer(torch.nn.Module):
    """
    3D Fourier Bessel Basis

    Combines a radial (Bessel) basis evaluated on distances with the full set
    of spherical harmonics (all m for each degree l) evaluated on two angles.

    Parameters
    ----------
    num_spherical: int
        Number of spherical-harmonic degrees l (controls maximum angular frequency).
    num_radial: int
        Number of radial functions per degree (controls maximum radial frequency).
    cutoff: float
        Cutoff distance in Angstrom.
    envelope_exponent: int = 5
        Exponent of the envelope function.
    name: str
        Unused; kept for interface compatibility.
    """

    def __init__(
        self,
        num_spherical: int = 7,
        num_radial: int = 6,
        cutoff: float = 5,
        envelope_exponent: int = 5,
        name: str = "tensor_basis",
    ):
        super().__init__()

        assert num_radial <= 64
        self.num_radial = num_radial
        self.num_spherical = num_spherical

        self.inv_cutoff = 1 / cutoff
        self.envelope = Envelope(envelope_exponent)

        # retrieve formulas (symbolic expressions)
        bessel_formulas = basis_utils.bessel_basis(num_spherical, num_radial)
        Y_lm = basis_utils.real_sph_harm(
            num_spherical, spherical_coordinates=True, zero_m_only=False
        )
        self.sph_funcs = []  # (num_spherical**2,)
        self.bessel_funcs = []  # (num_spherical * num_radial,)
        self.norm_const = self.inv_cutoff ** 1.5

        # convert the symbolic formulas to torch functions
        x = sym.symbols("x")
        theta = sym.symbols("theta")
        phi = sym.symbols("phi")
        modules = {"sin": torch.sin, "cos": torch.cos, "sqrt": torch.sqrt}
        for l in range(len(Y_lm)):  # num_spherical
            for m in range(len(Y_lm[l])):
                if l == 0:
                    # Y_00 is only a constant -> function returns value and not
                    # tensor, so add it to a zero tensor shaped like the input.
                    first_sph = sym.lambdify([theta, phi], Y_lm[l][m], modules)
                    self.sph_funcs.append(
                        lambda theta, phi: torch.zeros_like(theta)
                        + first_sph(theta, phi)
                    )
                else:
                    self.sph_funcs.append(
                        sym.lambdify([theta, phi], Y_lm[l][m], modules)
                    )
            for j in range(num_radial):
                self.bessel_funcs.append(
                    sym.lambdify([x], bessel_formulas[l][j], modules)
                )

        # Number of orders per degree l: 1, 3, 5, ... (2l + 1).
        self.register_buffer(
            "degreeInOrder", torch.arange(num_spherical) * 2 + 1, persistent=False
        )

    def forward(self, D_ca, Alpha_cab, Theta_cabd, id4_reduce_ca=None):
        """
        Evaluate the basis.

        Parameters
        ----------
        D_ca: Tensor, shape = (nEdges,)
            Distances.
        Alpha_cab: Tensor, shape = (nQuadruplets,)
            First angle passed to the spherical harmonics.
        Theta_cabd: Tensor, shape = (nQuadruplets,)
            Second angle passed to the spherical harmonics.
        id4_reduce_ca: Tensor of indices, shape = (nQuadruplets,), optional
            Edge index of each quadruplet, used to gather the radial part per
            quadruplet. When omitted, `D_ca` must already have one entry per
            quadruplet.

        Returns
        -------
        Tensor, shape = (nQuadruplets, num_spherical**2 * num_radial)
        """
        d_scaled = D_ca * self.inv_cutoff
        u_d = self.envelope(d_scaled)

        rbf = [f(d_scaled) for f in self.bessel_funcs]
        # s: 0 0 0 0 1 1 1 1 ...
        # r: 0 1 2 3 0 1 2 3 ...
        rbf = torch.stack(rbf, dim=1)  # (nEdges, num_spherical * num_radial)
        rbf = rbf * self.norm_const

        rbf_env = u_d[:, None] * rbf  # (nEdges, num_spherical * num_radial)
        rbf_env = rbf_env.view(
            (-1, self.num_spherical, self.num_radial)
        )  # (nEdges, num_spherical, num_radial)
        # Repeat the radial block of degree l once per order m (2l + 1 times).
        rbf_env = torch.repeat_interleave(
            rbf_env, self.degreeInOrder, dim=1
        )  # (nEdges, num_spherical**2, num_radial)

        rbf_env = rbf_env.view(
            (-1, self.num_spherical ** 2 * self.num_radial)
        )  # (nEdges, num_spherical**2 * num_radial)
        if id4_reduce_ca is not None:
            # (nQuadruplets, num_spherical**2 * num_radial)
            rbf_env = rbf_env[id4_reduce_ca]
        # e.g. num_spherical = 3, num_radial = 2
        # j_ln: l: 0  0    1  1  1  1  1  1    2  2  2  2  2  2  2  2  2  2
        #       n: 0  1    0  1  0  1  0  1    0  1  0  1  0  1  0  1  0  1

        sph = [f(Alpha_cab, Theta_cabd) for f in self.sph_funcs]
        sph = torch.stack(sph, dim=1)  # (nQuadruplets, num_spherical**2)

        # Repeat every harmonic once per radial function so it lines up with rbf_env.
        sph = torch.repeat_interleave(
            sph, self.num_radial, dim=1
        )  # (nQuadruplets, num_spherical**2 * num_radial)
        # e.g. num_spherical = 3, num_radial = 2
        # Y_lm: l: 0  0    1  1  1  1  1  1    2  2  2  2  2  2  2  2  2  2
        #       m: 0  0   -1 -1  0  0  1  1   -2 -2 -1 -1  0  0  1  1  2  2
        return rbf_env * sph  # (nQuadruplets, num_spherical**2 * num_radial)

In [None]:
class BesselBasisLayer(torch.nn.Module):
    """
    1D Bessel Basis

    Evaluates env(d / cutoff) * sqrt(2 / cutoff) * sin(f_n * d / cutoff) / d
    for each of the `num_radial` learnable frequencies f_n.

    Parameters
    ----------
    num_radial: int
        Number of basis functions (controls maximum frequency).
    cutoff: float
        Cutoff distance in Angstrom.
    envelope_exponent: int = 5
        Exponent of the envelope function.
    name: str
        Unused; kept for interface compatibility.
    """

    def __init__(
        self,
        num_radial: int = 6,
        cutoff: float = 5,
        envelope_exponent: int = 5,
        name="bessel_basis",
    ):
        super().__init__()
        self.num_radial = num_radial
        self.inv_cutoff = 1 / cutoff
        self.norm_const = (2 * self.inv_cutoff) ** 0.5

        self.envelope = Envelope(envelope_exponent)

        # Frequencies start at the canonical positions pi, 2*pi, ..., num_radial*pi
        # and are trainable.
        canonical_freqs = np.pi * np.arange(
            1, self.num_radial + 1, dtype=np.float32
        )
        self.frequencies = torch.nn.Parameter(
            data=torch.Tensor(canonical_freqs),
            requires_grad=True,
        )

    def forward(self, d):
        """Return the basis values, shape (nEdges, num_radial), for distances `d` of shape (nEdges,)."""
        dist = d[:, None]  # (nEdges,1)
        scaled = dist * self.inv_cutoff
        weight = self.envelope(scaled) * self.norm_const
        waves = torch.sin(self.frequencies * scaled)
        # No guard against dist == 0: a zero distance divides by zero here.
        return weight * waves / dist

In [70]:
# Instantiate each basis layer with its default hyperparameters.
cbf_basis = SphericalBasisLayer()
# NOTE(review): second SphericalBasisLayer with identical defaults; it is not
# used in the cells visible here — confirm whether it is still needed.
cbf_basis3 = SphericalBasisLayer()
sbf_basis = TensorBasisLayer()
# BesselBasisLayer must have been defined (its cell run) before this line executes.
rbf_basis = BesselBasisLayer()

NameError: name 'BesselBasisLayer' is not defined

In [None]:
# Pairwise C-alpha distances: [N, N]
calpha_dists = torch.linalg.norm(calpha_vecs, dim=-1)

# Each basis layer works on flat inputs, so the distances are tiled to the
# shape of the matching angle tensor, everything is flattened, and the result
# is reshaped back with the basis features as the trailing dimension.

# Distance + phi angle embedding: [N, N, num_non_ca, n_features]
cbf4 = cbf_basis(
    torch.tile(calpha_dists[..., None], (1, 1, num_non_ca)).ravel(),
    phi_angles.ravel()
).reshape(num_res, num_res, num_non_ca, -1)

# Distance + phi angle + theta angle embedding:
# [N, N, num_non_ca, num_non_ca, n_features]
# FIX: the reshape used to be (num_res, num_res, ), which cannot hold the output.
sbf4 = sbf_basis(
    torch.tile(
        calpha_dists[..., None, None], (1, 1, num_non_ca, num_non_ca)
    ).ravel(),
    torch.tile(phi_angles[..., None], (1, 1, 1, num_non_ca)).ravel(),
    theta_angles.ravel()
).reshape(num_res, num_res, num_non_ca, num_non_ca, -1)

# Distance-only embedding: [N, N, n_features]
# FIX: `.reshape` used to be left uncalled, so `rbf` was a bound method
# rather than a tensor.
rbf = rbf_basis(calpha_dists.ravel()).reshape(num_res, num_res, -1)