In [1]:
import lightning as L
import torch
from torch import nn

In [2]:
import sys
sys.path.append('/Users/kevinerazocastillo/Desktop/MS_SSL/SSL4MS/')

In [3]:
%env PYTORCH_ENABLE_MPS_FALLBACK=1

env: PYTORCH_ENABLE_MPS_FALLBACK=1


# Set Up Data Module for Self-Supervised Training

In [4]:
from utils.data_utils import MassSpecSelfSupervisedDataModule

In [5]:
dm = MassSpecSelfSupervisedDataModule(root_dir='/Users/kevinerazocastillo/Desktop/MS_SSL/processed_data/spectra/',
                                     train_csv='/Users/kevinerazocastillo/Desktop/MS_SSL/processed_data/labels/train_labels_split1.csv',
                                     val_csv='/Users/kevinerazocastillo/Desktop/MS_SSL/processed_data/labels/val_labels_split1.csv',
                                     test_csv='/Users/kevinerazocastillo/Desktop/MS_SSL/processed_data/labels/test_labels_split1.csv',
                                     nl_csv='/Users/kevinerazocastillo/Desktop/MS_SSL/SSL4MS/data/neutral_losses.csv',
                                     corrupt_prob=0.1,
                                     max_len=128,
                                     batch_size=64)

In [6]:
dm.setup()

# Set Up Model Module

In [7]:
from models.selfsupervised import SpectrumSymmetricAE

In [8]:
spec_AE = SpectrumSymmetricAE(num_heads=4, ffn_factor=4, dropout=0.1, hidden_dims=[8, 32, 128], max_len=128)

In [9]:
print(spec_AE)

SpectrumSymmetricAE(
  (act): Sigmoid()
  (encoder): SpectrumEncoder(
    (act): Sigmoid()
    (enc_list): ModuleList(
      (0): EncoderBlock(
        (act): Sigmoid()
        (pe): PositionalEncoding(
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (cnn): Conv1d(2, 8, kernel_size=(1,), stride=(1,), padding=same)
        (enc): EncoderLayerGLU(
          (act): Sigmoid()
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=True)
          )
          (ffn): FFN(
            (fc1): Linear(in_features=8, out_features=32, bias=True)
            (fc2): Linear(in_features=32, out_features=8, bias=True)
            (glu): GLU(
              (act): Sigmoid()
              (linear1): Linear(in_features=32, out_features=32, bias=True)
              (linear2): Linear(in_features=32, out_features=32, bias=True)
            )
          )
          (norm1): LayerNorm((8,), eps=1e-05, elementwise_

In [10]:
class LitSpectrumSymmetricAE(L.LightningModule):
    def __init__(self, symmAE):
        super().__init__()
        self.save_hyperparameters()
        self.ae = symmAE
        
    def training_step(self, batch, batch_idx):
        x_corr, x_true = batch
        x_dec, x_enc, x_mask = self.ae(x_corr)
        loss = nn.MSELoss()(x_dec, x_true)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x_corr, x_true = batch
        x_dec, x_enc, x_mask = self.ae(x_corr)
        val_loss = nn.MSELoss()(x_dec, x_true)
        self.log("val_loss", val_loss)
        
    def test_step(self, batch, batch_idx):
        x_corr, x_true = batch
        x_dec, x_enc, x_mask = self.ae(x_corr)
        test_loss = nn.MSELoss()(x_dec, x_true)
        self.log("test_loss", test_loss)
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4, weight_decay=1e-3)
        return optimizer

In [11]:
lit_AE = LitSpectrumSymmetricAE(spec_AE)

/Users/kevinerazocastillo/anaconda3/envs/cheminf_MS/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'symmAE' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['symmAE'])`.


In [12]:
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

In [13]:
early_stop_callback = EarlyStopping(monitor="val_loss", patience=3, verbose=False, mode="min")

# Set Up the Trainer

In [None]:
# StochasticWeightAveraging(swa_lrs=1e-3),

In [14]:
from lightning.pytorch.callbacks import StochasticWeightAveraging

In [15]:
trainer = L.Trainer(gradient_clip_val=1.0, default_root_dir="./exp1_logs/", max_epochs=50, accelerator="mps",
                    precision='16-mixed', callbacks=[early_stop_callback])

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/kevinerazocastillo/anaconda3/envs/cheminf_MS/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [None]:
trainer.fit(lit_AE, dm)


  | Name | Type                | Params
---------------------------------------------
0 | ae   | SpectrumSymmetricAE | 829 K 
---------------------------------------------
829 K     Trainable params
0         Non-trainable params
829 K     Total params
3.320     Total estimated model params size (MB)


Sanity Checking: |                                                                                            …



Training: |                                                                                                   …



# Visualize Model

In [None]:
from torchview import draw_graph

In [None]:
model_graph = draw_graph(spec_AE, input_size=(64,128,2), expand_nested=True)
model_graph.visual_graph