In [1]:
import numpy as np
import numpy.random as npr
import pandas as pd
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
# Load each label table, keep only the sequence and its label column, and
# drop rows that repeat the same (sequence, label) pair.
mito_df = (
    pd.read_csv('./data/mitochondria_targeting.csv')
    .loc[:, ['Sequence', 'Mitochondrial Targeting Signal']]
    .drop_duplicates(keep='first')
)
cdc28_df = (
    pd.read_csv('./data/cdc28_binding.csv')
    .loc[:, ['Sequence', 'Cdc28 Binding']]
    .drop_duplicates(keep='first')
)

# The two files appear to cover the same sequences, so an inner join on
# 'Sequence' lines the two labels up side by side.
df = pd.merge(mito_df, cdc28_df, on='Sequence', how='inner')

# Length cutoff. A stricter cap of 50 residues was tried following review
# feedback (Kevin); the active cutoff is 1000.
df = df.loc[df['Sequence'].map(len).le(1000)]
df

Unnamed: 0,Sequence,Mitochondrial Targeting Signal,Cdc28 Binding
0,MPAVLRTRSKESSIEQKPASRTRTRSRRGKRGRDDDDDDDDEESDD...,1,0
1,EQKWQDEQELKKKEKELKRKNDAEAKRLRMEERKRQQMQKKIAKEQ...,0,0
2,IEKFKTKKIKAKLKADQKLNKEDAKPGSDVEKEVSFNPLF,0,0
3,MQKISKYSSMAILRKRPLVKTETGPESELLPEKRTKIKQEEVVPQPVD,0,0
4,RELNVEAEINVKHEEKTVEETMVKLENDISVKVED,0,0
...,...,...,...
5344,MSDYEEAFNDGNENFEDFDVEHFSDEETYEEKPQFKDGETTDANGK...,0,0
5345,PPEGHKKTEKETDIKDVDETNEDEVKDRVEDEVKDRVEDEVKDQDE...,0,0
5346,MDELLGEALSAENQTGESTVESEKLVTPEDVMTIS,0,0
5347,PLSDLKKRSQAKMNAKTDFAKIINKPNELSQILTVDPKT,0,0


In [4]:
from src.torch_helpers import NamedTensorDataset
from src.datamodule import PeptideDataModule
from src.constants import MSConstants

C = MSConstants()

# Per-sequence tensors:
#   x      - position of each character of the sequence within C.alphabet
#   x_mask - an array of ones with one entry per character
#   y      - the mitochondrial targeting label as a single int32 column
dataset = NamedTensorDataset(
    sequence=df['Sequence'],
    x=df['Sequence'].map(lambda seq: np.array(list(map(C.alphabet.index, seq)))),
    x_mask=df['Sequence'].map(lambda seq: np.array([1] * len(seq))),
    # Two-task alternative (currently disabled):
    #   y=df[['Mitochondrial Targeting Signal','Cdc28 Binding']].astype(np.int32).values
    y=df[['Mitochondrial Targeting Signal']].astype(np.int32).to_numpy()
)

# NOTE(review): the cdhit_* arguments presumably configure a CD-HIT based
# clustering for the train/validation split - confirm in src.datamodule.
dm = PeptideDataModule(
    dataset,
    train_val_split=0.9,
    batch_size=256,
    val_batch_size=1024,
    cdhit_threshold=0.5,
    cdhit_word_length=3
)

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
import torch
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl
from torchmetrics.functional import auroc

