In [34]:
from se3_transformer.model.layers.convolution import VersatileConvSE3
import importlib
importlib.reload(se3_transformer.model.layers.convolution)

<module 'se3_transformer.model.layers.convolution' from '/notebooks/SE3Transformer/se3_transformer/model/layers/convolution.py'>

In [46]:
from se3_transformer.model.transformer import Sequential


def fixed_sequential_forward(self, input, *args, **kwargs):
    print(f"\n--- Entering SE3 Sequential Block ---")
    for i, module in enumerate(self):
        print(f"  Input to Module {i} ({type(module).__name__}):")
        if isinstance(input, dict):
            for key, val in input.items():
                print(f"    features['{key}'].shape: {val.shape}")
        else:
            print(f"    input shape: {input.shape}") # Should likely be a dict

        input = module(input, *args, **kwargs) # Call the module

        print(f"  Output from Module {i} ({type(module).__name__}):")
        if isinstance(input, dict):
            for key, val in input.items():
                print(f"    features['{key}'].shape: {val.shape}")
        else:
             print(f"    output shape: {input.shape}") # Should likely be a dict
        print("-" * 20)
    print(f"--- Exiting SE3 Sequential Block ---\n")
    return input

Sequential.forward = fixed_sequential_forward
print("patched sequential forward")

patched sequential forward


In [47]:
import torch
from torch import Tensor
import torch.nn as nn
from typing import Dict

# Import the class to patch
from se3_transformer.model.layers.linear import LinearSE3

print("Attempting to monkeypatch LinearSE3.forward with einsum...")

# Store the original method if you might need it later (optional)
# original_linear_forward = LinearSE3.forward

# Define the patched function using einsum
def patched_linear_einsum_forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
    """
    Patched forward method for LinearSE3 using einsum for correct channel transformation.
    Input features expected shape: (N, D, C_in)
    Weights expected shape: (C_out, C_in)
    Output features shape: (N, D, C_out)
    """
    print(f"\n--- Patched LinearSE3.forward (Einsum) ---")
    print(f"  Input features keys: {list(features.keys())}")
    for degree, tensor in features.items():
        has_nan = torch.isnan(tensor).any()
        has_inf = torch.isinf(tensor).any()
        print(f"    features['{degree}'].shape: {tensor.shape}, Has NaN: {has_nan}, Has Inf: {has_inf}")

    print(f"\n  Weights keys: {list(self.weights.keys())}")
    for degree, weight in self.weights.items():
         print(f"    self.weights['{degree}'].shape: {weight.shape}")
    print("-" * 30)

    output_features = {}
    # Iterate through the degrees defined in the layer's weights
    for degree, weight in self.weights.items():
        print(f"  Processing degree '{degree}':")

        if degree not in features:
            print(f"    WARNING: Degree '{degree}' in weights but not found in input features dict. Skipping.")
            continue

        feat = features[degree]
        # W: (C_out, C_in)
        # X: (N, D, C_in)
        # Y: (N, D, C_out)
        # Einsum: 'ndi,oi->ndo' where W has shape (C_out, C_in) -> oi

        print(f"      Weight shape (oi): {weight.shape}")
        print(f"      Feature shape (ndi): {feat.shape}")

        # Check compatibility for einsum: weight C_in must match feature C_in
        expected_feat_channels = weight.shape[-1] # C_in from weight
        actual_feat_channels = feat.shape[-1]     # C_in from feature

        if expected_feat_channels != actual_feat_channels:
             print(f"    ERROR: Channel mismatch for einsum! Weight expects C_in={expected_feat_channels}, Feature has C_in={actual_feat_channels}")
             print(f"    Skipping einsum for degree '{degree}'.")
             continue

        try:
            # Apply einsum: N=nodes, D=dim, i=in_channels, o=out_channels
            print(f"    Attempting einsum: torch.einsum('ndi,oi->ndo', features['{degree}'], self.weights['{degree}'])")
            # Note the transpose ('oi') assumes weight is (C_out, C_in)
            result = torch.einsum('ndi,oi->ndo', feat, weight)
            output_features[degree] = result
            print(f"      -> Success! Output shape: {result.shape}")
            # Check for NaN/Inf in output
            has_nan_out = torch.isnan(result).any()
            has_inf_out = torch.isinf(result).any()
            if has_nan_out or has_inf_out:
                 print(f"      WARNING: Output for degree '{degree}' contains NaN: {has_nan_out}, Inf: {has_inf_out}")

        except RuntimeError as e:
            print(f"    !!! RuntimeError during einsum for degree '{degree}' !!!")
            print(f"      Error message: {e}")
            # raise e
            print(f"    Skipping degree '{degree}' after error.")
            continue # Skip to next degree

    print(f"--- End Patched LinearSE3.forward (Einsum) ---\n")
    return output_features


# Apply the patch
LinearSE3.forward = patched_linear_einsum_forward
print("Successfully patched LinearSE3.forward with einsum.")

Attempting to monkeypatch LinearSE3.forward with einsum...
Successfully patched LinearSE3.forward with einsum.


In [48]:
from se3_transformer.runtime.utils import degree_to_dim
import torch
from torch import Tensor
from typing import Dict, Optional


# Import the specific class and potentially dependencies like Fiber
from se3_transformer.model.layers.convolution import VersatileConvSE3, RadialProfile
# You might need other imports if your fixed function uses them

print("Attempting to monkeypatch VersatileConvSE3.forward...")

# 1. Define your new forward function (copy the corrected logic)
def fixed_versatile_forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Optional[Tensor], degree_in: int, degree_out: int):
    """
    Patched forward method for VersatileConvSE3.
    Uses einsum for basis contraction and handles scalar case.
    Includes contiguity checks.
    """
    num_edges, dim_in, channels_in_feat = features.shape # Get C_in from actual features

    # Basic check for empty graph edge case after bidirectionality
    if num_edges == 0:
        print("VersatileConvSE3.forward: Skipping computation for 0 edges.")
        # Need to determine expected output shape based on fiber_out
        # This might require inspecting self.fiber_out which should be available
        out_dim = self.fiber_out.dim
        out_channels = self.channels_out
        return torch.zeros(0, out_channels, out_dim, device=features.device, dtype=features.dtype)


    # with nvtx.range(f'RadialProfile_Patched'):
    # invariant_edge_feats shape is (num_edges, edge_dim)
    # Ensure edge_dim matches what RadialProfile expects
    
    # Access the internal network (assuming it's called 'net')

    radial_output = self.radial_func(invariant_edge_feats)

    if basis is not None:
        # basis_size = basis.shape[1] // dim_in
        basis_size = degree_to_dim(min(degree_in, degree_out))
        # Basis contraction path (e.g., 1->0, 1->1, 0->1, etc.)
        # basis shape is (num_edges, dim_in * basis_size, dim_out)
        print("------DEBUG VERSFORWARD------")
        print(f"features passed in shape: {features.shape}")
        print(f"basis: {basis.shape}")
        print(f"dim in: {dim_in}")
        print(f"        radial_output.shape = {radial_output.shape}")
        expected_radial_numel = num_edges * self.channels_out * channels_in_feat * basis_size
        print(f"        expected_radial_numel = {expected_radial_numel} (based on edges={num_edges}, C_out={self.channels_out}, C_in={channels_in_feat}, F={basis_size})")

        dim_out = basis.shape[-1]

        # Calculate expected radial output size
        # if radial_output.numel() != expected_radial_numel:
        #     raise ValueError(
        #         f"RadialProfile output size mismatch (basis path)! "
        #         f"Expected {expected_radial_numel}, got {radial_output.numel()}. "
        #         f"Shape: {radial_output.shape}. C_out={self.channels_out}, C_in={channels_in_feat}, F={basis_size}"
        #     )

        # Reshape radial_weights for einsum: (E, C_out, C_in, F)
        radial_weights_reshaped = radial_output.view(
            num_edges, self.channels_out, channels_in_feat, basis_size
        ).contiguous() # Add contiguous

        # Reshape basis for einsum: (E, D_in, F, D_out)
        basis_reshaped = basis.view(num_edges, dim_in, basis_size, dim_out).contiguous() # Add contiguous

        # Ensure features are contiguous
        features = features.contiguous()

        # Perform the correct tensor contraction: nli,noif,nlfk->nok
        try:
            out = torch.einsum('nli,noif,nlfk->nok', features, radial_weights_reshaped, basis_reshaped)
        except RuntimeError as e:
             print("ERROR during einsum (basis path)!")
             print(f"  features shape: {features.shape}, contiguous: {features.is_contiguous()}")
             print(f"  radial_weights_reshaped shape: {radial_weights_reshaped.shape}, contiguous: {radial_weights_reshaped.is_contiguous()}")
             print(f"  basis_reshaped shape: {basis_reshaped.shape}, contiguous: {basis_reshaped.is_contiguous()}")
             raise e
        return out

    else:
        # k = l = 0 non-fused case (scalar -> scalar)
        # features shape (E, D_in=1, C_in)
        expected_radial_numel = num_edges * self.channels_out * channels_in_feat * 1 # basis_size=1
        if radial_output.numel() != expected_radial_numel:
            pass #pass FOR NOW
            # raise ValueError(
            #     f"RadialProfile output size mismatch (scalar path)! "
            #     f"Expected {expected_radial_numel}, got {radial_output.numel()}. "
            #     f"Shape: {radial_output.shape}. C_out={self.channels_out}, C_in={channels_in_feat}"
            #  )

        try:
            radial_weights = radial_output.view(num_edges, self.channels_out, channels_in_feat).contiguous() # Line 206 + contiguous
        except RuntimeError as e:
            print("ERROR during .view() (scalar path)!")
            print(f"  radial_output shape: {radial_output.shape}")
            print(f"  Target shape: ({num_edges}, {self.channels_out}, {channels_in_feat})")
            raise e

        features = features.contiguous()

        try:
            # Use channels_in_feat from features for einsum consistency
            # features shape (E, 1, C_in) -> nli
            # radial_weights shape (E, C_out, C_in) -> noi
            result = torch.einsum('nli,noi->no', features, radial_weights).unsqueeze(-1) # Line 210. Output (E, C_out, 1)
        except RuntimeError as e:
             print("ERROR during einsum (scalar path)!")
             print(f"  features shape: {features.shape}, contiguous: {features.is_contiguous()}")
             print(f"  radial_weights shape: {radial_weights.shape}, contiguous: {radial_weights.is_contiguous()}")
             raise e

        return result

