In [14]:
!pip install --no-index /kaggle/input/datasets/kurshidbasheer/biopython-offline/biopython-1.83-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

Processing /kaggle/input/datasets/kurshidbasheer/biopython-offline/biopython-1.83-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
biopython is already installed with the same version as the provided wheel. Use --force-reinstall to force an installation of the wheel.


In [15]:
!pip install --no-index /kaggle/input/datasets/kurshidbasheer/pyg-2-7-torch-2-9-cpu-py312-kur/torch_geometric-2.7.0-py3-none-any.whl

Processing /kaggle/input/datasets/kurshidbasheer/pyg-2-7-torch-2-9-cpu-py312-kur/torch_geometric-2.7.0-py3-none-any.whl
torch-geometric is already installed with the same version as the provided wheel. Use --force-reinstall to force an installation of the wheel.


# Reproducibility

In [None]:
import torch
import numpy as np
import random

def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Covers Python's ``random`` module, NumPy's legacy global generator and
    PyTorch's CPU generator plus all CUDA device generators, in that order.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


set_seed(42)

# Paths

In [None]:
# Every competition input lives under a single Kaggle dataset mount.
DATA_ROOT = "/kaggle/input/stanford-rna-3d-folding-2"

# Sequences + coordinate labels used for fitting and for validation.
TRAIN_SEQ = f"{DATA_ROOT}/train_sequences.csv"
TRAIN_LBL = f"{DATA_ROOT}/train_labels.csv"

VAL_SEQ   = f"{DATA_ROOT}/validation_sequences.csv"
VAL_LBL   = f"{DATA_ROOT}/validation_labels.csv"

# Unlabelled sequences to predict.
TEST_SEQ  = f"{DATA_ROOT}/test_sequences.csv"

# Auxiliary resources (not read by the cells below).
MSA_DIR   = f"{DATA_ROOT}/MSA"
PDB_DIR   = f"{DATA_ROOT}/PDB_RNA"
META_PATH = f"{DATA_ROOT}/extra/rna_metadata.csv"

# Dataset (Train & Validation)

In [None]:
import pandas as pd
from torch.utils.data import Dataset
from Bio.Seq import Seq

# Nucleotide -> one-hot column index; also defines which characters are kept.
NUC_MAP = {'A':0, 'U':1, 'G':2, 'C':3}


def clean_sequence(seq):
    """Upper-case ``seq`` and drop every character that is not a key of NUC_MAP.

    Dropped characters are removed rather than replaced, so every position
    after a dropped character shifts left by one.

    Args:
        seq: raw sequence string.

    Returns:
        A string containing only 'A', 'U', 'G', 'C'.
    """
    # Plain string filtering: wrapping in Bio.Seq and converting straight back
    # with str() was an identity round-trip that only added overhead.
    return "".join(ch for ch in seq.upper() if ch in NUC_MAP)


def one_hot(seq):
    """One-hot encode a cleaned sequence.

    Args:
        seq: string whose characters are all keys of NUC_MAP.

    Returns:
        float32 tensor of shape ``(len(seq), len(NUC_MAP))``; an empty sequence
        gives shape ``(0, 4)``.

    Raises:
        KeyError: if ``seq`` contains a character missing from NUC_MAP.
    """
    cols = torch.tensor([NUC_MAP[s] for s in seq], dtype=torch.long)
    x = torch.zeros(len(seq), len(NUC_MAP))
    # One vectorised scatter instead of a Python-level write per residue.
    x[torch.arange(len(seq)), cols] = 1
    return x

