In [None]:
# mRNA-FM Tutorial

### Workflow of our tutorial

**Preparation**
1. install the RNA-FM package
2. load the necessary libraries

**Task 1. RNA family clustering**

Goal: to demonstrate that RNA-FM embeddings are biologically meaningful

1. read RNA sequences for each family from FASTA files
2. generate the RNA-FM embeddings for each sequence
3. t-SNE dimension reduction on the generated embeddings
4. plot the embeddings in the 2D space

**Task 2. mRNA expression classification**

Goal: to demonstrate how to use mRNA-FM for downstream applications

1. read mRNA sequences and their expression labels from a CSV file
2. generate mRNA-FM embeddings for each sequence
3. build the dataset and model
4. train and validate the model
5. test the model on a dataset excluded from training

## Install RNA-FM

In [None]:
!pip install rna-fm
#!pip install -U numpy
!pip install biopython

If `pip install` fails to install the required packages, uncomment the following cell to install RNA-FM from source instead.

In [None]:
# !git clone https://github.com/ml4bio/RNA-FM.git

# !pwd
# !ls
# %cd /content/RNA-FM
# !python setup.py install

In [None]:
import fm  # for development with RNA-FM

from pathlib import Path
import glob

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import numpy as np
import math

from Bio import SeqIO  # for file parsing

from sklearn.manifold import TSNE  # for dimension reduction

from sklearn.model_selection import train_test_split  # for splitting train/val/test

from tqdm.notebook import tqdm  # for showing progress

import matplotlib.pyplot as plt
import pandas as pd

In [None]:
# Clone the CodonBERT repository; the cells below read the E. coli protein CSV from its benchmarks folder.
!git clone https://github.com/Sanofi-Public/CodonBERT.git

In [None]:
# Preview the first lines of the fine-tuning CSV (the same path is assigned to `data_file` below).
!head CodonBERT/benchmarks/CodonBERT/data/fine-tune/E.Coli_proteins.csv

In [None]:
 if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(f'using {device} device')

data_file = 'CodonBERT/benchmarks/CodonBERT/data/fine-tune/E.Coli_proteins.csv'

## Task. mRNA expression

### Load the pretrained model

In [None]:
# !gdown 1zflX5hHTxuwqcZm6A1npq7ubP8m7LdNX

In [None]:
# Build the pretrained mRNA-FM network together with its token alphabet
fm_model, alphabet = fm.pretrained.mrna_fm_t12()   #Path(data_dir, 'RNA-FM_pretrained.pth'))

# converter that tokenises the (name, sequence) pairs fed to it later on
batch_converter = alphabet.get_batch_converter()

# place the model on the selected device (GPU if available), then switch to
# eval mode, which disables dropout for deterministic results
fm_model.to(device).eval()

### Load Model

You don't need to download the pretrained weights again if you have already done so in the model-loading step above.

In [None]:
# !gdown 1zflX5hHTxuwqcZm6A1npq7ubP8m7LdNX  # for Colab only

In [None]:
# Load the mRNA-FM model and its alphabet.
# NOTE(review): this repeats the earlier model-loading cell, so running both
# calls fm.pretrained.mrna_fm_t12() twice.
fm_model, alphabet = fm.pretrained.mrna_fm_t12()   # rna_fm_t12(Path(data_dir, 'RNA-FM_pretrained.pth'))
batch_converter = alphabet.get_batch_converter()

# keep the model on the chosen device
fm_model = fm_model.to(device)

# eval mode disables dropout for deterministic results
fm_model.eval()

### Load data

In [None]:
# load sequences and labels
# Only rows whose "Value" is 0 or 2 are kept; every other row is dropped here.
data_df = pd.read_csv(data_file)
data_df = data_df[data_df["Value"].isin([0, 2])]
display(data_df)
display(data_df.groupby("Split")["Value"].value_counts())

# One (name, sequence) pair per row, named after the row's index label.
# Built column-wise instead of with the slower row-by-row iterrows() loop.
raw_seqs = [(str(idx), seq) for idx, seq in data_df["Sequence"].items()]
labels = data_df["Value"].tolist()

In [None]:
# process binary labels (0: low expression; 1: high expression)
# Derived straight from the "Value" column (2 -> 1, anything else -> 0) so this
# cell can be re-run safely: recomputing `labels == 2` on labels that are
# already 0/1 would silently turn every label into 0.
labels = (data_df["Value"].to_numpy() == 2).astype(np.int64)
print(labels.shape)

