In [1]:
import os
import math
import json
import torch
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [2]:
class AffineCouplingLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim=512):
        super(AffineCouplingLayer, self).__init__()
        split_dim = input_dim // 2 
        if input_dim % 2 !=0:
            split_dim += 1
            
        self.hidden_dim = hidden_dim
        print(f"hidden dim{self.hidden_dim}")

        # Added skip connections
        self.scale_net_fc1 = nn.Linear(split_dim, self.hidden_dim)
        self.scale_net_fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.scale_net_fc3 = nn.Linear(self.hidden_dim, split_dim)

        # Added skip connections
        self.translate_net_fc1 = nn.Linear(split_dim, self.hidden_dim)
        self.translate_net_fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.translate_net_fc3 = nn.Linear(self.hidden_dim, split_dim)

    def forward(self, x):
        # splitting input into 2 chunks
        # print(f"Input shape: {x.shape}")
        split_dim = x.size(1) // 2
        print(f"split dim:{split_dim}")
        print(f"x{x.shape}")
        if x.size(1) % 2 !=0:
            split_dim += 1
            
        x1 = x[:, :split_dim]
        x2 = x[:, split_dim:]
        print(f"x1{x1.shape}")
        print(f"x2{x2.shape}")
        # x1, x2 = x.chunk(2, dim=1)

        # print(f"x1 shape: {x1.shape}, x2 shape: {x2.shape}")
        # Scale network with skip connection
        hidden = torch.relu(self.scale_net_fc1(x1))
        hidden = torch.relu(self.scale_net_fc2(hidden))+hidden # skip connection
        scale = self.scale_net_fc3(hidden)

        # Translate network with skip connection
        hidden = torch.relu(self.translate_net_fc1(x1))
        hidden = torch.relu(self.translate_net_fc2(hidden))+hidden # skip connection
        translate = self.translate_net_fc3(hidden)
        
        scale = torch.clamp(scale, min=-5.0, max=5.0)
        # Apply affine transformations to x2
        z2 = x2 * torch.exp(scale) + translate
        print(f"{x1.shape()}")
        print(f"{x2.shape()}")
        # output = torch.cat([x1[:, :z2.size(1)], z2], dim=1)
        return torch.cat([x1, z2], dim=1), scale

    def inverse(self, z):
        split_dim = z.size(1) // 2
        z1 = z[:, :split_dim]
        z2 = z[:, split_dim:]
        # z1, z2 = z.chunk(2, dim=1)

        # Scale network with skip connection
        hidden = torch.relu(self.scale_net_fc1(z1))
        hidden = torch.relu(self.scale_net_fc2(hidden))+hidden
        scale = self.scale_net_fc3(hidden)

        # Translate network with skip connection
        hidden = torch.relu(self.translate_net_fc1(z1))
        hidden = torch.relu(self.translate_net_fc2(hidden))+hidden
        translate = self.translate_net_fc3(hidden)
        
        scale = torch.clamp(scale, min=-5.0, max=5.0)
        # Inverse affine transformation
        x2 = (z2 - translate) * torch.exp(-scale)
        return torch.cat([z1, x2], dim=1)

In [3]:
class RealNVP(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers):
        super(RealNVP, self).__init__()
        # if input_dim % 2 != 0:
        #     raise ValueError("Input dimension must be even for chunking.")
        # self.layers = nn.ModuleList([AffineCouplingLayer(input_dim, hidden_dim) for _ in range(num_layers)])
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(AffineCouplingLayer(input_dim, hidden_dim))
            self.layers.append(nn.BatchNorm1d(input_dim))

    def forward(self, x):
        log_det_jacobian = 0
        for layer in self.layers:
            if isinstance(layer, AffineCouplingLayer):
                x, scale = layer(x)
                print(f"scale:{scale}")
                log_det_jacobian += scale.sum(dim=1)
            else:
                x = layer(x)
        return x, log_det_jacobian

    def inverse(self, z): 
        for layer in reversed(self.layers):
            if isintance(layer, AffineCouplingLayer):
                z = layer.inverse(z)
            else: 
                z = layer(z)
        return z