# 2. Apply the patch
VersatileConvSE3.forward = fixed_versatile_forward
print("Successfully patched VersatileConvSE3.forward.")

Attempting to monkeypatch VersatileConvSE3.forward...
Successfully patched VersatileConvSE3.forward.


In [49]:
from se3_transformer.model.layers.convolution import RadialProfile

def fixed__init__(
            self,
            num_freq: int,
            channels_in: int,
            channels_out: int,
            edge_dim: int = 1,
            mid_dim: int = 32,
            use_layer_norm: bool = False
    ):
        """
        :param num_freq:         Number of frequencies
        :param channels_in:      Number of input channels
        :param channels_out:     Number of output channels
        :param edge_dim:         Number of invariant edge features (input to the radial function)
        :param mid_dim:          Size of the hidden MLP layers
        :param use_layer_norm:   Apply layer normalization between MLP layers
        """
        super(RadialProfile, self).__init__()
        modules = [
            nn.Linear(edge_dim, mid_dim),
            nn.LayerNorm(mid_dim) if use_layer_norm else None,
            nn.ReLU(),
            nn.Linear(mid_dim, mid_dim),
            nn.LayerNorm(mid_dim) if use_layer_norm else None,
            nn.ReLU(),
            nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False)
        ]
        self.expected_output_size = num_freq * channels_in * channels_out
        self._edge_dim = edge_dim
        self._mid_dim = mid_dim
        self.net = torch.jit.script(nn.Sequential(*[m for m in modules if m is not None]))

RadialProfile.__init__ = fixed__init__

print("Patched RP")

Patched RP


In [50]:
from se3_transformer.model.layers.convolution import RadialProfile # Make sure it's imported

# original_radial_forward = RadialProfile.forward # Store original if needed

def patched_radial_forward(self, features: Tensor) -> Tensor:
    print(f"--- INSIDE RadialProfile.forward ---")
    print(f"    Received features shape: {features.shape}, size: {features.nelement()}")
    # You could even print min/max/mean/isnan to check for bad values
    # print(f"    Features has NaN: {torch.isnan(features).any()}")
    # print(f"    Features min/max: {features.min()}, {features.max()}")
    try:
        result = self.net(features)
        # assert result.nelement() == self.expected_output_size, f"SHAPE MISMATCH IN RADFORWARD: o: {result.nelement()}, e: {self.expected_output_size}\nedge_dim:{self._edge_dim}, mid_dim: {self._mid_dim}"
        print(f"Outputting shape: {result.shape}.")
        return result
    except RuntimeError as e:
        print(f"    !!! ERROR occurred within self.net(features) !!!")
        raise e

RadialProfile.forward = patched_radial_forward
print("Patched RadialProfile.forward with debug prints.")


Patched RadialProfile.forward with debug prints.


In [1]:
import os
!pip install einops
!pip install bitsandbytes
!pip install rna-fm # https://github.com/ml4bio/RNA-FM
!pip install torch_geometric
# !pip install viennarna
!pip install networkx
os.chdir("/notebooks/SE3Transformer")
!pip install -e .
!pip install -r requirements.txt
# !pip install dgl==1.0.0
!pip install --pre dgl -f https://data.dgl.ai/wheels/cu121/repo.html
os.chdir("/notebooks/proj")

print("Completed pip process")

Collecting einops
  Downloading einops-0.8.1-py3-none-any.whl.metadata (13 kB)
