In [1]:
import numpy as np
import numpy.random as npr
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%config InlineBackend.figure_format = 'retina'

In [2]:
# disable for training!
# %load_ext autoreload
# %autoreload 2

In [3]:
# Single source of truth for the run configuration. The same mapping is
# expanded into `--key value` flags for train.py, passed to the model
# constructor, and read when building the Trainer further down.
hparams = {
    # ---- datamodule ----
    'hdf_path': './data/ProteomeTools.hdf',
    'batch_size': 512,
    'train_split': 0.85,
    'val_split': 0.05,
    # smaller splits kept around for quick experiments:
    # 'train_split': 0.05,
    # 'val_split': 0.01,
    'cdhit_threshold': 0.5,
    'cdhit_word_length': 3,
    'tmp_env': 'TMPDIR',
    'num_workers': 8,  # don't need many when loading everything into RAM
    'random_state': 0,

    # ---- model ----
    'model_dim': 128,  # same size as CARP-600k
    'model_depth': 16,
    'lr': 1e-4,
    'dropout': 0.0,

    # ---- trainer ----
    'num_gpus': 1,
    'max_epochs': 1000,
    'precision': 32,
    'strategy': 'ddp',
    'es_monitor': 'val_cross_entropy',
    'es_mode': 'min',
    'es_patience': 3,
    'val_check_interval': 0.1,

    # ---- cluster ----
    'num_nodes': 1,
    'num_cpus': 20,
    'conda_env': 'MSPretraining',
    'time': '0-12:00:00',

    # ---- tensorboard ----
    'login_node': 'login-2',
}

In [4]:
args = ' '.join([f'--{k} {v}' for k,v in hparams.items()])

slurm = f'''#!/bin/bash -l 

#SBATCH --nodes={hparams['num_nodes']}
#SBATCH --gres=gpu:volta:{hparams['num_gpus']}
#SBATCH --ntasks-per-node={max(1,hparams['num_gpus'])}
#SBATCH --cpus-per-task={hparams['num_cpus']}
#SBATCH --time={hparams['time']}

#SBATCH --signal=SIGUSR1@90

source activate {hparams['conda_env']}
''' + '''
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1

# Set some environment variables needed by torch.distributed 
export MASTER_ADDR=$(hostname -s)
# Get unused port
export MASTER_PORT=$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1]); s.close()')

echo "MASTER_ADDR : ${MASTER_ADDR}"
echo "MASTER_PORT : ${MASTER_PORT}"
''' + f'''
srun python train.py {args}
'''

%store slurm >submit.sh

# !sbatch submit.sh

Writing 'slurm' (str) to file 'submit.sh'.


In [5]:
# from src.torch_helpers import start_tensorboard

# start_tensorboard(login_node=hparams['login_node'])

In [6]:
# must randomly set p_mask of the non-mask tokens to '*'
# and minimize CE on the dropped tokens

# you need ***TWO*** encoders
# one, as is, takes in sequence
# the other takes in SPECTRAL INFO
# -> 

# self.x_encoder = ByteNet(
#     n_tokens=self.input_dim,
#     d_embedding=self.embed_dim,
#     d_model=self.model_dim,
#     n_layers=self.model_depth,
#     kernel_size=self.kernel_size,
#     r=self.r,
#     padding_idx=self.padding_idx, 
#     causal=False,
#     dropout=self.dropout,
#     activation='gelu'
# )

# self.y_encoder = ByteNet(
#     n_tokens=np.prod(self.output_dim)+self.condition_dim,
#     d_embedding=self.model_dim, # I do not want a bottleneck here
#     d_model=self.model_dim,
#     n_layers=self.model_depth,
#     kernel_size=self.kernel_size,
#     r=self.r,
#     padding_idx=self.padding_idx, 
#     causal=False,
#     dropout=self.dropout,
#     activation='gelu'
# )

# self.conv1 = MaskedConv1d(self.model_dim * 2, self.model_dim)
# self.relu = nn.ReLU()
# self.conv2 = MaskedConv1d(self.model_dim, self.input_dim)

# def step(self, batch, step):
#     batch_size = batch['x'].shape[0]
#     max_length = batch['x'].shape[1]

#     x = batch['x']
#     input_mask = batch['x_mask']

#     c = torch.stack([
#         batch['charge'],
#         batch['collision_energy']
#     ],-1).float()

#     y = batch['y']
#     y = F.pad(y,[(0,0),(0,1)]+[(0,0) for _ in self.output_dim])
# #     y_mask = x_mask
# #     y_mask = y_mask.view(batch_size,max_length-1,1,1,1)
# #     y_mask = y_mask.expand_as(batch['y_mask'])

#     dropout_mask = torch.multinomial(input_mask, 1)
#     x_dropout = torch.zeros_like(input_mask)
#     x_dropout[range(batch_size),dropout_mask] = 1
#     x[x_dropout] = self.masking_idx
    
#     x_pred = self(x, y, c, input_mask)
    
#     xent = F.cross_entropy(
#         x[range(batch_size),dropout_mask],
#         x_pred[range(batch_size),dropout_mask]
#     )

# def forward(self, x, y, c, input_mask):
#     x = x.flatten(2)
#     y = y.flatten(2)
#     c = c.view(-1,1,1)
#     input_mask = input_mask.unsqueeze(-1)
    
#     x = self.x_encoder(x, input_mask=input_mask)

#     y = torch.cat([y,c],-1)
#     y = self.y_encoder(y, input_mask=input_mask)

#     z = torch.cat([x,y],-1)
#     z = self.conv1(z, input_mask=input_mask)
#     z = self.relu(z)
#     x_pred = self.conv2(z)
    