class DNAOrigamiDataset(Dataset):
    def __init__(self, data):
        # List of configurations: coordinates and angles
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.tensor(self.data[idx], dtype=torch.float32)

In [4]:
# class AffineCouplingLayer(nn.Module):
#     def __init__(self, input_dim, hidden_dim=512):
#         super(AffineCouplingLayer, self).__init__()
#         split_dim = input_dim // 2
#         self.hidden_dim = hidden_dim

#         # Define the scale and translate networks
#         self.scale_net = nn.Sequential(
#             nn.Linear(split_dim, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, split_dim)
#         )

#         self.translate_net = nn.Sequential(
#             nn.Linear(split_dim, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, split_dim)
#         )

#     def forward(self, x):
#         split_dim = x.size(1) // 2
#         x1 = x[:, :split_dim]
#         x2 = x[:, split_dim:]

#         # Apply scale and translation
#         scale = self.scale_net(x1)
#         translate = self.translate_net(x1)
#         scale = torch.clamp(scale, min=-5.0, max=5.0)

#         # Affine transformation
#         z2 = x2 * torch.exp(scale) + translate
#         return torch.cat([x1, z2], dim=1), scale

#     def inverse(self, z):
#         split_dim = z.size(1) // 2
#         z1 = z[:, :split_dim]
#         z2 = z[:, split_dim:]

#         # Apply inverse scale and translation
#         scale = self.scale_net(z1)
#         translate = self.translate_net(z1)
#         scale = torch.clamp(scale, min=-5.0, max=5.0)

#         x2 = (z2 - translate) * torch.exp(-scale)
#         return torch.cat([z1, x2], dim=1)

# class RealNVP(nn.Module):
#     def __init__(self, input_dim, hidden_dim, num_layers):
#         super(RealNVP, self).__init__()
#         self.layers = nn.ModuleList()
#         for i in range(num_layers):
#             self.layers.append(AffineCouplingLayer(input_dim, hidden_dim))
#             # batch normalization for numerical stability
#             self.layers.append(nn.BatchNorm1d(input_dim))

#     def forward(self, x):
#         log_det_jacobian = 0
#         for layer in self.layers:
#             if isinstance(layer, AffineCouplingLayer):
#                 x, scale = layer(x)
#                 log_det_jacobian += scale.sum(dim=1)
#             else:
#                 x = layer(x)
#         return x, log_det_jacobian

#     def inverse(self, z):
#         for layer in reversed(self.layers):
#             if isinstance(layer, AffineCouplingLayer):
#                 z = layer.inverse(z)
#             else:
#                 z = layer(z)
#         return z

# class DNAOrigamiDataset(Dataset):
#     def __init__(self, data):
#         # List of configurations: coordinates and angles
#         self.data = data

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         return torch.tensor(self.data[idx], dtype=torch.float32)

In [5]:
def calc_angles(x):
    """
    Computes the angles between consecutive nucleotides in the configuration.
    Args:
        x: Tensor of shape (batch_size, num_nucleotides, dim), where dim is usually 3, i.e. (x, y, z)

    Returns:
        Angles: Tensor of angles between three consecutive nucleotides.
    """
    # Vector between nucleotide i and i+1
    v1 = x[:, :-2] - x[:, 1:-1]
    # Vector between nucleotide i+1 and i+2
    v2 = x[:, 1:-1] - x[:, 2:]

    # Compute the angle between vectors v1 and v2 using dot product and norm
    v1_norm = torch.norm(v1, dim=-1)
    v2_norm = torch.norm(v2, dim=-1)

    dot_prod = (v1 * v2).sum(dim=-1)
    cos_theta = dot_prod / (v1_norm * v2_norm * 1e-9)
    cos_theta = torch.clamp(cos_theta, -1.0, 1.0)  # Clamp to avoid invalid values in acos
    # angles in radians
    angles = torch.acos(cos_theta)

    return angles

