# Description
Welcome to Stanford Ribonanza RNA Folding challenge. The task of this competition is predicting the chemical reactivity at each position of an RNA molecule. These data are extremely sensitive to the structure that each RNA forms, and an algorithm that could perfectly predict these chemical reactivities would need to have an implicit ‘understanding’ of RNA structure. Such an oracle could be then utilized to predictively model structures of novel RNA molecules. A better understanding of how to manipulate RNA could help usher in an age of programmable medicine, including first cures for pancreatic cancer and Alzheimer’s disease as well as much-needed antibiotics and new biotechnology approaches for climate change. 

This notebook provides a simple baseline that may be used as a starting point for further experiments. Improvment of the baseline may include: 
1. Use of proper loss function to incorporate SN_filter = 0 samples as well as reactivity errors into training
2. Model improvement and use of additional data, e.g. Ribonanza_bpp_files

Finally, working on this competition keep in mind that train/public LB have different sequence length distribution from private LB, i.e. 115-206 vs. 207-457. Therefore, to avoid a strong shakeup one may need to look into performance vs. sequence length end ensure generalizability.

In [1]:
# !git clone https://github.com/Horikitasaku/RNA-FM.git
# import os
# os.chdir('/kaggle/working/RNA-FM/')
# !pip install .
# !mkdir -p /root/.cache/torch/hub/checkpoints/
# !cp /kaggle/input/h-one-for-all-model/RNA-FM_pretrained.pth /root/.cache/torch/hub/checkpoints/

In [2]:
# !git clone https://github.com/Horikitasaku/RNABERT.git

In [3]:
import fm

backbone, alphabet = fm.pretrained.rna_fm_t12()
batch_converter = alphabet.get_batch_converter()
# model.eval()  # disables dropout for deterministic results

In [4]:
import pandas as pd
import os, gc
import numpy as np
from sklearn.model_selection import KFold

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [5]:
import sys
sys.path.append('./RNABERT')
from utils.bert import Load_RNABert_Model
import torch

In [6]:
BERTmodel = Load_RNABert_Model('RNABERT/RNABERT.pth')

device:  cuda
---Loaded---


