In [1]:
import argparse
import json
import os
import math

from catalyst import dl
from catalyst.dl import utils
from catalyst.core.callbacks.scheduler import SchedulerCallback
from catalyst.contrib.dl.callbacks.neptune_logger import NeptuneLogger
from catalyst.contrib.nn.schedulers.onecycle import OneCycleLRWithWarmup
from catalyst.contrib.nn.optimizers.ralamb import Ralamb
from gensim.models import KeyedVectors
import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import StratifiedKFold
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

# from torchcontrib.optim import SWA

from constants import FilePaths, TGT_COLS
from datasets import RNAAugData, RNAAugDatav2
from modellib import RNNmodels
from modellib.RNNmodels import ParamModel, Conv2dBn, Conv1dBn, onehot
from modellib.transformermodels import Conv2D1x1, TransformerCustomEncoder, PositionalEncoding
from nn_trainer import MCRMSE


In [2]:
NUM_WORKERS = 8
NUM_FOLDS = 5

BATCH_SIZE = 12
FP = FilePaths("data")
train = pd.read_json(FP.train_json, lines=True)
cvlist = list(
    StratifiedKFold(NUM_FOLDS, shuffle=True, random_state=1234786).split(
        train, train["SN_filter"]
    )
)

device = utils.get_device()


In [3]:
hparams = json.load(open("hparams.json"))
hparams

{'batch_size': 16,
 'num_epochs': 150,
 'num_folds': 5,
 'seed': 1234786,
 'seq_emb_dim': 64,
 'struct_emb_dim': 64,
 'pl_emb_dim': 64,
 'combined_emb_dim': 150,
 'gru_dim': 256,
 'rnn_type': 'lstm',
 'gru_layers': 3,
 'bidirectional': True,
 'dropout_prob': 0.75,
 'spatial_dropout': 0.5,
 'target_dim': 5,
 'num_features': 3,
 'max_seq_pred': 68,
 'lr': 0.0005,
 'wd': 1e-08,
 'filter_sn': True,
 'add_error_noise': True,
 'model_name': 'RCNNGRUModelv6',
 'conv_channels': [128, 128, 256, 256, 256],
 'bpp_conv_channels': [128, 128, 256, 256, 256],
 'kernel_size': 5,
 'stride': 1,
 'add_bpp': True,
 'use_one_hot': False,
 'sig_factor': 1.0,
 'run_on_single': False,
 'optimizer': 'adam',
 'scheduler': 'reducelrplateau',
 'conv_drop': 0.75,
 'use_codon': False,
 'prob_thresh': 0.0,
 'signal_to_noise': 1.0,
 'use_augment': True,
 'loss_func': 'mcrmse'}

In [4]:
for fold_num in [0]:
    tr_idx, val_idx = cvlist[fold_num]
    tr, vl = train.iloc[tr_idx], train.iloc[val_idx]
    if hparams.get("filter_sn"):
        tr = tr.loc[tr["signal_to_noise"] > hparams.get("signal_to_noise", 1.0)]
        vl = vl.loc[vl["signal_to_noise"] > hparams.get("signal_to_noise", 1.0)]