In [6]:
def epotential_backbone(x, R0=1.5, k=30):
    # Potential based on dist between nucleotides; R0: max bond length, k: spring constant

    # Dist between adjacent nucleotides
    r_ij = torch.norm(x[:, :-1] - x[:, 1:], dim=-1) + 1e-9
    valid_mask = r_ij < R0  # Ensure distances are within valid range for log
    r_ij = torch.clamp(r_ij, max=R0 - 1e-6)  # Avoid log of zero or negative values
    return (k / 2) * R0**2 * torch.log(1 - (r_ij**2) / R0**2)[valid_mask].sum()
    # return (k/2) * R0**2 * torch.log(1 - (r_ij**2) / R0**2).sum()


def epotential_stacking(x, epsilon_stack=1.0):
    # Stacking potential for sequence-dependant interactions

    # Dist between adjacent bases
    r_ij = torch.norm(x[:, :-1] - x[:, 1:], dim=-1)

    # Angle between consecutive nucleotides
    theta = calc_angles(x)

    # Dist and Angle dependent terms
    f1 = torch.exp(-1 * r_ij**2)
    f2 = torch.cos(theta)

    return -1 * epsilon_stack * (f1 * f2).sum()


def hydrogen_bonding(x, base_pairs, epsilon_hb=2.0):
    # Hydrogen bonding between complementary base pairs
    hb_energy = 0
    batch_size = x.size(0)

    for i, j in base_pairs:
        if i < x.size(1) and j < x.size(1): # Ensure indices are within bounds
            # Dist between paired nucleotides
            r_ij = torch.norm(x[:, i] - x[:, j], dim=-1)
            r_ij = torch.clamp(r_ij, min=1e-3)  # Avoid extremely small distances
            # Angle between bases
            theta_ij = calc_angles(torch.stack([x[:, i], x[:, j]], dim=1).view(batch_size, -1, 3))
    
            hb_energy += -1 * epsilon_hb * torch.exp(-1*r_ij) * torch.cos(theta_ij)

    return hb_energy


def excluded_volume(x, A=5.0, lambda_val=1.0):
    # Excluded volume to prevent nucleotides from overlapping
    r_ij = torch.norm(x[:, :, None] - x[:, None, :], dim=-1)  # Pairwise distances

    return A * torch.exp(-1*r_ij / lambda_val).sum()


def coaxial_stacking(x, epsilon_coaxial=1.5):
    # Coaxial stacking interaction, important for junctions and nicks
    r_ij = torch.norm(x[:, :-1] - x[:, 1:], dim=-1)
    theta_ij = calc_angles(x) # calc angles at junctions

    return -1 * epsilon_coaxial * torch.exp(-1*r_ij) * torch.cos(theta_ij).sum()


def entropy_loss(x, scaffold_loops, kuhn_length=1.0):
    """
    Calculate the entropy loss based on scaffold loop configurations.
    Args:
        x: Tensor of shape (batch_size, num_nucleotides, dim), representing the coordinates of the DNA scaffold.
        scaffold_loops: List of tuples indicating start and end indices of each loop.
        kuhn_length: Approximate length of a Kuhn segment.

    Returns:
        Entropy: Scalar value representing the total entropy contribution for the given configuration.
    """
    R = 8.314  # Gas constant in J/(mol*K)
    total_entropy = 0.0

    for loop_start, loop_end in scaffold_loops:
        loop_length = loop_end - loop_start
        if loop_length > 1:
            # Calculate end-to-end vector r
            p_start = x[:, loop_start, :]
            p_end = x[:, loop_end, :]
            r = torch.norm(p_end - p_start, dim=-1).mean()  # Average over batch

            # Calculate num of Kuhn segments
            N = loop_length / kuhn_length
            loop_entropy = -R * math.log(N + 1e-9)  # avoid log(0)
            total_entropy += loop_entropy

    return total_entropy


