In [None]:
# Core numerical / data / plotting stack used throughout the notebook.
# NOTE(review): `npr` and `pd` are not referenced in any cell shown here — confirm before removing.
import numpy as np
import numpy.random as npr
import pandas as pd
import matplotlib.pyplot as plt
# Render inline matplotlib figures as high-resolution ("retina") PNGs.
%config InlineBackend.figure_format = 'retina'

In [None]:
# Automatically re-import edited modules (e.g. the project's `src.*` code) before
# each cell executes, so library edits take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Project modules: data module, model definition, and shared constants.
from src.datamodule import MSDataModule
from src.model import MSTransformer
from src.constants import MSConstants
# Single constants instance; its alphabet / ion / charge-range / loss fields are
# read when constructing the model further down.
C = MSConstants()

In [None]:
import tempfile

# Data module over the ProteomeTools HDF file.
#  - train_val_split=0.95: presumably 95% train / 5% validation — confirm in src.datamodule.
#  - cdhit_threshold / cdhit_word_length: CD-HIT clustering parameters, presumably
#    used to group similar sequences before splitting — confirm in src.datamodule.
#
# cache_dir: a Python string literal is never shell-expanded, so passing '$TMPDIR'
# hands the literal text "$TMPDIR" to the data module (creating/using a directory
# with that name unless MSDataModule expands it itself). tempfile.gettempdir()
# honours $TMPDIR and falls back to TEMP/TMP or the platform default.
dm = MSDataModule(
    hdf_path='./data/ProteomeTools.hdf',
    batch_size=1024,
    train_val_split=0.95,
    cdhit_threshold=0.5,
    cdhit_word_length=3,
    cache_dir=tempfile.gettempdir(),
    num_workers=20,
)

In [None]:
# Launch TensorBoard to monitor training. The login node is presumably the
# cluster host used to reach the TensorBoard server — confirm in src.torch_helpers.
from src.torch_helpers import start_tensorboard

TENSORBOARD_LOGIN_NODE = 'login-2'
start_tensorboard(login_node=TENSORBOARD_LOGIN_NODE)

In [None]:
# Model configuration. Vocabulary, ion types, charge ranges and losses come from
# the shared constants object `C`; the remaining entries are architecture and
# training hyperparameters.
model_kwargs = dict(
    residues=C.alphabet,
    ions=C.ions,
    parent_min_charge=C.min_charge,
    parent_max_charge=C.max_charge,
    fragment_min_charge=C.min_frag_charge,
    fragment_max_charge=C.max_frag_charge,
    losses=C.losses,
    model_dim=512,
    model_depth=4,
    num_heads=8,
    lr=1e-4,
    dropout=0.1,
    max_length=100,
)
model = MSTransformer(**model_kwargs)

In [None]:
# Build the Lightning trainer. Logs from any previous run are deleted first,
# presumably so the TensorBoard view only reflects the current run.
import shutil

from pytorch_lightning import Trainer

# Pure-Python equivalent of `rm -rf ./lightning_logs`: portable (no shell needed)
# and a silent no-op when the folder does not exist.
shutil.rmtree('./lightning_logs', ignore_errors=True)

trainer = Trainer(
    # NOTE(review): `gpus=` was removed in Lightning 2.0; on newer versions use
    # accelerator='gpu', devices=1 instead.
    gpus=1,
    max_epochs=100,
    log_every_n_steps=5,
)

In [None]:
# Train `model` in place on data supplied by `dm`, for up to the `max_epochs`
# configured on the trainer.
trainer.fit(model, dm)

In [None]:
# Prepare the data module's datasets for prediction, then move the trained model
# to CPU and switch it to inference mode (disables dropout).
# Note: `model` holds the weights as they were at the end of training; no
# checkpoint is reloaded here.
dm.setup()
model = model.cpu().eval()

In [None]:
# Qualitative check: mirror plots of observed (upward) vs predicted (downward)
# fragment intensities for the first item of several randomly shuffled batches.
from itertools import islice

import torch

from src.plotting import faststem
from src.spectrum import fragment_mz_tensor

N_EXAMPLES = 10  # number of batches to visualise

# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    # islice fetches exactly N_EXAMPLES batches (enumerate + break pulled one extra).
    for batch in islice(dm.predict_dataloader(shuffle=True), N_EXAMPLES):
        batch['y_pred'] = model.predict_step(batch)

        sequence = batch['sequence'][0]
        mz = fragment_mz_tensor(sequence).ravel()
        y = batch['y'][0].detach().cpu().numpy().ravel()
        y_pred = batch['y_pred'][0].detach().cpu().numpy().ravel()

        plt.figure(figsize=(6, 3))
        faststem(mz, y)
        faststem(mz, -y_pred)
        # Symmetric y-limits so observed and predicted peaks share one scale.
        yl = max(np.abs(plt.ylim()))
        plt.ylim([-yl, yl])
        plt.xlabel('m/z')
        plt.ylabel('observed (+) / predicted (-)')
        plt.title(f"{sequence} {batch['charge'][0]}+")