In [1]:
# Notebook-wide configuration:
#   MAX_LEN    - padded length of each sequence when a pair is encoded for the RNN
#   LR         - Adam learning rate read by RNNRegression.configure_optimizers
#   BATCH_SIZE - batch size of every DataLoader built by RNNRegression
MAX_LEN, LR, BATCH_SIZE = 26, 0.0005974060251967456, 128

In [2]:
import os
import numpy as np
import pandas as pd
import deepdish as dd

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader

from torchnlp.encoders.text import CharacterEncoder

In [3]:
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, roc_auc_score, roc_curve, matthews_corrcoef, plot_confusion_matrix
from collections import defaultdict
import seaborn as sns

In [4]:
class RNNRegression(pl.LightningModule):
    """Bidirectional-LSTM regressor that maps an encoded sequence to one scalar.

    Token indices (possibly stored as floats) are cast with ``.long()``, embedded,
    run through a bidirectional LSTM, pooled into a single vector, and projected
    to one output trained with MSE loss.

    The data loaders read the notebook globals ``X_train``/``y_train``,
    ``X_val``/``y_val`` and ``X_test``/``y_test`` plus ``BATCH_SIZE``; the
    optimizer reads the global ``LR``.

    Args:
        vocab_size: number of rows in the embedding table.
        emb_dim: embedding dimension.
        hidden_size: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers.
        dropout: passed to ``nn.LSTM(dropout=...)``.
        out_dropout: dropout applied to the pooled vector before the output
            layer (defaults to the value previously hard-coded here).

    Note:
        Pooling now uses the forward state at the last step and the backward
        state at the first step. Parameter shapes are unchanged, so older
        checkpoints still load, but they were trained on a different pooled
        representation and should be retrained.
    """

    def __init__(self, vocab_size, emb_dim, hidden_size, num_layers, dropout,
                 out_dropout=0.21659149581080556):
        super().__init__()
        self.hidden_size = hidden_size
        # Raw outputs of each test batch, in order; cleared at the start of every test run.
        self.out_predictions = []

        self.embeddings = nn.Embedding(vocab_size, emb_dim)
        self.rnn = nn.LSTM(input_size=emb_dim, hidden_size=hidden_size, num_layers=num_layers,
                           dropout=dropout, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(out_dropout)
        # Forward and backward summaries are concatenated, hence 2 * hidden_size inputs.
        self.linear = nn.Linear(2 * hidden_size, 1)

    def forward(self, x):
        """Return predictions of shape ``(batch, 1)`` for a batch of token-index sequences."""
        emb = self.embeddings(x.long())
        out, _ = self.rnn(emb)  # (batch, seq_len, 2 * hidden_size) because batch_first=True
        # out[:, 0, :hidden_size] is the forward direction after reading only the first
        # token (always the SOS index in this notebook), so it carries no information
        # about the sequence. Take the forward output at the last step and the backward
        # output at the first step: both have consumed the whole sequence.
        forward_summary = out[:, -1, :self.hidden_size]
        backward_summary = out[:, 0, self.hidden_size:]
        pooled = torch.cat((forward_summary, backward_summary), dim=-1)
        return self.linear(self.dropout(pooled))

    def _predict_and_loss(self, batch):
        """Run the model on ``batch`` and return ``(raw_output, mse_loss)``."""
        x, y = batch
        out = self(x)
        # squeeze(-1) rather than squeeze(): a final batch of size 1 keeps shape (1,)
        # instead of collapsing to a scalar. Assumes y is 1-D per batch — TODO confirm.
        loss = F.mse_loss(out.squeeze(-1), y)
        return out, loss

    def training_step(self, batch, batch_idx):
        _, loss = self._predict_and_loss(batch)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=LR)

    def validation_step(self, batch, batch_idx):
        _, val_loss = self._predict_and_loss(batch)
        self.log('val_loss', val_loss)
        return val_loss

    def on_test_epoch_start(self):
        # Without this reset, a second trainer.test() call would append to the
        # predictions of the previous run.
        self.out_predictions = []

    def test_step(self, batch, batch_idx):
        test_out, test_loss = self._predict_and_loss(batch)
        # Store detached outputs so no autograd graph is kept alive by the buffer.
        self.out_predictions.append(test_out.detach())
        self.log('test_loss', test_loss)
        return test_loss

    @staticmethod
    def _make_loader(features, targets, shuffle):
        """Wrap aligned feature/target sequences in a DataLoader of ``(x, y)`` pairs."""
        return torch.utils.data.DataLoader(
            list(zip(features, targets)),
            batch_size=BATCH_SIZE,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        return self._make_loader(X_train, y_train, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(X_val, y_val, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(X_test, y_test, shuffle=False)

In [5]:
# class CNNRegression(pl.LightningModule):
#     def __init__(self):
#         super(CNNRegression, self).__init__()
#         self.out_predictions = None
        
#         self.conv2d_block1 = nn.Sequential(
#             nn.Conv2d(2, 512, (4, 9)),
#             nn.ZeroPad2d((0, 8, 0, 3)),
#             nn.ReLU(),
#             nn.BatchNorm2d(512),
#             nn.Dropout2d(0.2),
#         )
        
#         self.conv2d_block2 = nn.Sequential(
#             nn.Conv2d(512, 512, (1, 9)),
#             nn.ZeroPad2d((0, 8, 0, 0)),
#             nn.MaxPool2d((2, 1)),
#             nn.ReLU(),
#             nn.BatchNorm2d(512),
#         )
        
#         self.conv2d_block3 = nn.Sequential(
#             nn.Conv2d(512, 128, (1, 3)),
#             nn.ZeroPad2d((0, 2, 0, 0)),
#             nn.ReLU(),
#             nn.BatchNorm2d(128),
#             nn.Dropout2d(0.2),
#         )
        
#         self.conv2d_block4 = nn.Sequential(
#             nn.Conv2d(128, 128, (1, 3)),
#             nn.ZeroPad2d((0, 2, 0, 0)),
#             nn.ReLU(),
#             nn.BatchNorm2d(128),
#         )
        
#         self.conv2d_block5 = nn.Sequential(
#             nn.Conv2d(128, 64, (1, 1)),
#             nn.MaxPool2d((2, 3)),
#             nn.ReLU(),
#             nn.BatchNorm2d(64),
#         )
            
#         self.lin_block = nn.Sequential(
#             nn.Linear(512, 256),
#             nn.ReLU(),
#             nn.Linear(256, 128),
#             nn.ReLU(),
#             nn.Dropout(0.2),
#             nn.Linear(128, 1)
#         )

#     def forward(self, x):
#         x = self.conv2d_block1(x)
#         x = self.conv2d_block2(x)
#         x = self.conv2d_block3(x)
#         x = self.conv2d_block4(x)
#         x = self.conv2d_block5(x)
#         x = x.view(x.size(0), -1)
#         x = self.lin_block(x)
#         return x

#     def training_step(self, batch, batch_idx):
#         x, y = batch
#         loss = F.mse_loss(torch.squeeze(self(x)), y)
#         self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
#         return loss

#     def configure_optimizers(self):
#         return torch.optim.Adam(self.parameters(), lr=LR)
            
#     def validation_step(self, batch, batch_idx):
#         x, y = batch
#         val_loss = F.mse_loss(torch.squeeze(self(x)), y)
#         self.log('val_loss', val_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
#         return val_loss
            
#     def test_step(self, batch, batch_idx):
#         x, y = batch
#         test_loss = F.mse_loss(torch.squeeze(self(x)), y)
#         self.log('test_loss', test_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
#         return test_loss
    
#     def test_epoch_end(self, outputs):
#         # do something with the outputs of all test batches
#         all_test_preds = test_step_outputs.predictions
#         self.out_predictions = all_test_preds

In [6]:
# Load the pre-computed splits: sequence-pair data from HDF5 (deepdish) and targets from .npy.
train, val, test = (
    dd.io.load(path) for path in ('splits/train.h5', 'splits/val.h5', 'splits/test.h5')
)
y_train, y_val, y_test = (
    np.load(path) for path in ('splits/y_train.npy', 'splits/y_val.npy', 'splits/y_test.npy')
)

In [7]:
# Multiply every target by 100 (presumably fraction -> percentage; TODO confirm).
# NOTE(review): not idempotent — re-running this cell without re-running the
# loading cell scales the targets again.
y_train, y_val, y_test = (target * 100 for target in (y_train, y_val, y_test))

In [8]:
# Every distinct sequence appearing on either side of a pair, across all splits.
# Insertion order (split, then side, then pair) is kept identical to the original
# one-update-per-side construction.
all_seqs = {
    pair[side]
    for dataset in (train, val, test)
    for side in (0, 1)
    for pair in dataset
}

In [9]:
# Iterating a set of strings follows hash order, which is randomized per Python
# process, and the encoder's vocabulary order presumably follows the order it sees
# characters in. Sorting makes the char -> index mapping (and hence enc_dict and the
# embedding rows a checkpoint was trained with) reproducible across kernel sessions.
encoder = CharacterEncoder(sorted(all_seqs))

In [10]:
# Token -> integer index, in the order of the encoder's vocabulary.
enc_dict = {token: index for index, token in enumerate(encoder.vocab)}

In [11]:
from torchnlp.encoders.text.default_reserved_tokens import DEFAULT_SOS_INDEX
from torchnlp.encoders.text.default_reserved_tokens import DEFAULT_EOS_INDEX
from torchnlp.encoders.text.default_reserved_tokens import DEFAULT_PADDING_INDEX

In [12]:
def split(X, y, split, low=0.1, high=0.90, rng=None):
    """Stratified split of ``X``/``y`` using three target-value bins.

    Samples are bucketed by target into ``y < low``, ``low <= y <= high`` and
    ``y > high``. Each bucket is shuffled independently; its first
    ``int(len(bucket) * split)`` items go to the first (train) output and the
    rest to the second (test) output. Outputs are concatenated bucket by bucket.

    Args:
        X: indexable samples aligned with ``y`` (``X[i]`` pairs with the i-th target).
        y: iterable of scalar targets.
        split: fraction of each bucket assigned to the train side.
        low, high: bucket boundaries. The defaults (0.1, 0.90) suit targets in
            [0, 1]; this notebook multiplies the targets by 100 after loading,
            so pass ``low=10, high=90`` when splitting the scaled targets.
        rng: optional ``numpy.random.Generator`` for a reproducible shuffle.
            ``None`` uses the global ``np.random`` state (original behaviour).

    Returns:
        ``(X_train, y_train, X_test, y_test)`` as Python lists.
    """
    buckets = [[], [], []]
    for i, target in enumerate(y):
        if target < low:
            buckets[0].append((X[i], target))
        elif target <= high:
            buckets[1].append((X[i], target))
        else:
            buckets[2].append((X[i], target))

    shuffle = np.random.shuffle if rng is None else rng.shuffle
    X_train, y_train, X_test, y_test = [], [], [], []
    for bucket in buckets:
        shuffle(bucket)
        limit = int(len(bucket) * split)
        for sample, target in bucket[:limit]:
            X_train.append(sample)
            y_train.append(target)
        for sample, target in bucket[limit:]:
            X_test.append(sample)
            y_test.append(target)

    return X_train, y_train, X_test, y_test

In [13]:
def encode_for_rnn(seq, max_len, vocab=None):
    """Encode a nucleotide string as a fixed-length vector of vocabulary indices.

    Args:
        seq: string over the alphabet A/C/G/T.
        max_len: length of the returned vector; positions past ``len(seq)``
            stay 0 (the padding index).
        vocab: token -> index mapping; defaults to the notebook-level ``enc_dict``.

    Returns:
        ``np.ndarray`` of ints with shape ``(max_len,)``.

    Raises:
        ValueError: if ``seq`` is longer than ``max_len``.
        KeyError: if ``seq`` contains a character other than A, C, G or T.
    """
    if vocab is None:
        vocab = enc_dict
    if len(seq) > max_len:
        raise ValueError(f"sequence of length {len(seq)} exceeds max_len={max_len}")
    # Only the four nucleotides are accepted, even if the vocabulary holds more tokens.
    nucl_dict = {nucl: vocab[nucl] for nucl in 'ACGT'}
    # Default padding index is zero for the character encoder
    mat = np.zeros(max_len, dtype=int)

    for i, nucl in enumerate(seq):
        mat[i] = nucl_dict[nucl]
    return mat

def encode_pair_for_rnn(seq1, seq2, max_len):
    """Encode a pair as one model input: [SOS] seq1 [EOS] seq2 [EOS].

    Each sequence is padded to ``max_len`` by ``encode_for_rnn``, so the result
    always has length ``2 * max_len + 3``.
    """
    enc1 = encode_for_rnn(seq1, max_len)
    enc2 = encode_for_rnn(seq2, max_len)
    return np.hstack((np.array([DEFAULT_SOS_INDEX]), enc1, np.array([DEFAULT_EOS_INDEX]), enc2, np.array([DEFAULT_EOS_INDEX])))

In [14]:
# Turn every (seq1, seq2) pair of each split into one fixed-length index vector.
X_train, X_val, X_test = (
    [encode_pair_for_rnn(pair[0], pair[1], MAX_LEN) for pair in dataset]
    for dataset in (train, val, test)
)

In [15]:
# Stack each split's equal-length index vectors into a 2-D float64 array
# (np.dtype('d') is float64); the model casts back with .long() before embedding.
X_train, X_val, X_test = (
    np.array(rows, dtype=np.float64) for rows in (X_train, X_val, X_test)
)

In [16]:
# Make float64 the default floating dtype so model parameters match the float64
# inputs built above. torch.set_default_tensor_type(torch.DoubleTensor) did the
# same thing but is deprecated, and was redundant with this call.
# NOTE(review): FP64 throughput on consumer GPUs is far below FP32; consider float32.
torch.set_default_dtype(torch.double)

In [17]:
# Model hyper-parameters (the dropout value looks like a tuning result; kept verbatim).
model = RNNRegression(
    vocab_size=len(encoder.vocab),
    emb_dim=32,
    hidden_size=64,
    num_layers=3,
    dropout=0.2412375022122436,
)

# Early stopping on val_loss with patience 3.
early_stopping = EarlyStopping(monitor='val_loss', patience=3)

# Keep the 3 checkpoints with the lowest val_loss.
checkpoint_callback = ModelCheckpoint(
    dirpath='lightning_checkpoints_rnn/',
    filename='rnn-{epoch:02d}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=3,
)

# CSV metrics logger rooted at 'logs_rnn', experiment name 'rnn'.
logger = CSVLogger('logs_rnn', name='rnn')

In [19]:
# Trainer on GPU 0 with early stopping, checkpointing and the CSV logger.
trainer = pl.Trainer(
    gpus=[0],
    logger=logger,
    callbacks=[early_stopping, checkpoint_callback],
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [20]:
# Show which physical GPU CUDA device 0 is.
torch.cuda.get_device_name(device=0)

'GeForce RTX 3090'

In [21]:
# Train; the model supplies its own train/val DataLoaders.
trainer.fit(model=model)


  | Name       | Type      | Params
-----------------------------------------
0 | embeddings | Embedding | 288   
1 | rnn        | LSTM      | 248 K 
2 | dropout    | Dropout   | 0     
3 | linear     | Linear    | 129   
-----------------------------------------
249 K     Trainable params
0         Non-trainable params
249 K     Total params


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

1