In [4]:
# Tunable settings consumed by later cells.
MAX_LEN = 26        # column count of each one-hot matrix built by encode()/encode_pair()
LR = 1e-4           # Adam learning rate used by CNNRegression.configure_optimizers
NUM_BATCHES = 512   # NOTE: despite the name, this is passed to DataLoader as batch_size

import os
import time
import json
from tqdm import tqdm
import numpy as np
import pandas as pd

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.callbacks.progress import ProgressBar
from pytorch_lightning.loggers import CSVLogger

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader, TensorDataset

from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, roc_auc_score, roc_curve, matthews_corrcoef, plot_confusion_matrix
from collections import defaultdict
import seaborn as sns

In [25]:
class CNNRegression(pl.LightningModule):
    """CNN that regresses one scalar from a pair of one-hot encoded sequences.

    Expected input is a batch shaped ``(B, 2, 4, max_len)``, as produced by stacking
    ``encode_pair`` outputs in this notebook (TODO confirm for other callers).
    Shape trace for the default ``max_len=26``::

        (B, 2, 4, 26) -Conv2d(4, 9)-> (B, 256, 1, 18) -squeeze(2)-> (B, 256, 18)
        -Conv1d(9)-> (B, 128, 10) -Conv1d(3)-> (B, 64, 8) -flatten-> (B, 512)

    Args:
        max_len: Sequence length of the encoded inputs. It sets the width of the
            first linear layer and must exceed 18 so the convolutions leave a
            non-empty feature map. Defaults to 26, the length the layer was
            originally hard-coded for.
        lr: Adam learning rate. ``None`` (default) falls back to the global ``LR``
            when the optimizer is configured.

    Raises:
        ValueError: if ``max_len`` is too short for the convolution stack.
    """

    # Width lost along the sequence axis by the three unpadded convolutions:
    # (9 - 1) + (9 - 1) + (3 - 1).
    _CONV_SHRINK = 18

    def __init__(self, max_len=26, lr=None):
        super().__init__()
        if max_len <= self._CONV_SHRINK:
            raise ValueError(
                f'max_len must be greater than {self._CONV_SHRINK} for this '
                f'architecture, got {max_len}'
            )
        self.lr = lr
        # Model outputs collected by test_step, one tensor per test batch.
        self.out_predictions = []
        # Not used by this class; kept so external code that reads it still works.
        self.forward_flag = 0

        self.conv2d_block = nn.Sequential(
            nn.Conv2d(2, 256, (4, 9)),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Dropout2d(0.2),
        )

        self.conv1d_block = nn.Sequential(
            nn.Conv1d(256, 128, 9),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, 64, 3),
            nn.ReLU(),
            nn.BatchNorm1d(64),
        )

        self.lin_block = nn.Sequential(
            nn.Linear(64 * (max_len - self._CONV_SHRINK), 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        """Return predictions of shape ``(B, 1)`` for a ``(B, 2, 4, max_len)`` batch."""
        x = self.conv2d_block(x)
        # With 4-row inputs the (4, 9) kernel collapses dim 2 to size 1. Squeeze only
        # that dim: a bare squeeze() also drops the batch dim when B == 1, which
        # breaks the per-sample flatten below.
        x = self.conv1d_block(x.squeeze(2))
        x = x.flatten(start_dim=1)
        return self.lin_block(x)

    def _predict_and_loss(self, batch):
        """Run the model on ``(x, y)`` and return ``(raw_predictions, mse_loss)``."""
        x, y = batch
        preds = self(x)
        # reshape_as keeps predictions aligned with y for any batch size; the old
        # torch.squeeze() produced a 0-d tensor for single-sample batches.
        loss = F.mse_loss(preds.reshape_as(y), y)
        return preds, loss

    def training_step(self, batch, batch_idx):
        _, loss = self._predict_and_loss(batch)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        lr = LR if self.lr is None else self.lr
        return torch.optim.Adam(self.parameters(), lr=lr)

    def validation_step(self, batch, batch_idx):
        _, val_loss = self._predict_and_loss(batch)
        self.log('val_loss', val_loss)
        return val_loss

    def test_step(self, batch, batch_idx):
        test_out, test_loss = self._predict_and_loss(batch)
        # Store detached CPU copies so collected predictions do not pin GPU memory.
        # NOTE(review): the list is never cleared, so repeated trainer.test() calls
        # keep appending to it.
        self.out_predictions.append(test_out.detach().cpu())
        self.log('test_loss', test_loss)
        return test_loss

In [6]:
def encode(seq, max_len):
    """One-hot encode a DNA sequence into a ``(4, max_len)`` integer matrix.

    Rows are A, C, G, T in that order, and column ``i`` marks the base at position
    ``i``. Columns past ``len(seq)`` stay all-zero, so shorter sequences are
    right-padded.

    Args:
        seq: Sequence over the alphabet ``ACGT``; lower-case bases are accepted.
        max_len: Number of columns in the output matrix.

    Returns:
        ``np.ndarray`` of shape ``(4, max_len)`` and dtype ``int``.

    Raises:
        ValueError: if ``seq`` is longer than ``max_len``.
        KeyError: if ``seq`` contains a character other than A/C/G/T.
    """
    nucl_dict = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    if len(seq) > max_len:
        # Previously surfaced as an opaque numpy IndexError mid-loop.
        raise ValueError(f'sequence of length {len(seq)} exceeds max_len={max_len}')

    mat = np.zeros((4, max_len), dtype=int)
    for i, nucl in enumerate(seq.upper()):
        mat[nucl_dict[nucl], i] = 1

    return mat

def encode_pair(seq1, seq2, max_len):
    """Stack the one-hot encodings of two sequences along a new leading axis.

    Index 0 holds ``encode(seq1, max_len)`` and index 1 holds
    ``encode(seq2, max_len)``; the result therefore has shape ``(2, 4, max_len)``.
    """
    return np.stack([encode(seq, max_len) for seq in (seq1, seq2)])

In [7]:
# TODO: replace json with dd as in other files.
# Directory holding the train/val/test split files read below.
SPLITS_DIR = '../splits'


def _load_json(path):
    """Parse and return the JSON document at ``path``, closing the file afterwards."""
    with open(path) as fh:
        return json.load(fh)


train = _load_json(os.path.join(SPLITS_DIR, 'train.json'))
val = _load_json(os.path.join(SPLITS_DIR, 'val.json'))
test = _load_json(os.path.join(SPLITS_DIR, 'test.json'))
y_train = np.load(os.path.join(SPLITS_DIR, 'y_train.npy'))
y_val = np.load(os.path.join(SPLITS_DIR, 'y_val.npy'))
y_test = np.load(os.path.join(SPLITS_DIR, 'y_test.npy'))

In [8]:
# Scale the regression targets by 100.
# NOTE(review): this rescales the arrays loaded by the previous cell in place of
# their old values, so re-running this cell alone multiplies them by 100 again —
# re-run the loading cell first.
TARGET_SCALE = 100
y_train = y_train * TARGET_SCALE
y_val = y_val * TARGET_SCALE
y_test = y_test * TARGET_SCALE


def encode_split(pairs, max_len):
    """Encode every ``(seq1, seq2, ...)`` item of a split into one float64 array.

    Only the first two elements of each item are used. Returns an array of shape
    ``(len(pairs), 2, 4, max_len)``, including for an empty split.
    """
    if len(pairs) == 0:
        return np.zeros((0, 2, 4, max_len), dtype=np.float64)
    return np.array(
        [encode_pair(item[0], item[1], max_len) for item in pairs],
        dtype=np.float64,
    )


X_train = encode_split(train, MAX_LEN)
X_val = encode_split(val, MAX_LEN)
X_test = encode_split(test, MAX_LEN)

In [17]:
# Make float64 the default floating-point dtype so model parameters created later
# match the float64 ('d') feature arrays built above.
# The former torch.set_default_tensor_type(torch.DoubleTensor) call was dropped:
# it is deprecated, and set_default_dtype alone already yields float64 CPU tensors.
torch.set_default_dtype(torch.float64)

In [18]:
def make_loader(X, y, shuffle=False):
    """Wrap numpy feature/target arrays in a DataLoader.

    ``torch.from_numpy`` shares memory with the arrays, so no copy is made.
    NUM_BATCHES is used as the per-batch sample count.
    """
    dataset = TensorDataset(torch.from_numpy(X), torch.from_numpy(y))
    return DataLoader(dataset, batch_size=NUM_BATCHES, shuffle=shuffle, num_workers=0)


# Only training data is shuffled; val/test keep file order, so collected test
# predictions line up with y_test.
train_dataloader = make_loader(X_train, y_train, shuffle=True)
val_dataloader = make_loader(X_val, y_val)
test_dataloader = make_loader(X_test, y_test)

In [26]:
# Seed the Python/NumPy/torch RNGs before building the model so weight init,
# dropout and the training-set shuffle repeat across Restart & Run All
# (cuDNN kernels may still add some GPU nondeterminism).
SEED = 42
pl.seed_everything(SEED)

model = CNNRegression()

# Stop once val_loss has not improved for 3 consecutive validation runs.
early_stopping = EarlyStopping(monitor='val_loss', patience=3, mode='min')

# Keep the three checkpoints with the lowest val_loss.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='lightning_checkpoints_cnn_lite/',
    filename='cnn-lite-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
)

logger = CSVLogger('logs_cnn_lite', name='cnn_lite')

In [27]:
# Train on GPU 0 when CUDA is available; otherwise fall back to CPU instead of
# failing at Trainer construction on a GPU-less machine.
# NOTE(review): `gpus` is the pre-2.0 Lightning API; newer releases use
# `accelerator=` / `devices=` instead — update if the environment is upgraded.
trainer = pl.Trainer(
    gpus=[0] if torch.cuda.is_available() else None,
    callbacks=[early_stopping, checkpoint_callback],
    logger=logger,
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [28]:
# Train `model` on the training loader, validating on the val loader; the
# EarlyStopping and ModelCheckpoint callbacks given to the Trainer monitor val_loss.
# NOTE(review): loaders are passed positionally; the keyword for the training
# loader has differed between Lightning releases, so verify before switching to keywords.
trainer.fit(model, train_dataloader, val_dataloader)


  | Name         | Type       | Params
--------------------------------------------
0 | conv2d_block | Sequential | 19.2 K
1 | conv1d_block | Sequential | 320 K 
2 | lin_block    | Sequential | 131 K 
--------------------------------------------
470 K     Trainable params
0         Non-trainable params
470 K     Total params


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

1