In [7]:
BERTmodel.eval()

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(6, 120, padding_idx=0)
      (position_embeddings): Embedding(440, 120)
      (token_type_embeddings): Embedding(2, 120)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (selfattn): BertSelfAttention(
              (query): Linear(in_features=120, out_features=120, bias=True)
              (key): Linear(in_features=120, out_features=120, bias=True)
              (value): Linear(in_features=120, out_features=120, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=120, out_features=120, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.0, inplace=False)
           

In [8]:
# Fix fastai bug to enable fp16 training with dictionaries

import torch
from fastai.vision.all import *
def flatten(o):
    "Concatenate all collections and items as a generator"
    for item in o:
        if isinstance(o, dict): yield o[item]; continue
        elif isinstance(item, str): yield item; continue
        try: yield from flatten(item)
        except TypeError: yield item

from torch.cuda.amp import GradScaler, autocast
@delegates(GradScaler)
class MixedPrecision(Callback):
    "Mixed precision training using Pytorch's `autocast` and `GradScaler`"
    order = 10
    def __init__(self, **kwargs): self.kwargs = kwargs
    def before_fit(self): 
        self.autocast,self.learn.scaler,self.scales = autocast(),GradScaler(**self.kwargs),L()
    def before_batch(self): self.autocast.__enter__()
    def after_pred(self):
        if next(flatten(self.pred)).dtype==torch.float16: self.learn.pred = to_float(self.pred)
    def after_loss(self): self.autocast.__exit__(None, None, None)
    def before_backward(self): self.learn.loss_grad = self.scaler.scale(self.loss_grad)
    def before_step(self):
        "Use `self` as a fake optimizer. `self.skipped` will be set to True `after_step` if gradients overflow. "
        self.skipped=True
        self.scaler.step(self)
        if self.skipped: raise CancelStepException()
        self.scales.append(self.scaler.get_scale())
    def after_step(self): self.learn.scaler.update()

    @property 
    def param_groups(self): 
        "Pretend to be an optimizer for `GradScaler`"
        return self.opt.param_groups
    def step(self, *args, **kwargs): 
        "Fake optimizer step to detect whether this batch was skipped from `GradScaler`"
        self.skipped=False
    def after_fit(self): self.autocast,self.learn.scaler,self.scales = None,None,None
        
import fastai
fastai.callback.fp16.MixedPrecision = MixedPrecision


In [9]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

In [10]:
fname = 'RNA-FM-EM-TRANS'
PATH = './stanford-ribonanza-rna-folding-converted'
OUT = './'
bs = 64
num_workers = 90
SEED = 2023
nfolds = 4
device = 'cuda' if torch.cuda.is_available() else 'cpu'

backbone.to(device)

RNABertModel(
  (embed_tokens): Embedding(25, 640, padding_idx=1)
  (layers): ModuleList(
    (0-11): 12 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=640, out_features=640, bias=True)
        (v_proj): Linear(in_features=640, out_features=640, bias=True)
        (q_proj): Linear(in_features=640, out_features=640, bias=True)
        (out_proj): Linear(in_features=640, out_features=640, bias=True)
      )
      (self_attn_layer_norm): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=640, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=640, bias=True)
      (final_layer_norm): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=240, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (embed_positions): LearnedPositionalEmbedding(1026, 640, padding_idx=1)
  (emb_layer_norm_

# Data

The primary training data is provided in train_data.csv, which contains 821840 RNA sequences and the corresponding reactivity measurements with 2A3_MaP and DMS_MaP methods. The reactivity is reported in columns reactivity_0001 - reactivity_0206 and is set to NaN for the first 26 and the last 21 nucleotides as well as padding for sequences shorter than 206. For faster loading and effective RAM use, I converted the data into a float32 parquet file. 

Evaluation in this competition is performed only on samples with SN_filter = 1 for both measurement methods. In this example, I perform training only on samples wiht SN_filter = 1, which gives a noticeable CV boost but uses only 1/4 of the data (i.e. training on noisy SN_filter = 0 data degrades the performance). A proper consideration of all data as well as reactivity errors may boost the performance.

In this example, I use a simple CV Kfold split. However, given a mismatch in the RNA length between train/public LB vs. private LB data, **it may be important to verify the effect of the sequence length** to avoid a significant shakeup at the private LB.

One of the tricks, well known in NLP community, which I use here, is length matching batch sampling: composing batches of samples of approximately the same length to minimize the overhead caused by padding tokens.

In [11]:
def generate_token_batch(alphabet, seq_strs):
    batch_size = len(seq_strs)
#     max_len = 206
    max_len = max(len(seq_str) for seq_str in seq_strs)
    tokens = torch.empty(
        (
            batch_size,
            max_len
            + int(alphabet.prepend_bos)
            + int(alphabet.append_eos),
        ),
        dtype=torch.int64,
    )
    tokens.fill_(alphabet.padding_idx)
    for i, seq_str in enumerate(seq_strs):              
        if alphabet.prepend_bos:
            tokens[i, 0] = alphabet.cls_idx
        seq = torch.tensor([alphabet.get_idx(s) for s in seq_str], dtype=torch.int64)
        tokens[i, int(alphabet.prepend_bos): len(seq_str)+ int(alphabet.prepend_bos),] = seq
        if alphabet.append_eos:
            tokens[i, len(seq_str) + int(alphabet.prepend_bos)] = alphabet.eos_idx
    return tokens

In [12]:
class RNA_Dataset(Dataset):
    def __init__(self, df, mode='train', seed=2023, fold=0, nfolds=4, 
                 mask_only=False, **kwargs):
        self.seq_map = {'A':0,'C':1,'G':2,'U':3}
        self.Lmax = 206
        df['L'] = df.sequence.apply(len)
        df_2A3 = df.loc[df.experiment_type=='2A3_MaP']
        df_DMS = df.loc[df.experiment_type=='DMS_MaP']
        
        split = list(KFold(n_splits=nfolds, random_state=seed, 
                shuffle=True).split(df_2A3))[fold][0 if mode=='train' else 1]
        df_2A3 = df_2A3.iloc[split].reset_index(drop=True)
        df_DMS = df_DMS.iloc[split].reset_index(drop=True)
        
        m = (df_2A3['SN_filter'].values > 0) & (df_DMS['SN_filter'].values > 0)
        df_2A3 = df_2A3.loc[m].reset_index(drop=True)
        df_DMS = df_DMS.loc[m].reset_index(drop=True)
        
        self.seq = df_2A3['sequence'].values
        self.L = df_2A3['L'].values
        
        self.react_2A3 = df_2A3[[c for c in df_2A3.columns if \
                                 'reactivity_0' in c]].values
        self.react_DMS = df_DMS[[c for c in df_DMS.columns if \
                                 'reactivity_0' in c]].values
        self.react_err_2A3 = df_2A3[[c for c in df_2A3.columns if \
                                 'reactivity_error_0' in c]].values
        self.react_err_DMS = df_DMS[[c for c in df_DMS.columns if \
                                'reactivity_error_0' in c]].values
        self.sn_2A3 = df_2A3['signal_to_noise'].values
        self.sn_DMS = df_DMS['signal_to_noise'].values
        self.mask_only = mask_only
        
    def __len__(self):
        return len(self.seq)  
    
    def __getitem__(self, idx):
        seq = self.seq[idx]
        tokens = seq
        records = BERTmodel.load_data_EMB(tokens)
        if self.mask_only:
            mask = torch.zeros(self.Lmax, dtype=torch.bool)
            mask[:len(seq)] = True
            return {'mask':mask},{'mask':mask}
        
        seq = [self.seq_map[s] for s in seq]
#         print(seq)
#         tokens = generate_token_batch(alphabet,[tokens])
#         seq = np.array(seq)
#         with torch.no_grad():
#             results = backbone(tokens, repr_layers=[12], need_head_weights=False,return_contacts=False)
#             token_embeddings = results["representations"][12]
        mask = torch.zeros(self.Lmax, dtype=torch.bool)
        mask[:len(seq)] = True
        seq = np.pad(seq,(0,self.Lmax-len(seq)))
        
        react = torch.from_numpy(np.stack([self.react_2A3[idx],
                                           self.react_DMS[idx]],-1))
        # print(react.shape)
        react_err = torch.from_numpy(np.stack([self.react_err_2A3[idx],
                                               self.react_err_DMS[idx]],-1))
        sn = torch.FloatTensor([self.sn_2A3[idx],self.sn_DMS[idx]])
        return {'seq':torch.from_numpy(seq),'tokens': tokens,'mask':mask, 'records': records[1]}, \
               {'react':react, 'react_err':react_err,
                'sn':sn, 'mask':mask}
    
class LenMatchBatchSampler(torch.utils.data.BatchSampler):
    def __iter__(self):
        buckets = [[]] * 100
        yielded = 0

        for idx in self.sampler:
            s = self.sampler.data_source[idx]
            if isinstance(s,tuple): L = s[0]["mask"].sum()
            else: L = s["mask"].sum()
            L = max(1,L // 16) 
            if len(buckets[L]) == 0:  buckets[L] = []
            buckets[L].append(idx)
            
            if len(buckets[L]) == self.batch_size:
                batch = list(buckets[L])
                yield batch
                yielded += 1
                buckets[L] = []
                
        batch = []
        leftover = [idx for bucket in buckets for idx in bucket]

        for idx in leftover:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yielded += 1
                yield batch
                batch = []

        if len(batch) > 0 and not self.drop_last:
            yielded += 1
            yield batch
            
def dict_to(x, device='cuda'):
    return {k:x[k].to(device) for k in x}

def to_device(x, device='cuda'):
    return tuple(dict_to(e,device) for e in x)

class DeviceDataLoader:
    def __init__(self, dataloader, device='cuda'):
        self.dataloader = dataloader
        self.device = device
    
    def __len__(self):
        return len(self.dataloader)
    
    def __iter__(self):
        for batch in self.dataloader:
            # for each batch
            batch[0]['tokens'] = generate_token_batch(alphabet,batch[0]['tokens'])
            

            yield tuple(dict_to(x, self.device) for x in batch)

# Model
Transformers are ideal for the considered task because they naturally capture long dependences in RNN, which define the secondary structure of the molecule and the corresponding chemical reactivity. For illustration purposes, below I provide a simple S-size transformer model.

In [13]:
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim=64, M=10000000):
        super().__init__()
        self.dim = dim
        self.M = M

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(self.M) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * (-emb))
        emb = x[...,None] * emb[None,...]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        