class RNADataset(Dataset):
    """Map-style dataset yielding ``(target_id, features, coords)`` per RNA target.

    ``features`` is the one-hot nucleotide encoding (4 columns) with a
    relative-position column ``i / L`` appended. ``coords`` is an ``(L, 3)``
    float32 tensor taken from the ``x_1``/``y_1``/``z_1`` label columns
    (presumably the first reference structure -- confirm against the data
    description), or ``None`` when no label file is given.

    Args:
        seq_csv: CSV with at least ``target_id`` and ``sequence`` columns.
        label_csv: optional CSV with an ``ID`` column of the form
            ``<target_id>_<residue_index>`` plus ``x_1``, ``y_1``, ``z_1``.
        max_length: sequences (and coords) are truncated to this many residues.
    """

    def __init__(self, seq_csv, label_csv=None, max_length=1000):

        self.q_df = pd.read_csv(seq_csv)
        self.max_length = max_length
        self.has_labels = label_csv is not None

        # O(1) sequence lookup for __getitem__ (previously a full boolean scan of
        # the frame per item). keep="first" matches the old ``.iloc[0]`` choice.
        first_rows = self.q_df.drop_duplicates("target_id", keep="first")
        self._seq_by_id = dict(zip(first_rows["target_id"], first_rows["sequence"]))

        if self.has_labels:
            labels = pd.read_csv(label_csv, low_memory=False)
            # Split only on the LAST underscore: target IDs may themselves
            # contain "_" (e.g. a chain suffix), and the residue index is always
            # the final component. "T_A_12" -> ("T_A", 12); "T_12" -> ("T", 12).
            id_parts = labels["ID"].str.rsplit("_", n=1)
            labels["struct_id"] = id_parts.str[0]
            labels["res_idx"]   = id_parts.str[1].astype(int)

            self.structures = {}

            for k, g in labels.groupby("struct_id"):
                g = g.sort_values("res_idx")
                coords = g[["x_1", "y_1", "z_1"]].values.astype(np.float32)
                self.structures[k] = torch.from_numpy(coords)

            # Only targets that actually have coordinates are exposed.
            self.valid_ids = [
                sid for sid in self.q_df["target_id"]
                if sid in self.structures
            ]
        else:
            self.valid_ids = list(self.q_df["target_id"])

    def __len__(self):
        return len(self.valid_ids)

    def __getitem__(self, idx):

        sid = self.valid_ids[idx]
        seq = clean_sequence(self._seq_by_id[sid])

        coords = None
        if self.has_labels:
            coords = self.structures[sid]
            # NOTE(review): clean_sequence drops unknown characters, which can
            # shift residues relative to the label rows; trimming both to the
            # shorter length hides such a mismatch rather than fixing it.
            L = min(len(seq), coords.shape[0])
            seq = seq[:L]
            coords = coords[:L]

        if len(seq) > self.max_length:
            seq = seq[:self.max_length]
            if coords is not None:
                coords = coords[:self.max_length]

        x = one_hot(seq)
        pos_feat = torch.arange(len(seq)).float().unsqueeze(-1) / len(seq)
        x = torch.cat([x, pos_feat], dim=1)

        return sid, x, coords

# Graph Builder

In [None]:
from torch_geometric.data import Data

def center_coordinates(coords):
    """Return ``coords`` translated so that their centroid is at the origin."""
    center = coords.mean(dim=0, keepdim=True)
    return coords - center

