In [1]:
%reload_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

## Transfer Learning

Generally, transfer learning means leveraging the knowledge learned by one model to achieve better performance on a different task. "Task" is a loose term: it can mean learning a different objective (for example, moving from regression to classification), learning the same objective with a different loss function or optimizer, or learning the same objective with the same loss but on different data. The last case is the one used in alphaDIA: when a dataset is too small to train a model from scratch without overfitting, we start from a pretrained model that performs well on a larger dataset and continue training it on the small one.
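To make the idea concrete, here is a minimal, framework-free sketch (not alphaDIA code): a linear model is first trained on a large dataset from a related "source" task and then fine-tuned for a few epochs on a small "target" dataset, compared with training for the same few epochs from scratch. All data, weights, and the learning rate are made-up assumptions chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)


def make_data(n, true_w, noise=0.1):
    """Synthetic linear-regression data: y = X @ true_w + Gaussian noise."""
    X = rng.normal(size=(n, true_w.size))
    y = X @ true_w + noise * rng.normal(size=n)
    return X, y


def train(X, y, w_init, lr=0.05, epochs=20):
    """Full-batch gradient descent on the mean squared error, starting from w_init."""
    w = w_init.copy()
    for _ in range(epochs):
        grad = 2 * X.T @ (X @ w - y) / len(y)
        w -= lr * grad
    return w


def mse(X, y, w):
    return float(np.mean((X @ w - y) ** 2))


# A source task with plenty of data and a related target task with very little data
w_source = np.array([1.0, -2.0, 0.5])
w_target = w_source + np.array([0.1, 0.0, -0.1])
X_large, y_large = make_data(10_000, w_source)
X_small, y_small = make_data(30, w_target)
X_test, y_test = make_data(1_000, w_target)

w_pretrained = train(X_large, y_large, np.zeros(3), epochs=200)
# Same small budget (5 epochs on 30 samples) for both strategies
w_scratch = train(X_small, y_small, np.zeros(3), epochs=5)
w_finetuned = train(X_small, y_small, w_pretrained, epochs=5)

print(f"test MSE from scratch: {mse(X_test, y_test, w_scratch):.3f}")
print(f"test MSE fine-tuned:   {mse(X_test, y_test, w_finetuned):.3f}")
```

Starting from the pretrained weights, a handful of updates on the small dataset is enough to get close to the target task, while the model trained from scratch is still far from it.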

## Transfer Learning for alphaDIA

In alphaDIA, we use the peptides identified in the first search run as a signal to fine-tune the PeptDeep prediction models on the custom dataset we are interested in. We then examine how this affects a second search run, in which the fine-tuned models are used to predict the spectral library.

Training neural networks, and transformers in particular (such as those used for MS2 prediction), usually requires substantial hyperparameter tuning: users try out different learning rates, numbers of epochs, and so on to see what works best. For users with limited experience, this can feel like a slow process of trial and error. The goal of the transfer learning module in alphaDIA is to be robust with minimal user intervention, making fine-tuning accessible to users from all backgrounds, including those with little experience in deep learning.


In this notebook, we walk through the transfer learning done in alphaDIA, starting with two components that provide the robustness we are targeting: **Learning Rate Schedulers** and **Early Stopping**. If you already understand these concepts and want to go straight to fine-tuning models with our API, skip to the [fine-tuning section](#transfer-learning-in-alphadia).


In [2]:

from alphadia.transferlearning.train import CustomScheduler


## Learning Rate Scheduler

The learning rate is a crucial hyperparameter that sets the size of the updates made to the model weights, essentially controlling "how fast we learn". A higher learning rate might seem beneficial, but it can make the weights quickly settle near sub-optimal values and oscillate around them. If the learning rate is much too high, the updates overshoot (over-correct) the weights and the model can even diverge. This is where learning rate schedulers come in: a learning rate scheduler dynamically adjusts the learning rate of a neural network (or part of it) based on time/epochs or on the loss of the model (more on that later).
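The effect of the learning rate is visible even on the simplest possible problem: minimizing $f(w) = w^2$ with plain gradient descent. In this toy sketch (not alphaDIA code), each update multiplies $w$ by $(1 - 2\,\text{lr})$, so a small learning rate converges smoothly, a larger one overshoots and oscillates around the minimum, and a learning rate that is too high diverges.

```python
def gradient_descent(lr, w0=1.0, steps=50):
    """Minimize f(w) = w**2 (gradient 2*w) with plain gradient descent."""
    w = w0
    for _ in range(steps):
        w -= lr * 2 * w  # each update multiplies w by (1 - 2 * lr)
    return w


for lr in (0.1, 0.9, 1.1):
    print(f"lr={lr}: w after one step = {gradient_descent(lr, steps=1):+.2f}, "
          f"after 50 steps = {gradient_descent(lr):.3g}")
```

With `lr=0.9` the sign of `w` flips at every step (the update overshoots the minimum at 0) but its magnitude still shrinks; with `lr=1.1` each overshoot is larger than the previous one, so `|w|` grows without bound.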

For alphaDIA, we use a custom learning rate scheduler with two phases:

### 1) Warmup Phase
In this phase, the learning rate starts small and increases linearly over a number of "warmup epochs". Warmup is known to help considerably when training transformers with optimizers such as Adam or SGD ([https://arxiv.org/abs/2002.04745](https://arxiv.org/abs/2002.04745)). Since we are not training from scratch, a short warmup suffices, and our default is **5** warmup epochs. The user only needs to define the maximum learning rate and the number of warmup epochs. During this phase, the learning rate at warmup epoch $t$ (counting from 1) is:

$$
\text{lr}(t) = \text{max\_lr} \times \left( \frac{t}{\text{number of warmup epochs}} \right)
$$
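Evaluating the warmup formula with the defaults used in this notebook (`max_lr = 0.01`, 5 warmup epochs) gives a linear ramp. The sketch assumes that $t$ counts epochs from 1, so the first warmup epoch already uses a non-zero learning rate; `warmup_lr` is an illustrative helper, not alphaDIA's `CustomScheduler`.

```python
def warmup_lr(t, max_lr, num_warmup_epochs):
    """Linear warmup: lr(t) = max_lr * t / num_warmup_epochs, with t counted from 1."""
    if not 1 <= t <= num_warmup_epochs:
        raise ValueError("t must lie between 1 and num_warmup_epochs")
    return max_lr * t / num_warmup_epochs


warmup_schedule = [warmup_lr(t, max_lr=0.01, num_warmup_epochs=5) for t in range(1, 6)]
print(warmup_schedule)
```

The resulting learning rates are 0.002, 0.004, 0.006, 0.008 and 0.01 (up to floating-point rounding).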

### 2) Reduce on Plateau LR Schedule
After the warmup phase, the learning rate has reached the maximum value set by the user and stays there until the training loss reaches a plateau. A plateau means that the training loss has not improved significantly for more than a certain number of epochs, called the "patience". For this phase, we use PyTorch's `torch.optim.lr_scheduler.ReduceLROnPlateau` with a default patience of 3 epochs.

This approach makes fine-tuning less sensitive to the user-defined learning rate: if the loss has not improved for more than 3 epochs, the learning rate is likely too high, so the scheduler reduces it to let the model continue learning.
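The plateau rule can be sketched without PyTorch. The function below is a simplified, hypothetical re-implementation for illustration only: it assumes PyTorch's default `mode="min"` with a relative improvement threshold of `1e-4`, ignores options such as `cooldown` and `min_lr`, and uses a reduction factor of 0.5 (halving).

```python
def plateau_schedule(losses, lr, patience=3, factor=0.5, threshold=1e-4):
    """Learning rate after each epoch under a simplified reduce-on-plateau rule."""
    best_loss = float("inf")
    num_bad_epochs = 0
    learning_rates = []
    for loss in losses:
        if loss < best_loss * (1 - threshold):
            # significant (relative) improvement: remember it and reset the counter
            best_loss = loss
            num_bad_epochs = 0
        else:
            num_bad_epochs += 1
        if num_bad_epochs > patience:
            # no improvement for more than `patience` epochs: reduce the learning rate
            lr *= factor
            num_bad_epochs = 0
        learning_rates.append(lr)
    return learning_rates


# A flat loss right after warmup, starting from max_lr = 0.01
print(plateau_schedule([0.06] * 7, lr=0.01))
```

The first epoch sets the best loss; the next four epochs bring no improvement, and on the fourth of them (more than 3 bad epochs) the learning rate is halved from 0.01 to 0.005.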



In [3]:
"""
To show how our LR scheduler works, we will use a dummy optimizer with a dummy model parameters.
"""

import torch

NUM_WARMUP_STEPS = 5
MAX_LR = 0.01

class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(1, 1)
    
    def forward(self, x):
        return self.fc(x)
    
model = DummyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=MAX_LR)


In [None]:
"""
Now since our lr scheduler uses reduce_lr_on_plateau, we need to pass the training loss per each epoch to the scheduler.
"""

# Dummy training loss
losses = [0.12,0.1, 0.09, 0.08, 0.07, 0.06, 0.06,0.06,0.06,0.06,0.06, 0.06] 
scheduler = CustomScheduler(optimizer, max_lr=MAX_LR, num_warmup_steps=NUM_WARMUP_STEPS)

learning_rates = []
for epoch, loss in enumerate(losses):
    scheduler.step(epoch, loss)
    learning_rates.append(optimizer.param_groups[0]['lr'])
    print(f"Epoch {epoch+1}, Loss: {loss}, LR: {learning_rates[-1]}")


fig, ax1 = plt.subplots()

color = 'tab:red'
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Learning Rate', color=color)
ax1.plot(learning_rates, color=color)
ax1.tick_params(axis='y', labelcolor=color)

ax2 = ax1.twinx()  

color = 'tab:blue'
ax2.set_ylabel('Loss', color=color)  
ax2.plot(losses, color=color)
ax2.tick_params(axis='y', labelcolor=color)

fig.tight_layout() 
plt.show()

    


Notice how, during the first 5 epochs, the learning rate starts at
$\frac{\text{max\_lr}}{\text{number of warmup epochs}} = \frac{0.01}{5} = 0.002$
and increases until it reaches $0.01$ in the 5th epoch.

Once the loss has not improved for more than 3 epochs (the patience value), the learning rate is halved. Later in this notebook, we will see how this halving helps retention time (RT) and MS2 fine-tuning consistently reach much better performance without extensive hyperparameter experimentation.



In [5]:
from alphadia.transferlearning.train import EarlyStopping

## Early Stopping

With our learning rate scheduler, we could in principle keep training indefinitely, since the learning rate is reduced every time the loss plateaus, eventually becoming vanishingly small. However, this approach has two issues:

1. The performance gains when the learning rate is very small are often not significant enough to justify continued training.
2. Longer training times increase the risk of overfitting on the small dataset we are fine-tuning on.

To address these issues, we implement two measures:

### 1) Maximum Number of Epochs
We cap training at 50 epochs. In our experiments, 50 epochs are usually enough to achieve significant performance gains without spending epochs on negligible improvements.

### 2) Early Stopping
We use an Early Stopping implementation that monitors the validation loss and terminates training if either of the following criteria holds for more than `patience` epochs (this patience is separate from the learning rate scheduler's patience, although the two are related, as explained below):

a) The validation loss is increasing, which may indicate overfitting.

b) The validation loss is not significantly improving, indicating no significant performance gains on the validation dataset.

The early stopping patience is the number of epochs for which we tolerate these criteria being met before taking action. Training neural networks with small batches can be noisy, so we allow some volatility before intervening. We set the early stopping patience to a multiple of the learning rate scheduler's patience, which gives the scheduler a chance to fix the problem (by lowering the learning rate) before training is terminated.

It's important to note that there are many implementations of Early Stopping algorithms, some offering better performance against overfitting by monitoring the generalization gap (val_loss - train_loss). However, we found that the simple implementation we use is sufficient for our fine-tuning tasks.
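To make the two stopping criteria concrete, here is a minimal, hypothetical early-stopping rule in plain Python. It is not alphaDIA's `EarlyStopping` class: the relative improvement `margin` of 1% is an assumption, and the real implementation may define "significant improvement" differently. A single check covers both criteria: an epoch counts as "bad" if the validation loss is not at least `margin` below the best loss seen so far, which is the case both when the loss increases and when it improves only marginally.

```python
class SimpleEarlyStopping:
    """Hypothetical early-stopping rule (not alphaDIA's EarlyStopping class)."""

    def __init__(self, patience=3, margin=0.01):
        self.patience = patience
        self.margin = margin  # assumed minimum relative improvement
        self.best_loss = float("inf")
        self.num_bad_epochs = 0

    def step(self, val_loss):
        """Return True to continue training, False to stop."""
        if val_loss < self.best_loss * (1 - self.margin):
            # significant improvement: remember it and reset the counter
            self.best_loss = val_loss
            self.num_bad_epochs = 0
        else:
            # loss increased or improved too little
            self.num_bad_epochs += 1
        return self.num_bad_epochs <= self.patience


def stop_epoch(val_losses, **kwargs):
    """Epoch (counting from 0) at which training would stop, or None if it never stops."""
    early_stopping = SimpleEarlyStopping(**kwargs)
    for epoch, loss in enumerate(val_losses):
        if not early_stopping.step(loss):
            return epoch
    return None


print(stop_epoch([0.5, 0.3, 0.2, 0.1, 0.125, 0.15, 0.2, 0.3, 0.5, 0.7, 1.0]))
```

Applied to the three simulated loss curves in the next cell, this assumed rule stops the diverging curve at epoch 7, stops the slowly improving curve only at its final epoch, and never stops the converging curve; alphaDIA's own `EarlyStopping` may trigger at different epochs.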


In [None]:
"""
To illustrate how our early stopping works we will try it on simulated validation losses in differnet cases and see how and when the early stopping is triggered.
"""
simulated_losses = {
    "diverging": [0.5, 0.3 ,0.2, 0.1, 0.125, 0.15, 0.2, 0.3, 0.5, 0.7, 1.0],
    "converging": [0.5, 0.3 ,0.2, 0.1, 0.05, 0.03, 0.02, 0.01, 0.005, 0.002, 0.0005],
    "not significantly improving" : [0.5, 0.3 ,0.2, 0.1, 0.07, 0.0695, 0.0689, 0.06883, 0.06878,0.06874, 0.06869],
}

stopped_at = {case: len(losses)-1 for case, losses in simulated_losses.items()}
for case, losses in simulated_losses.items():
    early_stopping = EarlyStopping(patience=3)
    for epoch, loss in enumerate(losses):
        continue_training = early_stopping.step(loss)
        if not continue_training:
            stopped_at[case] = epoch
            break

fig, ax = plt.subplots()
for case, losses in simulated_losses.items():
    ax.plot(losses, label="Loss "+case)
    ax.scatter(stopped_at[case], losses[stopped_at[case]], color='red')
ax.legend()
ax.set_xlabel("Epoch")
ax.set_ylabel("Validation Loss")
plt.show()


## Transfer Learning in alphaDIA

Finally, we have arrived at the interesting part: fine-tuning the prediction models to achieve much better search results. Before we show how to use our `FinetuneManager` class, there is one simple thing to note. The data we fine-tune on is collected from the first search, and it's already part of the pipeline. If you have completed the search, it will be located under `"workspace_dir/output/speclib_transfer.hdf"`. For this tutorial, we will download a pre-accumulated dataset.
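If you have already run a first search yourself, you can point the notebook at its output instead of downloading the example file. A minimal sketch, assuming the `workspace_dir/output/speclib_transfer.hdf` layout described above; `find_transfer_library` is a hypothetical helper, not part of alphaDIA.

```python
import os
from typing import Optional


def find_transfer_library(workspace_dir: str) -> Optional[str]:
    """Path of speclib_transfer.hdf written by a finished first search, or None if absent."""
    candidate = os.path.join(workspace_dir, "output", "speclib_transfer.hdf")
    return candidate if os.path.isfile(candidate) else None
```

If it returns `None`, fall back to downloading the pre-accumulated dataset as in the next cell.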


In [None]:
!pip install gdown

import gdown

speclib_url = "https://drive.google.com/uc?id=16uE07CiT2Oz76pfTVJILWZK7cj9VfaCm"
path_to_speclib = "speclib_transfer.hdf"
gdown.download(speclib_url, path_to_speclib, quiet=False)



In [None]:
from alphabase.spectral_library.base import SpecLibBase
import os

transfer_lib = SpecLibBase()
transfer_lib.load_hdf(path_to_speclib, load_mod_seq=True)
os.remove(path_to_speclib)
transfer_lib.precursor_df.head()

## Using the `FinetuneManager`

One last thing before we dive in. The `FinetuneManager` provides one fine-tuning method per model (RT, Charge, MS2). Each method fine-tunes the respective model and evaluates it on the validation dataset after every epoch. When fine-tuning finishes, the method returns a `pandas.DataFrame` with all metrics accumulated over the epochs: the training loss, the learning rate, and all test metrics for the respective model.

The test metrics are calculated as the average over all validation samples and are as follows:

| Model   | Metrics |
|---------|---------|
| **RT**  | L1 loss, Linear regression analysis, Absolute Error 95th percentile |
| **Charge** | Cross Entropy, Accuracy, Precision, Recall |
| **MS2**    | L1 loss, Pearson Correlation Coefficient, Cosine Similarity, Spectral Angle, Spearman Correlation |
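For reference, here is how several of the RT and MS2 metrics in the table are commonly defined, written as a small NumPy sketch. These are standard textbook definitions and are assumptions about, not copies of, the alphaDIA/PeptDeep implementations: for example, the spectral angle below uses the normalized form $1 - \frac{2}{\pi}\arccos(\text{cosine similarity})$, and the Spearman ranks do not average ties. The RT linear regression analysis and the charge metrics are not covered.

```python
import numpy as np


def rt_metrics(rt_observed, rt_predicted):
    """L1 loss and 95th percentile of the absolute error for (normalized) RTs."""
    abs_error = np.abs(np.asarray(rt_predicted, dtype=float) - np.asarray(rt_observed, dtype=float))
    return {"l1_loss": abs_error.mean(), "abs_error_95th_percentile": np.percentile(abs_error, 95)}


def ms2_metrics(observed, predicted):
    """Similarity metrics between one observed and one predicted fragment-intensity vector."""
    observed = np.asarray(observed, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    cosine = observed @ predicted / (np.linalg.norm(observed) * np.linalg.norm(predicted))
    # Normalized spectral angle: 1 for identical spectra, 0 for orthogonal ones
    spectral_angle = 1 - 2 * np.arccos(np.clip(cosine, -1.0, 1.0)) / np.pi
    pearson = np.corrcoef(observed, predicted)[0, 1]
    # Spearman correlation = Pearson correlation of the ranks (ties not averaged here)
    ranks_observed = np.argsort(np.argsort(observed))
    ranks_predicted = np.argsort(np.argsort(predicted))
    spearman = np.corrcoef(ranks_observed, ranks_predicted)[0, 1]
    return {
        "l1_loss": np.abs(observed - predicted).mean(),
        "pearson": pearson,
        "cosine": cosine,
        "spectral_angle": spectral_angle,
        "spearman": spearman,
    }


# Per-sample metrics are averaged over the validation samples
pairs = [
    ([0.1, 0.5, 1.0, 0.0], [0.2, 0.4, 0.9, 0.1]),
    ([1.0, 0.3, 0.0, 0.6], [0.8, 0.3, 0.1, 0.7]),
]
per_sample = [ms2_metrics(obs, pred) for obs, pred in pairs]
mean_cosine = np.mean([m["cosine"] for m in per_sample])
print(f"mean cosine similarity: {mean_cosine:.3f}")
```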

Let's prepare a simple function that will plot these statistics once the fine-tuning finishes.


In [9]:
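def plot_stats(model, stats):
    """Plot the per-epoch metrics returned by one of the FinetuneManager fine-tuning methods."""
    # Skip the first and last rows of the metrics table
    stats = stats[1:-1]
    # Every metric other than the epoch and the train/test losses gets its own panel
    columns_to_plot = stats.columns.drop(["epoch", "train_loss", "test_loss"])
    fig_col = 2
    # One extra panel for the train/test loss plot
    fig_row = int(np.ceil((len(columns_to_plot) + 1) / fig_col))
    # squeeze=False keeps `ax` two-dimensional even when there is a single row
    fig, ax = plt.subplots(fig_row, fig_col, figsize=(15, 5 * fig_row), squeeze=False)

    fig.suptitle(f"Fine-tuning {model}", fontsize=25)

    x_axis = stats["epoch"]
    # Train and test loss in the first panel
    ax[0, 0].plot(x_axis, stats["train_loss"], label="train")
    ax[0, 0].plot(x_axis, stats["test_loss"], label="test")
    ax[0, 0].set_title("Loss")
    ax[0, 0].set_xlabel("Epoch")
    ax[0, 0].set_ylabel("Loss")
    ax[0, 0].legend()

    # One panel per remaining metric, filling the grid row by row
    for i, column_name in enumerate(columns_to_plot, start=1):
        row, col = divmod(i, fig_col)
        ax[row, col].plot(x_axis, stats[column_name])
        ax[row, col].set_title(column_name)
        ax[row, col].set_xlabel("Epoch")
        ax[row, col].set_ylabel(column_name)

    # Hide any unused panels
    for unused_ax in ax.flat[len(columns_to_plot) + 1:]:
        unused_ax.set_visible(False)

    plt.tight_layout()
    plt.show()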

In [10]:
from alphadia.transferlearning.train import FinetuneManager, settings

tune_mgr = FinetuneManager(
    device="gpu",
    settings=settings)
# Default normalized collision energy (NCE) and instrument type, used as inputs for MS2 prediction
tune_mgr.nce = 25
tune_mgr.instrument = 'Lumos'

## RT Fine-tuning


In [None]:

transfer_lib.precursor_df = tune_mgr.predict_rt(transfer_lib.precursor_df)
plt.scatter(transfer_lib.precursor_df['rt_norm'], transfer_lib.precursor_df['rt_norm_pred'], s=1, alpha=0.1)
plt.title('RT prediction before fine-tuning')
plt.xlabel('RT observed')
plt.ylabel('RT predicted')

In [None]:
rt_stats = tune_mgr.finetune_rt(transfer_lib.precursor_df)

In [None]:
transfer_lib.precursor_df = tune_mgr.predict_rt(transfer_lib.precursor_df)
plt.scatter(transfer_lib.precursor_df['rt_norm'], transfer_lib.precursor_df['rt_norm_pred'], s=1, alpha=0.1)
plt.title('RT prediction after fine-tuning')
plt.xlabel('RT observed')
plt.ylabel('RT predicted')

In [None]:
plot_stats("RT", rt_stats)

## Charge Fine-tuning

In [None]:
charge_stats = tune_mgr.finetune_charge(psm_df=transfer_lib.precursor_df)

In [None]:
plot_stats("Charge", charge_stats)

## MS2 Fine-tuning

In [18]:

# Keep only the precursors flagged for MS2 fine-tuning
transfer_lib.precursor_df = transfer_lib.precursor_df[transfer_lib.precursor_df['use_for_ms2']]


In [19]:
def calculate_similarity(precursor_df_a, precursor_df_b, intensity_df_a, intensity_df_b):
    """Cosine similarity between the fragment intensities of precursors present in both libraries.

    Each precursor's fragments are the rows frag_start_idx:frag_stop_idx of the corresponding
    intensity DataFrame; only the first four intensity columns are compared.
    """
    _a_df = precursor_df_a[['precursor_idx', 'frag_start_idx', 'frag_stop_idx']].copy()
    _b_df = precursor_df_b[['precursor_idx', 'frag_start_idx', 'frag_stop_idx']].copy()

    _merged_df = pd.merge(_a_df, _b_df, on='precursor_idx', suffixes=('_a', '_b'))
    # keep only the first match for each precursor_idx
    _merged_df = _merged_df.drop_duplicates(subset='precursor_idx', keep='first')
    similarity_list = []

    for i, row in enumerate(_merged_df.itertuples(index=False)):
        intensity_a = intensity_df_a.iloc[row.frag_start_idx_a:row.frag_stop_idx_a, :4].values.flatten()
        intensity_b = intensity_df_b.iloc[row.frag_start_idx_b:row.frag_stop_idx_b, :4].values.flatten()

        # cosine similarity is symmetric, so the order of a and b does not matter
        similarity = np.dot(intensity_a, intensity_b) / (np.linalg.norm(intensity_a) * np.linalg.norm(intensity_b))
        similarity_list.append({'similarity': similarity, 'index': i, 'precursor_idx': row.precursor_idx})

    return pd.DataFrame(similarity_list)
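A self-contained toy version of the same per-precursor cosine similarity helps to see what `calculate_similarity` measures. It assumes the fragment layout used throughout this notebook: each precursor row points at the contiguous block `frag_start_idx:frag_stop_idx` of the fragment-intensity DataFrame. All data are made up, the ion-type column names are illustrative, and `cosine_per_precursor` is a hypothetical, simplified helper (it skips the merge on `precursor_idx`, so both intensity tables must share the same fragment layout).

```python
import numpy as np
import pandas as pd

# Toy precursor table: each precursor owns the fragment rows frag_start_idx:frag_stop_idx
precursor_df = pd.DataFrame({
    "precursor_idx": [0, 1],
    "frag_start_idx": [0, 2],
    "frag_stop_idx": [2, 5],
})

# Toy fragment intensities with four ion-type columns (column names are illustrative)
columns = ["b_z1", "b_z2", "y_z1", "y_z2"]
observed = pd.DataFrame(
    [[0.0, 0.0, 1.0, 0.0],
     [0.5, 0.0, 0.2, 0.0],
     [0.0, 0.0, 0.8, 0.1],
     [0.3, 0.0, 0.0, 0.0],
     [0.0, 0.0, 1.0, 0.0]],
    columns=columns,
)
predicted = observed.copy()
predicted.iloc[0:2] = predicted.iloc[0:2] * 2.0  # precursor 0: same pattern, different scale
predicted.iloc[2, 2] = 0.2                       # precursor 1: one fragment predicted too low


def cosine_per_precursor(precursor_df, intensity_a, intensity_b):
    """Cosine similarity of the flattened fragment block of each precursor."""
    similarities = []
    for start, stop in zip(precursor_df["frag_start_idx"], precursor_df["frag_stop_idx"]):
        a = intensity_a.iloc[start:stop].to_numpy().ravel()
        b = intensity_b.iloc[start:stop].to_numpy().ravel()
        similarities.append(float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b))))
    return similarities


print(cosine_per_precursor(precursor_df, observed, predicted))
```

Scaling a spectrum does not change its cosine similarity (precursor 0 scores 1), while mispredicting a single fragment lowers it (precursor 1 scores about 0.89), which is why cosine similarity suits comparing relative fragment intensities.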

In [None]:
# Predict MS2 spectra with the models before fine-tuning
res = tune_mgr.predict_all(transfer_lib.precursor_df.copy(), predict_items=['ms2'])

precursor_before_df = res['precursor_df']
fragment_mz_before_df = res['fragment_mz_df']
fragment_intensity_before_df = res['fragment_intensity_df']
similarity_before_df = calculate_similarity(precursor_before_df, transfer_lib.precursor_df, fragment_intensity_before_df, transfer_lib.fragment_intensity_df)
print(f"Median cosine similarity: {similarity_before_df['similarity'].median():.3f}")
plt.scatter(similarity_before_df['index'], similarity_before_df['similarity'], s=0.1)
plt.title('Similarity before fine-tuning')
plt.xlabel('Precursor')
plt.ylabel('Cosine similarity')

In [None]:

ms2_stats = tune_mgr.finetune_ms2(psm_df=transfer_lib.precursor_df.copy(), matched_intensity_df=transfer_lib.fragment_intensity_df.copy())

In [None]:
# Predict MS2 spectra with the fine-tuned models
res = tune_mgr.predict_all(transfer_lib.precursor_df.copy(), predict_items=['ms2'])

precursor_after_df = res['precursor_df']
fragment_mz_after_df = res['fragment_mz_df']
fragment_intensity_after_df = res['fragment_intensity_df']
similarity_after_df = calculate_similarity(precursor_after_df, transfer_lib.precursor_df, fragment_intensity_after_df, transfer_lib.fragment_intensity_df)
print(f"Median cosine similarity: {similarity_after_df['similarity'].median():.3f}")
plt.scatter(similarity_after_df['index'], similarity_after_df['similarity'], s=0.1)
plt.title('Similarity after fine-tuning')
plt.xlabel('Precursor')
plt.ylabel('Cosine similarity')

In [None]:
plot_stats("MS2", ms2_stats)