#         emb = emb.unsqueeze(1) + x.unsqueeze(0)
        return emb

class ShapeEmb(nn.Module):
    def __init__(self, dim=192):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # print(x.size())
        _, sequence_length = x.size()
        emb = torch.zeros(sequence_length, self.dim, device=x.device)
        return emb

class Pos_emb(nn.Module):
    def __init__(self, max_position_embeddings, hidden_size=512):
        super().__init__()
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        

    def forward(self, x):
        device = x.device
        pos_emb = self.position_embeddings(x)
        return pos_emb

class RNA_Model(nn.Module):
    def __init__(self, dim=512, depth=12, head_size=32, **kwargs):
        super().__init__()
        self.emb = nn.Embedding(4,dim)
        self.BERTmodel = Load_RNABert_Model('RNABERT/RNABERT.pth')
        self.pos_enc = Pos_emb(max_position_embeddings = dim)
        # self.lstm = nn.RNN(input_size=dim, hidden_size=512, batch_first=True, bidirectional=True)
        # self.embedding_layer_norm = nn.LayerNorm(dim)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=dim//head_size, dim_feedforward=4*dim,
                dropout=0.1, activation=nn.GELU(), batch_first=True, norm_first=True), depth)
        self.proj_out = nn.Sequential(
                nn.Linear(dim,dim),
                nn.GELU(),
                nn.Linear(dim,2),
            )
    
    def forward(self, x0):
        mask = x0['mask']
        records = x0['records']
        Lmax = mask.sum(-1).max()
        mask = mask[:,:Lmax]
        x = x0['seq'][:,:Lmax]
        bert_emb = self.BERTmodel.predict_by_tokens(Lmax,records)
        pos = torch.arange(Lmax, device=x.device).unsqueeze(0)
        pos = self.pos_enc(pos)
