In [9]:
import json
import torch
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils
from torch_geometric.data import Data
from torch.utils.data import Dataset, DataLoader

ValueError: numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject

In [12]:
#mamba install pyg=*=*cu* -c pyg
!mamba update pandas

^C

CondaError: KeyboardInterrupt



In [82]:
# GNN Model for DNA Origami
class GNNModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GNNModel, self).__init__()
        self.conv1 = pyg_nn.GCNConv(input_dim, hidden_dim)
        self.conv2 = pyg_nn.GCNConv(hidden_dim, hidden_dim)
        self.conv3 = pyg_nn.GCNConv(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.relu(self.conv1(x, edge_index))
        x = self.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        return x


In [79]:
# Custom Dataset for DNA configurations
class DNADataset(Dataset):
    def __init__(self, data, edge_index):
        self.data = data
        self.edge_index = edge_index

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = torch.tensor(self.data[idx], dtype=torch.float32)
        edge_index = torch.tensor(self.edge_index, dtype=torch.long)
        return Data(x=x, edge_index=edge_index)

In [104]:
def calc_angles(x):
    """
    Computes the angles between consecutive nucleotides in the configuration.
    Args:
        x: Tensor of shape (batch_size, num_nucleotides, dim), where dim is usually 3, i.e. (x, y, z)

    Returns:
        Angles: Tensor of angles between three consecutive nucleotides.
    """
    # Vector between nucleotide i and i+1
    v1 = x[:, :-2] - x[:, 1:-1]
    # Vector between nucleotide i+1 and i+2
    v2 = x[:, 1:-1] - x[:, 2:]

    # Compute the angle between vectors v1 and v2 using dot product and norm
    v1_norm = torch.norm(v1, dim=-1)
    v2_norm = torch.norm(v2, dim=-1)

    dot_prod = (v1 * v2).sum(dim=-1)
    cos_theta = dot_prod / (v1_norm * v2_norm)
    cos_theta = torch.clamp(cos_theta, -1.0, 1.0)  # Clamp to avoid invalid values in acos
    # angles in radians
    angles = torch.acos(cos_theta)

    return angles

In [1]:
def epotential_backbone(x, R0=1.5, k=30):
    r_ij = torch.norm(x[:, :-1] - x[:, 1:], dim=-1) + 1e-9
    valid_mask = r_ij < R0  # Ensure distances are within valid range for log
    r_ij = torch.clamp(r_ij, max=R0 - 1e-6)  # Avoid log of zero or negative values
    potential = (k / 2) * R0**2 * torch.log(1 - (r_ij**2) / R0**2)
    potential = torch.where(valid_mask, potential, torch.zeros_like(potential))  # Zero out invalid values
    return potential.sum()


def epotential_stacking(x, epsilon_stack=1.0):
    r_ij = torch.norm(x[:, :-1] - x[:, 1:], dim=-1) + 1e-9
    theta = calc_angles(x)
    f1 = torch.exp(-1 * r_ij**2)
    f2 = torch.cos(theta)
    return -1 * epsilon_stack * (f1 * f2).sum()

def hydrogen_bonding(x, base_pairs, epsilon_hb=2.0):
    hb_energy = 0
    batch_size = x.size(0)
    for i, j in base_pairs:
        if i < x.size(1) and j < x.size(1):  # Ensure indices are within bounds
            r_ij = torch.norm(x[:, i] - x[:, j], dim=-1) + 1e-9
            r_ij = torch.clamp(r_ij, min=1e-3)  # Avoid extremely small distances
            theta_ij = calc_angles(torch.stack([x[:, i], x[:, j]], dim=1).view(batch_size, -1, 3))
            hb_energy += -1 * epsilon_hb * torch.exp(-1 * r_ij) * torch.cos(theta_ij)
    return hb_energy

def excluded_volume(x, A=5.0, lambda_val=1.0):
    r_ij = torch.norm(x[:, :, None] - x[:, None, :], dim=-1) + 1e-9
    r_ij = torch.clamp(r_ij, min=1e-3)  # Avoid extremely small distances
    return A * torch.exp(-1 * r_ij / lambda_val).sum()

def coaxial_stacking(x, epsilon_coaxial=1.5):
    r_ij = torch.norm(x[:, :-1] - x[:, 1:], dim=-1) + 1e-9
    theta_ij = calc_angles(x)
    return -1 * epsilon_coaxial * torch.exp(-1 * r_ij) * torch.cos(theta_ij).sum()

def dna_energy_func(x, base_pairs):
    e_backbone = epotential_backbone(x)
    e_stacking = epotential_stacking(x)
    e_hb = hydrogen_bonding(x, base_pairs)
    e_excluded = excluded_volume(x)
    e_coaxial = coaxial_stacking(x)
    e_total = e_backbone + e_stacking + e_hb + e_excluded + e_coaxial
    return e_total

In [2]:
# KL Divergence Loss with Energy Function
def kl_divergence_loss(model, data, base_pairs, energy_func):
    x = model(data)
    energy = energy_func(x, base_pairs)
    kl_loss = energy.mean()
    if torch.isnan(kl_loss):
        print("NaN detected in KL loss! Debugging information:")
        print(f"Energy: {energy}")
    return kl_loss

In [None]:
# def monitor_gradients(model):
#     for name, param in model.named_parameters():
#         # print(f"Layer {name}; Gradient norm = {param.grad.norm()}")
#         if param.grad is not None:
#             print(f"Layer {name}; Gradient norm = {param.grad.norm()}")

In [None]:
# def initialize_weights(m):
#     if isinstance(m, nn.Linear):
#         torch.nn.init.xavier_uniform_(m.weight)
#         if m.bias is not None:
#             torch.nn.init.zeros_(m.bias) # init biases to zero

In [3]:
def train(model, dataloader, optimizer, scheduler, base_pairs, epochs):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for data in dataloader:
            data = data.to(device)
            optimizer.zero_grad()
            loss = kl_divergence_loss(model, data, base_pairs, dna_energy_func)
            if torch.isnan(loss):
                print("Skipping batch due to NaN loss.")
                continue  # Skip this batch if loss is NaN
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)  # Reduce gradient clipping value
            optimizer.step()
            total_loss += loss.item()
        scheduler.step(total_loss / len(dataloader))  # Update the learning rate scheduler
        print(f"Epoch {epoch + 1}/{epochs} -> Loss: {total_loss / len(dataloader):.4f}")