# Cal Gibbs Free energy
def dna_energy_func(x, base_pairs, scaffold_loops, temperature=300):
    # Existing energy calculations
    e_backbone = epotential_backbone(x)
    e_stacking = epotential_stacking(x)
    e_hb = hydrogen_bonding(x, base_pairs)
    e_excluded = excluded_volume(x)
    e_coaxial = coaxial_stacking(x)

    # Calculate entropy loss
    s_entropy = entropy_loss(x, scaffold_loops)

    # Gibbs Free Energy: G = H - TS
    g_total = e_backbone + e_stacking + e_hb + e_excluded + e_coaxial - temperature * s_entropy
    
    return g_total

In [7]:
# KL Divergence Loss with Energy Function
def kl_divergence_loss(model, data, base_pairs, scaffold_loops, energy_func, temperature=300):
    x = model(data)
    z, log_det_jacobian = model.forward(data)  # Use the invertible property of Boltzmann generators to obtain z and Jacobian
    energy = energy_func(x, base_pairs, scaffold_loops, temperature)
    kl_loss = (energy - log_det_jacobian).mean()  # KL divergence with Jacobian correction
    if torch.isnan(kl_loss):
        print("NaN detected in KL loss!")
        print(f"Energy: {energy}")
        
    return kl_loss

In [8]:
def monitor_gradients(model):
    for name, param in model.named_parameters():
        # print(f"Layer {name}; Gradient norm = {param.grad.norm()}")
        if param.grad is not None:
            print(f"Layer {name}; Gradient norm = {param.grad.norm()}")

In [9]:
def initialize_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias) # init biases to zero

In [10]:
def identify_scaffold_loops(json_data):
    """
    Identifies scaffold loops in the DNA configuration based on base pairing and position data.
    
    Args:
        json_data: Parsed JSON data containing systems, strands, and monomer information.
    
    Returns:
        scaffold_loops: List of tuples representing (loop_start, loop_end) indices for each identified loop.
    """
    scaffold_loops = []
    monomer_data_map = {}  # Keep track of monomer data, including their index and position
    idx = 0

    for system in json_data['systems']:
        for strand in system['strands']:
            for monomer in strand['monomers']:
                monomer_data_map[monomer['id']] = {
                    'index': idx,
                    'p': monomer['p']
                }
                idx += 1

    # Reset index for processing loop identification
    idx = 0
    for system in json_data['systems']:
        for strand in system['strands']:
            for monomer in strand['monomers']:
                if 'bp' in monomer:
                    paired_id = monomer['bp']
                    if paired_id in monomer_data_map:
                        paired_idx = monomer_data_map[paired_id]['index']
                        # Identify if paired_idx is before idx, which may indicate a loop
                        if paired_idx > idx:
                            scaffold_loops.append((idx, paired_idx))
                        # if paired_idx < idx:
                        #     loop_start = paired_idx
                        #     loop_end = idx
                        #     # Calculate distance between paired nucleotides
                        #     p_start = torch.tensor(monomer_data_map[paired_id]['p'])
                        #     p_end = torch.tensor(monomer['p'])
                        #     distance = torch.norm(p_end - p_start).item()
                        #     # threshold to confirm if it's a loop
                        #     if distance < 5.0:  # Arbitrary threshold for identifying spatial proximity
                        #         scaffold_loops.append((loop_start, loop_end))
                idx += 1

    return scaffold_loops

In [11]:
# Boltzmann generator training
def train(model, dataloader, optimizer, scheduler, base_pairs, scaffold_loops, dna_energy_func, epochs, temperature=300):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for x in dataloader:
            # moving data to gpu
            x = x.to(device)
            # base_pairs = [(i, j) for i, j in base_pairs]
            # Add gaussian noise to input data
            noise = torch.randn_like(x) * 0.01
            x += noise

            optimizer.zero_grad()
            # Calc KL divergence loss
            loss = kl_divergence_loss(model, x, base_pairs, scaffold_loops, dna_energy_func, temperature)
            monitor_gradients(model)
            if torch.isnan(loss):
                continue 
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
            optimizer.step()
            # monitor_gradients(model)
            total_loss += loss.item()

        scheduler.step(total_loss/len(dataloader))
        print(f"Epoch {epoch+1}/{epochs} -> Loss: {total_loss / len(dataloader):.4f}")

