## Step 1: Imports

In [1]:
import pandas as pd
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
import pickle
from tqdm import tqdm

## Step 2: Define model configurations

In [2]:
# Notebook-wide settings, kept in a single dictionary.
# NOTE(review): most of these look like training-time hyperparameters;
# confirm which ones the inference cells actually need.
config = dict(
    seed=0,
    cutoff_date="2020-01-01",
    test_cutoff_date="2022-05-01",
    max_len=384,
    batch_size=1,
    learning_rate=1e-4,
    weight_decay=0.0,
    mixed_precision="bf16",
    model_config_path="../working/configs/pairwise.yaml",  # Adjust path as needed
    epochs=10,
    cos_epoch=5,
    loss_power_scale=1.0,
    max_cycles=1,
    grad_clip=0.1,
    gradient_accumulation_steps=1,
    d_clamp=30,
    max_len_filter=9999999,
    structural_violation_epoch=50,
    balance_weight=False,
)

## Step 3: Create test dataset

In [3]:
# Read the competition's test sequences; the trailing expression makes the
# cell display the frame's (rows, columns) shape.
test_data = pd.read_csv(
    "/kaggle/input/stanford-rna-3d-folding/test_sequences.csv"
)
test_data.shape

(12, 5)

In [4]:
from torch.utils.data import Dataset, DataLoader

class RNADataset(Dataset):
    """Map-style dataset that turns RNA sequences into integer token tensors.

    Args:
        data: DataFrame with a 'sequence' column of strings made up of the
            characters 'A', 'C', 'G' and 'U'.

    Each item is a dict ``{'sequence': tensor}`` where the tensor holds one
    int64 token id per nucleotide (A=0, C=1, G=2, U=3).

    Raises:
        KeyError: from ``__getitem__`` if a sequence contains a character
            outside 'ACGU'.
    """

    def __init__(self, data):
        self.data = data
        self.tokens = {nt: i for i, nt in enumerate('ACGU')}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Positional lookup: `idx` runs over range(len(self)), so it must not
        # be interpreted as an index *label* (which would fail, or silently
        # pick another row, on a filtered or re-indexed DataFrame).
        sequence = self.data['sequence'].iloc[idx]
        token_ids = [self.tokens[nt] for nt in sequence]
        # Pin the dtype so the result is int64 on every platform and also for
        # an empty sequence.
        return {'sequence': torch.tensor(token_ids, dtype=torch.long)}

In [5]:
# Wrap the test DataFrame in the dataset and display the first item as a
# quick check of the tokenisation.
test_dataset = RNADataset(test_data)
test_dataset[0]

{'sequence': tensor([2, 2, 2, 2, 2, 1, 1, 0, 1, 0, 2, 1, 0, 2, 0, 0, 2, 1, 2, 3, 3, 1, 0, 1,
         2, 3, 1, 2, 1, 0, 2, 1, 1, 1, 1, 3, 2, 3, 1, 0, 2, 1, 1, 0, 3, 3, 2, 1,
         0, 1, 3, 1, 1, 2, 2, 1, 3, 2, 1, 2, 0, 0, 3, 3, 1, 3, 2, 1, 3])}

## Step 4: Model Class

In [6]:
import sys
import yaml
sys.path.append("/kaggle/input/ribonanzanet2d-final")
from Network import *

class Config:
    """Attribute-style view over a set of keyword settings.

    Every keyword argument becomes an instance attribute, and the full
    keyword mapping is additionally stored on ``self.entries``.
    """

    def __init__(self, **entries):
        # Order matters: `entries` is assigned last, so a keyword literally
        # named "entries" ends up replaced by the full mapping.
        vars(self).update(entries)
        self.entries = entries

    def print(self):
        """Print the stored keyword mapping to stdout."""
        print(self.entries)

def load_config_from_yaml(file_path):
    """Parse the YAML file at ``file_path`` and return its keys as a ``Config``.

    The parsed document is expanded as keyword arguments, so it has to be a
    mapping at the top level.
    """
    with open(file_path, 'r') as stream:
        settings = yaml.safe_load(stream)
    return Config(**settings)

