In [1]:
import numpy as np
import numpy.random as npr
import pandas as pd
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from pytorch_lightning import LightningDataModule
from torch.utils.data import Dataset, Subset, DataLoader
from src.cdhit import CDHIT, cdhit_split
from src.constants import MSConstants
from src.torch_helpers import zero_padding_collate, NamedTensorDataset
from src.model import PositionalEncoding
from sklearn.model_selection import train_test_split
# Project constants object; `C.alphabet` supplies the residue vocabulary that
# the data module and the model below index into.
C = MSConstants()

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class PeptideDataModule(LightningDataModule):
    """Data module serving peptide sequences with masked regression targets.

    The input frame must contain a ``sequence`` column; every column after the
    first one is used as a regression target (assumes ``sequence`` is the
    first column -- TODO confirm against the CSV layout). Missing targets are
    filled with 0 and flagged through ``y_mask`` so they can be excluded from
    the loss.

    Args:
        df: DataFrame holding the sequences and their target columns.
        batch_size: Batch size for the train and validation loaders.
        train_val_split: Forwarded to ``cdhit_split`` as ``split``.
        cdhit_threshold: Forwarded to ``cdhit_split`` as ``threshold``.
        cdhit_word_length: Forwarded to ``cdhit_split`` as ``word_length``.
        num_workers: Worker processes for the train and validation loaders.
        random_state: Forwarded to ``cdhit_split`` as ``random_state``.
    """

    def __init__(
        self,
        df,
        batch_size,
        train_val_split,
        cdhit_threshold,
        cdhit_word_length,
        num_workers=1,
        random_state=0
    ):
        # Let Lightning initialise its own bookkeeping before we add state.
        super().__init__()
        self.df = df
        self.batch_size = batch_size
        self.train_val_split = train_val_split
        self.cdhit_threshold = cdhit_threshold
        self.cdhit_word_length = cdhit_word_length
        self.num_workers = num_workers
        # Honour the caller's value (was previously hard-coded to 0).
        self.random_state = random_state

    def setup(self, stage=None):
        """Build the full dataset and its train/validation subsets.

        Args:
            stage: Lightning stage name; unused, the same split is built for
                every stage.
        """
        # Use the frame given to this instance, not a notebook-level global.
        df = self.df
        targets = df.iloc[:, 1:]

        self.sequences = df['sequence'].tolist()

        self.dataset = NamedTensorDataset(
            sequence=df['sequence'],
            # Each residue is encoded as its position in C.alphabet.
            x=df['sequence'].map(lambda s: np.array([C.alphabet.index(c) for c in s])),
            # One entry per real residue; padding is added later by the collate fn.
            x_mask=df['sequence'].map(lambda s: np.array([1 for c in s])),
            # Missing targets become 0 here and are masked out via y_mask.
            y=targets.fillna(0).values,
            y_mask=~np.isnan(targets.values)
        )

        # cdhit_split returns the split sequences and the matching row indices;
        # only the indices are needed to build the subsets.
        train_seqs, val_seqs, train_idxs, val_idxs = cdhit_split(
            self.sequences,
            range(len(self.sequences)),
            split=self.train_val_split,
            threshold=self.cdhit_threshold,
            word_length=self.cdhit_word_length,
            random_state=self.random_state
        )
        self.train_dataset = Subset(self.dataset, train_idxs)
        self.val_dataset = Subset(self.dataset, val_idxs)

    def train_dataloader(self):
        """Return the shuffled training loader; incomplete last batches are dropped."""
        dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=zero_padding_collate,
            num_workers=self.num_workers,
            shuffle=True,
            drop_last=True
        )
        return dataloader

    def val_dataloader(self):
        """Return the validation loader (fixed order, all samples kept)."""
        dataloader = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=zero_padding_collate,
            num_workers=self.num_workers,
            shuffle=False,
            drop_last=False
        )
        return dataloader

    def predict_dataloader(self, shuffle=False):
        """Return a loader over the validation subset, one sample per batch.

        Args:
            shuffle: Whether to shuffle the samples.
        """
        dataloader = DataLoader(
            self.val_dataset,
            batch_size=1,
            collate_fn=zero_padding_collate,
            num_workers=1,
            shuffle=shuffle,
            drop_last=False
        )
        return dataloader

In [5]:
# Read the DBAASP table from disk.
# (Single-target variant kept for reference:
#  df[['sequence','Escherichia coli']].dropna())
df = pd.read_csv('./data/dbaasp.csv')

# One model output per column after the first one.
output_dim = len(df.columns) - 1

# Data module with a CD-HIT based train/validation split.
dm = PeptideDataModule(
    df=df,
    batch_size=64,
    num_workers=4,
    train_val_split=0.8,
    cdhit_threshold=0.5,
    cdhit_word_length=3,
)

In [6]:
# from src.torch_helpers import start_tensorboard

# start_tensorboard(login_node='login-2')

In [7]:
import torch
from torch import nn
from pytorch_lightning import LightningModule