In [12]:
# Sampling new DNA configurations from latent space
def sample_dna_configs(model, num_samples, input_dim):
    model.eval()
    with torch.no_grad():
        # Latent space dimension should match the input dimension of the RealNVP model
        # latent_dim = model.layers[0].scale_net[0].in_features * 2
        #z = torch.randn(num_samples, latent_dim, device=device)

        z = torch.randn(num_samples, input_dim).to(device)
        x_samples = model.inverse(z)

    return x_samples.cpu().numpy()

In [13]:
# Randomly generated data for demo
input_dim = 9
hidden_dim = 256
num_layers = 10
learning_rate = 0.0001
temperature = 300
batch_size = 16
epochs = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [14]:
# with open("./data/output.oxview", 'r') as file:
#     json_data = json.load(file)

In [15]:
def pad_data(df, target_length=None):
    if target_length is None:
        target_length = max([len(x) for x in df])

    padded_data_list = []
    for x in df:
        current_length = len(x)
        if current_length < target_length:
            pad_length = target_length - current_length
            padding = np.zeros((pad_length, x.shape[1]))
            padded_data = np.vstack([x, padding])
        else:
            padded_data = x

        padded_data_list.append(padded_data)

    return padded_data_list

In [16]:
data_dir = './data'
data_list = []
base_pair_list = []
scaffold_loops_list = []

for filename in os.listdir(data_dir):
    if filename.endswith(".oxview") or filename.endswith(".json"):
        with open(os.path.join(data_dir, filename), 'r') as f:
            json_data = json.load(f)
            position_data = []
            base_pairs = []
            scaffolds_loops = []
            monomer_index_map = {} # Map to keep track of monomer IDs and their indices in the dataset

            idx = 0
            edge_index = []       
            for system in json_data['systems']:
                for strand in system['strands']:
                    for monomer in strand['monomers']:
                        p = monomer['p'] + monomer['a1'] + monomer['a3']
                        position_data.append(p)
                        monomer_index_map[monomer['id']] = idx
                        idx += 1
                        if 'bp' in monomer and monomer['bp'] in monomer_index_map:
                            base_pairs.append((monomer_index_map[monomer['id']], monomer_index_map[monomer['bp']]))

            scaffold_loops = identify_scaffold_loops(json_data)
            data_list.append(np.array(position_data))
            base_pair_list.append(base_pairs)
            scaffold_loops_list.append(scaffolds_loops)

data_list = pad_data(data_list)

In [17]:
print(f"dir:{os.getcwd()}")

dir:/home/sanbaras/sulcLab/origamiModels


In [18]:
#data = [torch.randn(1, input_dim) for _ in range(100)] # we will replace with actual dna coords/oxDNA datasets
# data = np.random.randn(1000, input_dim)
dataset = DNAOrigamiDataset(data_list)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [19]:
 # Model, optimizer, and training
realnvp = RealNVP(input_dim, hidden_dim, num_layers).to(device)
# realnvp.apply(initialize_weights)
optimizer = optim.Adam(realnvp.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# Train the model
train(realnvp, dataloader, optimizer, scheduler, base_pair_list, scaffold_loops_list, epochs, temperature)

hidden dim256
hidden dim256
hidden dim256
hidden dim256
hidden dim256
hidden dim256
hidden dim256
hidden dim256
hidden dim256
hidden dim256
split dim:8117
xtorch.Size([2, 16235, 9])
x1torch.Size([2, 8118, 9])
x2torch.Size([2, 8117, 9])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (16236x9 and 5x256)

In [None]:
new_dna_configs = sample_dna_configs(realnvp, num_samples=10, input_dim=input_dim)
print(new_dna_configs)