Downloading einops-0.8.1-py3-none-any.whl (64 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.4/64.4 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.1
[0mCollecting rna-fm
  Downloading rna_fm-0.2.2-py3-none-any.whl.metadata (10 kB)
Collecting ptflops (from rna-fm)
  Downloading ptflops-0.7.4-py3-none-any.whl.metadata (9.4 kB)
Downloading rna_fm-0.2.2-py3-none-any.whl (46 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.7/46.7 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading ptflops-0.7.4-py3-none-any.whl (19 kB)
Installing collected packages: ptflops, rna-fm
Successfully installed ptflops-0.7.4 rna-fm-0.2.2
[0mCollecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63

In [2]:
import sys
sys.path.append("/notebooks/SE3Transformer")
import se3_transformer
sys.path.append("/notebooks/RNAstructure/exe")
import RNAstructure
# sys.path.append("/notebooks/GCNfold")
# from GCNfold import models
# from nets.gcnfold_net import GCNFoldNet_UNet


In [3]:
import torch_geometric

In [5]:
##################################################################################
#----- MIGHT HAVE TO  RUN BLOCK TWICE TO RESOLVE TYPERROR && DO NOT CHANGE -----##
##################################################################################
!pip install torchdata==0.7.0
import warnings
warnings.filterwarnings("ignore")

from sklearn.manifold import MDS
import networkx as nx
import random
import pickle
import yaml
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import time
from sklearn.model_selection import train_test_split
from torch.nn.utils.rnn import pad_sequence
import fm
from sklearn.manifold import MDS
from torch_geometric.data import Data
import dgl
import torch_geometric
# import RNA

from se3_transformer.model.transformer import SE3Transformer
from se3_transformer.model.fiber import Fiber

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

[0m

DGL backend not selected or invalid.  Assuming PyTorch for now.


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)
Using device: cuda


# 1. CONFIG & SEED

In [9]:
def set_seed(seed: int):
    """Set a random seed for Python, NumPy, PyTorch (CPU & GPU) to ensure reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

# Example configuration (you can load this from a YAML, JSON, etc.)
config = {
    "seed": 42,
    "cutoff_date": "2020-01-01",
    "test_cutoff_date": "2022-05-01",
    "max_len": 1024,
    "batch_size": 1,
    "model_config_path": "ribonanzanet2d-final/configs/pairwise.yaml",
    "max_len_filter": 1024,
    "min_len_filter": 10,
    
    "train_sequences_path": "data/Competition/train_sequences.csv",
    "train_labels_path": "data/Competition/train_labels.csv",
    "test_data_path": "data/Competition/test_sequences.csv",
    "combined_train_data_path": "data/Combined/total_processed_rna_data.pt",
    "final_pretrained_weights_path": "weights/RibonanzaNet-3D-final.pt",
    "nonfinal_pretrained_weights_path": "weights/RibonanzaNet-3D.pt",
    "save_weights_name": "weights/RibonanzaNet-3D.pt",
    "save_weights_final": "weights/RibonanzaNet-3D-final.pt",
    "rna_fm_weights": "weights/RNA-FM_pretrained.pth",
    "path_to_GCNFold_weights": "weights/model_unet_99.pth",
    "rna_fm_embedding_dim": 640 # default 640; DO NOT CHANGE
}

# Set the seed for reproducibility
set_seed(config["seed"])

# import shutil
# shutil.copy("/root/.cache/torch/hub/checkpoints/RNA-FM_pretrained.pth", config["rna_fm_weights"])

# 2. DATA LOADING & PREPARATION

In [10]:

# Load CSVs
train_sequences = pd.read_csv(config["train_sequences_path"])
train_labels = pd.read_csv(config["train_labels_path"])

# Create a pdb_id field
train_labels["pdb_id"] = train_labels["ID"].apply(
    lambda x: x.split("_")[0] + "_" + x.split("_")[1]
)

# Collect xyz data for each sequence
all_xyz = []
for pdb_id in tqdm(train_sequences["target_id"], desc="Collecting XYZ data"):
    df = train_labels[train_labels["pdb_id"] == pdb_id]
    xyz = df[["x_1", "y_1", "z_1"]].to_numpy().astype("float32")
    xyz[xyz < -1e17] = float("nan")
    all_xyz.append(xyz)
    


Collecting XYZ data: 100%|██████████| 844/844 [00:05<00:00, 153.74it/s]


# 2.5 SECONDARY DATA (BPPMs, initial 3D structs, initial sequence embeddings, etc.) GENERATION

In [11]:
sys.path.append("ribonanzanet2d-final")

from Network import *

class Config:
    def __init__(self, **entries):
        self.__dict__.update(entries)
        self.entries=entries

    def print(self):
        print(self.entries)

def load_config_from_yaml(file_path):
    with open(file_path, 'r') as file:
        config = yaml.safe_load(file)
    return Config(**config)

class finetuned_RibonanzaNet(RibonanzaNet):
    def __init__(self, config):
        config.dropout=0.2
        super(finetuned_RibonanzaNet, self).__init__(config)
        self.use_gradient_checkpoint = False
        self.ct_predictor=nn.Linear(64,1)
        self.dropout = nn.Dropout(0.0)
        
    def forward(self,src):
        
        #with torch.no_grad():
        _, pairwise_features=self.get_embeddings(src, torch.ones_like(src).long().to(src.device))

        pairwise_features=pairwise_features+pairwise_features.permute(0,2,1,3) #symmetrize

        output=self.ct_predictor(self.dropout(pairwise_features)) #predict

        return output.squeeze(-1)

ribonet=finetuned_RibonanzaNet(load_config_from_yaml("ribonanzanet2d-final/configs/pairwise.yaml")).cuda()
ribonet.load_state_dict(torch.load("weights/RibonanzaNet-SS.pt",map_location='cpu'))
ribonet.eval()

rna_fmodel, alphabet = fm.pretrained.rna_fm_t12(config["rna_fm_weights"])
rnafm_batch_converter = alphabet.get_batch_converter()
rna_fmodel.eval()

reverse_map = {
    0: "A", 1: "C", 2: "G", 3: "U"
}

def tokens_to_str(tokens):
    tokens = tokens.tolist()
    seq = ""
    for token in tokens:
        seq+=reverse_map[token]
    return seq

def init_coords_from_sequence(
    seq,
    bppm,
    contact_d=6.0,
    noncontact_d=25.0,
    mds_kwargs=None):
    """
    Args:
        seq: RNA sequence str of len L
        bppm: pair prob matrix of (L, L)
        contact_d: target distance (Å) for predicted base pairs
        noncontact_d: target distance (Å) for non-paired nucleotides
        mds_kwargs: extra args for sklearn.manifold.MDS

    Returns:
        coords: tensor of shape (L,3)
    """

    P = bppm
    
    L = P.shape[0]
    
    # 2. Build graph & run MWM
    G = nx.Graph()
    for i in range(L):
        for j in range(i+4, L):  # enforce minimum loop length
            p = P[i, j]
            if p > 0.01:  # skip ultra-low probs
                w = torch.log(p / (1 - p + 1e-9))
                if w > 0:
                    G.add_edge(i, j, weight=w)
    match = nx.algorithms.matching.max_weight_matching(
        G, maxcardinality=False
    )  # O(L³) but usually <0.05 s for L≈400

    # 3. Build a target distance matrix
    D = np.full((L, L), noncontact_d, dtype=float)
    for i, j in match:
        D[i, j] = D[j, i] = contact_d
    np.fill_diagonal(D, 0.0)

    # 4. Run classical MDS to embed into ℝ³
    mds_kwargs = mds_kwargs or {}
    mds = MDS(
        n_components=3,
        dissimilarity="precomputed",
        n_init=4,
        max_iter=300,
        **mds_kwargs
    )
    coords = mds.fit_transform(D)  # (L,3), preserves the “contact” proximities
    return torch.from_numpy(coords).float().cuda()

vocab = {"A":0, "C":1, "G":2, "U":3}
def get_ribonet_bpp(sequence): # tensor of shape (1, L, L)
    src = sequence.unsqueeze(0)
    return ribonet(src).sigmoid().detach().cpu()
    
def get_rnaf_seq_encoding(sequence): 
    # sequence = tokens_to_str(sequence[0]) # CURRENTLY ONLY SUPPORTS BATCH SIZE 1 ### FIX ###
     
    # Prepare data
    data = [
        ("Sequence", sequence)
    ]
    _, _, batch_tokens = rnafm_batch_converter(data) # [(id, seq),...] -> batch label, seq, tokens

    # Extract embeddings (on CPU)
    with torch.no_grad():
        results = rna_fmodel(batch_tokens, repr_layers=[rna_fmodel.num_layers])
    # print(results["representations"])
    token_embeddings = results["representations"][rna_fmodel.num_layers].cuda()
    token_embeddings = token_embeddings[:, 1:-1, :]
    return token_embeddings # (1, seqlen, 640)



constructing 9 ConvTransformerEncoderLayers


In [10]:
# # I want to create literal arrays of initial_embedding, initial_3d, bppm for each sequence
# # print(train_sequences["sequence"].head())

# import csv

# init_seq_embeddings, initial_3ds, bppms = [], [], []
# invalid_indices = []
# def generate_support_data():
#     """
#     Generates all support data and saves to respective arrays. Do not run every time.
#     """
#     total = 0
    
    

    
#     for i, sequence in tqdm(enumerate(train_sequences["sequence"])):
#         if len(sequence) > 1024: # RNA-FM constraint
#             invalid_indices.append(i)
#             init_seq_embeddings.append([])
#             initial_3ds.append([])
#             bppms.append([])
#             total+=1
#             continue
#         emb = get_rnaf_seq_encoding(sequence)
#         init_seq_embeddings.append(emb)
#         bppm = get_ribonet_bpp(sequence)
#         bppms.append(bppm)
#         init3ds = init_coords_from_sequence(sequence, bppm)
#         initial_3ds.append(init3ds)
#         total+=1
#     print(f"Finished processing {i} sequences")

# print(f"Generating support data for {len(train_sequences['sequence'])} sequences...")

# generate_support_data()

# def load_support_data(path):
    
#     pass

# assert not bppms==[], "Must either call load or save support data"



# DATA FILTERING

In [12]:
valid_indices = []
max_len_seen = 0

for i, xyz in enumerate(all_xyz):
    # Track the maximum length
    if len(xyz) > max_len_seen:
        max_len_seen = len(xyz)

    nan_ratio = np.isnan(xyz).mean()
    seq_len = len(xyz)
    # Keep sequence if it meets criteria
    if (nan_ratio <= 0.5) and (config["min_len_filter"] < seq_len <= config["max_len_filter"]):
        valid_indices.append(i)

print(f"Longest sequence in train: {max_len_seen}")

# Filter sequences & xyz based on valid_indices
train_sequences = train_sequences.loc[valid_indices].reset_index(drop=True)
all_xyz = [all_xyz[i] for i in valid_indices]
# init_seq_embeddings = [init_seq_embeddings[i] for i in valid_indices]
# initial_3ds = [initial_3ds[i] for i in valid_indices]
# bppms = [bppms[i] for i in valid_indices]

# Prepare final data dictionary
data = {
    "sequence": train_sequences["sequence"].tolist(),
    "temporal_cutoff": train_sequences["temporal_cutoff"].tolist(),
    "description": train_sequences["description"].tolist(),
    "all_sequences": train_sequences["all_sequences"].tolist(),
    "xyz": all_xyz
    # "base_pair_matrices": bppms,
    # "3d_inits": tertiary_inits,
    # "seq_embedding_inits": seq_emb_inits
}

Longest sequence in train: 4298


# 4. TRAIN / VAL SPLIT

In [13]:
'''
cutoff_date = pd.Timestamp(config["cutoff_date"])
test_cutoff_date = pd.Timestamp(config["test_cutoff_date"])

train_indices = [i for i, date_str in enumerate(data["temporal_cutoff"]) if pd.Timestamp(date_str) <= cutoff_date]
test_indices = [i for i, date_str in enumerate(data["temporal_cutoff"]) if cutoff_date < pd.Timestamp(date_str) <= test_cutoff_date]
'''



all_indices = list(range(len(data["sequence"])))
train_indices, test_indices = train_test_split(all_indices, test_size=0.1, random_state=config["seed"])


# 5. DATASET & DATALOADER

In [14]:
def rna_collate_fn(batch):
    sequences = [item["sequence"] for item in batch]
    xyzs = [item["xyz"] for item in batch]

    # Create masks before padding
    masks = [torch.ones(len(seq), dtype=torch.bool) for seq in sequences]

    # Pad sequences and coordinates
    padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=4)  # 4 = <PAD> token index
    padded_xyzs = pad_sequence(xyzs, batch_first=True, padding_value=0.0)
    padded_masks = pad_sequence(masks, batch_first=True, padding_value=0)

    return {
        "sequence": padded_sequences,
        "xyz": padded_xyzs,
        "mask": padded_masks
    }


class RNA3D_Dataset(Dataset):
    """
    A PyTorch Dataset for 3D RNA structures.
    """
    def __init__(self, indices, data_dict, max_len=384):
        self.indices = indices
        self.data = data_dict
        self.max_len = max_len
        self.nt_to_idx = {nt: i for i, nt in enumerate("ACGU")}
        
    def __len__(self):
        return len(self.indices)
   
    def clean_sequences(self):
        clean_seqs = []
        clean_xyz = []
        clean_indices = []

        for seq, coords in zip(self.data["sequence"], self.data["xyz"]):
            if 'X' in seq or coords is None or len(seq) != len(coords):
                continue
            clean_seqs.append(seq)
            clean_xyz.append(coords)

        self.data["sequence"] = clean_seqs
        self.data["xyz"] = clean_xyz
        self.indices = list(range(len(clean_seqs)))

    def __getitem__(self, idx):
        data_idx = self.indices[idx]
        # Convert nucleotides to integer tokens
        sequence = []

        sequence = [self.nt_to_idx[nt] for nt in self.data["sequence"][data_idx]]
        sequence = torch.tensor(sequence, dtype=torch.long)
        # Convert xyz to torch tensor
        xyz = torch.tensor(self.data["xyz"][data_idx], dtype=torch.float32)

        # If sequence is longer than max_len, randomly crop
        if len(sequence) > self.max_len:
            crop_start = np.random.randint(len(sequence) - self.max_len)
            crop_end = crop_start + self.max_len
            sequence = sequence[crop_start:crop_end]
            xyz = xyz[crop_start:crop_end]

        return {"sequence": sequence, "xyz": xyz}

train_dataset = RNA3D_Dataset(train_indices, data, max_len=config["max_len"])
train_dataset.clean_sequences()
val_dataset = RNA3D_Dataset(test_indices, data, max_len=config["max_len"])
val_dataset.clean_sequences()

train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True,
    num_workers=8,  # Adjust based on CPU cores
    pin_memory=True,
    prefetch_factor=2,
    collate_fn=rna_collate_fn
    )
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=8, pin_memory=True, 
                        collate_fn=rna_collate_fn)

# 6. MODEL, CONFIG CLASSES & HELPER FUNCTIONS

In [17]:
# with torch.no_grad():
#     L = 20
#     dummy_S   = torch.randn(L, 640, device='cuda')
#     dummy_P   = torch.rand  (L, L,  device='cuda')
#     dummy_xyz = torch.randn(L, 3,   device='cuda')
#     g = _make_graph(dummy_S, dummy_P, dummy_xyz,
#                     model.rbf_mu, model.rbf_sigma, thresh=0.2)
#     print("scalar width =", g.edge_feats[0].shape[1])   # must be 32

In [15]:
'''
class SeqPairBlock(nn.Module):
    def __init__(self, seq_dim, pair_dim, n_heads=8, use_triangular_attention=True):
        super().__init__()
        self.qkv = nn.Linear(seq_dim, 3*seq_dim)
        self.p_bias = nn.Linear(pair_dim, n_heads)           # per‑head scalar bias
        self.attn = nn.MultiheadAttention(embed_dim=256, num_heads=8, batch_first=True)
        self.triangle_update_out=TriangleMultiplicativeModule(dim=pair_dimension,mix='outgoing')
        self.triangle_update_in=TriangleMultiplicativeModule(dim=pair_dimension,mix='ingoing')

        self.pair_dropout_out=DropoutRowwise(dropout)
        self.pair_dropout_in=DropoutRowwise(dropout)

        self.use_triangular_attention=use_triangular_attention

        if self.use_triangular_attention:
            self.triangle_attention_out=TriangleAttention(in_dim=pair_dimension,
                                                                    dim=pair_dimension//4,
                                                                    wise='row')
            self.triangle_attention_in=TriangleAttention(in_dim=pair_dimension,
                                                                    dim=pair_dimension//4,
                                                                    wise='col')
            self.pair_attention_dropout_out=DropoutRowwise(dropout)
            self.pair_attention_dropout_in=DropoutColumnwise(dropout)

        self.ffn_seq = nn.Sequential(nn.Linear(seq_dim,4*seq_dim),
                                     nn.GELU(),
                                     nn.Linear(4*seq_dim,seq_dim))
        self.outer_proj = nn.Sequential(nn.Linear(seq_dim*2+seq_dim**2, pair_dim),
                                        nn.ReLU(),
                                        nn.Linear(pair_dim,pair_dim))
        # self.tri_mult = TriangleMul(d_p)                # optional
        # self.tri_att  = TriangleAtt(d_p//2)             # optional

    def forward(self, S, P):
        # 1. self‑att with pair bias
        q,k,v = self.qkv(S).chunk(3,dim=-1)
        bias  = self.p_bias(P).permute(2,0,1)           # (heads,L,L)
        S = S + self.attn(q,k,v, attn_bias=bias)

        # 2. FF on S
        S = S + self.ffn_seq(S)

        # 3. pair update
        op = torch.einsum('id,jd->ijd', S, S)           # outer product
        feats = torch.cat((op, S[:,None]+S[None,:]), dim=-1)
        P = P + self.outer_proj(feats)

        # # 4. triangle refinement every k blocks
        # if do_triangle:
        #     P = P + self.tri_mult(P)
        #     P = P + self.tri_att(P)

        return S, P
'''

class PairEmbedding(nn.Module):
    def __init__(self, d_seq, d_pair, d_hidden=128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_pair)
        )
        self.outer_product_mean = Outer_Product_Mean(in_dim=d_seq, pairwise_dim=d_pair)
        self.rel_pos_embed = relpos(dim=d_pair)

    def forward(self, seq_rep, bppm):
        print(f"embedding seq_rep of shape {seq_rep.shape}, and bppm of shape {bppm.shape}")
        x = bppm.unsqueeze(-1)                       # (L,L,1) bppm is len 28
        pair_embed = self.mlp(x)                        # (L,L,d_pair)
        outer_prod_mean = self.outer_product_mean(seq_rep)  # seq_rep is len 30
        rel_embeddings = self.rel_pos_embed(seq_rep)
        print(f"Pair: {pair_embed.shape}, outer: {outer_prod_mean.shape}, relpos: {rel_embeddings.shape}")
        summed_pair_rep = outer_prod_mean + rel_embeddings + pair_embed
        return summed_pair_rep

