In [19]:
!pip install torch_geometric
!pip install ViennaRNA



In [20]:
import torch
from torch.utils.data import random_split
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
import pandas as pd
from pathlib import Path
import numpy as np
from sklearn.preprocessing import OneHotEncoder
import polars as pl
import re
from tqdm import tqdm
import networkx as nx


In [21]:
DATA_DIR = Path("/kaggle/input/stanford-ribonanza-rna-folding/")
TRAIN_CSV = DATA_DIR / "train_data_QUICK_START.csv"
TRAIN_PARQUET_FILE = "train_data.parquet"
TEST_CSV = DATA_DIR / "test_sequences.csv"
TEST_PARQUET_FILE = "test_sequences.parquet"
PRED_CSV = "submission.csv"

PARQUET_DATA_DIR = Path("/kaggle/input/parquet-files/")
TRAIN_PARQUET_DIR = PARQUET_DATA_DIR / "train_data.parquet"

In [22]:
def to_parquet(csv_file, parquet_file):
    dummy_df = pl.scan_csv(csv_file)

    new_schema = {}
    for key, value in dummy_df.schema.items():
        if key.startswith("reactivity"):
            new_schema[key] = pl.Float32
        else:
            new_schema[key] = value

    df = pl.scan_csv(csv_file, schema=new_schema)
    
    df.sink_parquet(
            parquet_file,
            compression='uncompressed',
            row_group_size=10,
    )

In [23]:
# to_parquet(TRAIN_CSV, TRAIN_PARQUET_FILE)
# to_parquet(TEST_CSV, TEST_PARQUET_FILE)

## Define adjacency

There are countless ways you can define adjacency. 
The function below creates an edge connection array that specifies how the edges are connected. 

In this case, an edge is connected to the neighbouring `n` molecules to each side of it, plus optionally itself. 

In [24]:
def nearest_adjacency(sequence_length, n=2, loops=True):
    base = np.arange(sequence_length)
    connections = []
    for i in range(-n, n + 1):
        if i == 0 and not loops:
            continue
        elif i == 0 and loops:
            stack = np.vstack([base, base])
            connections.append(stack)
            continue

        neighbours = base.take(range(i,sequence_length+i), mode='wrap')
        stack = np.vstack([base, neighbours])
        
        if i < 0:
            connections.append(stack[:, -i:])
        elif i > 0:
            connections.append(stack[:, :-i])

    return np.hstack(connections)

In [25]:
from ViennaRNA import RNA

import RNA
import numpy as np

def rna_adjacency_matrix(sequence):
    """
    Creates an adjacency matrix for an RNA sequence based on its predicted secondary structure.
    
    Parameters:
    sequence (str): The RNA sequence.
    
    Returns:
    numpy.ndarray: An adjacency matrix representing the base pairs in the RNA structure.
    """
    # Predicting the secondary structure
    ss, _ = RNA.fold(sequence)

    sequence_length = len(sequence)
    matrix = np.zeros((sequence_length, sequence_length), dtype=int)

    # Stack to hold the indices of open brackets
    stack = []

    for i, char in enumerate(ss):
        if char == '(':
            stack.append(i)
        elif char == ')':
            j = stack.pop()
            matrix[i][j] = 1
            matrix[j][i] = 1

    return matrix

def get_rna_structure_edges(sequence):
    # Predict secondary structure
    ss, _ = RNA.fold(sequence)
    edges = []
    stack = []
    for i, char in enumerate(ss):
        if char == '(':
            stack.append(i)
        elif char == ')':
            j = stack.pop()
            edges.append([i, j])
    return np.array(edges).T




In [26]:
EDGE_DISTANCE = 4

## Defining the dataloader. 

The below defines a simple dataloader that parses a parquet file into node embeddings (for now just the one hot encoded bases A, G, U and C), the adjacency (using the function above), and the targets (the reactivity scores). 