class PeptideTransformer(LightningModule):
    """Transformer encoder with a CLS-token regression head for peptides.

    A CLS token is prepended to every sequence, the tokens are embedded and
    passed through the encoder, and the encoder output at the CLS position is
    fed to an MLP that predicts ``output_dim`` values. Training minimises a
    masked mean squared error over the observed targets.

    Args:
        residues: Residue vocabulary; its length fixes the CLS token index.
        output_dim: Number of regression outputs.
        model_dim: Embedding / transformer width.
        model_depth: Number of encoder layers; the head has ``model_depth``
            linear layers in total.
        num_heads: Number of attention heads.
        lr: Adam learning rate.
        dropout: Dropout used by the positional encoding and the transformer.
        max_length: Maximum sequence length; the positional encoding is built
            for ``2 * max_length`` positions.
        encoder_weights: Optional state dict loaded into the encoder.
        residue_weights: Optional state dict loaded into the embedding.
        train_encoder: Whether encoder parameters receive gradients.
        train_residues: Whether embedding parameters receive gradients.
    """

    def __init__(
        self,
        residues,
        output_dim,
        model_dim,
        model_depth,
        num_heads,
        lr,
        dropout,
        max_length,
        encoder_weights=None,
        residue_weights=None,
        train_encoder=True,
        train_residues=True
    ):
        super().__init__()
        self.save_hyperparameters()

        self.residues = residues
        self.model_dim = model_dim
        self.output_dim = output_dim
        self.max_length = max_length
        self.model_depth = model_depth
        self.num_heads = num_heads
        self.dropout = dropout
        self.lr = lr
        self.train_encoder = train_encoder
        self.train_residues = train_residues

        # One extra row for the CLS token, whose index is len(residues).
        # NOTE(review): padding_idx=0 means index 0 embeds to zeros; presumably
        # the first alphabet entry is a padding symbol -- confirm.
        self.residue_embedding = nn.Embedding(
            len(self.residues)+1,
            model_dim,
            padding_idx=0
        )

        self.positional_encoding = PositionalEncoding(
            d_model=model_dim,
            max_len=2*max_length, # striding
            dropout=dropout
        ).requires_grad_(False)

        # Only the encoder is used; the full nn.Transformer is kept so that
        # parameter names stay under `transformer.encoder.*`.
        self.transformer = nn.Transformer(
            d_model=model_dim,
            nhead=num_heads,
            num_encoder_layers=model_depth,
            num_decoder_layers=model_depth,
            dim_feedforward=model_dim,
            dropout=dropout,
            batch_first=True
        )
        self.transformer.decoder = None

        # Regression head: (model_depth - 1) hidden blocks plus an output layer.
        clf_layers = []
        for _ in range(model_depth-1):
            clf_layers.append(nn.Linear(model_dim, model_dim))
            clf_layers.append(nn.ReLU(inplace=True))
            clf_layers.append(nn.BatchNorm1d(model_dim))
        clf_layers.append(nn.Linear(model_dim,output_dim))
        self.classifier = nn.Sequential(*clf_layers)

        if residue_weights is not None:
            self.residue_embedding.load_state_dict(residue_weights)
        if encoder_weights is not None:
            self.transformer.encoder.load_state_dict(encoder_weights)

        # Freeze or unfreeze the (possibly pretrained) parts as requested.
        self.residue_embedding.requires_grad_(train_residues)
        self.transformer.encoder.requires_grad_(train_encoder)

    def _encode_src(self, sequence, sequence_mask):
        """Prepend the CLS token, embed, and add positional encodings.

        Args:
            sequence: Integer token tensor indexed as (batch, position).
            sequence_mask: Boolean tensor of the same shape, True on real tokens.

        Returns:
            Tuple ``(x, x_mask)``: the embedded tokens and the mask, each with
            one extra leading position for the CLS token.
        """
        # prepend CLS token
        cls_token = len(self.residues) * torch.ones_like(sequence[:,[0]])
        x = torch.cat([cls_token,sequence],dim=1)
        # The CLS position is always attended to (non-zero index -> True).
        x_mask = torch.cat([cls_token.bool(),sequence_mask],dim=1)
        x = self.residue_embedding(x)
        # Positional encoding is applied to the residues only, not to CLS.
        x[:,1:] = self.positional_encoding(x[:,1:], offset=0, stride=2)
        return x, x_mask

    def forward(self, sequence, sequence_mask):
        """Predict the targets for a batch of sequences.

        Args:
            sequence: Integer token tensor indexed as (batch, position).
            sequence_mask: Boolean tensor of the same shape, True on real tokens.

        Returns:
            The classifier output computed from the CLS position.
        """
        # Use this instance's encoder path, not a notebook-level `model`.
        x_src, x_src_mask = self._encode_src(sequence, sequence_mask)

        z = self.transformer.encoder(
            src = x_src,
            # The encoder expects True on positions to ignore.
            src_key_padding_mask = ~x_src_mask,
        )
        # cls token
        z = z[:,0]

        y_pred = self.classifier(z)

        return y_pred

    def step(self, batch, predict_step=False):
        """Run the model on a batch; return predictions or the masked MSE.

        Args:
            batch: Mapping with keys ``x``, ``x_mask``, ``y`` and ``y_mask``.
            predict_step: If True, return the raw predictions, not the loss.
        """
        y = batch['y'].float()
        y_mask = batch['y_mask'].bool()

        y_pred = self(
            sequence=batch['x'].long(),
            sequence_mask=batch['x_mask'].bool()
        )

        if predict_step:
            return y_pred

        # Mean squared error over observed targets only. If a batch has no
        # observed target at all this is 0/0 = NaN.
        loss = ((y_pred - y).square() * y_mask).sum() / y_mask.sum()

        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss; fail fast on NaN."""
        batch_size = batch['x'].shape[0]
        loss = self.step(batch)
        assert not torch.isnan(loss).any().item(), batch_idx
        self.log('train_mse',loss,batch_size=batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss."""
        batch_size = batch['x'].shape[0]
        loss = self.step(batch)
        self.log('valid_mse',loss,batch_size=batch_size,sync_dist=True)

    def predict_step(self, batch, batch_idx=None):
        """Return the raw predictions for a batch."""
        return self.step(batch, predict_step=True)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with rate ``self.lr``."""
        opt = torch.optim.Adam(self.parameters(), lr=self.lr)
        return opt

# Goal: show that initialising from weights pretrained on MS gives a lower validation error.

In [9]:
PRETRAIN = True

if PRETRAIN:
    [last_ckpt] = !ls -t1 ./lightning_logs/version_15649130/checkpoints/*.ckpt | head -n1
    #pretrained_checkpoint = './lightning_logs/version_15655593/checkpoints/epoch=37-step=24927.ckpt'
    pretrained_checkpoint = last_ckpt
    state_dict = torch.load(pretrained_checkpoint,map_location=torch.device('cpu'))['state_dict']

    name = 'residue_embedding'
    residue_weights = {k.replace(name+'.',''):v for k,v in state_dict.items() if k.startswith(name)}
    
    name = 'transformer.encoder'
    encoder_weights = {k.replace(name+'.',''):v for k,v in state_dict.items() if k.startswith(name)}
else:
    residue_weights = encoder_weights = None

In [10]:
# Seed torch before building the model so the weight initialisation is repeatable.
torch.manual_seed(0)

model = PeptideTransformer(
    # vocabulary and output size
    residues=C.alphabet,
    output_dim=output_dim,
    max_length=100,
    # architecture
    model_dim=512,
    model_depth=4,
    num_heads=4,
    dropout=0.1,
    # optimisation
    lr=1e-4,
    # optional pretrained weights, kept frozen during training
    encoder_weights=encoder_weights,
    residue_weights=residue_weights,
    train_encoder=False,
    train_residues=False,
)

In [11]:
from pytorch_lightning import Trainer

# !rm -rf ./lightning_logs/`ls -t ./lightning_logs | head -n1`
# CPU-only, full-precision trainer; metrics are logged every 10 steps.
trainer = Trainer(
    max_epochs=100,
    log_every_n_steps=10,
    precision=32,
    gpus=0,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
# Train the model on the loaders provided by the data module.
trainer.fit(model, datamodule=dm)

Set SLURM handle signals.

  | Name                | Type               | Params
-----------------------------------------------------------
0 | residue_embedding   | Embedding          | 12.8 K
1 | positional_encoding | PositionalEncoding | 0     
2 | transformer         | Transformer        | 6.3 M 
3 | classifier          | Linear             | 2.6 K 
-----------------------------------------------------------
2.6 K     Trainable params
6.3 M     Non-trainable params
6.3 M     Total params
25.313    Total estimated model params size (MB)


Epoch 0:  78%|███████▊  | 32/41 [00:01<00:00, 17.47it/s, loss=5.23, v_num=1.57e+7]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/9 [00:00<?, ?it/s][A
Epoch 0:  83%|████████▎ | 34/41 [00:02<00:00, 15.12it/s, loss=5.23, v_num=1.57e+7]
Epoch 0:  90%|█████████ | 37/41 [00:02<00:00, 15.41it/s, loss=5.23, v_num=1.57e+7]
Epoch 0: 100%|██████████| 41/41 [00:02<00:00, 15.92it/s, loss=5.23, v_num=1.57e+7]
Epoch 1:  80%|████████  | 33/41 [00:01<00:00, 16.54it/s, loss=3.95, v_num=1.57e+7]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/9 [00:00<?, ?it/s][A
Validating:  11%|█         | 1/9 [00:00<00:02,  3.57it/s][A
Epoch 1:  88%|████████▊ | 36/41 [00:02<00:00, 14.85it/s, loss=3.95, v_num=1.57e+7]
Epoch 1:  95%|█████████▌| 39/41 [00:02<00:00, 15.43it/s, loss=3.95, v_num=1.57e+7]
Epoch 1: 100%|██████████| 41/41 [00:02<00:00, 15.24it/s, loss=3.95, v_num=1.57e+7]
Epoch 2:  80%|████████  | 33/41 [00:01<00:00, 17.87it/s, loss=3.43, v_num=1.57e+7]
Validating: 0it [0

In [None]:
# dm.setup()
# model = model.cpu()
# model.eval();