class ConvFormerBlocks(nn.Module):
    def __init__(self, n_blocks, seq_dim, nhead, pair_dim,
                 use_triangular_attention, dropout):
        super(ConvFormerBlocks, self).__init__()
        self.blocks = nn.ModuleList([
            ConvTransformerEncoderLayer(
                d_model = seq_dim,
                nhead = nhead,
                dim_feedforward = seq_dim*3, 
                pairwise_dimension= pair_dim,
                use_triangular_attention=use_triangular_attention,
                dropout = dropout
            )
            for _ in range(n_blocks)
        ])
    
    def forward(self, seq_embedding, pair_embedding):
        print(f"s: {seq_embedding.shape}, p: {pair_embedding.shape}")
        seqrep = seq_embedding
        pairrep = pair_embedding
        mask = torch.ones(seqrep.size(0), seqrep.size(1), dtype=torch.bool, device=seqrep.device)
        for block in self.blocks:
            seqrep, pairrep = block(seqrep, pairrep, src_mask=mask)
        return seqrep, pairrep

''' PER CLAUDE
class SE3FormerBlocks(nn.Module):
    def __init__(self, n_blocks, seq_dim, thresh):
        super(SE3FormerBlocks, self).__init__()
        self.thresh = thresh
        self.blocks = nn.ModuleList([
            SE3Transformer(
                num_layers     = 4,                    # == 4 equivariant blocks
                num_heads      = 8,                    # matches DeepMind default
                channels_div   = 2,                    # head dim = hidden/2
                fiber_in       = Fiber({0: seq_dim, 1: 1}),
                fiber_hidden   = Fiber({0:128, 1:128, 2:64}),
                fiber_out      = Fiber({1:1}),         # emit coordinate delta
                fiber_edge=Fiber({0:32, 1:1}),
                edge_dim=32,
                use_layer_norm = True,
                self_interaction = True,               # linear on each fibre
                dropout        = 0.1
            )
            for _ in range(n_blocks)
        ])
'''