#     return x_pred

In [17]:
from torch import nn
from sequence_models.convolutional import ByteNet

# NOTE: design sketch only. The statements below reference `self`, so they
# belong inside a module's __init__ / forward, and the draft ended in an
# unfinished `def forward(...)` line, which made the whole cell a SyntaxError.
# They are kept as commented-out notes (like the other design-sketch cells in
# this notebook) so that Restart & Run All is not blocked by this cell.

# --- intended for __init__ ---
# self.embed_dim = 8
# self.padding_idx = 0
# self.masking_idx = self.input_dim + 1
# NOTE(review): self.embedder below is created with
# num_embeddings=self.input_dim, so index self.input_dim + 1 would be out of
# range if mask tokens are passed through it -- confirm the vocabulary size.

# self.encoder = ByteNet(
#     n_tokens=self.input_dim + 1, # unused
#     d_embedding=self.embed_dim + self.condition_dim + np.prod(self.output_dim),
#     d_model=self.model_dim,
#     n_layers=self.model_depth,
#     kernel_size=self.kernel_size,
#     r=self.r,
#     padding_idx=0, 
#     causal=False,
#     dropout=self.dropout,
#     activation='gelu'
# )
# # bypass ByteNet's own token embedding; embedding is done by self.embedder
# self.encoder.embedder = nn.Identity()

# self.embedder = nn.Embedding(
#     num_embeddings=self.input_dim,
#     embedding_dim=self.embed_dim,
#     padding_idx=self.padding_idx
# )

# --- intended forward pass ---
# TODO: the signature was left unfinished in the draft; remaining parameters
# are unknown.
# def forward(self, x, ...):
#     x = self.embedder(x)

In [None]:
# encoder = ByteNet(...)
# embedder = nn.Embedding(...)
# inner_embedder = nn.Linear(...)

# # first randomly mask
# x_dropout = torch.rand_like(x) < self.masking_rate
# x[x_dropout] = self.masking_idx

# x = self.embedder(x)
# y = torch.cat([y.flatten(2),
# x = torch.cat([x,y,c],-1)
# x = self.inner_embedder(x)
# x = self.encoder(x)
# x = self.classifier(x)

In [7]:
from src.datamodule import MSDataModule
from src.model import MSTransformer

# Instantiate the model from the shared configuration, then hand the
# hyperparameters the model exposes (copied into a plain dict) to the
# datamodule so both objects are configured from the same values.
model = MSTransformer(**hparams)
dm_kwargs = dict(model.hparams)
dm = MSDataModule(**dm_kwargs)

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from src.torch_helpers import NoValProgressBar

seed_everything(hparams['random_state'], workers=True)

!rm -rf ./lightning_logs/version_$SLURM_JOBID

trainer = Trainer(
    gpus=1,
    precision=hparams['precision'],
    val_check_interval=hparams['val_check_interval'],
    max_epochs=100,
    callbacks=[
        EarlyStopping(
            monitor=hparams['es_monitor'],
            mode=hparams['es_mode'],
            patience=hparams['es_patience']
        ),
        NoValProgressBar(),
        # checkpoitn the best so far
    ]
)

trainer.fit(model, dm)

# rename it

Global seed set to 0
Multiprocessing is handled by SLURM.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [GPU-a15c92af-b942-6540-1a2c-4509456ac8f1]

  | Name       | Type      | Params
-----------------------------------------
0 | encoder    | ByteNetLM | 603 K 
1 | decoder    | MSDecoder | 52.2 K
2 | classifier | Linear    | 3.1 K 
-----------------------------------------
658 K     Trainable params
0         Non-trainable params
658 K     Total params
2.635     Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Epoch 0: 100%|██████████| 7836/7836 [06:27<00:00, 20.22it/s, loss=2.92, v_num=1.79e+7]


In [9]:
# from tqdm import tqdm
# train_seqs = {item['sequence'] for item,_ in tqdm(zip(dm.train_dataset,range(1000)),position=0)}
# val_seqs = {item['sequence'] for item,_ in tqdm(zip(dm.val_dataset,range(1000)),position=0)}
# test_seqs = {item['sequence'] for item,_ in tqdm(zip(dm.test_dataset,range(1000)),position=0)}
# train_seqs&val_seqs, train_seqs&test_seqs, test_seqs&val_seqs

In [10]:
# from src.datamodule import MSDataModule
# from src.model import MSTransformer
# from src.plotting import faststem
# from src.spectrum import fragment_mz_tensor
# from tqdm import tqdm

# # [last_ckpt] = !ls -t1 ./lightning_logs/*/checkpoints/*.ckpt | head -n1
# # print(last_ckpt)
# # model = MSTransformer.load_from_checkpoint(last_ckpt)
# # dm = MSDataModule(**dict(model.hparams))

# dm.setup()

# model = model.cpu()
# model.eval();

# for i, batch in enumerate(dm.predict_dataloader()):
#     batch['y_pred'] = model.predict_step(batch)

#     mz = fragment_mz_tensor(batch['sequence'][0]).ravel()
#     y = batch['y'][0].detach().cpu().numpy().ravel()
#     y_pred = batch['y_pred'][0].detach().cpu().numpy().ravel()
    
#     plt.figure(figsize=(6,3))
#     faststem(mz,y)
#     faststem(mz,-y_pred)
#     yl = max(np.abs(plt.ylim()))
#     plt.ylim([-yl,yl])
#     plt.title(f"{batch['sequence'][0]} {batch['charge'][0]}+")
    
#     if i == 10:
#         break

In [11]:
# the CNN does not shrink unseen peaks to zero, while the transformer does