# A quick GNN baseline

In this notebook, I put together a very basic GNN starter. 
None of the below is necessarily optimal (it's basically what I was able to throw together on a Sunday afternoon), but it should be enough to get started with further experimentation. I wrote this on my PC, so the notebook seems to be a bit bottlenecked by the low number of CPU cores, but I'm sure it's possible to work around that. The model below doesn't actually do very well, and most of the work went into the dataloader. Maybe you can do better? 

## Graph neural networks 

[Graph neural networks](https://en.wikipedia.org/wiki/Graph_neural_network) (GNNs) are a class of neural networks that operate on graph data structures. Graphs can be used to describe many objects, such as social networks, road networks, or in this case, molecules. 

Unlike sequences and grids (images), which have a fixed notion of adjacency, graphs define adjacency flexibly through their edges. 
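
To make this concrete, here is a minimal sketch of a four-node path graph stored as an edge list, the same 2 x E layout that the `edge_index` tensors later in this notebook use. The node numbering and the variable names are made up for illustration.

```python
import numpy as np

# A path graph 0 - 1 - 2 - 3, stored as directed edges in both directions.
# Row 0 holds the source node of each edge, row 1 the target node.
edge_index = np.array([
    [0, 1, 1, 2, 2, 3],
    [1, 0, 2, 1, 3, 2],
])

# The equivalent dense adjacency matrix, built from the edge list.
num_nodes = 4
adjacency = np.zeros((num_nodes, num_nodes), dtype=int)
adjacency[edge_index[0], edge_index[1]] = 1

# Neighbours of node 1 are the targets of every edge whose source is 1.
neighbours_of_1 = edge_index[1][edge_index[0] == 1]
print(adjacency)
print(neighbours_of_1)
```

Changing which columns appear in `edge_index` changes what "adjacent" means, without touching the node features at all.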

## RNA as graphs

There is no canonical way to represent an RNA molecule as a graph, and I believe the "trick" will be constructing the adjacency. In the case below, each base is simply connected to its nearest neighbours along the sequence (up to `EDGE_DISTANCE` positions) on each side. If we had the 3D structure, spatial distance could be a good basis for connections as well. From the info we do have, the pairing probability might also work. 
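
As a rough sketch of the pairing-probability idea: given a symmetric matrix of base-pair probabilities, one could add an edge wherever the probability exceeds a threshold. The matrix `bpp`, the helper `bpp_edges`, and the 0.5 threshold are all hypothetical; nothing in this block is used in the rest of the notebook.

```python
import numpy as np

def bpp_edges(bpp, threshold=0.5):
    """Return a 2 x E edge array for all pairs (i, j) with bpp[i, j] > threshold."""
    sources, targets = np.nonzero(bpp > threshold)
    return np.vstack([sources, targets])

# Toy 5-base example: bases 0 and 4 pair strongly, bases 1 and 3 weakly.
bpp = np.zeros((5, 5))
bpp[0, 4] = bpp[4, 0] = 0.9
bpp[1, 3] = bpp[3, 1] = 0.2

pair_edges = bpp_edges(bpp)
print(pair_edges)
```

Edges like these could be stacked with the nearest-neighbour edges using `np.hstack`, so the graph would carry both backbone and pairing connections.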

## PyG

[PyG](https://pyg.org/) (AKA PyTorch Geometric) is a library that contains all you need to get going with learning on graphs. 

Below I've used the PyG dataset class to create a dataset, and imported a simple GNN, called EdgeCNN, to begin with. 

In [1]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.4.0-py3-none-any.whl (1.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m21.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.4.0


In [2]:
import torch
from torch.utils.data import random_split
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
import pandas as pd
from pathlib import Path
import numpy as np
from sklearn.preprocessing import OneHotEncoder
import polars as pl
import re
from tqdm import tqdm



In [3]:
DATA_DIR = Path("/kaggle/input/stanford-ribonanza-rna-folding/")
TRAIN_CSV = DATA_DIR / "train_data.csv"
TRAIN_PARQUET_FILE = "train_data.parquet"
TEST_CSV = DATA_DIR / "test_sequences.csv"
TEST_PARQUET_FILE = "test_sequences.parquet"
PRED_CSV = "submission.csv"

## Convert to Parquet

Reading the CSV files with pandas is too slow, so the cell below uses Polars to convert the training and test data to Parquet. 

In [4]:
def to_parquet(csv_file, parquet_file):
    dummy_df = pl.scan_csv(csv_file)

    new_schema = {}
    for key, value in dummy_df.schema.items():
        if key.startswith("reactivity"):
            new_schema[key] = pl.Float32
        else:
            new_schema[key] = value

    df = pl.scan_csv(csv_file, schema=new_schema)
    
    df.sink_parquet(
            parquet_file,
            compression='uncompressed',
            row_group_size=10,
    )

In [5]:
to_parquet(TRAIN_CSV, TRAIN_PARQUET_FILE)
to_parquet(TEST_CSV, TEST_PARQUET_FILE)

## Define adjacency

There are countless ways you can define adjacency. 
The function below creates an edge connection array that specifies how the edges are connected. 

In this case, each node is connected to its `n` nearest neighbours along the sequence on each side, plus optionally itself (a self-loop). 

In [6]:
def nearest_adjacency(sequence_length, n=2, loops=True):
    base = np.arange(sequence_length)
    connections = []
    for i in range(-n, n + 1):
        if i == 0 and not loops:
            continue
        elif i == 0 and loops:
            stack = np.vstack([base, base])
            connections.append(stack)
            continue

        neighbours = base.take(range(i,sequence_length+i), mode='wrap')
        stack = np.vstack([base, neighbours])
        
        if i < 0:
            connections.append(stack[:, -i:])
        elif i > 0:
            connections.append(stack[:, :-i])

    return np.hstack(connections)
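
As a sanity check on the layout described above, the sketch below builds the same kind of banded edge list in plain Python and counts the edges. `banded_edges` and `expected_edge_count` are hypothetical helpers written for this illustration, not part of the notebook's pipeline, and the count formula assumes `n` is smaller than the sequence length.

```python
def banded_edges(sequence_length, n=2, loops=True):
    """List (source, target) pairs for nodes at most n positions apart."""
    edges = []
    for offset in range(-n, n + 1):
        if offset == 0 and not loops:
            continue
        for source in range(sequence_length):
            target = source + offset
            # Drop neighbours that would fall off either end of the sequence.
            if 0 <= target < sequence_length:
                edges.append((source, target))
    return edges

def expected_edge_count(sequence_length, n, loops):
    """Each offset d in 1..n contributes (length - d) edges in each direction.

    Assumes n < sequence_length.
    """
    count = 2 * sum(sequence_length - d for d in range(1, n + 1))
    return count + (sequence_length if loops else 0)

edges = banded_edges(5, n=2, loops=False)
print(len(edges), expected_edge_count(5, n=2, loops=False))
```

For a sequence of length 5 with `n=2` and no self-loops, both numbers come out as 14: four edges per direction at distance 1 and three per direction at distance 2.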

In [7]:
EDGE_DISTANCE = 4

## Defining the dataset

The class below defines a simple dataset that parses a Parquet file into node features (for now just the one-hot encoded bases A, G, U and C), the adjacency (using the function above), and the targets (the reactivity scores). 
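
For reference, here is a NumPy-only sketch of the one-hot encoding step. It uses alphabetical column order (A, C, G, U), which, as far as I know, is also how scikit-learn's `OneHotEncoder` orders the categories it infers during `fit`, regardless of the order they are passed in. The helper `one_hot_bases` is hypothetical.

```python
import numpy as np

BASES = np.array(["A", "C", "G", "U"])  # alphabetical column order

def one_hot_bases(sequence):
    """Return a (len(sequence), 4) float32 one-hot matrix for an RNA string."""
    letters = np.array(list(sequence)).reshape(-1, 1)
    # Broadcasting compares every base against every column label.
    return (letters == BASES).astype(np.float32)

features = one_hot_bases("GGAUC")
print(features)
```

Each row has exactly one non-zero entry, so the model sees the identity of the base and nothing else.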

In [8]:
class SimpleGraphDataset(Dataset):
    def __init__(self, parquet_name, edge_distance=5, root=None, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        # Set parquet file name
        self.parquet_name = parquet_name
        # Set edge distance
        self.edge_distance = edge_distance
        # Initialize one hot encoder
        self.node_encoder = OneHotEncoder(sparse_output=False, max_categories=5)
        # Fit one-hot encoder to the possible values
        self.node_encoder.fit(np.array(['A', 'G', 'U', 'C']).reshape(-1, 1))
        # Load dataframe
        self.df = pl.read_parquet(self.parquet_name)
        self.df = self.df.filter(pl.col("SN_filter") == 1.0)
        # Get reactivity columns names
        reactivity_match = re.compile('(reactivity_[0-9])')
        reactivity_names = [col for col in self.df.columns if reactivity_match.match(col)]
        
        self.reactivity_df = self.df.select(reactivity_names) 

        self.sequence_df = self.df.select("sequence")

    def parse_row(self, idx):
        # Read row at idx
        sequence_row = self.sequence_df.row(idx)  
        reactivity_row = self.reactivity_df.row(idx)
        # Get sequence string and convert to array
        sequence = np.array(list(sequence_row[0])).reshape(-1, 1)
        # Encode sequence array
        encoded_sequence = self.node_encoder.transform(sequence)
        # Get sequence length
        sequence_length = len(sequence)
        # Get edge index 
        edges_np = nearest_adjacency(sequence_length, n=self.edge_distance, loops=False)
        # Convert to torch tensor
        edge_index = torch.tensor(edges_np, dtype=torch.long)

        # Get reactivity targets for nodes
        reactivity = np.array(reactivity_row, dtype=np.float32)[0:sequence_length]
     
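        # Create a boolean valid mask for nodes. A boolean mask stays correct
        # when PyG concatenates graphs into a batch, whereas per-graph integer
        # positions would not be offset for the later graphs in the batch.
        valid_mask = ~np.isnan(reactivity)
        torch_valid_mask = torch.tensor(valid_mask, dtype=torch.bool)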

        # Replace nan values for reactivity with 0. 
        # Not actually super important as they get masked
        reactivity = np.nan_to_num(reactivity, copy=False, nan=0.0)


        # Define node features as one-hot encoded sequence
        node_features = torch.Tensor(encoded_sequence)

        # Targets 
        targets = torch.Tensor(reactivity)

        data = Data(x=node_features, edge_index=edge_index, y=targets, valid_mask=torch_valid_mask)

        return data

    def len(self):
        return len(self.df)

    def get(self, idx):
        data = self.parse_row(idx) 
        return data
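
One thing to keep in mind when attaching per-node masks to a `Data` object: when several graphs are collated into a batch, node-level tensors are concatenated along the first dimension, and as far as I understand PyG only shifts attributes whose name contains `index`. The NumPy-only sketch below mimics that concatenation for two toy graphs, to show that a boolean mask survives it unchanged while per-graph integer positions need an offset. All arrays here are made up.

```python
import numpy as np

# Two toy graphs with 3 and 2 nodes; NaN marks a missing target.
y_a = np.array([0.1, np.nan, 0.3], dtype=np.float32)
y_b = np.array([np.nan, 0.5], dtype=np.float32)

# Batching concatenates node-level arrays.
y_batch = np.concatenate([y_a, y_b])

# Boolean masks concatenate to a correct mask for the batch.
bool_mask = np.concatenate([~np.isnan(y_a), ~np.isnan(y_b)])

# Integer positions are relative to each graph, so the second graph's
# positions must be shifted by the number of nodes that precede it.
idx_a = np.flatnonzero(~np.isnan(y_a))
idx_b = np.flatnonzero(~np.isnan(y_b))
naive_idx = np.concatenate([idx_a, idx_b])
shifted_idx = np.concatenate([idx_a, idx_b + len(y_a)])

print(y_batch[bool_mask])    # valid targets from both graphs
print(y_batch[shifted_idx])  # same values as the boolean mask
print(y_batch[naive_idx])    # only touches nodes of the first graph
```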

## Data handling

Define the train and validation datasets, and load them with dataloaders. 

In [9]:
full_train_dataset = SimpleGraphDataset(parquet_name=TRAIN_PARQUET_FILE, edge_distance=EDGE_DISTANCE)
generator1 = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(full_train_dataset, [0.7, 0.3], generator1)

In [10]:
train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=2)

## Loss and metrics

Define the loss function as the MSE, and the MAE as an additional metric. 

The target values are clipped to `[0, 1]`, as in the competition metric. 

In [11]:
import torch.nn.functional as F

def loss_fn(output, target):
    clipped_target = torch.clip(target, min=0, max=1)
    mses = F.mse_loss(output, clipped_target, reduction='mean')
    return mses

def mae_fn(output, target):
    clipped_target = torch.clip(target, min=0, max=1)
    maes = F.l1_loss(output, clipped_target, reduction='mean')
    return maes
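
A small worked example of what the two functions compute, using NumPy so the numbers are easy to check by hand. The values are made up; note that only the targets are clipped, not the model outputs.

```python
import numpy as np

output = np.array([0.2, 0.5, 1.0], dtype=np.float32)
target = np.array([-0.3, 0.5, 1.8], dtype=np.float32)

# Clip targets to [0, 1]: -0.3 becomes 0.0 and 1.8 becomes 1.0.
clipped_target = np.clip(target, 0.0, 1.0)

errors = output - clipped_target   # [0.2, 0.0, 0.0]
mse = np.mean(errors ** 2)         # 0.04 / 3
mae = np.mean(np.abs(errors))      # 0.2 / 3
print(mse, mae)
```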

## Define the model

Below we define a simple EdgeCNN from PyG. 
As a start, we give it 128 hidden channels, and 4 layers. 

In [12]:
from torch_geometric.nn.models import EdgeCNN

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = EdgeCNN(in_channels=full_train_dataset.num_features, hidden_channels=128,
                num_layers=4, out_channels=1).to(device)
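
One way to reason about the number of layers together with `EDGE_DISTANCE`: each message-passing layer lets information travel along one edge, so with `k` layers a node can be influenced by nodes up to `k` hops away. With the banded adjacency used here, one hop covers at most `EDGE_DISTANCE` positions along the sequence. The sketch below checks that with a breadth-first search on a toy sequence; `hops_from` is a hypothetical helper and the sequence length of 50 is arbitrary.

```python
from collections import deque

def hops_from(start, sequence_length, edge_distance):
    """Breadth-first search: number of hops from `start` to every position."""
    hops = {start: 0}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        for offset in range(-edge_distance, edge_distance + 1):
            neighbour = node + offset
            if 0 <= neighbour < sequence_length and neighbour not in hops:
                hops[neighbour] = hops[node] + 1
                queue.append(neighbour)
    return hops

num_layers, edge_distance = 4, 4
hops = hops_from(start=25, sequence_length=50, edge_distance=edge_distance)
# Positions reachable within `num_layers` hops form the receptive field.
reachable = sorted(pos for pos, h in hops.items() if h <= num_layers)
print(reachable[0], reachable[-1])
```

With 4 layers and an edge distance of 4, each prediction can draw on at most 16 bases on either side, which may be short compared with the long-range pairings an RNA structure can contain.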

In [13]:
# Make sure we are using the GPU
device

device(type='cuda')

## Training 

Train the model for 10 epochs. 
Is this a good learning rate? 

In [14]:
n_epochs = 10


optimizer = torch.optim.Adam(model.parameters(), lr=0.0003, weight_decay=5e-4)


for epoch in range(n_epochs):
    train_losses = []
    train_maes = []
    model.train()
    for batch in (pbar := tqdm(train_dataloader, position=0, leave=True)):
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        out = torch.squeeze(out)
        loss = loss_fn(out[batch.valid_mask], batch.y[batch.valid_mask])
        mae = mae_fn(out[batch.valid_mask], batch.y[batch.valid_mask])
        loss.backward()
        train_losses.append(loss.detach().cpu().numpy())
        train_maes.append(mae.detach().cpu().numpy())
        optimizer.step()
        pbar.set_description(f"Train loss {loss.detach().cpu().numpy():.4f}")       

    print(f"Epoch {epoch} train loss: ", np.mean(train_losses))    
    print(f"Epoch {epoch} train mae: ", np.mean(train_maes))    
    
    val_losses = []
    val_maes = []
    model.eval()
    # No gradients are needed for validation
    with torch.no_grad():
        for batch in (pbar := tqdm(val_dataloader, position=0, leave=True)):
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index)
            out = torch.squeeze(out)
            loss = loss_fn(out[batch.valid_mask], batch.y[batch.valid_mask])
            mae = mae_fn(out[batch.valid_mask], batch.y[batch.valid_mask])
            val_losses.append(loss.detach().cpu().numpy())
            val_maes.append(mae.detach().cpu().numpy())
            pbar.set_description(f"Validation loss {loss.detach().cpu().numpy():.4f}")
        
    print(f"Epoch {epoch} val loss: ", np.mean(val_losses)) 
    print(f"Epoch {epoch} val mae: ", np.mean(val_maes))



Train loss 0.1383: 100%|██████████| 1199/1199 [04:11<00:00,  4.76it/s]


Epoch 0 train loss:  0.12228444
Epoch 0 train mae:  0.2976345


Validation loss 0.1193: 100%|██████████| 514/514 [01:40<00:00,  5.10it/s]


Epoch 0 val loss:  0.12097534
Epoch 0 val mae:  0.294571


Train loss 0.1241: 100%|██████████| 1199/1199 [04:10<00:00,  4.79it/s]


Epoch 1 train loss:  0.12205741
Epoch 1 train mae:  0.29503754


Validation loss 0.1192: 100%|██████████| 514/514 [01:40<00:00,  5.11it/s]


Epoch 1 val loss:  0.12094678
Epoch 1 val mae:  0.29391912


Train loss 0.1468: 100%|██████████| 1199/1199 [04:13<00:00,  4.73it/s]


Epoch 2 train loss:  0.121927805
Epoch 2 train mae:  0.29464436


Validation loss 0.1192: 100%|██████████| 514/514 [01:44<00:00,  4.93it/s]


Epoch 2 val loss:  0.1209518
Epoch 2 val mae:  0.29405153


Train loss 0.1046: 100%|██████████| 1199/1199 [04:15<00:00,  4.68it/s]


Epoch 3 train loss:  0.12191963
Epoch 3 train mae:  0.29496667


Validation loss 0.1190: 100%|██████████| 514/514 [01:44<00:00,  4.91it/s]


Epoch 3 val loss:  0.12092747
Epoch 3 val mae:  0.2932423


Train loss 0.1411: 100%|██████████| 1199/1199 [04:16<00:00,  4.68it/s]


Epoch 4 train loss:  0.122150816
Epoch 4 train mae:  0.2945868


Validation loss 0.1193: 100%|██████████| 514/514 [01:42<00:00,  4.99it/s]


Epoch 4 val loss:  0.12096829
Epoch 4 val mae:  0.2944289


Train loss 0.1469: 100%|██████████| 1199/1199 [04:14<00:00,  4.71it/s]


Epoch 5 train loss:  0.12214798
Epoch 5 train mae:  0.29503128


Validation loss 0.1193: 100%|██████████| 514/514 [01:42<00:00,  5.03it/s]


Epoch 5 val loss:  0.120979026
Epoch 5 val mae:  0.2946417


Train loss 0.0975: 100%|██████████| 1199/1199 [04:14<00:00,  4.72it/s]


Epoch 6 train loss:  0.12199342
Epoch 6 train mae:  0.29454237


Validation loss 0.1193: 100%|██████████| 514/514 [01:42<00:00,  5.04it/s]


Epoch 6 val loss:  0.12099274
Epoch 6 val mae:  0.29488832


Train loss 0.1173: 100%|██████████| 1199/1199 [04:12<00:00,  4.74it/s]


Epoch 7 train loss:  0.1211497
Epoch 7 train mae:  0.29340014


Validation loss 0.1191: 100%|██████████| 514/514 [01:42<00:00,  5.00it/s]


Epoch 7 val loss:  0.12093639
Epoch 7 val mae:  0.29360175


Train loss 0.1181: 100%|██████████| 1199/1199 [04:14<00:00,  4.71it/s]


Epoch 8 train loss:  0.12180431
Epoch 8 train mae:  0.29447123


Validation loss 0.1192: 100%|██████████| 514/514 [01:42<00:00,  5.01it/s]


Epoch 8 val loss:  0.12095167
Epoch 8 val mae:  0.2940483


Train loss 0.1381: 100%|██████████| 1199/1199 [04:11<00:00,  4.78it/s]


Epoch 9 train loss:  0.121976785
Epoch 9 train mae:  0.29442555


Validation loss 0.1192: 100%|██████████| 514/514 [01:40<00:00,  5.12it/s]

Epoch 9 val loss:  0.1209598
Epoch 9 val mae:  0.29424408
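
The validation loss above stays at roughly 0.121 across all ten epochs. One way to judge a flat curve like this is to compare it against a model that predicts a single constant for every node: the best constant under MSE is the mean of the targets, and the resulting MSE equals their variance. The sketch below shows the computation on synthetic targets; the random data is a stand-in, not the competition data, so the printed numbers say nothing about this notebook's results.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for clipped reactivity targets in [0, 1].
targets = np.clip(rng.normal(loc=0.3, scale=0.35, size=10_000), 0.0, 1.0)

# Best constant prediction under MSE is the mean; under MAE it is the median.
baseline_mse = np.mean((targets - targets.mean()) ** 2)
baseline_mae = np.mean(np.abs(targets - np.median(targets)))
print(f"constant-baseline MSE: {baseline_mse:.4f}")
print(f"constant-baseline MAE: {baseline_mae:.4f}")
```

If a trained model's validation loss is no lower than this baseline computed on the real targets, it is probably not making much use of the input features.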





In [15]:
torch.cuda.empty_cache()

## Inference on the test set

Here we define a lightweight dataset to handle the inference step. 

In [16]:
class InferenceGraphDataset(Dataset):
    def __init__(self, parquet_name, edge_distance=2, root=None, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        # Set parquet file name
        self.parquet_name = parquet_name
        # Set edge distance
        self.edge_distance = edge_distance
        # Initialize one hot encoder
        self.node_encoder = OneHotEncoder(sparse_output=False, max_categories=4)
        # Fit one-hot encoder to the possible values
        self.node_encoder.fit(np.array(['A', 'G', 'U', 'C']).reshape(-1, 1))
        # Load dataframe
        self.df = pl.read_parquet(self.parquet_name)

        self.sequence_df = self.df.select("sequence")
        self.id_min_df = self.df.select("id_min")

    def parse_row(self, idx):
        # Read row at idx
        sequence_row = self.sequence_df.row(idx)  
        id_min = self.id_min_df.row(idx)[0]

        # Get sequence string and convert to array
        sequence = np.array(list(sequence_row[0])).reshape(-1, 1)
        # Encode sequence array
        encoded_sequence = self.node_encoder.transform(sequence)
        # Get sequence length
        sequence_length = len(sequence)
        # Get edge index 
        edges_np = nearest_adjacency(sequence_length, n=self.edge_distance, loops=False)
        # Convert to torch tensor
        edge_index = torch.tensor(edges_np, dtype=torch.long)

        # Define node features as one-hot encoded sequence
        node_features = torch.Tensor(encoded_sequence)
        ids = torch.arange(id_min, id_min+sequence_length, 1)

        data = Data(x=node_features, edge_index=edge_index, ids=ids)

        return data

    def len(self):
        return len(self.df)

    def get(self, idx):
        data = self.parse_row(idx) 
        return data
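
The `ids` tensor gives every nucleotide its own row id for the submission: a sequence whose first nucleotide has id `id_min` and that is `sequence_length` bases long covers `id_min` through `id_min + sequence_length - 1`. A NumPy sketch with made-up values (the two toy rows are not from the test file):

```python
import numpy as np

# Hypothetical (id_min, sequence) rows, mimicking the test file layout.
rows = [(0, "GGAAC"), (5, "ACU")]

all_ids = []
for id_min, sequence in rows:
    # One id per nucleotide, starting at the sequence's id_min.
    all_ids.append(np.arange(id_min, id_min + len(sequence)))

ids = np.concatenate(all_ids)
print(ids)
```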

In [17]:
infer_dataset = InferenceGraphDataset(parquet_name=TEST_PARQUET_FILE, edge_distance=EDGE_DISTANCE)
infer_dataloader = DataLoader(infer_dataset, batch_size=128, shuffle=False, num_workers=2)

In [18]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.eval().to(device)

EdgeCNN(4, 1, num_layers=4)

In [19]:
ids_list = []
preds_list = []

# Collect per-batch arrays and concatenate once at the end;
# calling np.append inside the loop copies the whole array on every batch.
with torch.no_grad():
    for batch in tqdm(infer_dataloader):
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index)

        ids_list.append(batch.ids.cpu().numpy())
        preds_list.append(out.cpu().numpy().reshape(-1))

ids = np.concatenate(ids_list)
preds = np.concatenate(preds_list)



100%|██████████| 10499/10499 [1:39:32<00:00,  1.76it/s]


## Submission

Create a csv file with the submission values. 
As you can see, I don't currently distinguish between `DMS_MaP` and `2A3_MaP`, so I just write the same value to both. 

In [20]:
submission_df = pl.DataFrame({"id": ids, "reactivity_DMS_MaP": preds, "reactivity_2A3_MaP": preds})

In [21]:
submission_df.write_csv(PRED_CSV)

## Conclusion

This is a basic GNN "starter kit" that covers the essentials. 
There are many things that can be improved, but I hope this helps people get started. 