In [1]:
# Run configuration:
#   MAX_LEN     - padded length every sequence is one-hot encoded to; must agree with the
#                 sequence length CNNRegression was built for.
#   LR          - Adam learning rate.
#   NUM_BATCHES - NOTE: despite its name, this is passed to the DataLoaders as `batch_size`.
MAX_LEN, LR, NUM_BATCHES = 26, 1e-4, 256

In [2]:
import os
import numpy as np
import pandas as pd
import deepdish as dd

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CometLogger
from comet_ml import Experiment

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader

In [102]:
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, roc_auc_score, roc_curve, matthews_corrcoef, plot_confusion_matrix
from collections import defaultdict
import seaborn as sns

In [None]:
def encode(seq, max_len):
    """One-hot encode a nucleotide string as a (4, max_len) integer matrix.

    Row r is set at column i when ``seq[i]`` is the r-th letter of "ACGT";
    columns at or beyond ``len(seq)`` stay all-zero (right padding).

    Raises:
        KeyError: if ``seq`` contains a character other than A, C, G or T.
        IndexError: if ``len(seq) > max_len``.
    """
    row_of = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    one_hot = np.zeros((4, max_len), dtype=int)
    for col, base in enumerate(seq):
        one_hot[row_of[base], col] = 1
    return one_hot

def encode_pair(seq1, seq2, max_len):
    """Stack the one-hot encodings of two sequences into a (2, 4, max_len) array."""
    return np.stack([encode(s, max_len) for s in (seq1, seq2)])

