In [1]:
# %pip install Bio
# %pip install geoopt
# %pip install flow_matching

In [19]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch import nn, Tensor

import time
import math
import numpy as np
from Bio.PDB import PDBParser

# from geoopt import Euclidean

from flow_matching.path import GeodesicProbPath
from flow_matching.path.scheduler import CondOTScheduler
from flow_matching.solver import ODESolver, RiemannianODESolver
from flow_matching.utils import ModelWrapper
from flow_matching.utils.manifolds import FlatTorus, Manifold

# visualization
import matplotlib.pyplot as plt
from matplotlib import cm

if torch.cuda.is_available():
    device = 'cuda:0'
    print('Using gpu')
else:
    device = 'cpu'
    print('Using cpu.')

Using gpu


In [11]:

# -- 1) LOAD SINGLE-RESIDUE LYSINE FROM PDB --
pdb_parser = PDBParser(QUIET=True)
structure = pdb_parser.get_structure("lys", "lysine.pdb")

# Assume there's exactly one model, one chain, one residue
model = structure[0]
chain = next(model.get_chains())
res = next(chain.get_residues())

# -- 2) EXTRACT BACKBONE ATOMS & BUILD LOCAL BACKBONE FRAME --

# Grab coordinates of N, CA, and C
N_coord  = np.array(res["N"].coord)
CA_coord = np.array(res["CA"].coord)
C_coord  = np.array(res["C"].coord)

def build_backbone_frame(N, CA, C):
    """
    Build a local coordinate frame (3x3 rotation + origin) from 
    the backbone atoms N, CA, C.

    One common convention:
    - Origin = CA
    - x-axis = (CA -> C) normalized
    - y-axis = projection of (CA -> N) orthogonal to x, then normalized
    - z-axis = x cross y
    Returns:
       R (3x3 np.array) - rotation matrix
       t (3, )          - translation (the CA coordinates)
    """
    # Origin at CA
    origin = CA

    # x-axis: from CA to C (normalized)
    x = C - CA
    x /= np.linalg.norm(x)

    # provisional y-axis: from CA to N
    y = N - CA
    # remove component along x
    y -= (np.dot(y, x) * x)
    y /= np.linalg.norm(y)

    # z-axis: cross(x, y)
    z = np.cross(x, y)
    z /= np.linalg.norm(z)

    R = np.stack([x, y, z], axis=1)  # shape (3, 3)
    t = origin
    return R, t

R_frame, t_frame = build_backbone_frame(N_coord, CA_coord, C_coord)

# Convert to torch tensors
R_frame_torch = torch.tensor(R_frame, dtype=torch.float32)  # (3x3)
t_frame_torch = torch.tensor(t_frame, dtype=torch.float32)  # (3,)

# -- 3) EXTRACT CA COORDINATES (as Torch tensor) --
ca_torch = torch.tensor(CA_coord, dtype=torch.float32)

# -- 4) COMPUTE SIDECHAIN CHI ANGLES FOR LYS --
# Lys sidechain atoms: 
#    χ1: N - CA - CB - CG
#    χ2: CA - CB - CG - CD
#    χ3: CB - CG - CD - CE
#    χ4: CG - CD - CE - NZ

def dihedral_angle(a, b, c, d):
    """
    Compute dihedral angle in radians for four points 
    (each one is an np.array of shape (3,)).
    Formula based on cross/cross method.
    """
    b1 = b - a
    b2 = c - b
    b3 = d - c
    
    # normal to plane 1
    n1 = np.cross(b1, b2)
    # normal to plane 2
    n2 = np.cross(b2, b3)
    
    # normalize
    n1 /= np.linalg.norm(n1)
    n2 /= np.linalg.norm(n2)
    
    # direction of b2 for sign of angle
    m1 = np.cross(n1, b2 / np.linalg.norm(b2))
    x = np.dot(n1, n2)
    y = np.dot(m1, n2)
    angle = -math.atan2(y, x)  # negative to match common convention
    return angle

atom_names = ["N","CA","CB","CG","CD","CE","NZ"]
coords = {an: np.array(res[an].coord) for an in atom_names}

chi1 = dihedral_angle(coords["N"],  coords["CA"], coords["CB"], coords["CG"])
chi2 = dihedral_angle(coords["CA"], coords["CB"], coords["CG"], coords["CD"])
chi3 = dihedral_angle(coords["CB"], coords["CG"], coords["CD"], coords["CE"])
chi4 = dihedral_angle(coords["CG"], coords["CD"], coords["CE"], coords["NZ"])

# Store sidechain chi angles in a torch tensor (radians)
chi_torch = torch.tensor([chi1, chi2, chi3, chi4], dtype=torch.float32)

# -- 5) Print or use the results --
print("Backbone rotation (R_frame_torch):\n", R_frame_torch)
print("Backbone origin (t_frame_torch):\n", t_frame_torch)
print("C-alpha coordinate:", ca_torch)
print("Chi angles (radians):", chi_torch)


Backbone rotation (R_frame_torch):
 tensor([[-0.8000,  0.6000,  0.0000],
        [ 0.6000,  0.8000,  0.0000],
        [ 0.0000,  0.0000, -1.0000]])
Backbone origin (t_frame_torch):
 tensor([0.0000, 1.2000, 0.0000])
C-alpha coordinate: tensor([0.0000, 1.2000, 0.0000])
Chi angles (radians): tensor([-0.6559, -1.7392, -3.1416, -3.1416])


