In [2]:
import numpy as np
import numpy.random as npr
import pandas as pd
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
from pytorch_lightning import LightningDataModule
from torch.utils.data import Dataset, Subset, DataLoader
from src.cdhit import CDHIT, cdhit_split
from src.constants import MSConstants
from src.torch_helpers import zero_padding_collate, NamedTensorDataset
from src.model import PositionalEncoding
from sklearn.model_selection import train_test_split
# Shared constants object; this notebook reads `C.alphabet` from it, both to
# turn residues into integer indices and to size the model's embedding table.
C = MSConstants()

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
class PeptideDataModule(LightningDataModule):
    """Data module serving peptide sequences and their regression targets.

    Sequences are read from the ``'sequence'`` column of ``df``; every column
    from position 1 onward is treated as a target, where NaN marks a missing
    label. The train/validation partition is produced by ``cdhit_split``.

    Args:
        df: DataFrame with a ``'sequence'`` column; columns ``1:`` are targets.
        batch_size: Batch size of the training loader.
        train_val_split: Forwarded to ``cdhit_split`` as ``split``.
        cdhit_threshold: Forwarded to ``cdhit_split`` as ``threshold``.
        cdhit_word_length: Forwarded to ``cdhit_split`` as ``word_length``.
        num_workers: Worker processes for the train and validation loaders.
        random_state: Forwarded to ``cdhit_split`` as ``random_state``.
        val_batch_size: Batch size of the validation loader; falls back to
            ``batch_size`` when ``None``.
    """

    def __init__(
        self,
        df,
        batch_size,
        train_val_split,
        cdhit_threshold,
        cdhit_word_length,
        num_workers=1,
        random_state=0,
        val_batch_size=None,
    ):
        # The Lightning base class sets up internal state in its own __init__.
        super().__init__()
        self.df = df
        self.batch_size = batch_size
        self.val_batch_size = batch_size if val_batch_size is None else val_batch_size
        self.train_val_split = train_val_split
        self.cdhit_threshold = cdhit_threshold
        self.cdhit_word_length = cdhit_word_length
        self.num_workers = num_workers
        # Honour the caller's seed (previously hard-coded to 0).
        self.random_state = random_state

    def setup(self, stage=None):
        """Build the full dataset and its train/validation subsets.

        Populates ``self.sequences``, ``self.dataset``, ``self.train_dataset``
        and ``self.val_dataset``. ``stage`` is accepted but not used.
        """
        # Use the frame handed to the constructor, not a notebook global.
        df = self.df
        self.sequences = df['sequence'].tolist()
        targets = df.iloc[:, 1:]

        self.dataset = NamedTensorDataset(
            sequence=df['sequence'],
            # Each residue becomes its position in C.alphabet.
            # NOTE(review): the model embeds with padding_idx=0, so a residue
            # sitting at C.alphabet[0] would share the padding row -- confirm
            # that position 0 of the alphabet is a padding symbol.
            x=df['sequence'].map(lambda s: np.array([C.alphabet.index(c) for c in s])),
            # One flag per residue; presumably padded with zeros by
            # zero_padding_collate -- TODO confirm.
            x_mask=df['sequence'].map(lambda s: np.array([1 for c in s])),
            # Missing targets are stored as 0 and flagged False in y_mask.
            y=targets.fillna(0).values,
            y_mask=~np.isnan(targets.values),
        )

        train_seqs, val_seqs, train_idxs, val_idxs = cdhit_split(
            self.sequences,
            range(len(self.sequences)),
            split=self.train_val_split,
            threshold=self.cdhit_threshold,
            word_length=self.cdhit_word_length,
            random_state=self.random_state,
        )
        self.train_dataset = Subset(self.dataset, train_idxs)
        self.val_dataset = Subset(self.dataset, val_idxs)

    def train_dataloader(self):
        """Return a shuffled loader over the training subset (last partial batch dropped)."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=zero_padding_collate,
            num_workers=self.num_workers,
            shuffle=True,
            drop_last=True,
        )

    def val_dataloader(self):
        """Return an ordered loader over the validation subset (all batches kept)."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            collate_fn=zero_padding_collate,
            num_workers=self.num_workers,
            shuffle=False,
            drop_last=False,
        )

    def predict_dataloader(self, shuffle=False):
        """Return a one-example-per-batch loader over the validation subset.

        Args:
            shuffle: Whether to shuffle the examples.
        """
        return DataLoader(
            self.val_dataset,
            batch_size=1,
            collate_fn=zero_padding_collate,
            num_workers=1,
            shuffle=shuffle,
            drop_last=False,
        )