def build_graph(x, coords=None, k=2, helix_radius=9.0, helix_rise=2.8,
                helix_twist=2 * np.pi / 11):
    """Build a PyG sequence-window graph for one RNA chain.

    Residue ``i`` gets a directed edge to every ``j`` with ``0 < |i - j| <= k``.
    Edges are ordered by source ``i``, then ascending ``j``.

    Args:
        x: ``(L, F)`` node feature tensor.
        coords: optional ``(L, 3)`` ground-truth coordinates. When given they are
            centred and stored ONLY as the regression target ``data.y``.
        k: half-width of the sequence window used for edges.
        helix_radius, helix_rise, helix_twist: geometry of the idealised helix
            used as the starting conformation ``data.pos``. The defaults are rough
            placeholders (TODO: tune / replace with a better prior).

    Returns:
        ``torch_geometric.data.Data`` with ``x``, ``edge_index`` of shape
        ``(2, E)``, ``edge_attr`` of shape ``(E, 1)`` holding ``|i - j|``,
        ``pos`` of shape ``(L, 3)`` and, when ``coords`` is given, ``y`` of
        shape ``(L, 3)``.
    """

    L = x.size(0)

    # Window edges, vectorised but in the same order as the nested i/j loop.
    # Building them as tensors keeps edge_index 2-D even when L <= 1 (E == 0).
    offsets = torch.cat([torch.arange(-k, 0), torch.arange(1, k + 1)])
    src = torch.arange(L).repeat_interleave(offsets.numel())
    dst = src + offsets.repeat(L)
    in_range = (dst >= 0) & (dst < L)
    src, dst = src[in_range], dst[in_range]
    edge_index = torch.stack([src, dst], dim=0)

    # Sequence separation only, identical for train and test graphs. True
    # inter-residue distances would leak the target into the inputs and are
    # unavailable at inference time.
    edge_attr = (dst - src).abs().float().unsqueeze(-1)

    # Starting conformation. It must not come from ``coords``: previously
    # pos == y, so the model could reach zero loss by returning its input. It
    # must also not be collinear, because EGNN position updates are combinations
    # of pairwise difference vectors, so a straight line would stay a line.
    t = torch.arange(L, dtype=torch.float32)
    angle = t * helix_twist
    start_pos = torch.stack([helix_radius * torch.cos(angle),
                             helix_radius * torch.sin(angle),
                             helix_rise * t], dim=1)
    start_pos = center_coordinates(start_pos)

    data = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        pos=start_pos
    )

    if coords is not None:
        data.y = center_coordinates(coords)

    return data

# Build Datasets and Train / Validation Graphs

In [None]:
# Labelled splits expose only targets that have coordinates; the test split
# has no labels, so its items carry coords=None.
train_dataset = RNADataset(TRAIN_SEQ, TRAIN_LBL)
val_dataset   = RNADataset(VAL_SEQ, VAL_LBL)
test_dataset  = RNADataset(TEST_SEQ, None)

# Graphs are materialised once up front so every epoch only re-batches them.
train_graphs = [build_graph(feats, target) for _sid, feats, target in train_dataset]
val_graphs = [build_graph(feats, target) for _sid, feats, target in val_dataset]

# DataLoaders

In [None]:
from torch_geometric.loader import DataLoader

BATCH_SIZE = 8  # graphs per mini-batch, shared by training and validation

# Only the training order is shuffled; validation order stays fixed.
train_loader = DataLoader(train_graphs, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_graphs, batch_size=BATCH_SIZE, shuffle=False)

# EGNN Model

In [None]:
import torch.nn as nn

class EGNNLayer(nn.Module):
    def __init__(self, feat_dim):
        super().__init__()

        self.edge_mlp = nn.Sequential(
            nn.Linear(2*feat_dim + 1, feat_dim),
            nn.ReLU(),
            nn.Linear(feat_dim, feat_dim)
        )

        self.node_mlp = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.ReLU(),
            nn.Linear(feat_dim, feat_dim)
        )

    def forward(self, x, pos, edge_index):

        row, col = edge_index
        diff = pos[row] - pos[col]
        dist2 = torch.sum(diff**2, dim=1, keepdim=True)

        edge_input = torch.cat([x[row], x[col], dist2], dim=1)
        m_ij = self.edge_mlp(edge_input)

        pos_update = diff * m_ij.mean(dim=1, keepdim=True)
        pos = pos + torch.zeros_like(pos).index_add_(0, row, pos_update)

        agg = torch.zeros_like(x).index_add_(0, row, m_ij)
        x = x + self.node_mlp(agg)

        return x, pos


class RNAEGNN(nn.Module):
    def __init__(self, in_channels, hidden_dim=128, layers=3):
        super().__init__()

        self.embedding = nn.Linear(in_channels, hidden_dim)
        self.layers = nn.ModuleList([
            EGNNLayer(hidden_dim) for _ in range(layers)
        ])

    def forward(self, data):

        x = self.embedding(data.x)
        pos = data.pos

        for layer in self.layers:
            x, pos = layer(x, pos, data.edge_index)

        return pos

# Training Setup

In [None]:
LEARNING_RATE = 1e-3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Input width is read from the graphs themselves rather than hard-coded.
n_node_features = train_graphs[0].x.shape[1]
model = RNAEGNN(in_channels=n_node_features).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()  # mean of squared errors over all coordinate components