class finetuned_RibonanzaNet(RibonanzaNet):
    """RibonanzaNet with a linear head mapping 256-wide sequence features to 3 outputs (xyz).

    Args:
        config: configuration object handed to the ``RibonanzaNet`` constructor.
        pretrained: when true, load the RibonanzaNet.pt checkpoint after the
            layers below have been created.

    NOTE(review): ``config.dropout`` is assigned only after the parent
    constructor has already run, and the assignment mutates the caller's
    config object. Whether the parent reads the value again later cannot be
    told from this file -- confirm the intended ordering.
    NOTE(review): ``self.dropout`` is created but not used by ``forward``.
    """

    def __init__(self, config, pretrained=False):
        super().__init__(config)
        config.dropout = 0.2
        self.dropout = nn.Dropout(0.0)
        self.xyz_predictor = nn.Linear(256, 3)

        if not pretrained:
            return

        # NOTE(review): this is a strict load performed after `xyz_predictor`
        # was added; confirm the checkpoint actually contains matching keys.
        # torch.load unpickles the file, so only use trusted checkpoints.
        backbone_weights = torch.load(
            "/kaggle/input/ribonanzanet-weights/RibonanzaNet.pt",
            map_location='cpu',
        )
        self.load_state_dict(backbone_weights)

    def forward(self, src):
        """Return the output of ``xyz_predictor`` applied to the sequence features of ``src``."""
        # Second argument: an int64 tensor of ones shaped like `src`
        # (presumably a mask -- confirm against Network.get_embeddings).
        ones_like_src = torch.ones_like(src).long().to(src.device)
        sequence_features, _ = self.get_embeddings(src, ones_like_src)
        return self.xyz_predictor(sequence_features)

## Step 5: Load model for inference

In [7]:
# Build the network from the bundled YAML config and move it to the GPU,
# then load the fine-tuned 3D checkpoint on top. The final expression is left
# bare so the cell displays the key-matching report from load_state_dict.
model = finetuned_RibonanzaNet(
    load_config_from_yaml("/kaggle/input/ribonanzanet2d-final/configs/pairwise.yaml"),
    pretrained=False,
).cuda()
model.load_state_dict(
    torch.load("/kaggle/input/ribonanzanet-3d-finetune/RibonanzaNet-3D.pt")
)

constructing 9 ConvTransformerEncoderLayers


  model.load_state_dict(torch.load("/kaggle/input/ribonanzanet-3d-finetune/RibonanzaNet-3D.pt"))


<All keys matched successfully>

In [8]:
# Display the module tree of the loaded network.
model

