In [1]:
import numpy as np
import numpy.random as npr
import pandas as pd
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from pytorch_lightning import LightningDataModule
from torch.utils.data import Dataset, Subset, DataLoader
from src.cdhit import CDHIT
from src.constants import MSConstants
from src.torch_helpers import zero_padding_collate, NamedTensorDataset
from src.model import PositionalEncoding
from sklearn.model_selection import train_test_split
C = MSConstants()

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class PeptideDataModule(LightningDataModule):
    def __init__(
        self,
        df, 
        batch_size,
        train_val_split,
        cdhit_threshold,
        cdhit_word_length,
        num_workers=1,
        random_state=0
    ):
        self.df = df
        self.batch_size = batch_size
        self.train_val_split = train_val_split
        self.cdhit_threshold = cdhit_threshold
        self.cdhit_word_length = cdhit_word_length
        self.num_workers = num_workers
        self.random_state = 0
    
    def setup(self, stage=None):
        self.sequences = set(df['sequence'])
        
        self.dataset = NamedTensorDataset(
            sequence=df['sequence'],
            x=df['sequence'].map(lambda s: np.array([C.alphabet.index(c) for c in s])),
            x_mask=df['sequence'].map(lambda s: np.array([1 for c in s])),
            y=df.iloc[:,1:].fillna(0).values,  #?#?#???#@$?#?@ ?#$??@!??$ ?%#
            y_mask=~np.isnan(df.iloc[:,1:].values)
        )
        
        cdhit = CDHIT(
            threshold=self.cdhit_threshold,
            word_length=self.cdhit_word_length
        )
        clusters = cdhit.fit_predict(list(set(self.sequences)))
        train_clusters, val_clusters = train_test_split(
            clusters, 
            train_size=self.train_val_split,
            random_state=self.random_state
        )
        train_clusters = set(train_clusters)
        val_clusters = set(val_clusters)
        self.train_sequences = [s for s, c in zip(self.sequences, clusters) if c in train_clusters]
        self.val_sequences = [s for s, c in zip(self.sequences, clusters) if c in val_clusters]
        train_idxs = [i for i, (s,c) in enumerate(zip(self.sequences, clusters)) if c in train_clusters]
        val_idxs = [i for i, (s,c) in enumerate(zip(self.sequences, clusters)) if c in val_clusters]
        self.train_dataset = Subset(self.dataset, train_idxs)
        self.val_dataset = Subset(self.dataset, val_idxs)
        
    def train_dataloader(self):
        dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=zero_padding_collate,
            num_workers=self.num_workers,
            shuffle=True,
            drop_last=True
        )
        return dataloader

    def val_dataloader(self):
        dataloader = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=zero_padding_collate,
            num_workers=self.num_workers,
            shuffle=False,
            drop_last=False
        )
        return dataloader
    
    def predict_dataloader(self, shuffle=False):
        dataloader = DataLoader(
            self.val_dataset,
            batch_size=1,
            collate_fn=zero_padding_collate,
            num_workers=1,
            shuffle=shuffle,
            drop_last=False
        )
        return dataloader

In [5]:
# from src.torch_helpers import start_tensorboard

# start_tensorboard(login_node='login-2')

In [6]:
df = pd.read_csv('./data/dbaasp.csv')
dm = PeptideDataModule(
    df,
    batch_size=64,
    train_val_split=0.9,
    cdhit_threshold=0.5,
    cdhit_word_length=3,
    num_workers=4
)
dm.setup()
for batch in dm.train_dataloader():
    break

In [15]:
import torch
from torch import nn
from pytorch_lightning import LightningModule

class PeptideTransformer(LightningModule):
    def __init__(
        self,
        residues,
        output_dim,
        model_dim,
        model_depth,
        num_heads,
        lr,
        dropout,
        max_length
    ):
        super().__init__()
        self.save_hyperparameters()
        
        self.residues = residues
        self.model_dim = model_dim
        self.output_dim = output_dim
        self.max_length = max_length
        self.model_depth = model_depth
        self.num_heads = num_heads
        self.dropout = dropout
        self.lr = lr
        
        self.residue_embedding = nn.Embedding(
            len(self.residues), 
            model_dim,
            padding_idx=0
        )
        
        self.positional_encoding = PositionalEncoding(
            d_model=model_dim,
            max_len=2*max_length, # striding
            dropout=dropout
        ).requires_grad_(False)

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=model_dim, 
                nhead=num_heads, 
                dim_feedforward=model_dim,
                dropout=dropout,
                batch_first=True
            ),
            num_layers=model_depth,
            norm=nn.LayerNorm(model_dim, eps=1e-5)
        )
        
        clf_layers = []
