In [1]:
import pandas as pd
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
import pickle

In [2]:
config = {
    "seed": 0,
    "cutoff_date": "2020-01-01",
    "test_cutoff_date": "2022-05-01",
    "max_len": 384,
    "batch_size": 1,
    "learning_rate": 1e-4,
    "weight_decay": 0.0,
    "mixed_precision": "bf16",
    "model_config_path": "../working/configs/pairwise.yaml",  # Adjust path as needed
    "epochs": 10,
    "cos_epoch": 5,
    "loss_power_scale": 1.0,
    "max_cycles": 1,
    "grad_clip": 0.1,
    "gradient_accumulation_steps": 1,
    "d_clamp": 30,
    "max_len_filter": 9999999,
    "structural_violation_epoch": 50,
    "balance_weight": False,
}

In [3]:
test_data=pd.read_csv("/kaggle/input/stanford-ribonanza-2-rna-folding-in-3-d/test_sequences.csv")

In [4]:
from torch.utils.data import Dataset, DataLoader

class RNADataset(Dataset):
    def __init__(self,data):
        self.data=data
        self.tokens={nt:i for i,nt in enumerate('ACGU')}

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sequence=[self.tokens[nt] for nt in (self.data.loc[idx,'sequence'])]
        sequence=np.array(sequence)
        sequence=torch.tensor(sequence)




        return {'sequence':sequence}

In [5]:
test_dataset=RNADataset(test_data)
test_dataset[0]

{'sequence': tensor([2, 2, 2, 2, 2, 1, 1, 0, 1, 0, 2, 1, 0, 2, 0, 0, 2, 1, 2, 3, 3, 1, 0, 1,
         2, 3, 1, 2, 1, 0, 2, 1, 1, 1, 1, 3, 2, 3, 1, 0, 2, 1, 1, 0, 3, 3, 2, 1,
         0, 1, 3, 1, 1, 2, 2, 1, 3, 2, 1, 2, 0, 0, 3, 3, 1, 3, 2, 1, 3])}

In [6]:
import sys

sys.path.append("/kaggle/input/ribonanzanet2d-final")

import torch.nn as nn
from Network import RibonanzaNet, MultiHeadAttention
import yaml

class SimpleStructureModule(nn.Module):

    def __init__(self, d_model, nhead, 
                 dim_feedforward, pairwise_dimension, dropout=0.1,
                 ):
        super(SimpleStructureModule, self).__init__()
        #self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.self_attn = MultiHeadAttention(d_model, nhead, d_model//nhead, d_model//nhead, dropout=dropout)


        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        #self.norm4 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        #self.dropout4 = nn.Dropout(dropout)

        self.pairwise2heads=nn.Linear(pairwise_dimension,nhead,bias=False)
        self.pairwise_norm=nn.LayerNorm(pairwise_dimension)

        self.distance2heads=nn.Linear(1,nhead,bias=False)
        #self.pairwise_norm=nn.LayerNorm(pairwise_dimension)

        self.activation = nn.GELU()

        
    def custom(self, module):
        def custom_forward(*inputs):
            inputs = module(*inputs)
            return inputs
        return custom_forward

    def forward(self, input):
        src , pairwise_features, pred_t, src_mask = input
        
        #src = src*src_mask.float().unsqueeze(-1)

        pairwise_bias=self.pairwise2heads(self.pairwise_norm(pairwise_features)).permute(0,3,1,2)

        
        distance_matrix=pred_t[None,:,:]-pred_t[:,None,:]
        distance_matrix=(distance_matrix**2).sum(-1).clip(2,37**2).sqrt()
        distance_matrix=distance_matrix[None,:,:,None]
        distance_bias=self.distance2heads(distance_matrix).permute(0,3,1,2)

                    
        
        pairwise_bias=pairwise_bias+distance_bias

        #print(src.shape)
        src2,attention_weights = self.self_attn(src, src, src, mask=pairwise_bias, src_mask=src_mask)
        

        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)


        return src