class SE3FormerBlocks(nn.Module):
    def __init__(self, n_blocks, seq_dim, thresh):
        super(SE3FormerBlocks, self).__init__()
        self.thresh = thresh
        self.blocks = nn.ModuleList([
            SE3Transformer(
                num_layers     = 4,
                num_heads      = 8,
                channels_div   = 2,
                fiber_in       = Fiber({0: seq_dim, 1: 1}), # Input node fibers
                fiber_hidden   = Fiber({0:128, 1:128, 2:64}),
                fiber_out      = Fiber({1:1}),         # Output update vector
                fiber_edge     = Fiber({0:32, 1:1}), # Edge features expected
                edge_dim       = 33, # Expect 32 scalar + 1 norm = 33 total scalar edge dim input for RadialProfile
                use_layer_norm = True,
                self_interaction = True,
                dropout        = 0.1
            )
            for _ in range(n_blocks)
        ])

    def forward(self, seq_rep, bppm, xyz_init, rbf_mu, rbf_sigma, thresh):
        xyz = xyz_init
        # Ensure seq_rep is (L, D) if it comes in as (1, L, D)
        if seq_rep.dim() == 3 and seq_rep.size(0) == 1:
             seq_rep = seq_rep.squeeze(0)

        for block_idx, block in enumerate(self.blocks):
            # Ensure xyz is (L, 3)
            if xyz.dim() == 3 and xyz.size(0) == 1:
                 xyz = xyz.squeeze(0)
            elif xyz.dim() != 2 or xyz.size(1) != 3:
                 raise ValueError(f"xyz shape entering block {block_idx} is wrong: {xyz.shape}")

            # 1. Create PyG Data object containing raw features
            data = _make_graph(seq_rep, bppm, xyz, rbf_mu, rbf_sigma, thresh)

            # 2. Create DGL graph and make bidirected
            src, dst = data.edge_index
            num_nodes = xyz.size(0)
            # --- FIX: Create graph on CPU ---
            g_cpu = dgl.graph((src.cpu(), dst.cpu()), num_nodes=num_nodes)
            # --- Convert to bidirected ON CPU ---
            g_bi_cpu = dgl.to_bidirected(g_cpu)
            # --- Move the bidirected graph to GPU ---
            g = g_bi_cpu.to(device)
            
            # Handle case of graph with no edges after bidirectionality
            if g.num_edges() == 0:
                print(f"Warning: Skipping SE3 block {block_idx} due to 0 edges in DGL graph.")
                continue # Skip to the next block or potentially return xyz if it's the last block

            # 3. Prepare Edge Features for DGL graph (duplicate for bidirected)
            edge_feats_dgl = {}
            for k, v in data.edge_feats.items(): # v is (E_orig, Dim, Channels)
                edge_feats_dgl[k] = torch.cat([v, v], dim=0) # Shape (E_dgl, Dim, Channels)

            # 4. Calculate Relative Positions for DGL graph
            u, v_ = g.edges() # Source and destination indices for DGL edges (E_dgl,)
            rel_pos = xyz[v_] - xyz[u] # Shape (E_dgl, 3)
            g.edata['rel_pos'] = rel_pos

            # 5. Prepare Node Features for DGL message passing context
            #    Fetch features corresponding to *source* nodes (u) and reshape/transpose
            node_feats_dgl = {}
            # Degree 0: Expected shape (E_dgl, 1, 640)
            node_feats_0_orig = data.node_feats['0'][u] # Shape (E_dgl, 640, 1)
            node_feats_dgl['0'] = node_feats_0_orig.transpose(1, 2) # Shape (E_dgl, 1, 640)

            # Degree 1: Expected shape (E_dgl, 3, 1)
            node_feats_1_orig = data.node_feats['1'][u] # Shape (E_dgl, 3, 1)
            node_feats_dgl['1'] = node_feats_1_orig # Already in correct format
            print("PRINTING KEYS::::--------")
            print(f"nodefeats keys: {list(node_feats_dgl.keys())}, edgefeats keys: {list(edge_feats_dgl.keys())}")
            # # --- Assertions for prepared DGL features ---
            # first_layer_fiber_in = self.blocks[block_idx].graph_modules[0].fiber_in
            # assert node_feats_dgl[0].shape == (g.num_edges(), 1, first_layer_fiber_in[0])
            # assert node_feats_dgl[1].shape == (g.num_edges(), 3, first_layer_fiber_in[1])
            # assert edge_feats_dgl[0].shape == (g.num_edges(), 32, 1), f"expected shape {(g.num_edges(), 33, 1)} but got {edge_feats_dgl[0].shape}"
            # assert edge_feats_dgl[1].shape == (g.num_edges(), 3, 1)

            # --- End Assertions ---

            print(f"  Verifying shapes before block call...")
            first_layer_fiber_in = self.blocks[block_idx].graph_modules[0].fiber_in
            assert node_feats_dgl['0'].shape == (g.num_edges(), 1, first_layer_fiber_in[0]), f"Node feat 0 shape mismatch: {node_feats_dgl['0'].shape}"
            assert node_feats_dgl['1'].shape == (g.num_edges(), 3, first_layer_fiber_in[1]), f"Node feat 1 shape mismatch: {node_feats_dgl['1'].shape}"
            assert edge_feats_dgl['0'].shape == (g.num_edges(), 32, 1), f"Edge feat 0 shape mismatch: {edge_feats_dgl['0'].shape}" # Expect 32 now
            assert edge_feats_dgl['1'].shape == (g.num_edges(), 3, 1), f"Edge feat 1 shape mismatch: {edge_feats_dgl['1'].shape}"
            print(f"     Shapes verified.")
            
            # 6. Call the SE3Transformer block
            out_feats = block(
                g,
                node_feats_dgl, # Pass features indexed by source nodes & correctly shaped
                edge_feats=edge_feats_dgl
            )

            # 7. Process Output
            if 1 not in out_feats:
                 raise KeyError(f"SE3Transformer block {block_idx} output missing type 1 features.")

            # Output features are aggregated at destination nodes, shape (L, Dim, Channels)
            xyz_change = out_feats[1].squeeze(-1) # Shape (L, 3)

            if xyz.shape != xyz_change.shape:
                 raise ValueError(f"Shape mismatch after block {block_idx}: xyz ({xyz.shape}) vs xyz_change ({xyz_change.shape})")

        xyz = xyz + xyz_change

        return xyz