#         print(pos.shape)
#         print(x.shape)

        x = self.emb(x)
        x = x + bert_emb + pos
        # x = self.embedding_layer_norm(x)
        x = self.transformer(x, src_key_padding_mask=~mask)
        
        # x, _ = self.lstm(x)

        x = self.proj_out(x)
        
        return x

class RNA_embedding_model(nn.Module):
    def __init__(self, dim=64, depth=12, head_size=32, **kwargs):
        super().__init__()

        # self.emb = nn.Embedding(4,dim)
        # self.pos_enc = SinusoidalPosEmb(dim)
        self.transformer1 = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=dim//head_size, dim_feedforward=4*dim,
                dropout=0.1, activation=nn.GELU(), batch_first=True, norm_first=True), 6)

    def forward(self, x0):

        x = self.transformer1(x0.squeeze(0))

        return x

In [14]:
class Human5PrimeUTRPredictor(torch.nn.Module):
    """
    contact predictor with inner product
    """
    def __init__(self, alphabet=None, task="rgs", arch="cnn", input_types=["seq", "emb-rnafm"]):
        """
        :param depth_reduction: mean, first
        """       
        super().__init__()     
        self.alphabet = alphabet   # backbone alphabet: pad_idx=1, eos_idx=2, append_eos=True, prepend_bos=True
        self.task = task
        self.arch = arch
        self.input_types = input_types        
        self.padding_mode = "right"
        self.token_len = 500
        self.out_plane = 2
        self.in_channels = 0
        self.main_planes = 640
        dim=512
        depth=12 
        head_size=32
        if "seq" in self.input_types:
            self.in_channels = self.in_channels + 4

        if "emb-rnafm" in self.input_types:
            self.reductio_module = nn.Linear(640, 512)
            self.in_channels = self.in_channels + 512  

        if self.arch == "cnn" and self.in_channels != 0:
            # self.predictor = self.create_1dcnn_for_emd(in_planes=self.in_channels, out_planes= self.out_plane)
            ...
        else:
            raise Exception("Wrong Arch Type")
        # self.RNA_linear = RNA_Model()
        self.RNA_embed = RNA_embedding_model(dim = dim)
        self.backbone, self.alphabet = fm.pretrained.rna_fm_t12()
        self.pos_out = nn.Sequential(
            nn.Linear(512, 128),
            nn.Linear(128, 2),
        )
        self.pos_enc = Pos_emb(max_position_embeddings = dim)
    def forward(self, x):
        tokens = x['tokens'].squeeze(1)