class Config:
    def __init__(self, **entries):
        self.__dict__.update(entries)
        self.entries=entries

    def print(self):
        print(self.entries)

def load_config_from_yaml(file_path):
    with open(file_path, 'r') as file:
        config = yaml.safe_load(file)
    return Config(**config)



class finetuned_RibonanzaNet(RibonanzaNet):
    def __init__(self, config, pretrained=False):
        config.dropout=0.1
        config.use_grad_checkpoint=True
        super(finetuned_RibonanzaNet, self).__init__(config)
        if pretrained:
            self.load_state_dict(torch.load("/kaggle/input/ribonanzanet-weights/RibonanzaNet.pt",map_location='cpu'))
        # self.ct_predictor=nn.Sequential(nn.Linear(64,256),
        #                                 nn.ReLU(),
        #                                 nn.Linear(256,64),
        #                                 nn.ReLU(),
        #                                 nn.Linear(64,1)) 
        self.dropout=nn.Dropout(0.0)

        self.structure_module=SimpleStructureModule(d_model=256, nhead=8, 
                 dim_feedforward=1024, pairwise_dimension=64)
        
        self.xyz_predictor=nn.Linear(256,3)

    def custom(self, module):
        def custom_forward(*inputs):
            inputs = module(*inputs)
            return inputs
        return custom_forward
    
    def forward(self,src):
        
        #with torch.no_grad():
        sequence_features, pairwise_features=self.get_embeddings(src, torch.ones_like(src).long().to(src.device))
        
        xyzs=[]
        xyz=torch.zeros(sequence_features.shape[1],3).cuda().float()
        #print(xyz.shape)
        #xyz=self.xyz_predictor(sequence_features)

        for i in range(18):
            sequence_features=self.structure_module([sequence_features,pairwise_features,xyz,None])
            xyz=xyz+self.xyz_predictor(sequence_features).squeeze(0)
            xyzs.append(xyz)
            
        
        return xyzs

model=finetuned_RibonanzaNet(load_config_from_yaml("/kaggle/input/ribonanzanet2d-final/configs/pairwise.yaml"),pretrained=False).cuda()

constructing 9 ConvTransformerEncoderLayers


In [7]:
model=finetuned_RibonanzaNet(load_config_from_yaml("/kaggle/input/ribonanzanet2d-final/configs/pairwise.yaml"),pretrained=False).cuda()

model.load_state_dict(torch.load("/kaggle/input/ribonanzanet-3d-finetune-add-structure-module/RibonanzaNet-3D.pt"))


constructing 9 ConvTransformerEncoderLayers


  model.load_state_dict(torch.load("/kaggle/input/ribonanzanet-3d-finetune-add-structure-module/RibonanzaNet-3D.pt"))


<All keys matched successfully>

In [8]:
test_dataset[0]['sequence'].shape

torch.Size([69])

In [9]:
model.eval()
preds=[]
for i in range(len(test_dataset)):
    src=test_dataset[i]['sequence'].long()
    src=src.unsqueeze(0).cuda()

    model.train()

    tmp=[]
    for i in range(4):
        with torch.no_grad():
            xyz=model(src)[-1].squeeze()
        tmp.append(xyz.cpu().numpy())

    model.eval()
    with torch.no_grad():
        xyz=model(src)[-1].squeeze()
    tmp.append(xyz.cpu().numpy())

    tmp=np.stack(tmp,0)
    #exit()
    preds.append(tmp)


  return fn(*args, **kwargs)


In [10]:
import plotly.graph_objects as go
import numpy as np

# Example: Generate an Nx3 matrix

xyz = preds[2][0]  # Replace this with your actual Nx3 data
N = len(xyz)

# Extract columns
x, y, z = xyz[:, 0], xyz[:, 1], xyz[:, 2]