''' PER CLAUDE
    def forward(self, seq_rep, bppm, xyz_init, rbf_mu, rbf_sigma, thresh):
        xyz = xyz_init
        
        for block in self.blocks:
            data = _make_graph(seq_rep, bppm, xyz, rbf_mu, rbf_sigma, thresh)
            
            edge_feats = data.edge_feats[0].squeeze(-1)  # (E, D)
            print("edge_feats.shape:", edge_feats.shape)
            # → should be (98, 33) --> correct

            src, dst = data.edge_index   # each is a 1-D tensor of length E
            # num_nodes = data.node_feats[0].shape[0]        # L
            g = dgl.graph((src.tolist(), dst.tolist()))
            g = dgl.to_bidirected(g).to(device)

  
            for k,v in data.edge_feats.items():
                # v was (E, F_k), but g has 2E edges now
                # data.edge_feats[k] = torch.cat([v, v], dim=0)  # now (2E, F_k)
                data.edge_feats[k] = torch.cat([v, v], dim=0)
            
            u, v_ = g.edges()                      # each is (2E,) LongTensor
            rel_pos = xyz[v_] - xyz[u]             # (2E, 3)
            # 3) stash it in g.edata
            g.edata['rel_pos'] = rel_pos
            
            node_feats = {
                "0": data.node_feats[0],     
                "1": data.node_feats[1]      
            }
            edge_feats = {
                "0": data.edge_feats[0],      # (E, #scalar_feats)
                "1": data.edge_feats[1]       # (E, 3, 1)
            }
            assert set(node_feats.keys()) == {"0", "1"},      "node_feats must have exactly keys 0 and 1"
            assert node_feats["0"].shape[1] == 640,       f"expected {640} scalars, got {node_feats[0].shape[1]}"
            assert node_feats["1"].shape[1:] == (3, 1),      "ℓ=1 features must be of shape (L,3,1)"
            # assert set(edge_feats.keys()) == {"0", "1"},      "edge_feats must have exactly keys 0 and 1"
            # assert edge_feats["1"].shape[2] == 1,           "you must provide exactly one vector channel"
            out_feats = block(
                g,
                node_feats,
                edge_feats=edge_feats
            )

            xyz_change = out_feats[1].squeeze(-1)

            xyz = xyz + xyz_change
        return xyz
'''
print("Complete")


Complete


# MODEL INSTANTANTIATION

In [16]:
# SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks

def _make_graph(S, P, xyz, rbf_mu, rbf_sigma, thresh=0.2):
    """
    S  : (L, d_seq)   – updated sequence scalars
    P  : (L, L)       – pair probabilities
    xyz: (L, 3)       – C1′ coordinates from RNAComposer/FARNA
    thresh: float     - threshold for high prob contact classification
    returns PyG Data with edge scalars & vectors ready for SE3-Trf.
    """

    if isinstance(xyz, np.ndarray):
        xyz = torch.from_numpy(xyz).cuda()
    print(xyz.shape)
    # 2) Squeeze off a leading batch dim if present
    #    Now xyz should be exactly (L,3)
    if xyz.dim() == 3 and xyz.size(0) == 1:
        xyz = xyz.squeeze(0)
    elif xyz.dim() == 1 and xyz.numel() == 3:
        raise ValueError("xyz looks like a single point; did you pass the wrong tensor?")
    if S.dim() == 3 and S.size(0) == 1:
        S = S.squeeze(0)       # now (L, d_seq)
    if P.dim() == 3 and P.size(0) == 1:
        P = P.squeeze(0)       # now (L, L)

    L = xyz.size(0)
    src, dst, e_scalar, e_vec = [], [], [], []

    def _add_edge(i, j, etype, pij):
        src.append(i); dst.append(j)

        rel_pos = xyz[j] - xyz[i] # Shape (3,)
        d = torch.norm(rel_pos)

        rbf_feat = rbf(d, rbf_mu, rbf_sigma)  # (30)
        # base_feats = torch.tensor([etype, pij, d], device=P.device)
        # e_scalar.append(torch.cat([base_feats, rbf_feat], dim=0))  # (33,)
        e_scalar.append(torch.cat([torch.tensor([etype, pij], device=P.device),rbf_feat], dim=0))
        
        norm = d + 1e-6 # Add epsilon for stability
        e_vec.append(rel_pos / norm)

    # (a) backbone
    for i in range(L - 1):
        _add_edge(i, i + 1, etype=0, pij=1.0)

    # (b) high-prob contacts
    idx_i, idx_j = torch.where(P > thresh)
    for i, j in zip(idx_i.tolist(), idx_j.tolist()):
        if j <= i + 2:                 # skip tiny loops
            continue
        _add_edge(i, j, etype=1, pij=P[i, j].item())

    # pack tensors
    edge_index = torch.tensor([src, dst], dtype=torch.long, device=P.device)
    e_scalar   = torch.stack(e_scalar, dim=0)                 # (E, 18)
    e_vec      = torch.stack(e_vec, dim=0)                    # (E, 3)
    
    # node features: scalar S_i  (degree-0), vector xyz_i (degree-1)
    node_scalars = S                                           # (L, 3, 1)
    print(f"ns: {node_scalars.shape}, nv: {xyz.shape}, es: {e_scalar.shape}, ev: {e_vec.shape}")
    # SE3-Transformer (Fabian Fuchs) expects dicts keyed by degree
    node_feats = {'0': node_scalars.unsqueeze(-1), '1': xyz.unsqueeze(-1)}
    edge_feats = {'0': e_scalar.unsqueeze(-1), '1': e_vec.unsqueeze(-1)}       # (E, 3, 1)
    
    print("--- INSIDE _make_graph (End) ---")
    print(f"    Output edge_feats keys: {list(edge_feats.keys())}")
    if 0 in edge_feats:
        print(f"    Output edge_feats[0] shape: {edge_feats[0].shape}")
    else:
        print("    Output edge_feats has NO key '0'")
    if 1 in edge_feats:
        print(f"    Output edge_feats[1] shape: {edge_feats[1].shape}")

    data = Data()
    data.edge_index = edge_index
    data.node_feats = node_feats
    data.edge_feats = edge_feats
    print(f"  Output data.edge_index shape: {data.edge_index.shape}")
    return data

def rbf(d, centers, widths):
    """Gaussian radial basis for a distance tensor d (..., 1)."""
    return torch.exp(-((d - centers) ** 2) / (2 * widths ** 2))


class ChocolateNet(nn.Module):
    """
    pretrained_state: either 0, 1, or 2 depending on how weights should be loaded:
    - 0: no pretraining
    - 1: load non-final pretrained weights
    - 2: load final pretrained weights
    """

    def __init__(self, thresh=0.20, pretrained_state=0, dropout=0.1):

        super(ChocolateNet,self).__init__()
        if pretrained_state==2:
            print("loading final pretrained weights...")
            self.load_state_dict(
                torch.load(config["final_pretrained_weights_path"], map_location="cpu"), strict = True
            )
        elif pretrained_state==1:
            print("loading nonfinal pretrained weights...")
            self.load_state_dict(
                torch.load(config["nonfinal_pretrained_weights_path"], map_location="cpu"), strict = True
            )
        elif pretrained_state==0:
            print("initializing fresh model...")
        else:
            raise ValueError("Unknown pretrained_state configuration. See class description.")
        
        self.config = {"gradient_accumulation_steps": 1}
        self.thresh = thresh
        self.seq_dim = config["rna_fm_embedding_dim"]
        self.pair_dim = 128
        self.heads = 8
        
        self.dropout = nn.Dropout(p=dropout)
        
        self.pair_embedding = PairEmbedding(self.seq_dim, self.pair_dim)

        self.sequence_transformer = ConvFormerBlocks(
            n_blocks = 3,
            seq_dim = self.seq_dim, 
            nhead = self.heads, 
            pair_dim = self.pair_dim,
            use_triangular_attention=True,
            dropout = dropout
        )
        
        # (3) RBF parameters for edge-length encoding
        mu = torch.linspace(0, 20, 30)               # 30 Gaussians
        sigma = 0.8 * torch.ones_like(mu)
        self.register_buffer("rbf_mu", mu)
        self.register_buffer("rbf_sigma", sigma)
        
        self.se3_transformer = SE3FormerBlocks(
            n_blocks = 1, seq_dim=self.seq_dim, thresh=self.thresh
        )
        
        
    def forward(self, sequence):
        sequence = sequence[0] # DOES NOT SUPPORT BATCH SIZE > 1, FIX!!
        print(sequence.shape)
        # 1) Get raw RNA-FM embeddings (1, L, d_seq)
        fm_emb = get_rnaf_seq_encoding(sequence).cuda()      # → torch.FloatTensor on CPU

        # 2) Get BPPM from RiboNet, convert to Tensor
        bppm = get_ribonet_bpp(sequence).float().cuda()
        # 3) Now build your pair embedding correctly
        pair_embedding = self.pair_embedding(fm_emb, bppm)      # both use L
        bppm_raw = bppm.squeeze(0)
        
        # # fm_embedding = get_rnaf_seq_encoding(sequence[0])
        # bppm = get_ribonet_bpp(sequence[0])
        # bppm_src = torch.from_numpy(bppm).float().cuda()
        
        
        # pair_embedding = self.pair_embedding(fm_embedding, bppm_src)
        
        xyz_init = init_coords_from_sequence(sequence, bppm_raw)
        seq_rep, pair_rep = self.sequence_transformer(fm_emb, pair_embedding)
        xyz_pred = self.se3_transformer(
                    seq_rep, bppm_raw, xyz_init, self.rbf_mu, self.rbf_sigma, self.thresh
                    )
        
        return xyz_pred
        