#         print(tokens.shape)
        seq_len = tokens.shape[1]-2
        # self.pos_enc = Pos_Encoder(seq_len = seq_len, dim = self.main_planes)
        # self.pos_enc = Shape_Encoder(seq_len = seq_len, dim = self.main_planes)
        # Base_RNA = self.RNA_linear(x)
#         print(tokens.shape)
        # with torch.no_grad():
        results = backbone(tokens, repr_layers=[12], need_head_weights=False,return_contacts=False)
        inputs = results["representations"][12]
        ensemble_inputs = []
        if "seq" in self.input_types:
            # padding one-hot embedding            
            nest_tokens = (tokens[:, 1:-1] - 4)   # covert token for RNA-FM (20 tokens) to nest version (4 tokens A,U,C,G)
            nest_tokens = torch.nn.functional.pad(nest_tokens, (0, self.token_len - nest_tokens.shape[1]), value=-2)
            token_padding_mask = nest_tokens.ge(0).long()
            one_hot_tokens = torch.nn.functional.one_hot((nest_tokens * token_padding_mask), num_classes=4)
            one_hot_tokens = one_hot_tokens.float() * token_padding_mask.unsqueeze(-1)            
            # reserve padded one-hot embedding
            one_hot_tokens = one_hot_tokens.permute(0, 2, 1)  # B, L, 4
            ensemble_inputs.append(one_hot_tokens)

        if "emb-rnafm" in self.input_types:
            embeddings = inputs
            # padding RNA-FM embedding
            embeddings, padding_masks = self.remove_pend_tokens_1d(tokens, embeddings)  # remove auxiliary tokens
            embeddings = embeddings.squeeze(dim=1)
            batch_size, seqlen, hiddendim = embeddings.size()
            
            # embeddings = torch.nn.functional.pad(embeddings, (0, 0, 0, self.token_len - embeddings.shape[1]))            
            # # channel reduction
            embeddings = self.reductio_module(embeddings)
            # # reserve padded RNA-FM embedding
            # embeddings = embeddings.permute(0, 2, 1)
            ensemble_inputs.append(embeddings)        

        ensemble_inputs = torch.cat(ensemble_inputs, dim=1)  
        

        output = self.pos_out(ensemble_inputs).squeeze(0)
        return output
 
    def create_1dcnn_for_emd(self, in_planes, out_planes):
        main_planes = self.main_planes
        dropout = 0.2
        emb_cnn = nn.Sequential(
            nn.Conv1d(in_planes, main_planes, kernel_size=3, padding=1), 
            ResBlock(main_planes * 1, main_planes * 1, stride=2, dilation=1, conv_layer=nn.Conv1d,
                     norm_layer=nn.BatchNorm1d), 
            ResBlock(main_planes * 1, main_planes * 1, stride=1, dilation=1, conv_layer=nn.Conv1d,
                     norm_layer=nn.BatchNorm1d),  
            ResBlock(main_planes * 1, main_planes * 1, stride=2, dilation=1, conv_layer=nn.Conv1d,
                     norm_layer=nn.BatchNorm1d), 
            ResBlock(main_planes * 1, main_planes * 1, stride=1, dilation=1, conv_layer=nn.Conv1d,
                     norm_layer=nn.BatchNorm1d),  
            ResBlock(main_planes * 1, main_planes * 1, stride=2, dilation=1, conv_layer=nn.Conv1d,
                     norm_layer=nn.BatchNorm1d), 
            ResBlock(main_planes * 1, main_planes * 1, stride=1, dilation=1, conv_layer=nn.Conv1d,
                     norm_layer=nn.BatchNorm1d),       
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Dropout(dropout),
            # nn.Linear(seq_len, 2),
        )
        return emb_cnn
    
    def remove_pend_tokens_1d(self, tokens, seqs):
        padding_masks = tokens.ne(self.alphabet.padding_idx)

        # remove eos token  （suffix first）
        if self.alphabet.append_eos:     # default is right
            eos_masks = tokens.ne(self.alphabet.eos_idx)
            eos_pad_masks = (eos_masks & padding_masks).to(seqs)
            seqs = seqs * eos_pad_masks.unsqueeze(-1)
            seqs = seqs[:, ..., :-1, :]
            padding_masks = padding_masks[:, ..., :-1]

        # remove bos token
        if self.alphabet.prepend_bos:    # default is left
            seqs = seqs[:, ..., 1:, :]
            padding_masks = padding_masks[:, ..., 1:]

        if not padding_masks.any():
            padding_masks = None

        return seqs, padding_masks
    