In [5]:
def train_one_fold(tr, vl, hparams, logger, logdir, device, embeddings):
    tr_ds = RNAAugDatav2(
        tr,
        targets=TGT_COLS,
        augment_strucures= False,  # hparams.get("use_augment", True),
        aug_data_sources=[
                          "data/augmented_data_public/aug_data5.csv",
                          "data/augmented_data_public/aug_data5_10.csv",
                          "data/vienna_7_mec.csv", 
                          "data/vienna_17_mec.csv", 
                          "data/vienna_27_mec.csv",
                          "data/vienna_47_mec.csv",
                          "data/vienna_57_mec.csv",
                          "data/vienna_67_mec.csv"
                           ],
        target_aug=False,
        bpps_path="data/bpps",
    )
    vl_ds = RNAAugDatav2(vl, targets=TGT_COLS, bpps_path="data/bpps")

    tr_dl = DataLoader(tr_ds, shuffle=True, drop_last=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,)
    vl_dl = DataLoader(vl_ds, shuffle=False, drop_last=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,)

    model = RNATransformer(hparams, embeddings.to(device))
    if hparams.get("optimizer", "adam") == "adam":
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=hparams.get("lr", 1e-3), weight_decay=hparams.get("wd", 0),
        )
    else:
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=hparams.get("lr", 1e-3),
            weight_decay=hparams.get("wd", 0),
            momentum=0.9,
            nesterov=True,
        )
    if hparams.get("scheduler", "reducelrplateau") == "reducelrplateau":
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, factor=0.5, min_lr=2e-4)
    if hparams.get("scheduler", "reducelrplateau") == "one_cycle":
        total_steps = hparams.get("num_epochs") * (len(tr) // hparams.get("batch_size"))
        max_lr = hparams.get("lr", 1e-3)
        scheduler = OneCycleLRWithWarmup(
            optimizer, num_steps=total_steps, lr_range=(max_lr, max_lr / 10, max_lr / 100), warmup_fraction=0.5,
        )
    
    if hparams.get("loss_func", "mcrmse") == "mcrmse":
        criterion = MCRMSE()
    elif hparams.get("loss_func") == "mcmsre":
        criterion = MCMSRE()
    runner = dl.SupervisedRunner(device=device)
    runner.train(
        loaders={"train": tr_dl, "valid": vl_dl},
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=hparams.get("num_epochs", 10),
        logdir=logdir,
        verbose=0,
        # callbacks=[logger, SchedulerCallback(mode="epoch")],
        load_best_on_end=True,
        # resume="logs/filter__cnnlstm__posembv5/fold_0/checkpoints/best_full.pth"
    )
    return model, tr_dl, vl_dl


In [6]:
hparams["loss_func"] = "mcrmse"
hparams["batch_size"] = BATCH_SIZE
hparams["num_epochs"] = 300
hparams["lr"] = 5e-4
hparams["wd"] = 1e-5
tags = ["transformer", "test"]
exp_dir = Path("./logs") / "__".join(tags)
exp_dir.mkdir(exist_ok=True, parents=True)
logdir = exp_dir / f"fold_{fold_num}"
logdir.mkdir(exist_ok=True)

# neptune_logger = NeptuneLogger(
#     api_token=os.environ["NEPTUNE_API_TOKEN"],
#     project_name="tezdhar/Covid-RNA-degradation",
#     name="covid_rna_degradation",
#     params=hparams,
#     tags=tags + [f"fold_{fold_num}"],
#     upload_source_files=["*.py", "modellib/*.py"],
# )
logdir

PosixPath('logs/transformer__test/fold_0')

In [13]:
class RNATransformer(ParamModel):
    def __init__(self, hparams, embeddings):
        super().__init__(hparams=hparams)
        emb_dim = 32
        self.sequence_embedding = nn.Embedding(self.num_seq_tokens, emb_dim) # nn.Embedding.from_pretrained(embeddings, freeze=False) 
        # self.sequence_embedding.weight = nn.Parameter(embeddings/8)

        self.structure_embedding = nn.Embedding(self.num_struct_tokens, emb_dim)
        self.predicted_loop_embedding = nn.Embedding(self.num_pl_tokens, emb_dim)

        # xseq = onehot(xinputs["sequence"], self.num_seq_tokens)
        # xstruct = onehot(xinputs["structure"], self.num_struct_tokens)
        # xpl = onehot(xinputs["predicted_loop_type"], self.num_pl_tokens)
        # xpairseq = onehot(xinputs["pair_sequence"], self.num_seq_tokens)

        value_conv_channels1 = [emb_dim*2] + [64, 64, 64] 
        self.value_conv1 = nn.Sequential(*[Conv1dBn(value_conv_channels1[i], value_conv_channels1[i+1], drop=0.1)
                                       for i in range(len(value_conv_channels1)-1)])
        value_conv_channels2 = [emb_dim*2] + [64, 64, 64] 
        self.value_conv2 = nn.Sequential(*[Conv1dBn(value_conv_channels2[i], value_conv_channels2[i+1], drop=0.1)
                                       for i in range(len(value_conv_channels2)-1)])
        #self.value_conv3 = nn.Sequential(*[Conv1dBn(value_conv_channels[i], value_conv_channels[i+1], drop=0.2)
        #                               for i in range(len(value_conv_channels)-1)])
        self.cont_conv = Conv1dBn(4, 32)
        self.pos1 = PositionalEncoding(emb_dim)
        self.pos2 = PositionalEncoding(emb_dim)

        # self.fc11 = nn.Linear
        bpp_input_channels = 15
        bpp_conv_channels = [256, 8]
        d_input = value_conv_channels1[-1] + value_conv_channels2[-1] + emb_dim * 2
        # bpps_single_head1 = SingleHeadStaticAttn(d_input, 512, 0.5)
        # self.bpps_multihead1 = MultiHeadStaticAttn(bpps_single_head1, bpp_input_channels, bpp_conv_channels, 0.5)
        self.bpp_transformer = TransformerCustomEncoder(1, d_input, 512, bpp_input_channels, bpp_conv_channels, 0.1, 0.1)
        encoder = nn.TransformerEncoderLayer(d_input, 8, 512, activation='gelu', dropout=0.1)
        self.transformer = nn.TransformerEncoder(encoder, 3)
        self.drop = nn.Dropout(0.1)
        self.fc1 = nn.Linear(d_input * 2, 512)
        self.fc2 = nn.Linear(512, 5)


    def forward(self, xinputs):
        # transformer layer does weighted sum of all vectors
        # the weight score is Q * K.T, which in our case is similar to bpp matrix
        # the original bpp_matrix has very low probilities with non-normal distribution
        # so we use conv layers to map BPP matric to attention map, (bathc_size, seq_len, heads)
        # the remains value layer V --> we do 1D conv for that

        xseq = self.sequence_embedding(xinputs["sequence"])
        xstruct = self.structure_embedding(xinputs["structure"])
        xpl = self.predicted_loop_embedding(xinputs["predicted_loop_type"])
        xpairseq = self.sequence_embedding(xinputs["pair_sequence"])
        
        xseq = self.value_conv1(torch.cat([xseq, xpairseq], -1).permute(0, 2, 1).contiguous())
        xstruct = self.value_conv2(torch.cat([xstruct, xpl], -1).permute(0, 2, 1).contiguous())
        # xpl = self.value_conv3(xpl.permute(0, 2, 1).contiguous())
        # xpairseq = self.value_conv1(xseq.permute(0, 2, 1).contiguous())

        # V = self.value_conv(xseqs.permute(0, 2, 1).contiguous())  # batch_size, key_dim, seq_length
        xseq = self.pos1(xseq.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
        xstruct = self.pos2(xstruct.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
        
        V = torch.cat([xseq, xstruct], 1)
        
        # print(xpos.shape)
        xbpps_inp = xinputs["bpps"].permute(0, 3, 1, 2).contiguous()
    
        #V = self.pos(V.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
        # x_attn = self.bpps_multihead1(V, xbpps_inp)
        x_attn = self.bpp_transformer(V, xbpps_inp)
        x_attn = x_attn.permute(0,2,1).contiguous()
        x2 = self.transformer(V.permute(0,2,1).contiguous())
        x_attn = torch.cat([x_attn, x2], dim=-1)
        bpps_mean_mean = xbpps_inp.mean(dim=(1, 2)).unsqueeze(2)
        bpps_max_mean = xbpps_inp.max(dim=2).values.mean(dim=1).unsqueeze(2)
        bpps_max_max = xbpps_inp.max(dim=2).values.max(dim=1).values.unsqueeze(2)
        bpps_mean_std = torch.clamp(xbpps_inp.mean(dim=2).std(dim=1).unsqueeze(2), -2, 2)
        # cont_emb = torch.cat([bpps_mean_mean, bpps_max_mean, bpps_max_max, bpps_mean_std], -1)
        # cont_emb = self.cont_conv(cont_emb.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
        # x_attn = torch.cat([x_attn, cont_emb], -1)
        x_attn = x_attn[:, :self.max_seq_pred, :]
        x_attn = self.drop(x_attn)
        return self.fc2(self.drop(self.fc1(x_attn)))

In [14]:
w2v_model = KeyedVectors.load("data/w2v_seq_6gram.vectors")
embeddings = [np.zeros(shape=(32,))] + [w2v_model.wv[str(i)] for i in range(len(w2v_model.vocab))]
embeddings = np.vstack(embeddings).astype('float32')


Call to deprecated `wv` (Attribute will be removed in 4.0.0, use self instead).



In [15]:
model, _, _ = train_one_fold(tr, vl, hparams, logger=None, logdir=logdir, device=device, embeddings=torch.tensor(embeddings))


arrays to stack must be passed as a "sequence" type such as list or tuple. Support for non-sequence iterables such as generators is deprecated as of NumPy 1.16 and will raise an error in the future.



[2020-09-27 10:32:51,053] 
1/300 * Epoch 1 (_base): lr=0.0005 | momentum=0.9000
1/300 * Epoch 1 (train): loss=0.3877
1/300 * Epoch 1 (valid): loss=0.3153
[2020-09-27 10:32:51,053] 
1/300 * Epoch 1 (_base): lr=0.0005 | momentum=0.9000
1/300 * Epoch 1 (train): loss=0.3877
1/300 * Epoch 1 (valid): loss=0.3153
[2020-09-27 10:32:51,053] 
1/300 * Epoch 1 (_base): lr=0.0005 | momentum=0.9000
1/300 * Epoch 1 (train): loss=0.3877
1/300 * Epoch 1 (valid): loss=0.3153
[2020-09-27 10:33:18,203] 
2/300 * Epoch 2 (_base): lr=0.0005 | momentum=0.9000
2/300 * Epoch 2 (train): loss=0.3239
2/300 * Epoch 2 (valid): loss=0.2954
[2020-09-27 10:33:18,203] 
2/300 * Epoch 2 (_base): lr=0.0005 | momentum=0.9000
2/300 * Epoch 2 (train): loss=0.3239
2/300 * Epoch 2 (valid): loss=0.2954
[2020-09-27 10:33:18,203] 
2/300 * Epoch 2 (_base): lr=0.0005 | momentum=0.9000
2/300 * Epoch 2 (train): loss=0.3239
2/300 * Epoch 2 (valid): loss=0.2954
[2020-09-27 10:33:45,572] 
3/300 * Epoch 3 (_base): lr=0.0005 | momentum=0.9

KeyboardInterrupt: 

In [1]:
import torch

In [4]:
torch.rand(1, 10, 4).repeat(10, 1, 1).shape

torch.Size([10, 10, 4])