# Instantiate the model
model = ChocolateNet().cuda()
print("insted")

initializing fresh model...
DEBUG VersatileConvSE3 Init: edge_dim=33
  chan_in=2, chan_out=32
DEBUG VersatileConvSE3 Init: edge_dim=33
  chan_in=2, chan_out=128
DEBUG VersatileConvSE3 Init: edge_dim=33
  chan_in=2, chan_out=128
DEBUG VersatileConvSE3 Init: edge_dim=33
  chan_in=640, chan_out=32
DEBUG VersatileConvSE3 Init: edge_dim=33
  chan_in=640, chan_out=128
DEBUG VersatileConvSE3 Init: edge_dim=33
  chan_in=640, chan_out=128
DEBUG VersatileConvSE3 Init: edge_dim=33
  chan_in=64, chan_out=64
DEBUG VersatileConvSE3 Init: edge_dim=33
  chan_in=64, chan_out=128
DEBUG VersatileConvSE3 Init: edge_dim=33
  chan_in=64, chan_out=128
DEBUG VersatileConvSE3 Init: edge_dim=33
  chan_in=128, chan_out=64
DEBUG VersatileConvSE3 Init: edge_dim=33
  chan_in=128, chan_out=128
DEBUG VersatileConvSE3 Init: edge_dim=33
  chan_in=128, chan_out=128
DEBUG VersatileConvSE3 Init: edge_dim=33
  chan_in=129, chan_out=64
DEBUG VersatileConvSE3 Init: edge_dim=33
  chan_in=129, chan_out=128
DEBUG VersatileConvS

In [37]:
print(model.se3_transformer.blocks[0].graph_modules)