class CNNModel(pl.LightningModule):
    """Small 1-D convolutional classifier over integer-encoded sequences.

    Architecture: embedding -> (model_depth - 1) x [dropout, conv, batch-norm,
    leaky-ReLU, average-pool/2] -> global average pool -> linear head that
    emits one logit per output.  Each encoder stage halves the channel width
    (starting from ``model_dim``) and, through the stride-2 pooling, the
    sequence length.

    Each of the ``output_dim`` outputs is trained as an independent binary
    task with a logistic loss; the overall loss is the mean over tasks.

    Args:
        output_dim: number of binary outputs (logits) to predict.
        model_dim: embedding width and channel count entering the encoder.
        model_depth: the encoder has ``model_depth - 1`` convolutional stages;
            with ``model_depth <= 1`` the encoder is empty and the head is
            applied to the pooled embeddings directly.
        kernel_size: convolution kernel width (padding is ``kernel_size // 2``).
        num_residues: size of the embedding vocabulary.
        dropout: dropout probability applied before every convolution.
        balance_classes: if true, weight positives in the loss by the
            per-batch negative/positive count ratio.
        lr: Adam learning rate.
    """

    def __init__(
        self,
        output_dim,
        model_dim,
        model_depth,
        kernel_size,
        num_residues,
        dropout,
        balance_classes,
        lr
    ):
        super().__init__()

        self.model_dim = model_dim
        self.model_depth = model_depth
        self.kernel_size = kernel_size
        self.num_residues = num_residues
        self.output_dim = output_dim
        self.dropout = dropout
        self.balance_classes = balance_classes
        self.lr = lr

        # Index 0 is the padding index: its embedding vector is fixed at zero.
        self.embedding = nn.Embedding(
            num_embeddings=num_residues,
            embedding_dim=model_dim,
            padding_idx=0
        )

        encoder_layers = []
        in_dim = model_dim
        for _ in range(model_depth-1):
            out_dim = in_dim // 2
            encoder_layers += [
                nn.Dropout(dropout),
                nn.Conv1d(in_dim, out_dim, kernel_size, padding=kernel_size//2),
                nn.BatchNorm1d(out_dim),
                nn.LeakyReLU(0.2, inplace=True),
                nn.AvgPool1d(2, 2),
            ]
            in_dim = out_dim
        self.encoder = nn.Sequential(*encoder_layers)

        self.pooling = nn.AdaptiveAvgPool1d(1)

        # `in_dim` is the channel count leaving the encoder.  It is used here
        # instead of the loop-local `out_dim` so that an empty encoder
        # (model_depth <= 1) still builds instead of raising NameError.
        self.classifier = nn.Linear(in_dim, self.output_dim)

    def forward(self, x):
        """Return logits of shape (batch, output_dim) for index tensor ``x``.

        ``x`` is expected to be a (batch, length) integer tensor of residue
        indices.  No mask is applied, so every position - including any
        padded ones - takes part in the pooling averages.
        """
        x = self.embedding(x)
        # Conv1d wants channels before length: (batch, length, dim) -> (batch, dim, length).
        x = x.swapdims(1,2)
        x = self.encoder(x)
        x = self.pooling(x).squeeze(-1)
        x = self.classifier(x)
        return x

    def configure_optimizers(self):
        """Use plain Adam over all parameters with learning rate ``self.lr``."""
        opt = torch.optim.Adam(self.parameters(),lr=self.lr)
        return opt

    def step(self, batch, batch_idx):
        """Compute the loss and per-task AUROC for one batch.

        Reads ``batch['x']`` and ``batch['y']``.  Returns ``(loss, aucs)``
        where ``loss`` is the mean of the per-task losses and ``aucs`` is a
        list with one AUROC value per task, computed on this batch only.
        """
        x = batch['x']
        y = batch['y']
        y_pred = self(x)
        losses = []
        aucs = []
        for k in range(self.output_dim):
            if self.balance_classes:
                # Negative/positive ratio within the batch; the +1 terms keep
                # the ratio finite when a batch has no positives.
                pos_weight = (1+(y[:,k]==0).sum()) / (1+(y[:,k]==1).sum())
            else:
                pos_weight = None
            loss = F.binary_cross_entropy_with_logits(y_pred[:,k], y[:,k].float(), pos_weight=pos_weight)
            # AUROC is only reported, never optimised: detach the logits so
            # the metric does not extend the autograd graph.
            auc = auroc(y_pred[:,k].detach(), y[:,k])
            losses.append(loss)
            aucs.append(auc)
        loss = torch.stack(losses).mean()
        return loss, aucs

    def training_step(self, batch, batch_idx):
        """Log ``train_loss`` / ``train_auc_<k>`` and return the loss."""
        batch_size = batch['x'].shape[0]
        loss, aucs = self.step(batch, batch_idx)
        self.log('train_loss',loss,batch_size=batch_size)
        for k in range(self.output_dim):
            self.log(f'train_auc_{k}',aucs[k],batch_size=batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``valid_loss`` / ``valid_auc_<k>`` for one validation batch."""
        batch_size = batch['x'].shape[0]
        loss, aucs = self.step(batch, batch_idx)
        self.log('valid_loss',loss,sync_dist=True,batch_size=batch_size)
        for k in range(self.output_dim):
            self.log(f'valid_auc_{k}',aucs[k],sync_dist=True,batch_size=batch_size)

In [14]:
# Fix the torch RNG before building the model so weight initialisation is repeatable.
torch.manual_seed(0)

# Keyword arguments follow the order of CNNModel.__init__.
model = CNNModel(
    output_dim=len(dm.dataset[0]['y']),   # one logit per label column
    model_dim=64,
    model_depth=3,
    kernel_size=3,
    num_residues=len(C.alphabet),         # embedding vocabulary size
    dropout=0.1,
    balance_classes=True,
    lr=5e-4,
)

In [15]:
# Delete any existing ./lightning_logs/version_<SLURM_JOBID> directory
# (no error if it is absent, because of -f).
# NOTE(review): presumably this clears logs left by an earlier run under the
# same SLURM job before training starts - confirm.
!rm -rf ./lightning_logs/version_$SLURM_JOBID

In [16]:
from src.torch_helpers import NoValProgressBar

# Train for 100 epochs at 32-bit precision with zero GPUs requested.
trainer = pl.Trainer(
    max_epochs=100,
    precision=32,
    gpus=0,
    callbacks=[NoValProgressBar()],
)

# Fit the model using the train/validation loaders supplied by the data module.
trainer.fit(model, dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_deprecation(
Set SLURM handle signals.

  | Name       | Type              | Params
-------------------------------------------------
0 | embedding  | Embedding         | 1.5 K 
1 | encoder    | Sequential        | 7.8 K 
2 | pooling    | AdaptiveAvgPool1d | 0     
3 | classifier | Linear            | 17    
-------------------------------------------------
9.4 K     Trainable params
0         Non-trainable params
9.4 K     Total params
0.038     Total estimated model params size (MB)


                                                                      

  rank_zero_warn(
  rank_zero_warn(


Epoch 16:  58%|█████▊    | 11/19 [00:01<00:00, 10.23it/s, loss=1.06, v_num=1.66e+7]



Epoch 99: 100%|██████████| 19/19 [00:01<00:00, 12.00it/s, loss=0.627, v_num=1.66e+7]


In [None]:
# trainer.predict(model, dm.val_dataloader())

In [17]:
# NOTE(review): leftover line; `hparams` is not defined in the cells shown here.
#version = hparams['version']

# Rename this job's log directory (version_<SLURM_JOBID>) to a descriptive name.
# NOTE(review): if ./lightning_logs/cnn_singletask already exists as a directory,
# `mv` moves the version directory inside it instead of replacing it - check
# that this is what you want when re-running the notebook.
!mv ./lightning_logs/version_$SLURM_JOBID ./lightning_logs/cnn_singletask