### Extract embedding

In [None]:
chunk_size = 1    # number of sequences sent through the model per forward pass
repr_layer = 12   # layer whose hidden states are requested from the model
embed_dim = 1280  # width of each stored embedding row

# pre-allocate the output array so each chunk is written in place
token_embeddings = np.zeros((len(raw_seqs), embed_dim))

# divide all the sequences into chunks for processing due to the GPU memory limit
for start in tqdm(range(0, len(raw_seqs), chunk_size)):
    chunk = raw_seqs[start:start + chunk_size]

    _, _, batch_tokens = batch_converter(chunk)

    # inference only: no autograd graph is needed
    with torch.no_grad():
        results = fm_model(batch_tokens.to(device), repr_layers=[repr_layer])

    # Keep only the vector at sequence position 0 (presumably the prepended
    # CLS/BOS token - confirm against the alphabet). Slicing before .cpu()
    # avoids copying the full per-token tensor to host memory.
    emb = results["representations"][repr_layer][:, 0, :].cpu().numpy()

    token_embeddings[start:start + len(chunk), :] = emb


print(token_embeddings.shape)

### Construct the dataset and classifier

In [None]:
class RNATypeDataset(Dataset):
    """Map-style dataset pairing one embedding vector with one label.

    Args:
        embeddings: indexable collection with one embedding per sample.
        labels: indexable collection of labels, aligned with ``embeddings``.

    Raises:
        ValueError: if ``embeddings`` and ``labels`` have different lengths.
    """

    def __init__(self, embeddings, labels):
        # Fail early: __len__ follows `labels`, so a length mismatch would
        # otherwise drop samples silently or raise IndexError mid-epoch.
        if len(embeddings) != len(labels):
            raise ValueError(
                f"embeddings and labels must have the same length, "
                f"got {len(embeddings)} and {len(labels)}"
            )
        self.embeddings = embeddings
        self.labels = labels

    def __len__(self):
        """Return the number of samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(embedding, label)`` pair stored at ``idx``."""
        # the stored embeddings are returned as-is; no further pooling happens here
        return self.embeddings[idx], self.labels[idx]

In [None]:
class RNATypeClassifier(nn.Module):
    """Single linear layer that maps an input vector to per-class scores."""

    def __init__(self, in_dim, num_class):
        super().__init__()
        # one affine projection: in_dim input features -> num_class outputs
        self.fc = nn.Linear(in_dim, num_class)

    def forward(self, x):
        """Return the output of the linear layer for ``x`` (no softmax applied)."""
        return self.fc(x)

In [None]:
# dataset split
# Boolean row masks built from the "Split" column pick each partition; the same
# mask indexes both the embedding matrix and the label vector.
split_names = data_df["Split"].values
valid_value = data_df["Value"].values != 1

train_list, val_list, test_list = (
    (split_names == part) & valid_value for part in ("train", "val", "test")
)

x_train, y_train = token_embeddings[train_list], labels[train_list]
x_val, y_val = token_embeddings[val_list], labels[val_list]
x_test, y_test = token_embeddings[test_list], labels[test_list]

In [None]:
# hyper-parameters

batch_size = 4   # samples per batch drawn by the data loaders
lr = 1e-3        # learning rate passed to the optimizer
epochs = 100     # number of passes over the training split

# Seed the random number generators so the classifier initialisation and the
# shuffled batch order are repeatable across "Restart & Run All".
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Wrap every split in a dataset object ...
train_dataset, val_dataset, test_dataset = (
    RNATypeDataset(features, targets)
    for features, targets in ((x_train, y_train), (x_val, y_val), (x_test, y_test))
)

# ... and serve it in mini-batches; only the training split is shuffled
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

In [None]:
# classifier dimensions: embedding width in, number of classes out
in_dim, num_class = 1280, 2

model = RNATypeClassifier(in_dim, num_class)
model = model.to(device)
print(model)

# Adam updates the classifier weights; cross-entropy is the training loss
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

### Train the model

In [None]:
# best validation accuracy seen so far and the epoch that produced it
max_val_acc = -1
best_epoch = -1

# per-epoch curves, one entry appended per epoch
train_loss_history = []
val_loss_history = []