class Shape_Encoder(nn.Module):
    def __init__(self, seq_len, dim, **kwargs):
        super().__init__()
        self.seq_len = seq_len
        # self.pos_enc = SinusoidalPosEmb(dim=dim).to(device)
        self.pos_enc = ShapeEmb(dim=dim).to(device)
    
    def forward(self, x):

        x = x.unsqueeze(0)
        pos = torch.arange(self.seq_len, device=x.device).unsqueeze(0)
        pos = self.pos_enc(pos).unsqueeze(0)
        # print("x shape:", x.shape)
        # print("pos shape:", pos.shape)

        x = x.unsqueeze(2)
        pos = pos.unsqueeze(1)

        # print("x shape:", x.shape)
        # print("pos shape:", pos.shape)
        x = x + pos
        # print("x2 shape:", x.shape)

        return x

    
class Pos_Encoder(nn.Module):
    def __init__(self, seq_len, dim, **kwargs):
        super().__init__()
        self.seq_len = seq_len
        self.pos_enc = SinusoidalPosEmb(dim=dim).to(device)
        # self.pos_enc = ShapeEmb(dim=dim).to(device)
    
    def forward(self, x):

        # x = x.unsqueeze(0)
        pos = torch.arange(self.seq_len, device=x.device).unsqueeze(0)
        pos = self.pos_enc(pos)
        # print("x shape:", x.shape)
        # print("pos shape:", pos.shape)

        # x = x.unsqueeze(2)
        # pos = pos.unsqueeze(1)

        # print("x shape:", x.shape)
        # print("pos shape:", pos.shape)
        x = x + pos
        # print("x2 shape:", x.shape)

        return x
    
class ResBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        stride=1,
        dilation=1,
        conv_layer=nn.Conv1d,
        norm_layer=nn.BatchNorm1d,
    ):
        super(ResBlock, self).__init__()        
        self.bn1 = norm_layer(in_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = conv_layer(in_channels, out_channels, kernel_size=3, stride=stride, padding=dilation, bias=False)       
        self.bn2 = norm_layer(out_channels)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = conv_layer(out_channels, out_channels, kernel_size=3, padding=dilation, bias=False)

        if stride > 1 or out_channels != in_channels: 
            self.downsample = nn.Sequential(
                conv_layer(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                norm_layer(out_channels),
            )
        else:
            self.downsample = None
            
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.bn1(x)
        out = self.relu1(out)
        out = self.conv1(out)        
        out = self.bn2(out)
        out = self.relu2(out)
        out = self.conv2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, std=0.001)
        if isinstance(m.bias, nn.Parameter):
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)

In [15]:
task="rgs"
arch="cnn"
input_items = ["emb-rnafm"]   # ["seq"], ["emb-rnafm"]
model_name = arch.upper() + "_" + "_".join(input_items) 
utr_func_predictor = Human5PrimeUTRPredictor(
    alphabet, task=task, arch=arch, input_types=input_items    
)
utr_func_predictor.apply(weights_init)
utr_func_predictor.to(device)
print("create utr_func_predictor sucessfully")
print(utr_func_predictor)