# Create the 3D scatter plot
fig = go.Figure(data=[go.Scatter3d(
    x=x, y=y, z=z,
    mode='markers',
    marker=dict(
        size=5,
        color=z,  # Coloring based on z-value
        colorscale='Viridis',  # Choose a colorscale
        opacity=0.8
    )
)])

# Customize layout
fig.update_layout(
    scene=dict(
        xaxis_title="X",
        yaxis_title="Y",
        zaxis_title="Z"
    ),
    title="3D Scatter Plot"
)

# Show figure
fig.show(renderer='iframe')


In [11]:
ID=[]
resname=[]
resid=[]
x=[]
y=[]
z=[]

data=[]

for i in range(len(test_data)):
    #print(test_data.loc[i])

    
    for j in range(len(test_data.loc[i,'sequence'])):
        # ID.append(test_data.loc[i,'sequence_id']+f"_{j+1}")
        # resname.append(test_data.loc[i,'sequence'][j])
        # resid.append(j+1) # 1 indexed
        row=[test_data.loc[i,'target_id']+f"_{j+1}",
             test_data.loc[i,'sequence'][j],
             j+1]

        for k in range(5):
            for kk in range(3):
                row.append(preds[i][k][j][kk])
        data.append(row)

columns=['ID','resname','resid']
for i in range(1,6):
    columns+=[f"x_{i}"]
    columns+=[f"y_{i}"]
    columns+=[f"z_{i}"]


submission=pd.DataFrame(data,columns=columns)


submission
submission.to_csv('submission.csv',index=False)

In [12]:
submission

Unnamed: 0,ID,resname,resid,x_1,y_1,z_1,x_2,y_2,z_2,x_3,y_3,z_3,x_4,y_4,z_4,x_5,y_5,z_5
0,R1107_1,G,1,53.635681,66.559456,-34.423420,60.315742,57.900398,-43.905415,58.696072,59.156673,-40.072990,51.108139,66.942062,-38.840584,54.489334,70.626205,-38.725803
1,R1107_2,G,2,58.287857,65.983574,-34.299332,62.720192,52.816162,-42.148582,64.150169,56.134544,-41.512463,50.994209,63.339298,-41.907356,60.309254,65.389793,-38.050655
2,R1107_3,G,3,67.154060,65.099113,-39.039246,71.705643,52.917595,-44.311378,69.447968,56.764637,-43.238857,52.671638,61.361855,-41.751301,64.993057,61.934792,-42.161385
3,R1107_4,G,4,67.961090,67.289848,-40.953884,76.633507,55.426220,-47.218822,71.402245,60.206440,-43.195770,60.014065,62.185390,-42.888321,66.224281,65.724808,-46.224049
4,R1107_5,G,5,72.410324,70.802437,-45.156467,81.649597,64.363968,-49.616634,76.380188,68.123901,-51.666573,69.972137,66.061195,-47.296932,75.696434,70.249863,-52.228168
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2510,R1190_114,U,114,64.374596,69.868614,-69.597183,73.568245,68.301170,-64.523270,75.348236,66.783424,-77.719467,65.816902,57.605919,-69.379074,77.874931,68.187141,-76.415245
2511,R1190_115,U,115,64.549583,69.228973,-60.463833,76.659203,67.750916,-66.109856,76.949394,66.176109,-73.489342,72.351685,64.128082,-63.564926,79.920288,68.238556,-72.603264
2512,R1190_116,U,116,67.683144,64.909668,-61.671478,80.672112,70.127380,-66.373795,79.745140,64.077118,-73.035843,73.161530,58.984955,-64.417458,83.858116,67.360306,-71.030251
2513,R1190_117,U,117,70.679314,62.761627,-58.306507,87.106422,66.824257,-64.542381,80.278770,59.363712,-70.952736,79.552368,62.909046,-64.199837,87.099052,67.558113,-68.660759