In [27]:
class SimpleGraphDataset(Dataset):
    def __init__(self, edge_distance=5, root=None, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        # Set edge distance
        self.edge_distance = edge_distance
        
        
        # Initialize one hot encoder
        self.node_encoder = OneHotEncoder(sparse_output=False, max_categories=5)
        # For one-hot encoder to possible values
        self.node_encoder.fit(np.array(['A', 'G', 'U', 'C']).reshape(-1, 1))
        
        # Load dataframe
        self.df = pl.read_parquet(TRAIN_PARQUET_DIR)
        num_rows = int(len(self.df) * 0.1)
        self.df = self.df[:num_rows]

        # Filter dataframe
        self.df = self.df.filter(pl.col('experiment_type') == '2A3_MaP')
        if "SN_filter" in self.df.columns:
            self.df = self.df.filter(pl.col("SN_filter") == 1.0)
        
        # Get reactivity columns names
        reactivity_match = re.compile('(reactivity_[0-9])')
        reactivity_names = [col for col in self.df.columns if reactivity_match.match(col)]
        self.reactivity_df = self.df.select(reactivity_names) 

        self.sequence_df = self.df.select("sequence")
        
        
    def calculate_additional_feature(self, sequence):
        return
    
    
    def calculate_centrality(self, edge_index):
        # Convert edge index to a NetworkX graph to compute centrality
        G = nx.Graph()
        # Edge index is assumed to be a 2xN array where each column is an edge
        G.add_edges_from(edge_index.T.tolist())
        # Compute degree centrality (returns a dictionary)
        centrality = nx.degree_centrality(G)
        # Convert centrality values to a list in the order of nodes
        centrality_values = [centrality[node] for node in range(len(G))]
        return np.array(centrality_values).reshape(-1, 1)

    def parse_row(self, idx):
        # Read row at idx
        sequence_row = self.sequence_df.row(idx)  
        reactivity_row = self.reactivity_df.row(idx)
        
        # Get sequence string and convert to array
        sequence = np.array(list(sequence_row[0])).reshape(-1, 1)
        # Encode sequence array
        encoded_sequence = self.node_encoder.transform(sequence)
        # Get sequence length
        sequence_length = len(sequence)
        
        
        # Get edge index from nearest adjacency
        nearest_edges = nearest_adjacency(sequence_length, n=self.edge_distance, loops=False)

        # Get edge index from RNA secondary structure
        structure_edges = get_rna_structure_edges(sequence_row[0])

        # Combine both sets of edges
        combined_edges = np.hstack([nearest_edges, structure_edges])

        # Convert to torch tensor
        edge_index = torch.tensor(combined_edges, dtype=torch.long)

        # Get edge index 
#         edges_np = nearest_adjacency(sequence_length, n=self.edge_distance, loops=False)
#         # Convert to torch tensor
#         edge_index = torch.tensor(edges_np, dtype=torch.long)

        
        # Get reactivity targets for nodes
        reactivity = np.array(reactivity_row, dtype=np.float32)[0:sequence_length]
     
        # Create valid masks for nodes
        valid_mask = np.argwhere(~np.isnan(reactivity)).reshape(-1)
        torch_valid_mask = torch.tensor(valid_mask, dtype=torch.long)

        # Replace nan values for reactivity with 0. 
        # Not actually super important as they get masked
        reactivity = np.nan_to_num(reactivity, copy=False, nan=0.0)


        centrality_feature = self.calculate_centrality(edge_index)
        combined_features = np.concatenate([encoded_sequence, centrality_feature], axis=1)
        # Define node features as one-hot encoded sequence
        node_features = torch.Tensor(combined_features)


        # Targets 
        targets = torch.Tensor(reactivity)

        data = Data(x=node_features, edge_index=edge_index, y=targets, valid_mask=torch_valid_mask)

        return data

    def len(self):
        return len(self.df)

    def get(self, idx):
        data = self.parse_row(idx)
        return data

## Data handling

Define the train, validation, and test datasets, and load them with dataloaders. 

In [28]:
full_train_dataset = SimpleGraphDataset(edge_distance=EDGE_DISTANCE)
generator1 = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(full_train_dataset, [0.7, 0.3], generator1)
val_dataset, test_dataset = random_split(val_dataset, [0.7, 0.3], generator1)


In [29]:
train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

## Loss and metrics

Define the loss function and the MSE, and define the MAE as a loss. 

The target values are clipped to `(0, 1)` as the competition metric is. 

In [30]:
import torch.nn.functional as F

def loss_fn(output, target):
    clipped_target = torch.clip(target, min=0, max=1)
    mses = F.mse_loss(output, clipped_target, reduction='mean')
    return mses

def mae_fn(output, target):
    clipped_target = torch.clip(target, min=0, max=1)
    maes = F.l1_loss(output, clipped_target, reduction='mean')
    return maes

## Define the model

Below we define as simple EdgeCNN from PyG. 
As a start, we give it 128 hidden channels, and 4 layers. 

In [31]:
from torch_geometric.nn.models import EdgeCNN

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = EdgeCNN(in_channels=full_train_dataset.num_features, hidden_channels=128,
                num_layers=4, out_channels=1,dropout=0.5).to(device)
print(device)

cuda


## Training 


In [32]:
n_epochs = 10

optimizer = torch.optim.Adam(model.parameters(), lr=0.0003, weight_decay=5e-4)

for epoch in range(n_epochs):
    train_losses = []
    train_maes = []
    model.train()
    for batch in (pbar := tqdm(train_dataloader, position=0, leave=True)):
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        out = torch.squeeze(out)
        loss = loss_fn(out[batch.valid_mask], batch.y[batch.valid_mask])
        mae = mae_fn(out[batch.valid_mask], batch.y[batch.valid_mask])
        loss.backward()
        train_losses.append(loss.detach().cpu().numpy())
        train_maes.append(mae.detach().cpu().numpy())
        optimizer.step()
        pbar.set_description(f"Train loss {loss.detach().cpu().numpy():.4f}")       

    print(f"Epoch {epoch} train loss: ", np.mean(train_losses))    
    print(f"Epoch {epoch} train mae: ", np.mean(train_maes))    
    
    val_losses = []
    val_maes = []
    model.eval()
    for batch in (pbar := tqdm(val_dataloader, position=0, leave=True)):
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        out = torch.squeeze(out)
        loss = loss_fn(out[batch.valid_mask], batch.y[batch.valid_mask])
        mae = mae_fn(out[batch.valid_mask], batch.y[batch.valid_mask])
        val_losses.append(loss.detach().cpu().numpy())
        val_maes.append(mae.detach().cpu().numpy())
        pbar.set_description(f"Validation loss {loss.detach().cpu().numpy():.4f}")
        
    print(f"Epoch {epoch} val loss: ", np.mean(val_losses))
    print(f"Epoch {epoch} val mae: ", np.mean(val_maes))



Train loss 0.0983: 100%|██████████| 46/46 [02:44<00:00,  3.57s/it]


Epoch 0 train loss:  0.12950659
Epoch 0 train mae:  0.32078654


Validation loss 0.1329: 100%|██████████| 14/14 [00:48<00:00,  3.49s/it]


Epoch 0 val loss:  0.110138215
Epoch 0 val mae:  0.28558743


Train loss 0.1449: 100%|██████████| 46/46 [02:44<00:00,  3.57s/it]


Epoch 1 train loss:  0.11688521
Epoch 1 train mae:  0.28868455


Validation loss 0.1299: 100%|██████████| 14/14 [00:49<00:00,  3.56s/it]


Epoch 1 val loss:  0.10511852
Epoch 1 val mae:  0.27563912


Train loss 0.0859: 100%|██████████| 46/46 [02:43<00:00,  3.56s/it]


Epoch 2 train loss:  0.11312051
Epoch 2 train mae:  0.28394493


Validation loss 0.1294: 100%|██████████| 14/14 [00:49<00:00,  3.52s/it]


Epoch 2 val loss:  0.10444168
Epoch 2 val mae:  0.27286857


Train loss 0.1017: 100%|██████████| 46/46 [02:43<00:00,  3.56s/it]


Epoch 3 train loss:  0.115035266
Epoch 3 train mae:  0.2865838


Validation loss 0.1299: 100%|██████████| 14/14 [00:49<00:00,  3.55s/it]


Epoch 3 val loss:  0.103378706
Epoch 3 val mae:  0.26628202


Train loss 0.1146: 100%|██████████| 46/46 [02:44<00:00,  3.58s/it]


Epoch 4 train loss:  0.11279132
Epoch 4 train mae:  0.28228074


Validation loss 0.1292: 100%|██████████| 14/14 [00:49<00:00,  3.51s/it]


Epoch 4 val loss:  0.103542194
Epoch 4 val mae:  0.26828864


Train loss 0.1314: 100%|██████████| 46/46 [02:43<00:00,  3.56s/it]


Epoch 5 train loss:  0.11432973
Epoch 5 train mae:  0.28485757


Validation loss 0.1290: 100%|██████████| 14/14 [00:48<00:00,  3.46s/it]


Epoch 5 val loss:  0.10361055
Epoch 5 val mae:  0.26808384


Train loss 0.1204: 100%|██████████| 46/46 [02:43<00:00,  3.56s/it]


Epoch 6 train loss:  0.115758024
Epoch 6 train mae:  0.28435472


Validation loss 0.1294: 100%|██████████| 14/14 [00:49<00:00,  3.52s/it]


Epoch 6 val loss:  0.10378784
Epoch 6 val mae:  0.2696515


Train loss 0.1264: 100%|██████████| 46/46 [02:43<00:00,  3.55s/it]


Epoch 7 train loss:  0.11210272
Epoch 7 train mae:  0.2797796


Validation loss 0.1296: 100%|██████████| 14/14 [00:49<00:00,  3.51s/it]


Epoch 7 val loss:  0.10456257
Epoch 7 val mae:  0.27223137


Train loss 0.1150: 100%|██████████| 46/46 [02:43<00:00,  3.55s/it]


Epoch 8 train loss:  0.11202128
Epoch 8 train mae:  0.2819653


Validation loss 0.1299: 100%|██████████| 14/14 [00:49<00:00,  3.51s/it]


Epoch 8 val loss:  0.103532806
Epoch 8 val mae:  0.26751897


Train loss 0.0954: 100%|██████████| 46/46 [02:43<00:00,  3.54s/it]


Epoch 9 train loss:  0.11595975
Epoch 9 train mae:  0.28522447


Validation loss 0.1301: 100%|██████████| 14/14 [00:49<00:00,  3.52s/it]

Epoch 9 val loss:  0.10503002
Epoch 9 val mae:  0.27549264





In [33]:
torch.cuda.empty_cache()

### Check model performance

In [34]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.eval().to(device)

EdgeCNN(5, 1, num_layers=4)

In [35]:
test_maes = []
model.eval()  # Set the model to evaluation mode
with torch.no_grad():  # Disable gradient computation
    for batch in tqdm(test_dataloader, position=0, leave=True):
        batch = batch.to(device)  # Move the batch to the appropriate device
        out = model(batch.x, batch.edge_index)  # Forward pass
        out = torch.squeeze(out)
        mae = mae_fn(out[batch.valid_mask], batch.y[batch.valid_mask])  # Calculate MAE
        test_maes.append(mae.detach().cpu().numpy())  # Store MAE

# Calculate and print average loss and MAE for the test dataset
print(f"Test MAE: {np.mean(test_maes):.4f}")




100%|██████████| 12/12 [00:21<00:00,  1.83s/it]

Test MAE: 0.2861