create utr_func_predictor sucessfully
Human5PrimeUTRPredictor(
  (reductio_module): Linear(in_features=640, out_features=512, bias=True)
  (RNA_embed): RNA_embedding_model(
    (transformer1): TransformerEncoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
          (activation): GELU(approximate='none')
        )
      )
    )
  )
  (backbone): RNABertModel(
  

In [16]:
model = RNA_Model()
model.eval()

device:  cuda
---Loaded---


RNA_Model(
  (emb): Embedding(4, 512)
  (BERTmodel): Load_RNABert_Model(
    (model): BertForMaskedLM(
      (bert): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(6, 120, padding_idx=0)
          (position_embeddings): Embedding(440, 120)
          (token_type_embeddings): Embedding(2, 120)
          (LayerNorm): BertLayerNorm()
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0-5): 6 x BertLayer(
              (attention): BertAttention(
                (selfattn): BertSelfAttention(
                  (query): Linear(in_features=120, out_features=120, bias=True)
                  (key): Linear(in_features=120, out_features=120, bias=True)
                  (value): Linear(in_features=120, out_features=120, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
                )
                (output): BertSelfOutput(
                  (dens

# Loss & Metric
The metric accumulates all predictions and then performs average to be consistent with the competition metric. However, the difference with a simple batch-based average is negligible.

In [17]:
def loss(pred,target):

    p = pred[:,:pred.shape[1]].squeeze(1)
    seq_len = p.size(1)
    y = target['react'][:,:seq_len].clip(0,1)
    # print('p',p.shape)
    # print('y',y.shape)
    loss = F.l1_loss(p, y, reduction='none')
    loss = loss[~torch.isnan(loss)].mean()
#     print('--------',loss,'---------')
    return loss

class MAE(Metric):
    def __init__(self): 
        self.reset()
        
    def reset(self): 
        self.x,self.y = [],[]
        
    def accumulate(self, learn):
        x = learn.pred.squeeze(1)
        y = learn.y['react'][:,:x.size(1)].clip(0, 1)
        x = x.reshape(-1, 2)
        y = y.reshape(-1, 2)
        # print(x.shape)
        # print(y.shape)
        self.x.append(x)
        self.y.append(y)
    @property
    def value(self):
        # print(self.x)
        x = torch.cat(self.x,dim = 0)
        y = torch.cat(self.y,dim = 0)
        # print(x.shape)
        # print(y.shape)
        
        loss = F.l1_loss(x, y, reduction='none')
        loss = loss[~torch.isnan(loss)].mean()
        print('----',loss,'----')
        with open('log.txt', 'a+') as file:
            file.write(f'Loss: {loss}\n')
            
        return loss
    
class PrintMetricsCallback(Callback):
    def after_epoch(self):
        print(f"Epoch: {self.learn.epoch}, Training loss: {self.learn.recorder.losses[-1]}, Validation loss: {self.learn.recorder.values[-1][0]}")


# Training

In [18]:
seed_everything(SEED)
os.makedirs(OUT, exist_ok=True)
df = pd.read_parquet(os.path.join(PATH,'train_data.parquet'))

In [19]:
# {'seq': tensor([2, 2, 2, 0, 0, 1, 2, 0, 1, 3, 1, 2, 0, 2, 3, 0, 2, 0, 2, 3, 1, 2, 0, 0,
#          0, 0, 0, 1, 2, 0, 3, 2, 0, 3, 0, 3, 2, 2, 0, 3, 3, 3, 0, 1, 3, 1, 1, 2,
#          0, 2, 2, 0, 2, 0, 1, 2, 0, 0, 1, 3, 0, 1, 1, 0, 1, 2, 0, 0, 1, 0, 2, 2,
#          2, 2, 0, 0, 0, 1, 3, 1, 3, 0, 1, 1, 1, 2, 3, 2, 2, 1, 2, 3, 1, 3, 1, 1,
#          2, 3, 3, 3, 2, 0, 1, 2, 0, 2, 3, 0, 0, 2, 3, 1, 1, 3, 0, 0, 2, 3, 1, 0,
#          0, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 2, 2, 1, 3, 3, 1, 2, 2, 1, 1,
#          2, 1, 0, 3, 2, 0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0,
#          0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
#  'tokens': tensor([[6, 6, 6, 4, 4, 5, 6, 4, 5, 7, 5, 6, 4, 6, 7, 4, 6, 4, 6, 7, 5, 6, 4, 4,
#           4, 4, 4, 5, 6, 4, 7, 6, 4, 7, 4, 7, 6, 6, 4, 7, 7, 7, 4, 5, 7, 5, 5, 6,
#           4, 6, 6, 4, 6, 4, 5, 6, 4, 4, 5, 7, 4, 5, 5, 4, 5, 6, 4, 4, 5, 4, 6, 6,
#           6, 6, 4, 4, 4, 5, 7, 5, 7, 4, 5, 5, 5, 6, 7, 6, 6, 5, 6, 7, 5, 7, 5, 5,
#           6, 7, 7, 7, 6, 4, 5, 6, 4, 6, 7, 4, 4, 6, 7, 5, 5, 7, 4, 4, 6, 7, 5, 4,
#           4, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 6, 6, 5, 7, 7, 5, 6, 6, 5, 5,
#           6, 5, 4, 7, 6, 4, 4, 4, 4, 6, 4, 4, 4, 5, 4, 4, 5, 4, 4, 5, 4, 4, 5, 4,
#           4, 5]]),
#  'fm_embed': tensor([[[-0.0611,  0.1168,  0.0513,  ...,  0.0102, -0.0466, -0.0495],
#           [ 0.1240, -0.1895,  0.0813,  ..., -0.3088,  0.1768, -0.0184],
#           [ 0.1964, -0.2130,  0.1875,  ..., -0.2599,  0.0222, -0.1515],
#           ...,
#           [ 0.1636,  0.0329,  0.2152,  ..., -0.1018,  0.1041, -0.0279],
#           [ 0.1615, -0.0473,  0.1812,  ..., -0.0754,  0.1097, -0.1988],
#           [-0.1284,  0.0399, -0.0071,  ..., -0.1088,  0.2529, -0.2890]]]),
#  'mask': tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
#          False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False, False, False, False, False, False,
#          False, False, False, False, False, False])}

In [None]:
for fold in [0]: # running multiple folds at kaggle may cause OOM
#     df_sample = df.sample(frac=0.2)  # 选择数据的子集

    ds_train = RNA_Dataset(df, mode='train', fold=fold, nfolds=nfolds)
    ds_train_len = RNA_Dataset(df, mode='train', fold=fold, 
                nfolds=nfolds, mask_only=True)
    sampler_train = torch.utils.data.RandomSampler(ds_train_len)
    len_sampler_train = LenMatchBatchSampler(sampler_train, batch_size=bs,
                drop_last=True)
    dl_train = DeviceDataLoader(torch.utils.data.DataLoader(ds_train, 
                batch_sampler=len_sampler_train, num_workers=num_workers,
                persistent_workers=True), device)

    ds_val = RNA_Dataset(df, mode='eval', fold=fold, nfolds=nfolds)
    ds_val_len = RNA_Dataset(df, mode='eval', fold=fold, nfolds=nfolds, 
               mask_only=True)
    sampler_val = torch.utils.data.SequentialSampler(ds_val_len)
    len_sampler_val = LenMatchBatchSampler(sampler_val, batch_size=bs, 
               drop_last=False) 
    dl_val= DeviceDataLoader(torch.utils.data.DataLoader(ds_val, 
               batch_sampler=len_sampler_val, num_workers=num_workers), device)
    gc.collect()

    data = DataLoaders(dl_train,dl_val)
    model = RNA_Model()   
    model = model.to(device)

    # task="rgs"
    # arch="cnn"
    # input_items = ["emb-rnafm"]   # ["seq"], ["emb-rnafm"]
    # model_name = arch.upper() + "_" + "_".join(input_items) 
    # utr_func_predictor = Human5PrimeUTRPredictor(
    #     alphabet, task=task, arch=arch, input_types=input_items    
    # )
    # utr_func_predictor.apply(weights_init)
    # utr_func_predictor.to(device)
    # print("create utr_func_predictor sucessfully")

    learn = Learner(data, utr_func_predictor, loss_func=loss,cbs=[GradientClip(3.0)],
                metrics=[MAE()]).to_fp16() 
    #fp16 doesn't help at P100 but gives x1.6-1.8 speedup at modern hardware

    learn.fit_one_cycle(100, lr_max=5e-4, wd=0.05, pct_start=0.02)
    torch.save(learn.model.state_dict(),os.path.join(OUT,f'{fname}_{fold}.pth'))
    gc.collect()

device:  cuda
---Loaded---


epoch,train_loss,valid_loss,mae,time
0,0.227002,0.227577,0.227872,08:15


---- tensor(0.2279, device='cuda:0') ----