# Metrics

In [None]:
def compute_rmsd(pred, target, align=False):
    """Root-mean-square deviation between two ``(N, D)`` coordinate sets.

    Args:
        pred: predicted coordinates, shape ``(N, D)``.
        target: reference coordinates, same shape as ``pred``.
        align: if True, first remove both centroids and superimpose ``pred`` onto
            ``target`` with the optimal proper rotation (Kabsch algorithm), so the
            result ignores global translation and rotation. The default (False)
            keeps the raw, frame-dependent RMSD.

    Returns:
        0-dim tensor ``sqrt(mean_i ||pred_i - target_i||^2)``.

    Note:
        The SVD-based alignment may have unstable gradients for degenerate
        inputs; it is intended for evaluation under ``torch.no_grad()``.
    """
    if align:
        p = pred - pred.mean(dim=0, keepdim=True)
        q = target - target.mean(dim=0, keepdim=True)
        u, _, vh = torch.linalg.svd(p.T @ q)
        # Flip the last axis when needed so the result is a rotation, never a
        # reflection (det must be +1).
        flip = torch.ones(p.shape[-1], dtype=p.dtype, device=p.device)
        flip[-1] = torch.sign(torch.linalg.det(vh.T @ u.T))
        rot = vh.T @ torch.diag(flip) @ u.T
        pred, target = p @ rot.T, q
    return torch.sqrt(torch.mean(torch.sum((pred - target) ** 2, dim=1)))

# Train and Validate

In [None]:
def train():
    """Run one optimisation epoch over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion`` and ``device``.

    Returns:
        Mean of the per-batch ``criterion`` losses, or NaN if the loader is empty.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0

    for batch in train_loader:
        batch = batch.to(device)

        optimizer.zero_grad()
        pred = model(batch)

        loss = criterion(pred, batch.y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    return total_loss / n_batches if n_batches else float("nan")


def evaluate():
    """Score the model on ``val_loader`` without updating any weights.

    Returns:
        ``(mean_batch_loss, mean_structure_rmsd)``. RMSD is computed separately
        for each graph in a batch (using the batch's ``ptr`` node offsets) and
        averaged over graphs. Previously a single RMSD was pooled over all
        residues of several different RNAs. Either value is NaN when there is
        nothing to average.
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0
    rmsd_sum = 0.0
    n_structures = 0

    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)

            pred = model(batch)
            total_loss += criterion(pred, batch.y).item()
            n_batches += 1

            # Nodes of graph g occupy rows ptr[g]:ptr[g+1] of the batched tensors.
            bounds = batch.ptr.tolist()
            for start, end in zip(bounds[:-1], bounds[1:]):
                if end > start:  # an empty graph has no defined RMSD
                    rmsd_sum += compute_rmsd(pred[start:end], batch.y[start:end]).item()
                    n_structures += 1

    mean_loss = total_loss / n_batches if n_batches else float("nan")
    mean_rmsd = rmsd_sum / n_structures if n_structures else float("nan")
    return mean_loss, mean_rmsd

# Training Loop

In [None]:
NUM_EPOCHS = 20

# One pass of optimisation, then a no-grad validation pass, per epoch.
for epoch in range(NUM_EPOCHS):
    train_loss = train()
    val_loss, val_rmsd = evaluate()

    print(f"Epoch {epoch+1:02d} | "
          f"Train Loss: {train_loss:.4f} | "
          f"Val Loss: {val_loss:.4f} | "
          f"Val RMSD: {val_rmsd:.4f}")

# Test Inference

In [None]:
model.eval()
predictions = {}

with torch.no_grad():
    for sid, x, _ in test_dataset:
        graph = build_graph(x, coords=None).to(device)

        # The EGNN refines a starting conformation, so it can only run on
        # graphs that carry ``pos``; otherwise keep the placeholder value.
        start_pos = getattr(graph, "pos", None)
        if start_pos is None:
            predictions[sid] = None
        else:
            predictions[sid] = model(graph).cpu()

# TODO: convert ``predictions`` into the competition submission format.