finetuned_RibonanzaNet(
  (transformer_encoder): ModuleList(
    (0-7): 8 x ConvTransformerEncoderLayer(
      (self_attn): MultiHeadAttention(
        (w_qs): Linear(in_features=256, out_features=256, bias=False)
        (w_ks): Linear(in_features=256, out_features=256, bias=False)
        (w_vs): Linear(in_features=256, out_features=256, bias=False)
        (fc): Linear(in_features=256, out_features=256, bias=False)
        (attention): ScaledDotProductAttention(
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (dropout): Dropout(p=0.05, inplace=False)
        (layer_norm): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
      )
      (linear1): Linear(in_features=256, out_features=1024, bias=True)
      (dropout): Dropout(p=0.05, inplace=False)
      (linear2): Linear(in_features=1024, out_features=256, bias=True)
      (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (

## Step 5 (continued): Inference on test dataset

In [9]:
# Number of stochastic forward passes per target. Together with the single
# deterministic pass this gives the 5 predictions per target that the
# submission cell indexes.
N_STOCHASTIC_PASSES = 4

# Seed torch so the dropout-driven samples below are repeatable across runs.
torch.manual_seed(config["seed"])

model.eval()
preds = []
with torch.no_grad():
    for target_idx in tqdm(range(len(test_dataset))):
        # Add a leading batch axis of size 1 and move the tokens to the GPU.
        src = test_dataset[target_idx]['sequence'].long().unsqueeze(0).cuda()

        # Train mode activates the model's Dropout layers, so repeated passes
        # over the same input produce different predictions.
        # NOTE(review): train mode affects every submodule, not only dropout;
        # confirm the network has no layers whose statistics change here.
        model.train()
        samples = [
            # squeeze(0) drops only the batch axis; a bare squeeze() would
            # also collapse the sequence axis for a length-1 sequence.
            model(src).squeeze(0).cpu().numpy()
            for _ in range(N_STOCHASTIC_PASSES)
        ]

        # Final pass in eval mode (dropout disabled) as the last prediction.
        model.eval()
        samples.append(model(src).squeeze(0).cpu().numpy())

        # One stacked array per target: stochastic passes first, eval pass last.
        preds.append(np.stack(samples, 0))

  return fn(*args, **kwargs)
100%|██████████| 12/12 [00:12<00:00,  1.07s/it]


## Step 6: Plot a predicted structure in 3D

In [10]:
import plotly.graph_objects as go
import numpy as np

# Pick one prediction to visualise: entry 7 of `preds`, first of its stacked
# forward passes.
xyz = preds[7][0]
N = len(xyz)

# Split the three coordinate columns into separate arrays.
x = xyz[:, 0]
y = xyz[:, 1]
z = xyz[:, 2]

# One marker per row of `xyz`; marker colour follows the z value.
fig = go.Figure(data=[
    go.Scatter3d(
        x=x,
        y=y,
        z=z,
        mode='markers',
        marker={
            'size': 5,
            'color': z,
            'colorscale': 'Viridis',
            'opacity': 0.8,
        },
    )
])

# Title the figure and label its axes.
fig.update_layout(
    title="3D Scatter Plot",
    scene={
        'xaxis_title': "X",
        'yaxis_title': "Y",
        'zaxis_title': "Z",
    },
)

# Render the figure through the iframe renderer.
fig.show(renderer='iframe')

## Step 7: Create submission file

In [11]:
# Build the submission table: one row per residue holding its identifier,
# nucleotide, 1-based position, and the x/y/z coordinates of each of the
# N_STRUCTURES predictions.
N_STRUCTURES = 5

assert len(preds) == len(test_data), "expected one prediction array per test target"

data = []
for target_id, sequence, target_preds in zip(
    test_data['target_id'], test_data['sequence'], preds
):
    # Each prediction array must be indexable as [structure, residue, axis].
    assert target_preds.ndim == 3, f"unexpected prediction rank for {target_id}"
    assert target_preds.shape[0] >= N_STRUCTURES, f"too few predictions for {target_id}"
    assert target_preds.shape[1] == len(sequence), f"prediction length mismatch for {target_id}"

    for position, nucleotide in enumerate(sequence, start=1):
        row = [f"{target_id}_{position}", nucleotide, position]
        # Row-major flatten of (structure, axis) gives x_1, y_1, z_1, x_2, ...
        row.extend(target_preds[:N_STRUCTURES, position - 1, :3].reshape(-1))
        data.append(row)

columns = ['ID', 'resname', 'resid']
for structure in range(1, N_STRUCTURES + 1):
    columns += [f"x_{structure}", f"y_{structure}", f"z_{structure}"]

submission = pd.DataFrame(data, columns=columns)
submission.to_csv('submission.csv', index=False)

In [12]:
# Display the submission table that the previous cell wrote to submission.csv.
submission

Unnamed: 0,ID,resname,resid,x_1,y_1,z_1,x_2,y_2,z_2,x_3,y_3,z_3,x_4,y_4,z_4,x_5,y_5,z_5
0,R1107_1,G,1,27.382818,-5.014665,-7.504237,25.829113,-8.191240,-0.886308,24.197435,-8.640811,3.325861,19.201536,-18.832836,-10.449270,26.907227,-9.046555,-2.278727
1,R1107_2,G,2,28.504259,-2.098067,1.953158,28.461988,-3.984068,4.423258,25.236124,-7.479849,3.677811,26.290943,-4.945357,0.466802,28.874458,-4.260791,4.040070
2,R1107_3,G,3,27.315416,-4.197384,5.914920,29.492739,-2.537778,3.618463,27.098528,-1.815700,6.419380,24.046560,-6.597889,5.548576,30.631533,-1.310431,4.699128
3,R1107_4,G,4,26.636129,-2.189180,5.222723,28.784063,0.898845,7.224222,22.800381,-3.927438,8.077315,26.338175,-2.648475,8.038135,31.214788,1.479454,6.515719
4,R1107_5,G,5,28.003691,2.172521,7.201588,27.461763,1.576155,6.573967,27.156673,-0.474489,2.611463,26.051687,0.396920,7.117808,29.375944,1.194736,5.723046
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2510,R1190_114,U,114,19.666309,19.339981,27.091534,20.518946,18.253843,20.493660,15.961507,18.415779,30.640438,15.169608,18.806190,28.567331,18.383455,18.752460,31.582420
2511,R1190_115,U,115,16.043509,17.524162,29.082409,19.118191,15.794313,20.547287,14.515809,15.247387,31.592960,12.870030,18.579288,30.435932,16.322798,17.701756,32.300236
2512,R1190_116,U,116,15.216601,11.754694,30.492218,19.759642,11.880177,25.161959,13.912585,17.703234,33.970047,13.157947,19.311186,32.844975,15.308034,15.650916,34.541931
2513,R1190_117,U,117,12.137197,14.134659,32.694096,13.904111,13.003989,28.565392,9.205911,13.535951,33.497486,12.722174,19.108322,30.237867,12.111407,16.164152,34.995243