In [5]:
# Sampling new DNA configurations from latent space
def sample_dna_configs(model, data):
    model.eval()
    with torch.no_grad():
        x_generated = model(data)
        return x_generated.cpu().numpy()

In [3]:
# Randomly generated data for demo
input_dim = 9  # Input dimension to include position, a1, and a3 vectors
hidden_dim = 128
output_dim = 3  # Output dimension (x, y, z)
learning_rate = 0.0001
batch_size = 16
epochs = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [74]:
# with open("./data/output.oxview", 'r') as file:
#     json_data = json.load(file)

In [4]:
data_dir = './data'
position_data = []
base_pairs = []
monomer_index_map = {}  # Map to keep track of monomer IDs and their indices in the data array
idx = 0
edge_index = []
for system in json_data['systems']:
    for strand in system['strands']:
        for monomer in strand['monomers']:
            p = monomer['p'] + monomer['a1'] + monomer['a3']  # Concatenate position (x, y, z) with internal vectors a1 and a3 to get input_dim = 9
            position_data.append(p)
            monomer_index_map[monomer['id']] = idx
            idx += 1
            # Check if monomer has a base pair
            if 'bp' in monomer and monomer['bp'] in monomer_index_map:
                base_pairs.append((monomer_index_map[monomer['id']], monomer_index_map[monomer['bp']]))
            # Add edges for GNN (backbone connections)
            if idx > 0:
                edge_index.append([idx - 1, idx])
                edge_index.append([idx, idx - 1])
position_data = np.array(position_data)
edge_index = np.array(edge_index).T  # Transpose to match PyG format

NameError: name 'json_data' is not defined

In [76]:
#data = [torch.randn(1, input_dim) for _ in range(100)] # we will replace with actual dna coords/oxDNA datasets
# data = np.random.randn(1000, input_dim)
dataset = DNAOrigamiDataset(position_data)
# data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [103]:
 # Model, optimizer, and training
realnvp = RealNVP(input_dim, num_layers).to(device)
realnvp.apply(initialize_weights)
optimizer = optim.Adam(realnvp.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# Sample base pairs (replace with realistic pairs for DNA)
# base_pairs = [(0, 7), (1, 6), (2, 5), (3, 4)]

# Train the model
train(realnvp, dataloader, optimizer, base_pairs, epochs)

Epoch 1/50 -> Loss: -247.1958
Epoch 2/50 -> Loss: -249.9794
Epoch 3/50 -> Loss: -249.9796
Epoch 4/50 -> Loss: -249.9796
Epoch 5/50 -> Loss: -249.9797


KeyboardInterrupt: 

In [52]:
new_dna_configs = sample_dna_configs(realnvp, num_samples=10, input_dim=input_dim)
print(new_dna_configs)

[[ 2.71860695e+00 -9.27713990e-01 -7.84954250e-01 -1.26351155e-02
  -1.15915164e-02 -8.71867780e-03]
 [-5.28648905e-02  9.03109014e-02 -1.98490396e-01 -2.02346407e-02
  -4.08926718e-02 -1.74785182e-02]
 [ 2.07511753e-01 -7.60903358e-01 -7.22222805e-01 -4.96835215e-03
  -8.53362028e-03 -2.73933727e-02]
 [ 6.45293742e-02 -6.42429411e-01 -1.44040453e+00  7.56772701e-03
   2.90152850e-03 -1.62262004e-02]
 [-1.92330942e-01  1.80773631e-01 -3.92905444e-01 -2.24663224e-03
  -3.06282844e-02 -1.42831663e-02]
 [ 1.18639970e+00  1.75999510e+00  4.19231027e-01 -6.43657744e-02
  -4.34665978e-02  1.04686674e-02]
 [-3.27538401e-01  1.34096992e+00 -1.08950031e+00  2.16821488e-03
  -1.66193694e-02  1.68389790e-02]
 [ 9.00311470e-02  1.66173005e+00 -2.48448133e-01 -3.01189553e-02
  -3.59693840e-02  1.19226221e-02]
 [-2.42928982e-01 -8.11607480e-01 -2.25642040e-01  6.29604375e-03
  -1.54456720e-02 -1.92963537e-02]
 [-1.17423284e+00 -1.62350559e+00  1.24952888e+00  9.48551856e-03
  -1.27476389e-02 -6.2960