## Step 1 : Imports 

In [1]:
import pandas as pd
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
import pickle
from tqdm import tqdm

## Step 2: Configs

In [2]:
config = {
    "seed": 0,
    "cutoff_date": "2020-01-01",
    "test_cutoff_date": "2022-05-01",
    "max_len": 384,
    "batch_size": 1,
    "learning_rate": 1e-4,
    "weight_decay": 0.0,
    "mixed_precision": "bf16",
    "model_config_path": "../working/configs/pairwise.yaml",  # Adjust path as needed
    "epochs": 10,
    "cos_epoch": 5,
    "loss_power_scale": 1.0,
    "max_cycles": 1,
    "grad_clip": 0.1,
    "gradient_accumulation_steps": 1,
    "d_clamp": 30,
    "max_len_filter": 9999999,
    "structural_violation_epoch": 50,
    "balance_weight": False,
}

## Step 3: Load data

In [3]:
train_sequences = pd.read_csv("/kaggle/input/stanford-ribonanza-2-rna-folding-in-3-d/train_sequences.csv")
train_labels = pd.read_csv("/kaggle/input/stanford-ribonanza-2-rna-folding-in-3-d/train_labels.csv")

train_sequences.columns, train_labels.columns

(Index(['target_id', 'sequence', 'temporal_cutoff', 'description',
        'all_sequences'],
       dtype='object'),
 Index(['ID', 'resname', 'resid', 'x_1', 'y_1', 'z_1'], dtype='object'))

In [4]:
train_labels["pdb_id"] = train_labels["ID"].apply(lambda x: x.split("_")[0]+'_'+x.split("_")[1])
train_labels["pdb_id"] 

0         1SCL_A
1         1SCL_A
2         1SCL_A
3         1SCL_A
4         1SCL_A
           ...  
137090    8Z1F_T
137091    8Z1F_T
137092    8Z1F_T
137093    8Z1F_T
137094    8Z1F_T
Name: pdb_id, Length: 137095, dtype: object

In [5]:
all_xyz=[]

for pdb_id in tqdm(train_sequences['target_id']):
    df = train_labels[train_labels["pdb_id"]==pdb_id]
    xyz=df[['x_1','y_1','z_1']].to_numpy().astype('float32')
    xyz[xyz<-1e17]=float('Nan');
    all_xyz.append(xyz)

  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xy

In [6]:
data={
      "sequence":train_sequences['sequence'].to_list(),
      "xyz": all_xyz
}

In [7]:
train_index = train_sequences.index
print(f"Test size: {len(train_index)}")

Test size: 844


## Step 4: Dataset Class

In [8]:
from torch.utils.data import Dataset, DataLoader
from ast import literal_eval

def get_ct(bp,s):
    ct_matrix=np.zeros((len(s),len(s)))
    for b in bp:
        ct_matrix[b[0]-1,b[1]-1]=1
    return ct_matrix

class RNA3D_Dataset(Dataset):
    def __init__(self,indices,data):
        self.indices  = indices
        self.data     = data
        self.tokens   = {nt:i for i,nt in enumerate('ACGU')}

    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):
        idx      = self.indices[idx]
        sequence = [self.tokens[nt] for nt in (self.data['sequence'][idx]) if nt in self.tokens]
        sequence = np.array(sequence)
        sequence = torch.tensor(sequence)

        #get C1' xyz
        xyz      = self.data['xyz'][idx]
        xyz      = torch.tensor(np.array(xyz))


        if len(sequence)>config['max_len']:
            crop_start  = np.random.randint(len(sequence)-config['max_len'])
            crop_end    = crop_start+config['max_len']
            sequence    = sequence[crop_start:crop_end]
            xyz         = xyz[crop_start:crop_end]

        return {'sequence' : sequence, 'xyz' : xyz}

In [9]:
test_dataset=RNA3D_Dataset(train_index, data)
test_dataset[0]