Sequential(
  (0): AttentionBlockSE3(
    (to_key_value): ConvSE3(
      (conv): ModuleDict(
        (1,2): VersatileConvSE3(
          (radial_func): RadialProfile(
            (net): RecursiveScriptModule(
              original_name=Sequential
              (0): RecursiveScriptModule(original_name=Linear)
              (1): RecursiveScriptModule(original_name=LayerNorm)
              (2): RecursiveScriptModule(original_name=ReLU)
              (3): RecursiveScriptModule(original_name=Linear)
              (4): RecursiveScriptModule(original_name=LayerNorm)
              (5): RecursiveScriptModule(original_name=ReLU)
              (6): RecursiveScriptModule(original_name=Linear)
            )
          )
        )
        (1,0): VersatileConvSE3(
          (radial_func): RadialProfile(
            (net): RecursiveScriptModule(
              original_name=Sequential
              (0): RecursiveScriptModule(original_name=Linear)
              (1): RecursiveScriptModule(original_name=La

# 7. LOSS FUNCTIONS

In [17]:
def calculate_distance_matrix(X, Y, epsilon=1e-4):
    """
    Calculate pairwise distances between every point in X and every point in Y.
    Shape: (len(X), len(Y))
    """
    return ((X[:, None] - Y[None, :])**2 + epsilon).sum(dim=-1).sqrt()

def dRMSD(pred_x, pred_y, gt_x, gt_y, epsilon=1e-4, Z=10, d_clamp=None):
    """
    Distance-based RMSD.
    pred_x, pred_y: predicted coordinates (usually the same tensor for X and Y).
    gt_x, gt_y: ground truth coordinates.
    """
    pred_dm = calculate_distance_matrix(pred_x, pred_y)
    gt_dm = calculate_distance_matrix(gt_x, gt_y)

    mask = ~torch.isnan(gt_dm)
    mask[torch.eye(mask.shape[0], device=mask.device).bool()] = False

    diff_sq = (pred_dm[mask] - gt_dm[mask])**2 + epsilon
    if d_clamp is not None:
        diff_sq = diff_sq.clamp(max=d_clamp**2)

    return diff_sq.sqrt().mean() / Z

def local_dRMSD(pred_x, pred_y, gt_x, gt_y, epsilon=1e-4, Z=10, d_clamp=30):
    """
    Local distance-based RMSD, ignoring distances above a clamp threshold.
    """
    pred_dm = calculate_distance_matrix(pred_x, pred_y)
    gt_dm = calculate_distance_matrix(gt_x, gt_y)

    mask = (~torch.isnan(gt_dm)) & (gt_dm < d_clamp)
    mask[torch.eye(mask.shape[0], device=mask.device).bool()] = False

    diff_sq = (pred_dm[mask] - gt_dm[mask])**2 + epsilon
    return diff_sq.sqrt().mean() / Z

def dRMAE(pred_x, pred_y, gt_x, gt_y, epsilon=1e-4, Z=10):
    """
    Distance-based Mean Absolute Error.
    """
    pred_dm = calculate_distance_matrix(pred_x, pred_y)
    gt_dm = calculate_distance_matrix(gt_x, gt_y)

    mask = ~torch.isnan(gt_dm)
    mask[torch.eye(mask.shape[0], device=mask.device).bool()] = False

    diff = torch.abs(pred_dm[mask] - gt_dm[mask])
    return diff.mean() / Z

def align_svd_mae(input_coords, target_coords, Z=10):
    """
    Align input_coords to target_coords via SVD (Kabsch algorithm) and compute MAE.
    """
    assert input_coords.shape == target_coords.shape, "Input and target must have the same shape"

    # Create mask for valid points
    mask = ~torch.isnan(target_coords.sum(dim=-1))
    input_coords = input_coords[mask]
    target_coords = target_coords[mask]
    
    # Compute centroids
    centroid_input = input_coords.mean(dim=0, keepdim=True)
    centroid_target = target_coords.mean(dim=0, keepdim=True)

    # Center the points
    input_centered = input_coords - centroid_input
    target_centered = target_coords - centroid_target

    # Compute covariance matrix
    cov_matrix = input_centered.T @ target_centered

    # SVD to find optimal rotation
    U, S, Vt = torch.svd(cov_matrix)
    R = Vt @ U.T

    # Ensure a proper rotation (determinant R == 1)
    if torch.det(R) < 0:
        Vt_adj = Vt.clone()   # Clone to avoid in-place modification issues
        Vt_adj[-1, :] = -Vt_adj[-1, :]
        R = Vt_adj @ U.T

    # Rotate input and compute mean absolute error
    aligned_input = (input_centered @ R.T) + centroid_target
    return torch.abs(aligned_input - target_coords).mean() / Z


# 8. TRAINING LOOP

In [18]:
# IMPLEMENT TRAIN() FROM SE3TRANSFORMER

def train_model(model, train_dl, val_dl, epochs=50, cos_epoch=35, lr=3e-4, clip=1):
    """Train the model with a CosineAnnealingLR after `cos_epoch` epochs."""
    optimizer = torch.optim.AdamW(model.parameters(), weight_decay=0.0, lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=(epochs - cos_epoch) * len(train_dl),
    )
    grad_accum_steps = model.config["gradient_accumulation_steps"]
    best_val_loss = float("inf")
    best_preds = None
    
    for epoch in range(epochs):
        model.train()
        train_pbar = tqdm(train_dl, desc=f"Training Epoch {epoch+1}/{epochs}")
        running_loss = 0.0

        # Add profiling for the first few batches of the first epoch
        profiling_enabled = (epoch == 0)

        for idx, batch in enumerate(train_pbar):

            sequence = batch["sequence"].cuda()
            
            gt_xyz = batch["xyz"].squeeze().cuda()
            #mask = batch["mask"].cuda()
            # Only profile the first 5 batches of the first epoch
            if profiling_enabled and idx < 10:
                torch.cuda.synchronize()
                start_forward = time.time()
                
                # Remove autocast
                pred_xyz = model(sequence).squeeze()
                
                torch.cuda.synchronize()
                forward_time = time.time() - start_forward
                
                torch.cuda.synchronize()
                start_loss = time.time()
                
                # Remove autocast
                loss = dRMAE(pred_xyz, pred_xyz, gt_xyz, gt_xyz) + align_svd_mae(pred_xyz, gt_xyz)
                
                torch.cuda.synchronize()
                loss_time = time.time() - start_loss
                
                print(f"Batch {idx}: Forward pass: {forward_time:.4f}s, Loss computation: {loss_time:.4f}s")
                
                # Continue with normal training flow (without scaler)
                
                
            else:
                # Normal non-profiling training code (without autocast and scaler)
                pred_xyz = model(sequence).squeeze()
                loss = dRMAE(pred_xyz, pred_xyz, gt_xyz, gt_xyz) + align_svd_mae(pred_xyz, gt_xyz)
            
            loss = loss / grad_accum_steps
            loss.backward()
            if (idx + 1) % grad_accum_steps == 0 or (idx + 1) == len(train_dl):
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
                optimizer.step()
                optimizer.zero_grad()

                if (epoch + 1) > cos_epoch:
                    scheduler.step()
                            
            running_loss += loss.item()
            avg_loss = running_loss / (idx + 1)
            train_pbar.set_description(f"Epoch {epoch+1} | Loss: {avg_loss:.4f}")
            
        # Validation
        model.eval()
        val_loss = 0.0
        val_preds = []
        with torch.no_grad():
            for idx, batch in enumerate(val_dl):
                sequence = batch["sequence"].cuda()
                gt_xyz = batch["xyz"].squeeze().cuda()
                #mask = batch["mask"].cuda()
                pred_xyz = model(sequence).squeeze()
                loss = dRMAE(pred_xyz, pred_xyz, gt_xyz, gt_xyz)
                val_loss += loss.item()

                val_preds.append((gt_xyz.cpu().numpy(), pred_xyz.cpu().numpy()))

            val_loss /= len(val_dl)
            print(f"Validation Loss (Epoch {epoch+1}): {val_loss:.4f}")

            # Check for improvement
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_preds = val_preds
                torch.save(model.state_dict(), config["save_weights_name"])
                print(f"  -> New best model saved at epoch {epoch+1}")

    # Save final model
    torch.save(model.state_dict(), config["save_weights_final"])
    return best_val_loss, best_preds

# 9. RUN TRAINING

In [22]:
print(f"Configured batch size: {config['batch_size']}")
print(f"Train loader batch size: {train_loader.batch_size}")

Configured batch size: 1
Train loader batch size: 1


In [51]:


if __name__ == "__main__":
    best_loss, best_predictions = train_model(
        model=model,
        train_dl=train_loader,
        val_dl=val_loader,
        epochs=50,         # or config["epochs"]
        cos_epoch=35,      # or config["cos_epoch"]
        lr=3e-4,
        clip=1
    )
    print(f"Best Validation Loss: {best_loss:.4f}")
    

Training Epoch 1/50:   0%|          | 0/731 [00:00<?, ?it/s]

torch.Size([75])
embedding seq_rep of shape torch.Size([1, 75, 640]), and bppm of shape torch.Size([1, 75, 75])
Pair: torch.Size([1, 75, 75, 128]), outer: torch.Size([1, 75, 75, 128]), relpos: torch.Size([1, 75, 75, 128])


Training Epoch 1/50:   0%|          | 0/731 [00:01<?, ?it/s]

s: torch.Size([1, 75, 640]), p: torch.Size([1, 75, 75, 128])
torch.Size([75, 3])
ns: torch.Size([75, 640]), nv: torch.Size([75, 3]), es: torch.Size([99, 32]), ev: torch.Size([99, 3])
--- INSIDE _make_graph (End) ---
    Output edge_feats keys: ['0', '1']
    Output edge_feats has NO key '0'
  Output data.edge_index shape: torch.Size([2, 99])
PRINTING KEYS::::--------
nodefeats keys: ['0', '1'], edgefeats keys: ['0', '1']
  Verifying shapes before block call...
     Shapes verified.

--- Entering SE3 Sequential Block ---
  Input to Module 0 (AttentionBlockSE3):
    features['0'].shape: torch.Size([198, 1, 640])
    features['1'].shape: torch.Size([198, 3, 1])
DEBUG Concat (deg=1): Node Shape torch.Size([198, 3, 1]) + Edge Shape torch.Size([198, 3, 1]) -> Result Shape torch.Size([198, 3, 2])
ConvSE3FuseLevel: ConvSE3FuseLevel.NONE
--- INSIDE RadialProfile.forward ---
    Received features shape: torch.Size([198, 33]), size: 6534
Outputting shape: torch.Size([198, 81920]).
--- INSIDE Radi




DGLError: [04:20:50] /opt/dgl/src/array/./check.h:51: Check failed: gdim[uev_idx[i]] == arrays[i]->shape[0] (75 vs. 198) : Expect E_data to have size 75 on the first dimension, but got 198
Stack trace:
  [bt] (0) /usr/local/lib/python3.11/dist-packages/dgl/libdgl.so(+0x7c6d31) [0x7f02c81c6d31]
  [bt] (1) /usr/local/lib/python3.11/dist-packages/dgl/libdgl.so(dgl::aten::CheckShape(std::vector<unsigned long, std::allocator<unsigned long> > const&, std::vector<int, std::allocator<int> > const&, std::vector<dgl::runtime::NDArray, std::allocator<dgl::runtime::NDArray> > const&, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > const&)+0x35b) [0x7f02c81e9e6b]
  [bt] (2) /usr/local/lib/python3.11/dist-packages/dgl/libdgl.so(+0x7e4aaf) [0x7f02c81e4aaf]
  [bt] (3) /usr/local/lib/python3.11/dist-packages/dgl/libdgl.so(+0x7e4d61) [0x7f02c81e4d61]
  [bt] (4) /usr/local/lib/python3.11/dist-packages/dgl/libdgl.so(DGLFuncCall+0x4c) [0x7f02c8224b9c]
  [bt] (5) /usr/local/lib/python3.11/dist-packages/dgl/_ffi/_cy3/core.cpython-311-x86_64-linux-gnu.so(+0x1ba94) [0x7f033481ba94]
  [bt] (6) /usr/local/lib/python3.11/dist-packages/dgl/_ffi/_cy3/core.cpython-311-x86_64-linux-gnu.so(+0x1bdff) [0x7f033481bdff]
  [bt] (7) /usr/local/bin/python3(_PyObject_MakeTpCall+0x28c) [0x52edac]
  [bt] (8) /usr/local/bin/python3(_PyEval_EvalFrameDefault+0x6bd) [0x53cf5d]



# RUN INFERENCE

In [None]:
# !pip uninstall -y dgl
# !pip install --pre dgl -f https://data.dgl.ai/wheels/cu121/repo.html
torch.cuda.is_available()

In [None]:



test_df = pd.read_csv(config["test_data_path"]) # target_id,sequence,temporal_cutoff,description,all_sequences
print(test_df.head(10))
test_model = FinetunedRibonanzaNet(model_cfg, pretrained_state=2).cuda()
test_model.eval()

submission_rows = []

for idx, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Running inference"):
    seq_id = row["target_id"]
    seq = row["sequence"]
    
    token_map = {'A': 0, 'C': 1, 'U': 2, 'G': 3}
    token_ids = torch.tensor([token_map[c] for c in seq], dtype=torch.long).unsqueeze(0).cuda()  # shape (1, L)
    mask = torch.ones_like(token_ids).cuda()  # or derive if needed

    preds = []
    with torch.no_grad():
        for _ in range(5):  # generate 5 predictions
            pred_xyz = test_model(token_ids, mask).squeeze(0).cpu().numpy()  # shape (L, 3)
            preds.append(pred_xyz)

    preds = np.stack(preds, axis=0)  # shape (5, L, 3)

    for i in range(len(seq)):
        resname = seq[i]
        resid = i + 1
        flat_xyz = preds[:, i, :].flatten()  # (x1,y1,z1,...,x5,y5,z5)
        row = [f"{seq_id}", resname, resid] + flat_xyz.tolist()
        submission_rows.append(row)

# Save to CSV
columns = ["ID", "resname", "resid"] + [f"{axis}_{i+1}" for i in range(5) for axis in ["x", "y", "z"]]
submission = pd.DataFrame(submission_rows, columns=columns)
submission.to_csv("submission.csv", index=False)

print("Inference complete! Saved to submission.csv")