train_acc_history = []
val_acc_history = []

for epoch in tqdm(range(epochs)):

    # train the model
    train_losses = []
    train_preds = []
    train_targets = []

    model.train()

    for batch in train_loader:
        x, y = batch
        x, y = x.to(device).float(), y.to(device).long()

        # no need to apply the softmax function since it has been included in the loss function
        y_pred = model(x)

        # y_pred: (B, C) raw logits, y shape: (B,) with class indices
        loss = criterion(y_pred, y)

        train_losses.append(loss.item())
        train_preds.append(torch.max(y_pred.detach(), 1)[1])
        train_targets.append(y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # validate the model
    val_losses = []
    val_preds = []
    val_targets = []

    model.eval()

    # no gradients are needed for validation, so skip building the autograd graph
    with torch.no_grad():
        for batch in val_loader:
            x, y = batch
            x, y = x.to(device).float(), y.to(device).long()

            y_pred = model(x)

            # y_pred: (B, C) raw logits, y shape: (B,) with class indices
            loss = criterion(y_pred, y)

            val_losses.append(loss.item())
            val_preds.append(torch.max(y_pred, 1)[1])
            val_targets.append(y)

    # calculate the accuracy (stored as plain Python floats)
    train_preds = torch.cat(train_preds, dim=0)
    train_targets = torch.cat(train_targets, dim=0)
    train_acc = (train_preds == train_targets).float().mean().item()

    val_preds = torch.cat(val_preds, dim=0)
    val_targets = torch.cat(val_targets, dim=0)
    val_acc = (val_preds == val_targets).float().mean().item()

    # mean loss of THIS epoch; recorded before logging so the printed numbers
    # describe the current epoch rather than an average over earlier epochs
    train_loss = float(np.mean(train_losses))
    val_loss = float(np.mean(val_losses))

    train_loss_history.append(train_loss)
    val_loss_history.append(val_loss)

    train_acc_history.append(train_acc)
    val_acc_history.append(val_acc)

    # save the model checkpoint for the best validation accuracy
    if val_acc > max_val_acc:
        torch.save({'model_state_dict': model.state_dict()}, 'rna_type_checkpoint.pt')
        best_epoch = epoch
        max_val_acc = val_acc

    # show intermediate steps
    if epoch % 20 == 1:
        tqdm.write(f'epoch {epoch}/{epochs}: train loss={train_loss:.6f}, '
                   f'train acc={train_acc:.6f}, '
                   f'val loss={val_loss:.6f}, '
                   f'val acc={val_acc:.6f}')

### Visualize training results

In [None]:
plt.figure(figsize=(8, 6))

# one curve per split
for curve, curve_label in ((train_loss_history, 'train loss'),
                           (val_loss_history, 'val loss')):
    plt.plot(curve, label=curve_label)

# dashed line at best_epoch, the epoch with the best validation accuracy
# (the one whose checkpoint was saved during training)
plt.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.8)

plt.title('Loss History')
plt.xlabel('Epochs')
plt.ylabel('Loss')

plt.legend()

plt.show()

In [None]:
plt.figure(figsize=(8, 6))

# one curve per split
for curve, curve_label in ((train_acc_history, 'train accuracy'),
                           (val_acc_history, 'val accuracy')):
    plt.plot(curve, label=curve_label)

# dashed line at the epoch that reached the best validation accuracy
plt.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.8)

plt.title('Accuracy History')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')

plt.legend()

plt.show()

### Test the model

In [None]:
# test the model
test_preds = []

# Restore the weights saved at the epoch with the best validation accuracy.
# map_location places the checkpoint tensors on the current device even if
# they were saved from a different one.
checkpoint = torch.load('rna_type_checkpoint.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

model.eval()

# inference only: no autograd graph is needed
with torch.no_grad():
    for batch in test_loader:
        x, y = batch
        x, y = x.to(device).float(), y.to(device).long()

        output = model(x)

        _, y_pred = torch.max(output, 1)  # argmax in y_pred
        # print(y_pred.shape)

        test_preds.append(y_pred.cpu().numpy())


test_preds = np.concatenate(test_preds)

# test_loader is not shuffled, so predictions line up with y_test element-wise
total = len(y_test)
correct = np.sum(test_preds == y_test)

print(f'total number of test data: {total}, correct={correct}, test acc={correct/total:.4f}')