In [12]:
class FourierFeatures(nn.Module):
    """Assumes input is in [0, 2pi]."""

    def __init__(self, n_fourier_features: int):
        super().__init__()
        self.n_fourier_features = n_fourier_features

    def forward(self, x: Tensor) -> Tensor:
        feature_vector = [
            torch.sin((i + 1) * x) for i in range(self.n_fourier_features)
        ]
        feature_vector += [
            torch.cos((i + 1) * x) for i in range(self.n_fourier_features)
        ]
        return torch.cat(feature_vector, dim=-1)


class ProjectToTangent(nn.Module):
    """Projects a vector field onto the tangent plane at the input."""

    def __init__(self, vecfield: nn.Module, manifold: Manifold):
        super().__init__()
        self.vecfield = vecfield
        self.manifold = manifold

    def forward(self, x: Tensor, t: Tensor) -> Tensor:
        x = self.manifold.projx(x)
        v = self.vecfield(x, t)
        v = self.manifold.proju(x, v)
        return v

In [13]:


# Let's define a 2-layer MLP that:
# (1) has an input dimension of 16
# (2) has a hidden dimension (e.g. 32)
# (3) outputs dimension 16

class BackboneNet(nn.Module):
    def __init__(self, hidden_dim=32):
        super().__init__()
        self.fc1 = nn.Linear(16, hidden_dim)  # fully-connected layer 1
        self.fc2 = nn.Linear(hidden_dim, 16)  # fully-connected layer 2
        self.relu = nn.ReLU()
        
    def forward(self, x):
        # x shape: (batch_size, 16)
        h = self.fc1(x)         # shape: (batch_size, hidden_dim)
        h = self.relu(h)
        out = self.fc2(h)       # shape: (batch_size, 16)
        return out


In [22]:
def inf_train_gen(backbone_coords, frame,chi_angles,batch_size: int = 200, device: str = "cpu"):
    random_backbone = torch.randn((batch_size, backbone_coords.shape),3, device=device)
    random_frame = (torch.randn((batch_size,frame.shape), device=device)) #TODO cap these some how
    random_chi_angles = torch.randn((batch_size,frame.shape), device=device)

    return random_backbone, random_frame, random_chi_angles

def wrap(manifold, samples):
    center = torch.zeros_like(samples)

    return manifold.expmap(center, samples)

In [24]:
# training arguments
lr = 0.001
batch_size = 4096
iterations = 5001
print_every = 1000
manifold = FlatTorus()
dim = 2
hidden_dim = 16


## TODO get rid of projectToTangent if we use RCFM loss and a geodesic map?
# velocity field model init
vf = ProjectToTangent(  # Ensures we can just use Euclidean divergence.
    BackboneNet(  # Vector field in the ambient space.
        hidden_dim=hidden_dim,
    ),
    manifold=manifold,
)
vf.to(device)

# instantiate an affine path object
# TODO add linear and SO3 paths
path = GeodesicProbPath(scheduler=CondOTScheduler(), manifold=manifold)

# init optimizer
optim = torch.optim.Adam(vf.parameters(), lr=lr) 

# train
start_time = time.time()
for i in range(iterations):
    optim.zero_grad() 

    # sample data (user's responsibility): in this case, (X_0,X_1) ~ pi(X_0,X_1) = N(X_0|0,I)q(X_1)
    x_0_t, x_0_r , x_0_chi = inf_train_gen(t_frame_torch,R_frame_torch,chi_torch,batch_size=batch_size, device=device) # sample data

    # sample time (user's responsibility)
    t = torch.rand(x_1.shape[0]).to(device) 

    # sample probability path
    path_sample = path.sample(t=t, x_0=x_0, x_1=x_1)

    # flow matching l2 loss
    loss = torch.pow( vf(path_sample.x_t,path_sample.t) - path_sample.dx_t, 2).mean()

    # optimizer step
    loss.backward() # backward
    optim.step() # update
    
    # log loss
    if (i+1) % print_every == 0:
        elapsed = time.time() - start_time
        print('| iter {:6d} | {:5.2f} ms/step | loss {:8.3f} ' 
              .format(i+1, elapsed*1000/print_every, loss.item())) 
        start_time = time.time()

TypeError: randn() received an invalid combination of arguments - got (tuple, int, device=str), but expected one of:
 * (tuple of ints size, *, torch.Generator generator, tuple of names names, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
 * (tuple of ints size, *, torch.Generator generator, Tensor out = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
 * (tuple of ints size, *, Tensor out = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
 * (tuple of ints size, *, tuple of names names, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)


In [15]:
class FourierFeatures(nn.Module):
    """Assumes input is in [0, 2pi]."""

    def __init__(self, n_fourier_features: int):
        super().__init__()
        self.n_fourier_features = n_fourier_features

    def forward(self, x: Tensor) -> Tensor:
        feature_vector = [
            torch.sin((i + 1) * x) for i in range(self.n_fourier_features)
        ]
        feature_vector += [
            torch.cos((i + 1) * x) for i in range(self.n_fourier_features)
        ]
        return torch.cat(feature_vector, dim=-1)


class ProjectToTangent(nn.Module):
    """Projects a vector field onto the tangent plane at the input."""

    def __init__(self, vecfield: nn.Module, manifold: Manifold):
        super().__init__()
        self.vecfield = vecfield
        self.manifold = manifold

    def forward(self, x: Tensor, t: Tensor) -> Tensor:
        x = self.manifold.projx(x)
        v = self.vecfield(x, t)
        v = self.manifold.proju(x, v)
        return v