#         for i in range(model_depth-1):
#             clf_layers.append(nn.Linear(model_dim, model_dim))
#             clf_layers.append(nn.ReLU(inplace=True))
#             clf_layers.append(nn.BatchNorm1d(model_dim))
        clf_layers.append(nn.Linear(model_dim,output_dim))
        self.classifier = nn.Sequential(*clf_layers)
            
    def _encode_src(self, sequence):
        x = self.residue_embedding(sequence)
        x = self.positional_encoding(x, offset=0, stride=2)
        return x
    
    def forward(self, sequence, sequence_mask=None):
        batch_size, max_residues = sequence.shape
        max_bonds = max_residues - 1
        
        if sequence_mask is None:
            sequence_mask = torch.ones(batch_size, max_residues, device=self.device)
        x_src_mask = sequence_mask
        
        x_src = model._encode_src(sequence)
        x_src *= x_src_mask.unsqueeze(-1) # unsure

        z = self.encoder(src=x_src, src_key_padding_mask=~x_src_mask)
        z = z.mean(1) # average pool
        y_pred = self.classifier(z)
        
        return y_pred
    
    def step(self, batch, predict_step=False):
        batch_size = batch['x'].shape[0]

        y = batch['y'].float()
        y_mask = batch['y_mask'].bool()
        
        y_pred = self(
            sequence=batch['x'].long(),
            sequence_mask=batch['x_mask'].bool()
        )
        
        if predict_step:
            return y_pred

        # probably gonna have to mask here too
        loss = ((y_pred - y).square() * y_mask).sum() / y_mask.sum()
        
        return loss
    
    def training_step(self, batch, batch_idx):
        batch_size = batch['x'].shape[0]
        loss = self.step(batch)
        assert not torch.isnan(loss).any().item(), batch_idx
        self.log('train_mse',loss,batch_size=batch_size)
        return loss
    
    def validation_step(self, batch, batch_idx):
        batch_size = batch['x'].shape[0]
        loss = self.step(batch)
        self.log('valid_mse',loss,batch_size=batch_size,sync_dist=True)
        
    def predict_step(self, batch, batch_idx=None):
        return self.step(batch, predict_step=True)
    
    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=self.lr)
        return opt

In [16]:
model = PeptideTransformer(
    residues=C.alphabet,
    model_dim=256,
    output_dim=5,
    model_depth=4,
    num_heads=4,
    lr=1e-4,
    dropout=0.1,
    max_length=100
)

In [17]:
from pytorch_lightning import Trainer

!rm -rf ./lightning_logs/`ls -t ./lightning_logs | head -n1`
trainer = Trainer(
    gpus=1,
    max_epochs=10000,
    precision=32,
    # num_sanity_val_steps=0,
    log_every_n_steps=5
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
trainer.fit(model, dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [GPU-b72b8c0d-1f71-551f-8762-96059ca70389]
Set SLURM handle signals.

  | Name                | Type               | Params
-----------------------------------------------------------
0 | residue_embedding   | Embedding          | 6.1 K 
1 | positional_encoding | PositionalEncoding | 0     
2 | encoder             | TransformerEncoder | 1.6 M 
3 | classifier          | Sequential         | 1.3 K 
-----------------------------------------------------------
1.6 M     Trainable params
0         Non-trainable params
1.6 M     Total params
6.364     Total estimated model params size (MB)


Epoch 0:  59%|█████▉    | 69/117 [00:03<00:02, 19.84it/s, loss=3.02, v_num=1.56e+7]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/49 [00:00<?, ?it/s][A
Epoch 0:  65%|██████▍   | 76/117 [00:03<00:02, 20.25it/s, loss=3.02, v_num=1.56e+7]
Epoch 0:  85%|████████▍ | 99/117 [00:03<00:00, 25.69it/s, loss=3.02, v_num=1.56e+7]
Epoch 0: 100%|██████████| 117/117 [00:03<00:00, 29.45it/s, loss=3.02, v_num=1.56e+7]
Epoch 1:  59%|█████▉    | 69/117 [00:01<00:00, 55.37it/s, loss=2.73, v_num=1.56e+7] 
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/49 [00:00<?, ?it/s][A
Validating:   2%|▏         | 1/49 [00:00<00:09,  4.93it/s][A
Epoch 1:  79%|███████▊  | 92/117 [00:01<00:00, 58.77it/s, loss=2.73, v_num=1.56e+7]
Epoch 1: 100%|██████████| 117/117 [00:01<00:00, 68.44it/s, loss=2.73, v_num=1.56e+7]
Epoch 2:  59%|█████▉    | 69/117 [00:01<00:00, 53.07it/s, loss=2.6, v_num=1.56e+7]  
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/49 [00:00<?, ?it/s][

In [None]:
dm.setup()
model = model.cpu()
model.eval();