{'sequence': tensor([2, 2, 2, 3, 2, 1, 3, 1, 0, 2, 3, 0, 1, 2, 0, 2, 0, 2, 2, 0, 0, 1, 1, 2,
         1, 0, 1, 1, 1]),
 'xyz': tensor([[ 13.7600, -25.9740,   0.1020],
         [  9.3100, -29.6380,   2.6690],
         [  5.5290, -27.8130,   5.8780],
         [  2.6780, -24.9010,   9.7930],
         [  1.8270, -20.1360,  11.7930],
         [  2.0400, -14.9080,  11.7710],
         [  1.1070, -11.5130,   7.5170],
         [  2.9910,  -6.4060,   4.7830],
         [  0.8960,  -1.1930,   7.6080],
         [  0.2280,   2.6460,   9.1280],
         [  4.3290,   2.7180,   4.8040],
         [  5.1650,   4.7920,  -0.9140],
         [  2.6100,   9.4950,  -2.3080],
         [  1.1740,  13.8290,   0.2010],
         [  1.5800,  20.1150,   3.7600],
         [ -1.5750,  16.9280,   5.8970],
         [ -6.0510,  14.7620,   5.2240],
         [ -5.5540,  10.4150,   4.3090],
         [ -3.1070,   6.4050,   2.1200],
         [ -1.4100,   3.3350,  -2.6550],
         [  1.8660,  -0.7160,  -4.3330],
         [  3

## Step 5: Load model

In [10]:
import sys

sys.path.append("/kaggle/input/ribonanzanet2d-final")


from Network import *
import yaml



class Config:
    def __init__(self, **entries):
        self.__dict__.update(entries)
        self.entries=entries

    def print(self):
        print(self.entries)

def load_config_from_yaml(file_path):
    with open(file_path, 'r') as file:
        config = yaml.safe_load(file)
    return Config(**config)

class finetuned_RibonanzaNet(RibonanzaNet):
    def __init__(self, config, pretrained=False):
        config.dropout=0.2
        super(finetuned_RibonanzaNet, self).__init__(config)
        if pretrained:
            self.load_state_dict(torch.load("/kaggle/input/ribonanzanet-weights/RibonanzaNet.pt",map_location='cpu'))
        self.dropout=nn.Dropout(0.0)
        self.xyz_predictor=nn.Linear(256,3)

    def forward(self,src):
        sequence_features, pairwise_features=self.get_embeddings(src, torch.ones_like(src).long().to(src.device))

        xyz=self.xyz_predictor(sequence_features)

        return xyz

In [11]:
model=finetuned_RibonanzaNet(load_config_from_yaml("/kaggle/input/ribonanzanet2d-final/configs/pairwise.yaml"),pretrained=False).cuda()

model.load_state_dict(torch.load("/kaggle/input/ribonanzanet-3d-finetune/RibonanzaNet-3D.pt"))


constructing 9 ConvTransformerEncoderLayers


  model.load_state_dict(torch.load("/kaggle/input/ribonanzanet-3d-finetune/RibonanzaNet-3D.pt"))


<All keys matched successfully>

In [12]:
test_dataset[0]['sequence'].shape

torch.Size([29])

## Step 6: Model Inference

In [13]:
from tqdm import tqdm

model.eval()
preds=[]
for i in tqdm(range(len(test_dataset))):
    src=test_dataset[i]['sequence'].long()
    src=src.unsqueeze(0).cuda()

    model.train()

    tmp=[]
    for i in range(4):
        with torch.no_grad():
            xyz=model(src).squeeze()
        tmp.append(xyz.cpu().numpy())

    model.eval()
    with torch.no_grad():
        xyz=model(src).squeeze()
    tmp.append(xyz.cpu().numpy())

    tmp=np.stack(tmp,0)
    #exit()
    preds.append(tmp)

  return fn(*args, **kwargs)
100%|██████████| 844/844 [04:01<00:00,  3.50it/s]


In [14]:
import plotly.graph_objects as go
import numpy as np

# Example: Generate an Nx3 matrix

xyz = preds[2][1]  # Replace this with your actual Nx3 data
N = len(xyz)

# Extract columns
x, y, z = xyz[:, 0], xyz[:, 1], xyz[:, 2]

# Create the 3D scatter plot
fig = go.Figure(data=[go.Scatter3d(
    x=x, y=y, z=z,
    mode='markers',
    marker=dict(
        size=5,
        color=z,  # Coloring based on z-value
        colorscale='Viridis',  # Choose a colorscale
        opacity=0.8
    )
)])

# Customize layout
fig.update_layout(
    scene=dict(
        xaxis_title="X",
        yaxis_title="Y",
        zaxis_title="Z"
    ),
    title="3D Scatter Plot"
)

# Show figure
fig.show(renderer='iframe')

In [15]:
tmp.shape

(5, 86, 3)

In [16]:
preds[0]

array([[[ 7.33692312e+00,  1.67382355e+01,  4.83677387e+00],
        [ 9.51375961e+00,  1.51120358e+01,  4.14232016e+00],
        [ 1.23284683e+01,  1.06005154e+01,  3.29774737e+00],
        [ 1.21508303e+01,  4.39496422e+00,  2.79672027e+00],
        [ 1.11909971e+01,  1.31413722e+00,  3.45936567e-01],
        [ 7.06607628e+00, -8.35438371e-01,  5.89075804e-01],
        [ 2.60309315e+00, -4.58466816e+00, -1.30309594e+00],
        [ 1.52563319e-01, -3.48567247e+00, -2.88089252e+00],
        [-5.22892714e+00, -5.82414484e+00, -4.61940336e+00],
        [-4.80920506e+00, -8.44968605e+00, -6.02586889e+00],
        [-8.97599125e+00, -3.93101645e+00, -5.30110788e+00],
        [-1.32758398e+01, -2.49770951e+00, -7.48438883e+00],
        [-1.55210123e+01, -6.79638743e-01, -1.38572454e+01],
        [-1.70961227e+01, -4.73178291e+00, -1.34341812e+01],
        [-2.06844826e+01, -9.03652668e+00, -1.75919456e+01],
        [-1.73351440e+01, -1.14871922e+01, -1.52317257e+01],
        [-1.26567125e+01

In [17]:
preds[7][0].shape

(32, 3)

In [18]:
length = 0
length_greater_than_384 = 0
length_greater_than_384_cnt = 0

for idx, row in train_sequences.iterrows():
    current_length = len(row["sequence"])
    length = length + current_length

    if current_length > 384:
        length_greater_than_384 += current_length-384
        length_greater_than_384_cnt += 1

length, length_greater_than_384, length-length_greater_than_384, length_greater_than_384_cnt

(137095, 72651, 64444, 46)

In [19]:
ID=[]
resname=[]
resid=[]
x=[]
y=[]
z=[]

data=[]

for i in tqdm(range(len(train_sequences))):
    try:
        for j in range(len(train_sequences.loc[i,'sequence'])):
            row=[train_sequences.loc[i,'target_id']+f"_{j+1}", 
                 train_sequences.loc[i,'sequence'][j],
                 j+1]
    
            for k in range(5):
                for kk in range(3):
                    row.append(preds[i][k][j][kk])
            data.append(row)
    except Exception as e:
        print(e)

columns=['ID','resname','resid']
for i in range(1,6):
    columns+=[f"x_{i}"]
    columns+=[f"y_{i}"]
    columns+=[f"z_{i}"]


submission=pd.DataFrame(data,columns=columns)


submission
submission.to_csv('submission.csv',index=False)

 36%|███▌      | 300/844 [00:00<00:00, 1157.09it/s]

index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384


 50%|████▉     | 421/844 [00:00<00:00, 933.27it/s] 

index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384


 62%|██████▏   | 520/844 [00:00<00:00, 493.42it/s]

index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384


 77%|███████▋  | 653/844 [00:01<00:00, 463.68it/s]

index 384 is out of bounds for axis 0 with size 384
index 101 is out of bounds for axis 0 with size 101
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 47 is out of bounds for axis 0 with size 47
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 80 is out of bounds for axis 0 with size 80


 90%|█████████ | 762/844 [00:01<00:00, 454.41it/s]

index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384
index 384 is out of bounds for axis 0 with size 384


100%|██████████| 844/844 [00:01<00:00, 546.03it/s]


index 145 is out of bounds for axis 0 with size 145
index 384 is out of bounds for axis 0 with size 384


In [20]:
submission

Unnamed: 0,ID,resname,resid,x_1,y_1,z_1,x_2,y_2,z_2,x_3,y_3,z_3,x_4,y_4,z_4,x_5,y_5,z_5
0,1SCL_A_1,G,1,7.336923,16.738235,4.836774,6.323206,17.031868,6.877194,6.957625,19.433933,6.664508,5.929893,17.375671,7.380263,6.374305,19.093266,9.947103
1,1SCL_A_2,G,2,9.513760,15.112036,4.142320,8.508277,14.534754,4.998576,9.561449,14.671907,3.889475,9.783124,13.672105,8.113719,11.371324,16.423269,9.648827
2,1SCL_A_3,G,3,12.328468,10.600515,3.297747,13.055398,11.272874,2.629486,10.887737,12.156885,3.995100,11.687041,9.420231,4.592191,14.047790,12.028687,7.881827
3,1SCL_A_4,U,4,12.150830,4.394964,2.796720,12.677704,7.871842,0.635010,13.864484,6.878683,2.772907,10.056153,4.366991,3.333331,13.485300,7.117090,6.590899
4,1SCL_A_5,G,5,11.190997,1.314137,0.345937,10.649619,0.725469,0.999374,11.672052,0.940869,1.747799,8.596684,1.105153,0.975656,11.748466,2.713398,4.368589
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
64433,8Z1F_T_82,U,82,0.026509,31.825565,12.412617,5.286232,41.681656,30.244347,10.015234,39.762413,33.105328,11.494699,37.702106,20.402130,1.482342,43.018700,26.166098
64434,8Z1F_T_83,C,83,-4.547946,28.844946,9.258884,1.353285,38.043369,29.382427,4.806351,37.216377,24.306120,12.338634,42.154572,29.318447,1.607505,43.395157,26.519369
64435,8Z1F_T_84,A,84,3.868995,38.780537,22.346367,10.735499,40.702148,30.869923,10.250741,40.913345,28.519682,10.295520,37.302853,27.617889,1.085683,41.920780,23.554651
64436,8Z1F_T_85,U,85,0.583747,35.567463,16.640093,8.062943,39.324131,28.094059,9.664435,38.202061,27.783079,6.892276,36.668190,18.942503,2.878160,41.831444,21.701105


In [21]:
submission["flag"] = False

for idx, row in submission.iterrows():
    if row["x_1"] == 0.0 and row["y_1"] == 0.0 and row["z_1"] == 0.0:
        submission.loc[idx, "flag"] = True

In [22]:
train_labels.rename(columns={"x_1":"x_1_true","y_1":"y_1_true","z_1":"z_1_true" }, inplace=True)

In [23]:
merged_df = submission.merge(train_labels[["ID", "x_1_true", "y_1_true", "z_1_true"]], on=["ID"])

In [24]:
merged_df


invalid value encountered in greater


invalid value encountered in less


invalid value encountered in greater


invalid value encountered in greater


invalid value encountered in less


invalid value encountered in greater



Unnamed: 0,ID,resname,resid,x_1,y_1,z_1,x_2,y_2,z_2,x_3,...,x_4,y_4,z_4,x_5,y_5,z_5,flag,x_1_true,y_1_true,z_1_true
0,1SCL_A_1,G,1,7.336923,16.738235,4.836774,6.323206,17.031868,6.877194,6.957625,...,5.929893,17.375671,7.380263,6.374305,19.093266,9.947103,False,13.760,-25.974001,0.102
1,1SCL_A_2,G,2,9.513760,15.112036,4.142320,8.508277,14.534754,4.998576,9.561449,...,9.783124,13.672105,8.113719,11.371324,16.423269,9.648827,False,9.310,-29.638000,2.669
2,1SCL_A_3,G,3,12.328468,10.600515,3.297747,13.055398,11.272874,2.629486,10.887737,...,11.687041,9.420231,4.592191,14.047790,12.028687,7.881827,False,5.529,-27.813000,5.878
3,1SCL_A_4,U,4,12.150830,4.394964,2.796720,12.677704,7.871842,0.635010,13.864484,...,10.056153,4.366991,3.333331,13.485300,7.117090,6.590899,False,2.678,-24.900999,9.793
4,1SCL_A_5,G,5,11.190997,1.314137,0.345937,10.649619,0.725469,0.999374,11.672052,...,8.596684,1.105153,0.975656,11.748466,2.713398,4.368589,False,1.827,-20.136000,11.793
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
64433,8Z1F_T_82,U,82,0.026509,31.825565,12.412617,5.286232,41.681656,30.244347,10.015234,...,11.494699,37.702106,20.402130,1.482342,43.018700,26.166098,False,,,
64434,8Z1F_T_83,C,83,-4.547946,28.844946,9.258884,1.353285,38.043369,29.382427,4.806351,...,12.338634,42.154572,29.318447,1.607505,43.395157,26.519369,False,,,
64435,8Z1F_T_84,A,84,3.868995,38.780537,22.346367,10.735499,40.702148,30.869923,10.250741,...,10.295520,37.302853,27.617889,1.085683,41.920780,23.554651,False,,,
64436,8Z1F_T_85,U,85,0.583747,35.567463,16.640093,8.062943,39.324131,28.094059,9.664435,...,6.892276,36.668190,18.942503,2.878160,41.831444,21.701105,False,,,


In [25]:
final_df = merged_df[~merged_df["x_1_true"].isna()]

In [26]:
final_df


Unnamed: 0,ID,resname,resid,x_1,y_1,z_1,x_2,y_2,z_2,x_3,...,x_4,y_4,z_4,x_5,y_5,z_5,flag,x_1_true,y_1_true,z_1_true
0,1SCL_A_1,G,1,7.336923,16.738235,4.836774,6.323206,17.031868,6.877194,6.957625,...,5.929893,17.375671,7.380263,6.374305,19.093266,9.947103,False,13.760000,-25.974001,0.102000
1,1SCL_A_2,G,2,9.513760,15.112036,4.142320,8.508277,14.534754,4.998576,9.561449,...,9.783124,13.672105,8.113719,11.371324,16.423269,9.648827,False,9.310000,-29.638000,2.669000
2,1SCL_A_3,G,3,12.328468,10.600515,3.297747,13.055398,11.272874,2.629486,10.887737,...,11.687041,9.420231,4.592191,14.047790,12.028687,7.881827,False,5.529000,-27.813000,5.878000
3,1SCL_A_4,U,4,12.150830,4.394964,2.796720,12.677704,7.871842,0.635010,13.864484,...,10.056153,4.366991,3.333331,13.485300,7.117090,6.590899,False,2.678000,-24.900999,9.793000
4,1SCL_A_5,G,5,11.190997,1.314137,0.345937,10.649619,0.725469,0.999374,11.672052,...,8.596684,1.105153,0.975656,11.748466,2.713398,4.368589,False,1.827000,-20.136000,11.793000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
64413,8Z1F_T_62,U,62,9.595567,7.404656,13.769807,9.993127,4.528863,10.242212,9.730031,...,5.358267,2.095401,11.791045,11.311334,6.068553,10.731156,False,112.516998,117.880997,119.245003
64414,8Z1F_T_63,A,63,12.535778,8.934927,14.085622,11.642522,6.080662,12.158369,11.145070,...,9.600068,-2.015361,14.205014,15.480549,6.695855,9.949566,False,115.292999,116.571999,114.827003
64415,8Z1F_T_64,C,64,16.335852,9.249647,12.175087,16.647243,8.480448,10.581591,13.585906,...,11.012373,1.640385,14.024781,18.642389,8.138820,9.252250,False,115.857002,114.595001,109.509003
64416,8Z1F_T_65,C,65,17.428249,10.387369,11.426862,20.692896,12.840609,11.874450,17.572828,...,16.521351,7.169189,12.243554,21.481298,12.728823,8.446400,False,113.816002,113.236000,104.339996


In [27]:
final_df.to_csv("ribonanzanet1_public_oof.csv")