In [48]:
class CNNRegression(pl.LightningModule):
    """Convolutional regressor over a pair of one-hot encoded sequences.

    Input:  (batch, 2, 4, max_len), i.e. a batch of ``encode_pair`` outputs.
    Output: (batch, 1) predictions.

    The training, validation and test steps use mean-squared error against the
    batch targets. Raw test predictions are collected in ``out_predictions``.

    Args:
        max_len: padded sequence length of the inputs. It sets the width of the
            flattened feature vector entering the dense head. The default (26)
            reproduces the original hard-coded ``64 * 6`` input size.

    Raises:
        ValueError: if ``max_len`` is too short to survive the unpadded convolutions.
    """

    # Sequence positions lost along the length axis by the unpadded convolutions:
    # Conv2d width 9 (-8), Conv1d k=9 (-8), k=3 (-2), k=3 (-2), k=1 (0).
    _LEN_REDUCTION = 20

    def __init__(self, max_len=26):
        super().__init__()
        out_len = max_len - self._LEN_REDUCTION
        if out_len < 1:
            raise ValueError(
                f"max_len must be at least {self._LEN_REDUCTION + 1}, got {max_len}"
            )

        # Raw (batch, 1) prediction tensors appended by test_step; cleared at the
        # start of every test epoch so repeated trainer.test() calls don't mix runs.
        self.out_predictions = []

        # Kernel height 4 spans all four one-hot rows, collapsing the height axis to 1.
        self.conv2d_block = nn.Sequential(
            nn.Conv2d(2, 512, (4, 9)),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.Dropout2d(0.2),
        )

        self.conv1d_block = nn.Sequential(
            nn.Conv1d(512, 512, 9),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Conv1d(512, 128, 3),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.2),
            nn.Conv1d(128, 128, 3),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, 64, 1),
            nn.ReLU(),
            nn.BatchNorm1d(64),
        )

        self.lin_block = nn.Sequential(
            nn.Linear(64 * out_len, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        # (B, 2, 4, L) -> (B, 512, 1, L - 8). Squeeze only the collapsed height axis:
        # a bare squeeze() would also drop the batch axis when B == 1 and break the
        # Conv1d/BatchNorm1d layers that follow.
        x = self.conv2d_block(x).squeeze(2)
        x = self.conv1d_block(x)            # (B, 64, L - 20)
        x = torch.flatten(x, start_dim=1)   # (B, 64 * (L - 20))
        return self.lin_block(x)            # (B, 1)

    def _predict_and_loss(self, batch):
        """Run the model on ``batch`` and return ``(raw (B, 1) predictions, MSE loss)``."""
        x, y = batch
        preds = self(x)
        # Flatten both sides explicitly (instead of squeeze()) so a batch of size 1 still
        # compares a length-1 vector with a length-1 vector.
        loss = F.mse_loss(preds.reshape(-1), y.reshape(-1).to(preds.dtype))
        return preds, loss

    def training_step(self, batch, batch_idx):
        _, loss = self._predict_and_loss(batch)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=LR)

    def validation_step(self, batch, batch_idx):
        _, val_loss = self._predict_and_loss(batch)
        self.log('val_loss', val_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return val_loss

    def on_test_epoch_start(self):
        # Start each test run with an empty prediction buffer.
        self.out_predictions = []

    def test_step(self, batch, batch_idx):
        test_out, test_loss = self._predict_and_loss(batch)
        self.out_predictions.append(test_out.detach())
        self.log('test_loss', test_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return test_loss

In [4]:
# Load the pre-split inputs (deepdish HDF5; each item is indexed as item[0], item[1]
# by the encoding cell) and the matching NumPy target arrays.
train, val, test = (
    dd.io.load(path)
    for path in ('splits/train.h5', 'splits/val.h5', 'splits/test.h5')
)
y_train, y_val, y_test = (
    np.load(path)
    for path in ('splits/y_train.npy', 'splits/y_val.npy', 'splits/y_test.npy')
)

In [10]:
# Targets are multiplied by 100 (presumably fractions -> percentages; TODO confirm units).
# They are re-derived from the files on disk rather than `y = y * 100`, so that re-running
# this cell is idempotent instead of compounding the scale factor on every execution.
TARGET_SCALE = 100
y_train = np.load('splits/y_train.npy') * TARGET_SCALE
y_val = np.load('splits/y_val.npy') * TARGET_SCALE
y_test = np.load('splits/y_test.npy') * TARGET_SCALE

In [5]:
# One-hot encode every (seq1, seq2) pair of each split into a (2, 4, MAX_LEN) array.
X_train, X_val, X_test = (
    [encode_pair(pair[0], pair[1], MAX_LEN) for pair in split]
    for split in (train, val, test)
)

In [27]:
# Stack each split into a single float64 ndarray, matching torch's float64 default dtype.
X_train, X_val, X_test = (
    np.array(encoded, dtype=np.float64)
    for encoded in (X_train, X_val, X_test)
)

In [21]:
# Build the model in float64 so its parameters match the float64 input arrays.
# `torch.set_default_dtype` alone is sufficient. The former extra call to
# `torch.set_default_tensor_type(torch.DoubleTensor)` was redundant and has been
# deprecated since PyTorch 2.1.
torch.set_default_dtype(torch.float64)

In [28]:
# Pair each encoded input with its target; only the training split is shuffled.
# NUM_BATCHES is used here as the batch size (samples per batch), not a batch count.
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(list(zip(features, targets)), batch_size=NUM_BATCHES, shuffle=shuffle)
    for features, targets, shuffle in (
        (X_train, y_train, True),
        (X_val, y_val, False),
        (X_test, y_test, False),
    )
)

In [30]:
# Seed Python/NumPy/torch RNGs so weight init, dropout and DataLoader shuffling
# (which draws from torch's global RNG) are reproducible across runs.
SEED = 42
pl.seed_everything(SEED)

model = CNNRegression()
early_stopping = EarlyStopping(monitor='val_loss', patience=3, mode='min')

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='lightning_checkpoints_cnn/',
    filename='cnn-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
)

# The Trainer cell passes `comet_logger`, so it must be defined here; the original
# definition was missing from the saved notebook. The API key comes from the environment
# and is never hard-coded. Workspace/project match the experiment URLs in the logged output.
comet_logger = CometLogger(
    api_key=os.environ['COMET_API_KEY'],
    workspace='davidbuterez',
    project_name='hybridisation',
)

CometLogger will be initialized in online mode


In [31]:
# Single-GPU run on device 0 with early stopping, top-k checkpointing and Comet logging.
# NOTE: `gpus=` is the PyTorch Lightning 1.x API; PL >= 2.0 expects accelerator='gpu', devices=[0].
trainer = pl.Trainer(
    gpus=[0],
    logger=comet_logger,
    callbacks=[early_stopping, checkpoint_callback],
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


In [33]:
# Train with the callbacks (EarlyStopping on val_loss, patience 3; ModelCheckpoint keeping
# the 3 lowest-val_loss checkpoints) and the Comet logger configured on `trainer`.
# The dataloaders are passed positionally on purpose: the keyword names of `fit` changed
# across PyTorch Lightning versions.
trainer.fit(model, train_dataloader, val_dataloader)

COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/davidbuterez/hybridisation/37423c2c587041728cc02e9c4744de1d
COMET INFO:   Uploads:
COMET INFO:     environment details : 1
COMET INFO:     filename            : 1
COMET INFO:     installed packages  : 1
COMET INFO: ---------------------------
COMET INFO: Experiment is live on comet.ml https://www.comet.ml/davidbuterez/hybridisation/74a3f3e300374815a04dfd8dea6d1a45


  | Name         | Type       | Params
--------------------------------------------
0 | conv2d_block | Sequential | 38.4 K
1 | conv1d_block | Sequential | 2.6 M 
2 | lin_block    | Sequential | 131 K 
--------------------------------------------
2.8 M     Trainable params
0         Non-trainable params
2.8 M     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/davidbuterez/hybridisation/74a3f3e300374815a04dfd8dea6d1a45
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     train_loss [13] : (88.30232238769531, 456.5487976074219)
COMET INFO:     val_loss [13]   : (74.71394348144531, 155.08792114257812)
COMET INFO:   Uploads:
COMET INFO:     code                : 1 (7 KB)
COMET INFO:     environment details : 1
COMET INFO:     filename            : 1
COMET INFO:     installed packages  : 1
COMET INFO:     notebook            : 1
COMET INFO: ---------------------------





COMET INFO: Uploading stats to Comet before program termination (may take several seconds)


1