In [None]:
# Download the PyTorch/XLA Colab setup script.
# -f: fail on an HTTP error instead of saving the error page as the "script";
# -sS: hide the progress meter but still print errors; -L: follow redirects.
!curl -fsSL https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  5116  100  5116    0     0  30634      0 --:--:-- --:--:-- --:--:-- 30634


In [None]:
# Install the PyTorch/XLA TPU runtime with the script downloaded above, also
# requesting the apt packages libomp5 and libopenblas-dev.
# NOTE(review): `--version nightly` is unpinned, so the torch / torch_xla versions
# installed depend on the day this runs (the log shows torch-1.7.0+cu101 being replaced).
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

Updating... This may take around 2 minutes.
Updating TPU runtime to pytorch-nightly ...
Collecting cloud-tpu-client
  Downloading https://files.pythonhosted.org/packages/56/9f/7b1958c2886db06feb5de5b2c191096f9e619914b6c31fdf93999fdbbd8b/cloud_tpu_client-0.10-py3-none-any.whl
Collecting google-api-python-client==1.8.0
[?25l  Downloading https://files.pythonhosted.org/packages/9a/b4/a955f393b838bc47cbb6ae4643b9d0f90333d3b4db4dc1e819f36aad18cc/google_api_python_client-1.8.0-py3-none-any.whl (57kB)
[K     |████████████████████████████████| 61kB 3.2MB/s 
Uninstalling torch-1.7.0+cu101:
Installing collected packages: google-api-python-client, cloud-tpu-client
  Found existing installation: google-api-python-client 1.7.12
    Uninstalling google-api-python-client-1.7.12:
      Successfully uninstalled google-api-python-client-1.7.12
Successfully installed cloud-tpu-client-0.10 google-api-python-client-1.8.0
Done updating TPU runtime
  Successfully uninstalled torch-1.7.0+cu101
Uninstalling 

In [None]:
# Pin PyTorch Lightning to the release this notebook was run with (1.1.6, per the
# install log). Later releases removed the `tpu_cores` Trainer argument and the
# `test_dataloaders=` keyword of `trainer.test` used below.
# %pip installs into the running kernel's environment.
%pip install pytorch-lightning==1.1.6

Collecting pytorch-lightning
[?25l  Downloading https://files.pythonhosted.org/packages/c3/3d/fffaf4f83633552249a40d3d366f460f44539bce0592c568d8ee20d782fa/pytorch_lightning-1.1.6-py3-none-any.whl (687kB)
[K     |████████████████████████████████| 696kB 5.0MB/s 
Collecting PyYAML!=5.4.*,>=5.1
[?25l  Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)
[K     |████████████████████████████████| 276kB 10.2MB/s 
Collecting fsspec[http]>=0.8.1
[?25l  Downloading https://files.pythonhosted.org/packages/ec/80/72ac0982cc833945fada4b76c52f0f65435ba4d53bc9317d1c70b5f7e7d5/fsspec-0.8.5-py3-none-any.whl (98kB)
[K     |████████████████████████████████| 102kB 8.2MB/s 
[?25hCollecting future>=0.17.1
[?25l  Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)
[K     |████████████████████████████████| 829kB 13

In [None]:
# Number of columns every sequence is one-hot encoded to (see `encode`).
# CNNRegression's conv stack assumes this length: its first Linear layer expects
# 64 * (MAX_LEN - 18) = 512 features.
MAX_LEN = 26
# Adam learning rate, read by CNNRegression.configure_optimizers.
LR = 0.0001
# Batch size shared by the train/val/test DataLoaders.
BATCH_SIZE = 4096

import os
import time
import json
from tqdm import tqdm
import numpy as np
import pandas as pd

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.callbacks.progress import ProgressBar
from pytorch_lightning.loggers import CSVLogger

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader, TensorDataset

# NOTE(review): `plot_confusion_matrix` was removed in scikit-learn 1.2, so this line
# fails on current scikit-learn. None of these sklearn names are used in the cells shown.
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, roc_auc_score, roc_curve, auc, matthews_corrcoef, plot_confusion_matrix
from collections import defaultdict
import seaborn as sns

In [None]:
class CNNRegression(pl.LightningModule):
    """Convolutional regressor predicting one value per pair of encoded sequences.

    Inputs are assumed to be batches shaped (batch, 2, 4, max_len), presumably
    stacks of ``encode_pair`` outputs (TODO confirm against the data cells).
    None of the three convolutions pad, so they shorten the sequence axis by
    8, 8 and 2. The flattened conv output therefore has
    ``64 * (max_len - 18)`` features.

    The module also times test-time inference. The first ``forward`` call of
    each test epoch stores a start time, and ``on_test_epoch_end`` prints the
    elapsed seconds.

    Args:
        max_len: sequence length the layers are sized for. The default of 26
            reproduces the original ``nn.Linear(64 * 8, 256)`` layer. It is not
            saved as a hyperparameter, so ``load_from_checkpoint`` rebuilds the
            default-sized model unless ``max_len`` is passed explicitly.

    Raises:
        ValueError: if ``max_len`` is too short to survive the convolutions.
    """

    # Sequence-axis length lost to the unpadded Conv2d (width 9) and the two
    # Conv1d layers (widths 9 and 3).
    _CONV_SHRINK = (9 - 1) + (9 - 1) + (3 - 1)

    def __init__(self, max_len=26):
        super().__init__()
        conv_out_len = max_len - self._CONV_SHRINK
        if conv_out_len < 1:
            raise ValueError(
                f"max_len must be at least {self._CONV_SHRINK + 1}, got {max_len}"
            )
        # Not read by any method of this class; kept for external code that may use it.
        self.out_predictions = []
        # 0 until the first forward() after (re-)arming records self.start.
        self.forward_flag = 0

        self.conv2d_block = nn.Sequential(
            nn.Conv2d(2, 256, (4, 9)),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Dropout2d(0.2),
        )

        self.conv1d_block = nn.Sequential(
            nn.Conv1d(256, 128, 9),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, 64, 3),
            nn.ReLU(),
            nn.BatchNorm1d(64),
        )

        self.lin_block = nn.Sequential(
            nn.Linear(64 * conv_out_len, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        """Return predictions of shape (batch, 1) for ``x``.

        ``x`` is assumed to be (batch, 2, 4, max_len). The first call after the
        timer is (re-)armed records ``self.start``.
        """
        if not self.forward_flag:
            self.forward_flag = 1
            self.start = time.perf_counter()
        x = self.conv2d_block(x)
        # The height-4 kernel collapses the one-hot axis to size 1. Drop only that
        # axis: torch.squeeze would also drop the batch axis when batch == 1.
        x = self.conv1d_block(x.squeeze(2))
        x = x.flatten(1)
        return self.lin_block(x)

    def _batch_mse(self, batch):
        """Mean squared error between predictions and targets for one batch.

        Targets are assumed to be 1-D, shape (batch,) — TODO confirm.
        """
        x, y = batch
        # squeeze(1) removes only the size-1 output column. A batch of one sample
        # therefore stays shape (1,) instead of collapsing to a 0-d tensor.
        return F.mse_loss(self(x).squeeze(1), y)

    def training_step(self, batch, batch_idx):
        loss = self._batch_mse(batch)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        # LR is the notebook-level learning-rate constant.
        return torch.optim.Adam(self.parameters(), lr=LR)

    def validation_step(self, batch, batch_idx):
        val_loss = self._batch_mse(batch)
        self.log('val_loss', val_loss)
        return val_loss

    def test_step(self, batch, batch_idx):
        x, _ = batch
        # Forward pass only; the output is discarded because test runs here are
        # used to time inference.
        self(x)
        return 0

    def on_test_epoch_start(self):
        """Re-arm the timer so the epoch's first forward() starts a new measurement."""
        self.forward_flag = 0

    def on_test_epoch_end(self):
        """Print seconds elapsed since the first forward() of this test epoch."""
        if self.forward_flag:
            print(time.perf_counter() - self.start)
        # Also re-arm here, so a stale start time is never reused.
        self.forward_flag = 0

In [None]:
# Row index of each nucleotide in the one-hot matrix produced by `encode`.
_NUCL_INDEX = {'A': 0, 'C': 1, 'G': 2, 'T': 3}


def encode(seq, max_len):
    """One-hot encode a nucleotide sequence into a (4, max_len) integer matrix.

    Row r is 1 at column i when ``seq[i]`` is the r-th letter of "ACGT". Columns
    past the end of ``seq`` stay zero (right padding). Letters are matched
    case-insensitively.

    Args:
        seq: iterable of single-character strings, e.g. "ACGT".
        max_len: number of columns in the output.

    Returns:
        np.ndarray of shape (4, max_len) with dtype ``int``.

    Raises:
        IndexError: if ``seq`` has more than ``max_len`` letters.
        KeyError: if ``seq`` contains a letter other than A, C, G or T.
    """
    mat = np.zeros((4, max_len), dtype=int)

    for i, nucl in enumerate(seq):
        row = _NUCL_INDEX[nucl.upper()]
        if i >= max_len:
            raise IndexError(
                f"sequence has more than max_len={max_len} letters"
            )
        mat[row, i] = 1

    return mat

def encode_pair(seq1, seq2, max_len):
    """Stack the one-hot encodings of two sequences into a (2, 4, max_len) array.

    Index 0 along the first axis holds ``seq1`` and index 1 holds ``seq2``;
    both are encoded with ``encode`` at the same ``max_len``.
    """
    return np.stack([encode(sequence, max_len) for sequence in (seq1, seq2)])

In [None]:
# Mount Google Drive at /content/drive (Colab only).
# NOTE(review): the loading cell below opens 'train.json', 'y_train.npy', etc. by
# relative path, not under /content/drive. This mount therefore looks unused
# unless the working directory is changed elsewhere — confirm.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
def _load_json(path):
    """Parse and return the JSON document at ``path``, closing the file afterwards."""
    with open(path) as fh:
        return json.load(fh)


# Sequence pairs for each split; each item is indexed as item[0], item[1] by the encoding cell.
train = _load_json('train.json')
val = _load_json('val.json')
test = _load_json('test.json')
# Regression targets for each split. TensorDataset later requires one target per pair.
y_train = np.load('y_train.npy')
y_val = np.load('y_val.npy')
y_test = np.load('y_test.npy')

In [None]:
# Scale every split's targets by 100.
# NOTE: this rebinds y_train / y_val / y_test, so re-running this cell without
# re-running the loading cell scales the targets by 100 again.
y_train, y_val, y_test = [targets * 100 for targets in (y_train, y_val, y_test)]


def _encode_split(pairs):
    """One-hot encode each item's first two sequences into a float64 array.

    For a non-empty split the result has shape (len(pairs), 2, 4, MAX_LEN).
    """
    encoded = [encode_pair(pair[0], pair[1], MAX_LEN) for pair in pairs]
    return np.array(encoded, dtype=np.float64)


X_train = _encode_split(train)
X_val = _encode_split(val)
X_test = _encode_split(test)

In [None]:
# Make float64 the default floating-point dtype, so new tensors and module
# parameters match the float64 ('d') feature arrays built above.
# `torch.set_default_dtype` alone has the same effect as the deprecated
# `torch.set_default_tensor_type(torch.DoubleTensor)` call that used to precede it.
# NOTE(review): TPUs typically lack native float64, so timings measured in double
# precision may not reflect float32/bf16 inference — confirm this is intended.
torch.set_default_dtype(torch.double)

In [None]:
def _make_loader(features, targets, **loader_kwargs):
    """Wrap a NumPy (features, targets) pair in a single-process DataLoader.

    ``torch.from_numpy`` shares memory with the arrays. Every loader uses the
    notebook's BATCH_SIZE, and extra keyword arguments (e.g. ``shuffle``) are
    passed through to ``DataLoader``.
    """
    dataset = TensorDataset(torch.from_numpy(features), torch.from_numpy(targets))
    return DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=0, **loader_kwargs)


# Only the training split is shuffled.
train_dataloader = _make_loader(X_train, y_train, shuffle=True)
val_dataloader = _make_loader(X_val, y_val)
test_dataloader = _make_loader(X_test, y_test)

In [None]:
# Restore trained weights.
# NOTE(review): '...' is a placeholder, not a real path — replace it with the
# checkpoint file before running.
model = CNNRegression.load_from_checkpoint('...')
# Put the BatchNorm and Dropout layers into inference mode.
model.eval()

CNNRegression(
  (conv2d_block): Sequential(
    (0): Conv2d(2, 256, kernel_size=(4, 9), stride=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout2d(p=0.2, inplace=False)
  )
  (conv1d_block): Sequential(
    (0): Conv1d(256, 128, kernel_size=(9,), stride=(1,))
    (1): ReLU()
    (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv1d(128, 64, kernel_size=(3,), stride=(1,))
    (4): ReLU()
    (5): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (lin_block): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, out_features=1, bias=True)
  )
)

In [None]:
# Trainer used by the timing loop below (its only use in the cells shown).
# NOTE(review): `tpu_cores` is the PyTorch Lightning 1.x way to request TPU cores
# (later releases use accelerator="tpu", devices=...), so keep pytorch-lightning on 1.x.
trainer = pl.Trainer(tpu_cores=8)

GPU available: False, used: False
TPU available: True, using: 1 TPU cores


In [None]:
# Benchmark inference by running the full test loop `repetitions` times.
# Each trainer.test call is timed with the wall clock. Per the logs, every call
# re-initialises the 8 TPU cores, so this includes that setup as well as inference.
# Separately, CNNRegression.on_test_epoch_end prints its own elapsed time,
# presumably once per TPU process (hence 8 numbers per run in the output).
repetitions = 10
timings = np.zeros((repetitions, 1))

# MEASURE PERFORMANCE
with torch.no_grad():
    for rep in range(repetitions):
        started = time.perf_counter()
        trainer.test(model, test_dataloaders=test_dataloader)
        timings[rep, 0] = time.perf_counter() - started

# Summary statistics of the per-call wall-clock times, in seconds.
pd.Series(timings[:, 0], name='seconds per trainer.test call').describe()

training on 8 TPU cores
INIT TPU local core: 0, global rank: 0 with XLA_USE_BF16=None
INIT TPU local core: 6, global rank: 6 with XLA_USE_BF16=None
INIT TPU local core: 3, global rank: 3 with XLA_USE_BF16=None
INIT TPU local core: 1, global rank: 1 with XLA_USE_BF16=None
INIT TPU local core: 2, global rank: 2 with XLA_USE_BF16=None
INIT TPU local core: 5, global rank: 5 with XLA_USE_BF16=None
INIT TPU local core: 7, global rank: 7 with XLA_USE_BF16=None
INIT TPU local core: 4, global rank: 4 with XLA_USE_BF16=None


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

0.5515551567077637
1.4018564224243164
0.9029159545898438
0.9251351356506348
1.1174182891845703
1.0471866130828857
1.4676234722137451
1.0700621604919434

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{}
--------------------------------------------------------------------------------


training on 8 TPU cores
INIT TPU local core: 0, global rank: 0 with XLA_USE_BF16=None
INIT TPU local core: 5, global rank: 5 with XLA_USE_BF16=None
INIT TPU local core: 3, global rank: 3 with XLA_USE_BF16=None
INIT TPU local core: 6, global rank: 6 with XLA_USE_BF16=None
INIT TPU local core: 7, global rank: 7 with XLA_USE_BF16=None
INIT TPU local core: 1, global rank: 1 with XLA_USE_BF16=None
INIT TPU local core: 4, global rank: 4 with XLA_USE_BF16=None
INIT TPU local core: 2, global rank: 2 with XLA_USE_BF16=None


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

1.094714641571045
0.42415285110473633
0.5295510292053223
0.6572864055633545
1.0608184337615967
1.2986581325531006
0.9821577072143555
0.9410858154296875

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{}
--------------------------------------------------------------------------------


training on 8 TPU cores
INIT TPU local core: 0, global rank: 0 with XLA_USE_BF16=None
INIT TPU local core: 2, global rank: 2 with XLA_USE_BF16=None
INIT TPU local core: 3, global rank: 3 with XLA_USE_BF16=None
INIT TPU local core: 4, global rank: 4 with XLA_USE_BF16=None
INIT TPU local core: 7, global rank: 7 with XLA_USE_BF16=None
INIT TPU local core: 1, global rank: 1 with XLA_USE_BF16=None
INIT TPU local core: 5, global rank: 5 with XLA_USE_BF16=None
INIT TPU local core: 6, global rank: 6 with XLA_USE_BF16=None


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

1.0733416080474854
0.9919965267181396
1.086693286895752
1.0294976234436035
1.1506454944610596
1.1307377815246582
1.1616685390472412
1.0808813571929932

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{}
--------------------------------------------------------------------------------


training on 8 TPU cores
INIT TPU local core: 0, global rank: 0 with XLA_USE_BF16=None
INIT TPU local core: 4, global rank: 4 with XLA_USE_BF16=None
INIT TPU local core: 1, global rank: 1 with XLA_USE_BF16=None
INIT TPU local core: 6, global rank: 6 with XLA_USE_BF16=None
INIT TPU local core: 2, global rank: 2 with XLA_USE_BF16=None
INIT TPU local core: 5, global rank: 5 with XLA_USE_BF16=None
INIT TPU local core: 3, global rank: 3 with XLA_USE_BF16=None
INIT TPU local core: 7, global rank: 7 with XLA_USE_BF16=None


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

0.7641496658325195
0.8060812950134277
1.3052723407745361
0.8828670978546143
1.1088132858276367
1.2262179851531982
1.177123785018921
1.059922218322754

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{}
--------------------------------------------------------------------------------


training on 8 TPU cores
INIT TPU local core: 0, global rank: 0 with XLA_USE_BF16=None
INIT TPU local core: 7, global rank: 7 with XLA_USE_BF16=None
INIT TPU local core: 1, global rank: 1 with XLA_USE_BF16=None
INIT TPU local core: 3, global rank: 3 with XLA_USE_BF16=None
INIT TPU local core: 6, global rank: 6 with XLA_USE_BF16=None
INIT TPU local core: 4, global rank: 4 with XLA_USE_BF16=None
INIT TPU local core: 2, global rank: 2 with XLA_USE_BF16=None
INIT TPU local core: 5, global rank: 5 with XLA_USE_BF16=None


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

1.0619242191314697
1.1413476467132568
0.9638183116912842
0.9176235198974609
1.170250415802002
1.061676263809204
1.1099920272827148
1.0522587299346924

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{}
--------------------------------------------------------------------------------


training on 8 TPU cores
INIT TPU local core: 0, global rank: 0 with XLA_USE_BF16=None
INIT TPU local core: 6, global rank: 6 with XLA_USE_BF16=None
INIT TPU local core: 3, global rank: 3 with XLA_USE_BF16=None
INIT TPU local core: 7, global rank: 7 with XLA_USE_BF16=None
INIT TPU local core: 1, global rank: 1 with XLA_USE_BF16=None
INIT TPU local core: 4, global rank: 4 with XLA_USE_BF16=None
INIT TPU local core: 5, global rank: 5 with XLA_USE_BF16=None
INIT TPU local core: 2, global rank: 2 with XLA_USE_BF16=None


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

0.4677567481994629
0.5735352039337158
0.9353816509246826
1.335836410522461
1.2793536186218262
0.977691650390625
1.0682671070098877
1.0007519721984863

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{}
--------------------------------------------------------------------------------


training on 8 TPU cores
INIT TPU local core: 0, global rank: 0 with XLA_USE_BF16=None
INIT TPU local core: 6, global rank: 6 with XLA_USE_BF16=None
INIT TPU local core: 4, global rank: 4 with XLA_USE_BF16=None
INIT TPU local core: 3, global rank: 3 with XLA_USE_BF16=None
INIT TPU local core: 5, global rank: 5 with XLA_USE_BF16=None
INIT TPU local core: 2, global rank: 2 with XLA_USE_BF16=None
INIT TPU local core: 1, global rank: 1 with XLA_USE_BF16=None
INIT TPU local core: 7, global rank: 7 with XLA_USE_BF16=None


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

1.0800731182098389
0.6799407005310059
1.1756517887115479
1.0312867164611816
1.0083742141723633
1.1368184089660645
0.9817469120025635
1.5881175994873047

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{}
--------------------------------------------------------------------------------


training on 8 TPU cores
INIT TPU local core: 0, global rank: 0 with XLA_USE_BF16=None
INIT TPU local core: 2, global rank: 2 with XLA_USE_BF16=None
INIT TPU local core: 3, global rank: 3 with XLA_USE_BF16=None
INIT TPU local core: 5, global rank: 5 with XLA_USE_BF16=None
INIT TPU local core: 4, global rank: 4 with XLA_USE_BF16=None
INIT TPU local core: 7, global rank: 7 with XLA_USE_BF16=None
INIT TPU local core: 1, global rank: 1 with XLA_USE_BF16=None
INIT TPU local core: 6, global rank: 6 with XLA_USE_BF16=None


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

0.6588265895843506
0.5835988521575928
0.6526613235473633
1.0789673328399658
0.8110852241516113
0.8561956882476807

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
1.1263761520385742
{}
--------------------------------------------------------------------------------
0.9357857704162598


training on 8 TPU cores
INIT TPU local core: 0, global rank: 0 with XLA_USE_BF16=None
INIT TPU local core: 4, global rank: 4 with XLA_USE_BF16=None
INIT TPU local core: 5, global rank: 5 with XLA_USE_BF16=None
INIT TPU local core: 2, global rank: 2 with XLA_USE_BF16=None
INIT TPU local core: 1, global rank: 1 with XLA_USE_BF16=None
INIT TPU local core: 3, global rank: 3 with XLA_USE_BF16=None
INIT TPU local core: 6, global rank: 6 with XLA_USE_BF16=None
INIT TPU local core: 7, global rank: 7 with XLA_USE_BF16=None


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

1.0713238716125488
0.6763300895690918
1.017430067062378
0.8804104328155518
1.0535149574279785
1.1053080558776855
1.2095463275909424
0.9715063571929932

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{}
--------------------------------------------------------------------------------


training on 8 TPU cores
INIT TPU local core: 0, global rank: 0 with XLA_USE_BF16=None
INIT TPU local core: 2, global rank: 2 with XLA_USE_BF16=None
INIT TPU local core: 3, global rank: 3 with XLA_USE_BF16=None
INIT TPU local core: 5, global rank: 5 with XLA_USE_BF16=None
INIT TPU local core: 1, global rank: 1 with XLA_USE_BF16=None
INIT TPU local core: 4, global rank: 4 with XLA_USE_BF16=None
INIT TPU local core: 6, global rank: 6 with XLA_USE_BF16=None
INIT TPU local core: 7, global rank: 7 with XLA_USE_BF16=None


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

0.6723718643188477
0.8023910522460938
0.5736591815948486
1.1058402061462402
0.9704806804656982
1.1754088401794434
1.0661413669586182
1.0358450412750244

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{}
--------------------------------------------------------------------------------