In [6]:
# from src.torch_helpers import start_tensorboard
# start_tensorboard(login_node='login-2')

In [7]:
import torch
from torch import nn
from pytorch_lightning import LightningModule
from torchmetrics.functional import r2_score

class PeptideTransformer(LightningModule):
    """Transformer-encoder regressor that reads a prepended CLS token.

    A CLS token is prepended to each index sequence, the tokens are embedded,
    passed through the encoder half of an ``nn.Transformer``, and the encoder
    output at the CLS position is mapped to ``output_dim`` values by a linear
    layer.

    Args:
        residues: Residue alphabet. ``len(residues)`` is used as the CLS token
            id and the embedding table has ``len(residues) + 1`` rows.
        output_dim: Number of regression outputs.
        model_dim: Embedding width; also used as the feed-forward width.
        model_depth: Number of encoder layers.
        num_heads: Attention heads per layer.
        lr: Adam learning rate.
        dropout: Dropout rate for the positional encoding and the transformer.
        max_length: The positional encoding is built with
            ``max_len=2*max_length``.
        encoder_weights: Optional state dict loaded into the encoder.
        residue_weights: Optional state dict loaded into the residue embedding.
        train_encoder: Whether encoder parameters require gradients.
        train_residues: Whether embedding parameters require gradients.
    """

    def __init__(
        self,
        residues,
        output_dim,
        model_dim,
        model_depth,
        num_heads,
        lr,
        dropout,
        max_length,
        encoder_weights=None,
        residue_weights=None,
        train_encoder=True,
        train_residues=True
    ):
        super().__init__()
        self.save_hyperparameters()

        self.residues = residues
        self.model_dim = model_dim
        self.output_dim = output_dim
        self.max_length = max_length
        self.model_depth = model_depth
        self.num_heads = num_heads
        self.dropout = dropout
        self.lr = lr
        self.train_encoder = train_encoder
        self.train_residues = train_residues

        # One row per residue plus one for the CLS token (id == len(residues)).
        # Row 0 is the padding row and receives no gradient updates.
        self.residue_embedding = nn.Embedding(
            len(self.residues)+1,
            model_dim,
            padding_idx=0
        )

        # Frozen positional encoding; twice max_length because positions are
        # requested with stride=2 in _encode_src.
        self.positional_encoding = PositionalEncoding(
            d_model=model_dim,
            max_len=2*max_length, # striding
            dropout=dropout
        ).requires_grad_(False)

        self.transformer = nn.Transformer(
            d_model=model_dim,
            nhead=num_heads,
            num_encoder_layers=model_depth,
            num_decoder_layers=model_depth,
            dim_feedforward=model_dim,
            dropout=dropout,
            batch_first=True
        )
        # Only the encoder is used; discard the decoder.
        self.transformer.decoder = None

        self.classifier = nn.Linear(model_dim, output_dim)

        if residue_weights is not None:
            self.residue_embedding.load_state_dict(residue_weights)
        if encoder_weights is not None:
            self.transformer.encoder.load_state_dict(encoder_weights)

        self.residue_embedding.requires_grad_(train_residues)
        self.transformer.encoder.requires_grad_(train_encoder)

    def _encode_src(self, sequence, sequence_mask):
        """Prepend the CLS token and embed the sequence.

        Args:
            sequence: 2-D integer tensor of residue indices (batch first).
            sequence_mask: Boolean tensor shaped like ``sequence``.

        Returns:
            ``(x, x_mask)``: the embedded tokens including the CLS position,
            and the mask extended by one leading CLS column.
        """
        # prepend CLS token
        cls_token = len(self.residues) * torch.ones_like(sequence[:,[0]])
        x = torch.cat([cls_token, sequence], dim=1)
        x_mask = torch.cat([cls_token.bool(), sequence_mask], dim=1)
        x = self.residue_embedding(x)
        # Positional encoding is applied to the residues only, not to CLS.
        x[:,1:] = self.positional_encoding(x[:,1:], offset=0, stride=2)
        return x, x_mask

    def forward(self, sequence, sequence_mask):
        """Predict ``output_dim`` values per sequence from the CLS position.

        Args:
            sequence: 2-D tensor of residue indices (cast to long).
            sequence_mask: Tensor shaped like ``sequence`` (cast to bool);
                True marks positions the encoder should attend to.

        Returns:
            Tensor with one row of ``output_dim`` predictions per sequence.
        """
        # Use this instance's encoder path, not a notebook-global `model`.
        x_src, x_src_mask = self._encode_src(
            sequence.long(),
            sequence_mask.bool()
        )

        # src_key_padding_mask flags positions to IGNORE, hence the inversion.
        z = self.transformer.encoder(
            src = x_src,
            src_key_padding_mask = ~x_src_mask,
        )
        # cls token
        z = z[:,0]

        y_pred = self.classifier(z)

        return y_pred

    def _loss_fn(self, y_pred, y, y_mask):
        """Mean squared error over the entries where ``y_mask`` is True."""
        y = y.float()
        y_mask = y_mask.bool()
        loss = ((y_pred - y).square() * y_mask).sum() / y_mask.sum()
        return loss

    def _r2_score(self, y_pred, y, y_mask):
        """R^2 on target column 0, restricted to rows whose column-0 mask is True."""
        y_pred = y_pred[:,[0]]
        y = y[:,[0]]
        y_mask = y_mask[:,0]
        y_mask = y_mask.bool()
        return r2_score(y_pred[y_mask],y[y_mask])

    def training_step(self, batch, batch_idx):
        """Compute and log the masked MSE and column-0 R^2; return the loss."""
        batch_size = len(batch['sequence'])
        y_pred = self(
            sequence=batch['x'],
            sequence_mask=batch['x_mask']
        )
        loss = self._loss_fn(
            y_pred,
            batch['y'],
            batch['y_mask']
        )
        self.log('train_mse',loss,batch_size=batch_size)
        r2 = self._r2_score(
            y_pred,
            batch['y'],
            batch['y_mask']
        )
        self.log('train_r2',r2,batch_size=batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the masked MSE and R^2 of target column 0 for a validation batch."""
        batch_size = len(batch['sequence'])
        y_pred = self(
            sequence=batch['x'],
            sequence_mask=batch['x_mask']
        )
        # only report val err on primary task
        loss = self._loss_fn(
            y_pred[:,[0]],
            batch['y'][:,[0]],
            batch['y_mask'][:,[0]]
        )
        self.log('valid_mse',loss,sync_dist=True,batch_size=batch_size)
        r2 = self._r2_score(
            y_pred,
            batch['y'],
            batch['y_mask']
        )
        self.log('valid_r2',r2,sync_dist=True,batch_size=batch_size)

    def predict_step(self, batch, batch_idx=None):
        """Return the model predictions for ``batch``."""
        y_pred = self(
            sequence=batch['x'],
            sequence_mask=batch['x_mask']
        )
        return y_pred

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        opt = torch.optim.Adam(self.parameters(), lr=self.lr)
        return opt

In [84]:
# Experiment switches.
MULTITASK = False  # True: keep every column of the CSV as an extra target
SUBSAMPLE = False  # only referenced by the commented-out block below
PRETRAIN = True    # True: initialise the model from a pretrained checkpoint

# Target used for the row filter; downstream code treats target column 0
# as the primary task, so this column must come right after 'sequence'.
PRIMARY_TARGET = 'Escherichia coli'

dbaasp = pd.read_csv('./data/dbaasp.csv')

if not MULTITASK:
    df = dbaasp[['sequence', PRIMARY_TARGET]].dropna()

else:
    # Order the remaining columns by how many non-missing values they hold,
    # but pin 'sequence' and the primary target to positions 0 and 1 instead
    # of relying on the sort to put them there.
    by_coverage = (~dbaasp.isna()).sum().sort_values(ascending=False).index
    leading = ['sequence', PRIMARY_TARGET]
    df = dbaasp[leading + [col for col in by_coverage if col not in leading]]
    df = df.dropna(subset=[PRIMARY_TARGET])

# if SUBSAMPLE:
#     df = df.sample(frac=0.5,random_state=0)

# Every column after 'sequence' is a regression target.
output_dim = df.shape[1]-1
print(df.shape)

dm = PeptideDataModule(
    df,
    batch_size=128,
    # Validate on the whole validation subset in a single batch.
    val_batch_size=df.shape[0],
    train_val_split=0.8,
    cdhit_threshold=0.5,
    cdhit_word_length=3,
    num_workers=4
)

(2133, 2)


In [85]:
if PRETRAIN:
    [last_ckpt] = !ls -t1 ./lightning_logs/version_15815343/checkpoints/*.ckpt | head -n1
    print(last_ckpt)
    pretrained_checkpoint = last_ckpt
    state_dict = torch.load(pretrained_checkpoint,map_location=torch.device('cpu'))['state_dict']

    name = 'residue_embedding'
    residue_weights = {k.replace(name+'.',''):v for k,v in state_dict.items() if k.startswith(name)}
    
    name = 'transformer.encoder'
    encoder_weights = {k.replace(name+'.',''):v for k,v in state_dict.items() if k.startswith(name)}

./lightning_logs/version_15815343/checkpoints/epoch=23-step=7871.ckpt


In [86]:
# Seed torch's RNG before the model is constructed.
torch.manual_seed(0)

model = PeptideTransformer(
    # vocabulary and output size
    residues=C.alphabet,
    output_dim=output_dim,
    max_length=100,
    # architecture
    model_dim=64,
    model_depth=2,
    num_heads=2,
    dropout=0.1,
    # optimisation
    lr=1e-4,
    # optional warm start from the pretrained checkpoint
    encoder_weights=encoder_weights if PRETRAIN else None,
    residue_weights=residue_weights if PRETRAIN else None,
    # Left disabled: freeze the pretrained parts instead of fine-tuning them.
    #     train_encoder=not PRETRAIN,
    #     train_residues=not PRETRAIN
)

In [87]:
from pytorch_lightning import Trainer

!rm -rf ./lightning_logs/version_$SLURM_JOBID
trainer = Trainer(
    gpus=0,
    max_epochs=500,
    precision=32
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Epoch 128:   0%|          | 0/14 [00:16<?, ?it/s, loss=3.13, v_num=1.58e+7]

In [None]:
# Train `model` on the loaders provided by the data module `dm`.
trainer.fit(model, dm)

In [12]:
# dm.setup()
# model = model.cpu()
# model.eval();

# from tqdm import tqdm

# y_pred = []
# y = []
# y_mask = []
# sequences = []
# for batch in tqdm(dm.val_dataloader()):
#     sequences += batch['sequence']
#     y_pred.append(model.predict_step(batch).cpu().detach().numpy())
#     y.append(batch['y'].cpu().numpy())
#     y_mask.append(batch['y_mask'].cpu().numpy())
# y_pred = np.concatenate(y_pred,0)
# y = np.concatenate(y,0)
# y_mask = np.concatenate(y_mask,0)
# y[~y_mask] = np.nan

# plt.plot(y,y_pred,'.');
# np.corrcoef(y.T,y_pred.T)