# GNNを用いたCOVID-19 mRNAワクチン分解率予測


## 1. ライブラリとデータセットの準備

今回は [OpenVaccine data hosted on Kaggle](https://www.kaggle.com/competitions/stanford-covid-vaccine/overview)をデータセットとして用います。

In [None]:
!pip install dgl dgllife biopython seaborn transformers

In [None]:
%%bash
# Download the RNA stability data into ./OpenVaccine.
# -nc (no-clobber) skips a file that is already present, so re-running this
# cell does not leave train.json.1 / test.json.1 duplicates behind.
mkdir -p OpenVaccine

wget -nc https://d2125kp0qwrvcx.cloudfront.net/OpenVaccine/train.json -P OpenVaccine
wget -nc https://d2125kp0qwrvcx.cloudfront.net/OpenVaccine/test.json -P OpenVaccine

In [None]:
# Local paths of the two JSON files fetched into the OpenVaccine directory.
train_file = 'OpenVaccine/train.json'
test_file = 'OpenVaccine/test.json'

In [None]:
from typing import Tuple, Iterator

import pandas as pd
import numpy as np
import matplotlib.pylab as plt
import json
import seaborn as sns
import os
import random
from tqdm.notebook import tqdm

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils import data
from torch.nn import functional as F
from torch.optim import lr_scheduler

import dgl

In [None]:
'''
Utils for parsing the RNA data
'''
# Names of the label columns (the same five names RNADataset uses by default).
pred_cols = ['reactivity', 'deg_Mg_pH10', 'deg_pH10', 'deg_Mg_50C', 'deg_50C']

# For every categorical input column, map each symbol of its alphabet to an
# integer index. The column order here is the order in which row_to_graph
# concatenates the one-hot blocks.
token_to_idx = {
    column: {symbol: index for index, symbol in enumerate(alphabet)}
    for column, alphabet in (
        ('sequence', 'ACGU'),  # residue_to_idx
        ('structure', '().'),
        ('predicted_loop_type', 'BEHIMSX'),
    )
}

def get_couples(structure):
    """
    Pair every closing parenthesis of a dot-bracket string with its opening one.

    The string is scanned once from left to right while a stack keeps the
    positions of the '(' characters that are still unmatched; each ')' is
    paired with the most recent unmatched '('. Characters other than '(' and
    ')' (e.g. '.') are ignored.

    Args:
        structure: iterable of characters in dot-bracket notation.

    Returns:
        A list of ``[open_idx, close_idx]`` pairs, ordered by increasing
        ``close_idx``.

    Raises:
        ValueError: if a ')' has no preceding unmatched '(' or if some '('
            is never closed.
    """
    unmatched_open = []  # stack of positions of '(' awaiting their partner
    couples = []

    for idx, symbol in enumerate(structure):
        if symbol == '(':
            unmatched_open.append(idx)
        elif symbol == ')':
            if not unmatched_open:
                raise ValueError(
                    f"unbalanced structure: ')' at position {idx} has no matching '('"
                )
            # the innermost open parenthesis is the partner of this one
            couples.append([unmatched_open.pop(), idx])

    if unmatched_open:
        raise ValueError(
            f"unbalanced structure: '(' at position {unmatched_open[-1]} is never closed"
        )
    return couples


def build_edge_list(couples: list, size: int) -> tuple:
    '''
    Turn base pairs plus the backbone into an edge list for a dgl graph.

    Every node ``k`` in ``range(size)`` is linked to ``k + 1`` (when it exists)
    and then to ``k - 1`` (when it exists), so neighbouring bases are connected
    in both directions. After that, each ``(i, j)`` entry of `couples` (the
    output of `get_couples`) contributes the two edges ``i -> j`` and
    ``j -> i``.

    Returns a ``(src, dst)`` tuple of equally long lists of node indices.
    '''
    edges = []

    # backbone: forward edge first, then backward edge, node by node
    for node in range(size):
        if node + 1 < size:
            edges.append((node, node + 1))
        if node >= 1:
            edges.append((node, node - 1))

    # paired bases, one edge per direction
    for left, right in couples:
        edges.append((left, right))
        edges.append((right, left))

    src = [start for start, _ in edges]
    dst = [end for _, end in edges]
    return src, dst

def row_to_graph(row: pd.Series) -> dgl.DGLGraph:
    '''
    Process a row in the RNA data frame and convert it to a dgl.DGLGraph.

    Nodes are the positions of ``row['structure']``; edges come from
    `build_edge_list` (backbone neighbours plus the base pairs found by
    `get_couples`). For every column listed in `token_to_idx` the per-position
    symbols are one-hot encoded, and the blocks are concatenated column-wise
    and stored as float32 node features under ``g.ndata['h']``.

    Raises KeyError if a symbol is missing from the matching vocabulary.
    '''
    structure = row['structure']
    n_nodes = len(structure)
    edge_list = build_edge_list(get_couples(structure), n_nodes)
    # Pass the node count explicitly: inferring it from the edges breaks for a
    # sequence without any edge (a single residue would give a 0-node graph
    # and the feature assignment below would fail).
    g = dgl.graph(edge_list, num_nodes=n_nodes)

    node_features = []
    for node_feature_col, vocab in token_to_idx.items():
        # categorical encoding of each position (explicit integer dtype so
        # that one_hot also accepts an empty sequence)
        node_feature = torch.tensor([vocab[x] for x in row[node_feature_col]],
                                    dtype=torch.long)
        # then convert to one-hot with a fixed width per column
        node_features.append(F.one_hot(node_feature, num_classes=len(vocab)))

    # attach as node features
    g.ndata['h'] = torch.cat(node_features, dim=1).to(torch.float32)
    return g

In [None]:
class RNADataset(data.Dataset):
    '''mRNA stability prediction dataset.

    Each item is the DGL graph that `row_to_graph` builds from one row of
    `df`. When `is_train` is true the graph additionally carries

    * ``ndata['target']``: (n_nodes, len(pred_cols)) float32 labels, zero for
      positions without a label, and
    * ``ndata['train_mask']``: (n_nodes,) bool, True for the labelled positions.

    The labelled positions are taken to be the first ``n_labeled_nodes`` nodes,
    where ``n_labeled_nodes`` is the length of the label lists in the row.

    Args:
        df: data frame with one RNA per row.
        pred_cols: names of the label columns; ``None`` selects
            `DEFAULT_PRED_COLS`.
        is_train: whether to attach targets and the mask to each graph.
    '''
    # Label columns used when the caller does not pass `pred_cols`.
    DEFAULT_PRED_COLS = ('reactivity', 'deg_Mg_pH10', 'deg_pH10', 'deg_Mg_50C', 'deg_50C')

    def __init__(self, df, pred_cols=None, is_train=True):
        self.df = df
        # A fresh list per instance: a mutable default argument would be
        # shared (and could be modified) across all datasets. It stays a list
        # because it is used as a multi-key indexer in `row[self.pred_cols]`.
        self.pred_cols = list(self.DEFAULT_PRED_COLS if pred_cols is None else pred_cols)
        self.n_outputs = len(self.pred_cols)
        self.is_train = is_train

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        g = row_to_graph(row)

        if self.is_train:
            # one list per label column -> transpose to one row per position
            target = np.array(row[self.pred_cols].values.tolist()).T
            target = torch.tensor(target, dtype=torch.float32)  # shape: (n_labeled_nodes, len(pred_cols))

            n_labeled_nodes = target.shape[0]
            n_nodes = g.num_nodes()

            # unlabelled trailing positions keep a zero target and are masked out
            node_labels = torch.zeros([n_nodes, len(self.pred_cols)], dtype=torch.float32)
            node_labels[:n_labeled_nodes] = target
            g.ndata['target'] = node_labels  # shape: (n_nodes, len(pred_cols))

            train_mask = torch.zeros(n_nodes, dtype=torch.bool)
            train_mask[:n_labeled_nodes] = True
            g.ndata['train_mask'] = train_mask  # shape: (n_nodes, )
        return g

    @property
    def feature_dim(self):
        '''Width of the node feature matrix, read from the first item.'''
        g = self[0]
        return g.ndata['h'].shape[1]

In [None]:
# parse the data into data frames
train = pd.read_json(train_file, lines=True)
test = pd.read_json(test_file, lines=True)
print(train.shape, test.shape)

In [None]:
# Columns available in the training frame.
train.columns

In [None]:
# First rows of the training frame.
train.head()

In [None]:
# Columns available in the test frame.
test.columns

In [None]:
# First rows of the test frame.
test.head()

In [None]:
# Wrap both frames as graph datasets. The test frame is wrapped with
# is_train=False, so its graphs get no 'target' / 'train_mask' node data.
train_dataset = RNADataset(train)
test_dataset = RNADataset(test, is_train=False)
print(len(train_dataset), len(test_dataset))

In [None]:
# Look at one RNA graph in the dataset:
i = 0
g = train_dataset[i]
g

In [None]:
# 'h' holds the concatenated one-hot node features built by row_to_graph.
print('Shape of node features:', g.ndata['h'].shape)
g.ndata['h']

In [None]:
# 'target' has one column per entry of pred_cols; show the first 10 positions.
print('Shape of node targets:', g.ndata['target'].shape)
print('labels:', train_dataset.pred_cols)
g.ndata['target'][:10]

### 1.2. RNA分子を可視化

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# get the raw features from the data frame:
seq = train.loc[i, 'sequence']
print(seq)
predicted_loop_type = train.loc[i, 'predicted_loop_type']
print(predicted_loop_type)
structure = train.loc[i, 'structure']

# convert to an undirected networkx graph
mol_graph = dgl.to_networkx(g).to_undirected()
print(mol_graph.number_of_nodes(), mol_graph.number_of_edges())

In [None]:
# Compute the layout once; `pos` is reused by the later plots so that the
# nodes keep the same coordinates in every figure.
pos = nx.spring_layout(mol_graph)
nx.draw(mol_graph, pos)

In [None]:
# one residue per graph node
n_residues = mol_graph.number_of_nodes()

# caption every node with its position followed by the residue letter
numbered_seq = ['%d%s' % pair for pair in zip(range(n_residues), seq)]
node_labels = dict(enumerate(numbered_seq))

# one palette colour per predicted loop type
color_palette = sns.color_palette()
node_colors = [
    color_palette[token_to_idx['predicted_loop_type'][symbol]]
    for symbol in predicted_loop_type
]

nx.draw(mol_graph, pos, labels=node_labels, node_color=node_colors)

In [None]:
# color by reactivity
reactivities = train.loc[i, 'reactivity'].copy()
# fill 0's for trailing residues
reactivities.extend([0] * (n_residues - len(reactivities)))

nx.draw(mol_graph, pos, 
        labels=node_labels,
        node_color=reactivities,
        cmap='Reds'
       )

## 2. GNNモデルを定義

In [None]:
from dgllife.model import GCN

In [None]:
# Hyper-parameters of the GCN built in the next cell.
model_config = dict(
    num_layers=2,
    hidden_feats=8,
    dropout=0.2,
    residual=False,
    batchnorm=False,
)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# All layers but the last use `hidden_feats` units with ReLU and dropout; the
# last layer has one unit per target column, no activation and no dropout.
model = GCN(
    in_feats=train_dataset.feature_dim,
    hidden_feats=[model_config['hidden_feats']] * (model_config['num_layers'] - 1) + [train_dataset.n_outputs],
    activation=[F.relu] * (model_config['num_layers'] - 1) + [None],
    residual=[model_config['residual']] * model_config['num_layers'],
    batchnorm=[model_config['batchnorm']] * model_config['num_layers'],
    dropout=[model_config['dropout']] * (model_config['num_layers'] - 1) + [0],
).to(device)
model

## 3. Train/validation split とモデルの学習

In [None]:
train_config = {
    'frac_train': 0.8,
    'lr': 1e-3,
    'n_epochs': 10,
    'batch_size': 128,
    'num_workers': 0,
    'seed': 42
}

# Apply the configured seed (it was declared but never used) so that the
# random train/validation split below, and the later torch-side randomness,
# are the same on every run.
np.random.seed(train_config['seed'])
torch.manual_seed(train_config['seed'])

# Draw `frac_train` of the row positions without replacement for training;
# every remaining position goes to validation.
N = train.shape[0]
train_idx = np.random.choice(N, int(train_config['frac_train'] * N), replace=False)
valid_idx = np.setdiff1d(np.arange(N), train_idx)
print(train_idx.shape, valid_idx.shape)

In [None]:
# Training subset: rows selected by train_idx, reshuffled every epoch.
# collate_fn=dgl.batch merges the per-sample graphs of a mini-batch into a
# single batched graph.
train_dataset = RNADataset(train.iloc[train_idx])
train_loader = data.DataLoader(train_dataset, 
                               batch_size=train_config['batch_size'], 
                               shuffle=True, 
                               pin_memory=True,
                               num_workers=train_config['num_workers'], 
                               collate_fn=dgl.batch
                              )

# Validation subset: rows selected by valid_idx, iterated in a fixed order.
valid_dataset = RNADataset(train.iloc[valid_idx])
valid_loader = data.DataLoader(valid_dataset, 
                               batch_size=train_config['batch_size'], 
                               shuffle=False, 
                               pin_memory=True,
                               num_workers=train_config['num_workers'], 
                               collate_fn=dgl.batch
                              )

print(len(train_dataset), len(valid_dataset))

In [None]:
def train_fn(model, train_loader, criterion, optimizer, device):
    '''Run one optimisation pass over `train_loader`.

    For every batched graph the model is applied to the node features
    ``ndata['h']``; the loss is computed only on the nodes selected by
    ``ndata['train_mask']`` and one optimizer step is taken per batch.
    Prints and returns the mean of the per-batch losses.
    '''
    model.train()
    model.zero_grad()
    batch_losses = []

    for batch in train_loader:
        batch = batch.to(device)
        outputs = model(batch, batch.ndata['h'])
        labelled = batch.ndata['train_mask']
        expected = batch.ndata['target']

        # only the labelled nodes contribute to the loss
        loss = criterion(outputs[labelled], expected[labelled])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    train_loss_avg = np.mean(batch_losses)
    print(f"Train loss {train_loss_avg}")
    return train_loss_avg
    
def eval_fn(model, valid_loader, criterion, device):
    '''Evaluate model.

    Puts the model in eval mode and, for every batched graph of
    `valid_loader`, computes the loss on the nodes selected by
    ``ndata['train_mask']``. Prints and returns the mean of the per-batch
    losses. The model is left in eval mode.
    '''
    model.eval()
    eval_loss = []

    # No parameter update happens here, so skip autograd bookkeeping: this
    # avoids building a computation graph for every validation batch.
    with torch.no_grad():
        for graphs in valid_loader:
            graphs = graphs.to(device)
            preds = model(graphs, graphs.ndata['h'])
            train_mask = graphs.ndata['train_mask']
            targets = graphs.ndata['target']

            loss = criterion(preds[train_mask], targets[train_mask])
            eval_loss.append(loss.item())

    eval_loss_avg = np.mean(eval_loss)
    print(f"Valid loss {eval_loss_avg}")
    return eval_loss_avg


In [None]:
# Mean squared error on the labelled nodes, optimised with Adam.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_config['lr'], weight_decay=0.0)

# per-epoch history of the mean training / validation loss
train_losses, eval_losses = [], []

for epoch in range(train_config['n_epochs']):
    print('#################')
    print('###Epoch:', epoch)

    # one pass over the training data, then one over the validation data
    train_loss = train_fn(model, train_loader, criterion, optimizer, device)
    eval_loss = eval_fn(model, valid_loader, criterion, device)

    train_losses += [train_loss]
    eval_losses += [eval_loss]


## 4. テストデータの予測

In [None]:
# First graph of the test dataset (built with is_train=False, so it carries
# node features only).
test_graph = test_dataset[0]
test_graph

In [None]:
# Switch to eval mode before predicting.
model.eval()
# Inference only: no_grad() keeps autograd from recording the forward pass,
# so the predictions come back as plain tensors detached from the model.
with torch.no_grad():
    predicted_node_labels = model(test_graph.to(device),
                                  test_graph.ndata['h'].to(device))
predicted_node_labels.shape