**Note: this notebook will only work for pst versions >1.1.0**

# Finetuning with a new genome-level objective
The vPST was pretrained with a triplet loss objective, evaluating the genome embeddings.

If you want to apply the vPST to a new objective (transfer learning), then you need to subclass the `BaseProteinSetTransformer` module class and update the following methods:

1. `forward`: the code needed to handle a minibatch and compute the loss.
2. `setup_objective`: the code needed to create a callable that computes the loss directly. This code is called upon initialization of the model, and the `forward` method calls the `.criterion` callable that is returned by this method.

Additionally, if the loss function maintains state (such as the margin and scaling values of a triplet loss objective), then you can create a subclass of `BaseModelConfig` whose `loss` field is a custom subclass of `BaseLossConfig` that specifies the names and default values of the stateful parameters needed by the loss function. This is only necessary for tunable hyperparameters of the loss function, NOT for every argument needed to set up the loss function callable.
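
The relationship between the config, `setup_objective`, and `.criterion` can be sketched with a toy stand-in. This is not the `pst` implementation: `ToyLossConfig`, `ToyModelConfig`, `ToyBaseModel`, and `ToyModel` are hypothetical names, and plain dataclasses are used in place of the real config models. It only illustrates the flow described above, in which the loss config fields are passed as keyword arguments to `setup_objective` at initialization and `forward` calls the stored callable.

```python
from dataclasses import asdict, dataclass, field
from typing import Callable


@dataclass
class ToyLossConfig:
    # hypothetical stand-in for a BaseLossConfig subclass
    tunable_weight: float = 0.5


@dataclass
class ToyModelConfig:
    # hypothetical stand-in for a BaseModelConfig subclass
    loss: ToyLossConfig = field(default_factory=ToyLossConfig)


class ToyBaseModel:
    """Toy base class that builds the criterion once, at initialization."""

    def __init__(self, config: ToyModelConfig):
        self.config = config
        # every field of config.loss is passed as a keyword argument
        self.criterion = self.setup_objective(**asdict(config.loss))

    def setup_objective(self, **kwargs) -> Callable[[float, float], float]:
        raise NotImplementedError

    def forward(self, y_pred: float, y_true: float) -> float:
        # forward only calls the callable stored in self.criterion
        return self.criterion(y_pred, y_true)


class ToyModel(ToyBaseModel):
    def setup_objective(self, tunable_weight: float, **kwargs):
        # the parameter name matches the field name in ToyLossConfig
        def weighted_squared_error(y_pred: float, y_true: float) -> float:
            return tunable_weight * (y_pred - y_true) ** 2

        return weighted_squared_error


model = ToyModel(ToyModelConfig(loss=ToyLossConfig(tunable_weight=0.25)))
loss = model.forward(3.0, 1.0)  # 0.25 * (3 - 1) ** 2
```

Keeping only tunable values in the loss config means that anything else the loss needs can be created inside `setup_objective` itself.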

-----

Let's look at an example where we want to predict some random binary feature about the genomes in the sample dataset provided. For demonstration purposes, we will suppose that we have some tunable weight required for the loss function.

In [1]:
from pst import BaseProteinSetTransformer as BasePST
from pst import GenomeDataModule, BaseLossConfig, BaseModelConfig, GenomeGraphBatch

import lightning as L
import torch
from pydantic import Field

L.seed_everything(111)

Seed set to 111


111

Since we are changing the objective of our new model that is derived from a pretrained PST, we need to define:

1. A custom loss config model that subclasses `BaseLossConfig` IF the loss function requires a tunable state
2. A custom model config model that subclasses `BaseModelConfig` if any subfields need to be changed. The fields of this config model are available to the class through the `self.config` attribute.
3. A custom loss `torch.nn.Module` or function that computes the loss given the outputs of the model's forward pass and the expected targets, if any
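
To make item 3 concrete, here is a plain-Python sketch of the quantity that a weighted binary cross-entropy loss on logits computes. It uses the standard numerically stable form of the loss with mean reduction, which is the default reduction of `torch.nn.BCEWithLogitsLoss`. The function names are hypothetical and exist only for illustration; in the notebook itself the loss is a `torch.nn.Module`.

```python
import math
from typing import Sequence


def bce_with_logits(logits: Sequence[float], targets: Sequence[float]) -> float:
    """Mean binary cross-entropy computed directly from raw logits."""
    total = 0.0
    for x, y in zip(logits, targets):
        # stable form of -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def weighted_bce_with_logits(
    logits: Sequence[float], targets: Sequence[float], weight: float
) -> float:
    """Scale the mean loss by a single tunable weight."""
    return weight * bce_with_logits(logits, targets)


# a logit of 0 is a 50/50 prediction, so each example contributes log(2)
example = weighted_bce_with_logits([0.0, 0.0], [1.0, 0.0], weight=0.5)
# a confident, correct prediction contributes almost nothing
confident = bce_with_logits([20.0], [1.0])
```

Because the weight multiplies the whole loss, it rescales the gradients without changing where the minimum of the loss lies.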

In [2]:
class CustomLossConfig(BaseLossConfig):
    tunable_weight: float = Field(0.5, ge=0.0, le=1.0, description="some tunable weight")

class CustomModelConfig(BaseModelConfig):
    loss: CustomLossConfig

class CustomLossFn(torch.nn.Module):
    def __init__(self, weight: float):
        super().__init__()
        self.weight = weight # just an example, idk why you would use this
        self.fn = torch.nn.BCEWithLogitsLoss()

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        loss = self.fn(y_pred, y_true) * self.weight
        return loss

Now you can create a subclass of the `BaseProteinSetTransformer` model that redefines 3 methods:

1. In the `__init__` method, add any new layers and other attributes that are required by the model. Generally, there should only be 1 argument to the `__init__` method called "config". Thus, any additional attributes should be added as fields to a custom-defined config model.
2. `setup_objective` returns a callable that is used to compute the loss during a forward pass. It receives all values of the `self.config.loss` model as keyword arguments. The return value of this function is stored in `self.criterion`.
3. `forward` defines how data from a minibatch is handled to subsequently compute the loss using `self.criterion`

In [3]:
class CustomGenomeLevelPST(BasePST[CustomModelConfig]): # <- specifying the config type here is optional, but it enables IDEs to provide better autocompletion
    def __init__(self, config: CustomModelConfig):
        super().__init__(config)

        # define new layers for new objective
        self.pred_layer = torch.nn.Linear(self.config.out_dim, 1)

    def setup_objective(self, tunable_weight: float, **kwargs) -> CustomLossFn:
        # notice how the var name is the same as in the CustomLossConfig -- those fields get passed
        # as keyword arguments to this method
        return CustomLossFn(tunable_weight)

    def forward(self, batch: GenomeGraphBatch, stage: str, **kwargs):
        # add strand/pos embeddings
        x_cat, _, _ = self.internal_embeddings(batch)

        pst_output, _ = self.databatch_forward(batch=batch, x=x_cat)

        y_pred = self.pred_layer(pst_output).squeeze()
        y_true = batch.y

        loss = self.criterion(y_pred, y_true)

        self.log_loss(loss, batch.num_proteins.numel(), stage)

        return loss

Now that we have a custom model defined, let's see an extremely trivial example. In the sample dataset provided, there are 8 genomes that we will randomly generate a binary label for.

Then we will use our model's loss function (which is primarily just a binary cross entropy loss for binary classification) to train this model with help from `lightning.Trainer`.

Let's start by loading the sample dataset:

In [4]:
ckptfile = "pst-small_trained_model.ckpt"
data_file = "sample_dataset.graphfmt.h5"
datamodule = GenomeDataModule.from_pretrained(
    checkpoint_path=ckptfile, data_file=data_file, shuffle=False
)

Now we want to add a `y` field to our dataset that contains our randomly generated labels. NOTE: We store this in a `y` field since our model's `forward` method refers to the `y` attribute of the minibatch object (`batch.y`).

Here is how we can register new dataset attributes using the `GenomeDataModule.register_feature` method.

In [5]:
dataset = datamodule.dataset
n_genomes = len(dataset)
# randomly generated genome level labels
y_true = (torch.rand(n_genomes) >= 0.5).float()

datamodule.register_feature("y", y_true, feature_level="genome")

Then all you need to do is train your model!

In [6]:
model = CustomGenomeLevelPST.from_pretrained(ckptfile)

# disable checkpointing and logging for this demo
trainer = L.Trainer(max_epochs=25, enable_checkpointing=False, logger=False)

trainer.fit(model, datamodule=datamodule)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/scratch/ccmartin6/miniconda3/envs/pst/lib/python3.10/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
Loading `train_dataloader` to estimate number of stepping batches.
/scratch/ccmartin6/miniconda3/envs/pst/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.

  | Name                 | Type                | Params
-------------------------------------------------------------
0 | positional_embedding | PositionalEmbedding | 81.9 K
1 | strand_embedding     | Embedding           | 80    
2 | model         

Training: |          | 0/? [00:00<?, ?it/s]

/scratch/ccmartin6/miniconda3/envs/pst/lib/python3.10/site-packages/lightning/pytorch/core/module.py:507: You called `self.log('train_loss', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`
`Trainer.fit` stopped: `max_epochs=25` reached.


This demo is simple enough that the model almost certainly overfits the sample data. To confirm that training worked, here is the accuracy on the training dataset:

In [7]:
datamodule.setup("predict")
batch = next(iter(datamodule.predict_dataloader()))

with torch.no_grad():
    model.eval()
    x_cat, _, _ = model.internal_embeddings(batch)
    pst_encoder_output, _ = model.databatch_forward(batch=batch, x=x_cat)
    y_pred = model.pred_layer(pst_encoder_output).squeeze()

prob = torch.sigmoid(y_pred)
pred = (prob >= 0.5).float()

# accuracy
(pred == batch.y).sum() / pred.size(0)

tensor(1.)
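
The accuracy computation above (sigmoid, threshold at 0.5, compare to the labels) can be sketched in plain Python. The helper names and toy values below are hypothetical and unrelated to the sample dataset; they only mirror the steps of the cell.

```python
import math
from typing import Sequence


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def accuracy_from_logits(
    logits: Sequence[float], labels: Sequence[float], threshold: float = 0.5
) -> float:
    """Fraction of examples whose thresholded probability matches the label."""
    preds = [1.0 if sigmoid(x) >= threshold else 0.0 for x in logits]
    correct = sum(p == y for p, y in zip(preds, labels))
    return correct / len(labels)


toy_logits = [2.0, -1.5, 0.3, -0.2]
toy_labels = [1.0, 0.0, 0.0, 0.0]
toy_accuracy = accuracy_from_logits(toy_logits, toy_labels)  # 3 of 4 correct
```

Thresholding the probability at 0.5 is equivalent to checking whether the raw logit is at least 0, so the sigmoid is only needed when the probabilities themselves are of interest.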

# Finetuning with a new protein-level objective
The vPST was pretrained with a genome-level objective. However, it internally computes contextualized protein embeddings using genome context.

If you want to focus more on these protein embeddings rather than genome embeddings, such as for a protein prediction task or even pretraining a protein foundation model, then you need to create a subclass of the `BaseProteinSetTransformerEncoder` module and update the following methods:

1. `forward`: the code needed to handle a minibatch and compute the loss.
2. `setup_objective`: the code needed to create a callable that computes the loss directly. This code is called upon initialization of the model, and the `forward` method calls the `.criterion` callable that is returned by this method.

Additionally, if the loss function maintains state (such as the margin and scaling values of a triplet loss objective), then you can create a subclass of `BaseModelConfig` whose `loss` field is a custom subclass of `BaseLossConfig` that specifies the names and default values of the stateful parameters needed by the loss function. This is only necessary for tunable hyperparameters of the loss function, NOT for every argument needed to set up the loss function callable.

NOTE: This is nearly identical to the genome-level objective change above. The ONLY difference is that you need to subclass `BaseProteinSetTransformerEncoder` instead of `BaseProteinSetTransformer`.

-----

Let's look at an example where we want to predict some random binary feature about the proteins in the sample dataset provided. For demonstration purposes, we will suppose that we have some tunable weight required for the loss function.

In [8]:
from pst import BaseProteinSetTransformerEncoder as BasePSTEncoder

We are just reusing the loss function and custom model config defined in the genome-level demo to compute binary cross entropy loss for a randomly generated protein-level label.

In [9]:
class CustomProteinLevelPST(BasePSTEncoder[CustomModelConfig]): # <- again note the optional config type hint here
    def __init__(self, config: CustomModelConfig):
        super().__init__(config)

        # define new layers for new objective
        self.pred_layer = torch.nn.Linear(self.config.out_dim, 1)

    def setup_objective(self, tunable_weight: float, **kwargs) -> CustomLossFn:
        return CustomLossFn(tunable_weight)
    
    def forward(self, batch: GenomeGraphBatch, stage: str, **kwargs):
        # intentionally left this nearly identical to the previous example

        # add strand/pos embeddings
        x_cat, _, _ = self.internal_embeddings(batch)

        pst_encoder_output, _, _ = self.databatch_forward(batch=batch, x=x_cat)

        y_pred = self.pred_layer(pst_encoder_output).squeeze()
        y_true = batch.y

        loss = self.criterion(y_pred, y_true)

        self.log_loss(loss, batch.num_proteins.numel(), stage)

        return loss

We already loaded the datamodule previously, so we just need to register the randomly created protein labels.

In [10]:
n_proteins = dataset.data.shape[0]
y_true = (torch.rand(n_proteins) >= 0.5).float()
dataset.register_feature("y", y_true, feature_level="protein", overwrite_previously_registered=True)
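
The practical difference between the two feature levels is the expected number of labels: one per genome for a genome-level feature, one per protein for a protein-level feature. Here is a minimal sketch of that bookkeeping, assuming a toy batch described only by per-genome protein counts. The counts, labels, and `check_labels` helper are hypothetical and are not part of the `pst` API.

```python
from typing import List, Sequence

# hypothetical per-genome protein counts for a toy batch of 3 genomes
num_proteins: List[int] = [3, 1, 2]

n_genomes = len(num_proteins)
n_proteins = sum(num_proteins)

# index of the genome that each protein (row) belongs to
protein_to_genome = [
    genome for genome, count in enumerate(num_proteins) for _ in range(count)
]


def check_labels(labels: Sequence[float], feature_level: str) -> None:
    """Raise if the number of labels does not match the feature level."""
    if feature_level == "genome":
        expected = n_genomes
    elif feature_level == "protein":
        expected = n_proteins
    else:
        raise ValueError(f"unknown feature level: {feature_level!r}")
    if len(labels) != expected:
        raise ValueError(f"expected {expected} labels, got {len(labels)}")


genome_labels = [1.0, 0.0, 1.0]  # one label per genome
protein_labels = [0.0, 1.0, 1.0, 0.0, 1.0, 0.0]  # one label per protein

check_labels(genome_labels, "genome")
check_labels(protein_labels, "protein")
```

With protein-level labels, the prediction layer is applied to one embedding per protein, so the predictions line up with one label per protein.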

Then just train the model using the `lightning.Trainer`!

In [11]:
model = CustomProteinLevelPST.from_pretrained(ckptfile)

# disable checkpointing and logging for this demo
trainer = L.Trainer(max_epochs=25, enable_checkpointing=False, logger=False)

trainer.fit(model, datamodule=datamodule)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Loading `train_dataloader` to estimate number of stepping batches.

  | Name                 | Type                  | Params
---------------------------------------------------------------
0 | positional_embedding | PositionalEmbedding   | 81.9 K
1 | strand_embedding     | Embedding             | 80    
2 | model                | SetTransformerEncoder | 4.0 M 
3 | criterion            | CustomLossFn          | 0     
4 | pred_layer           | Linear                | 401   
---------------------------------------------------------------
4.1 M     Trainable params
0         Non-trainable params
4.1 M     Total params
16.422    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=25` reached.


Again, for confirmation, we check the prediction accuracy on the training set.

In [12]:
datamodule.setup("predict")
batch = next(iter(datamodule.predict_dataloader()))

with torch.no_grad():
    model.eval()
    x_cat, _, _ = model.internal_embeddings(batch)
    pst_encoder_output, _, _ = model.databatch_forward(batch=batch, x=x_cat)
    y_pred = model.pred_layer(pst_encoder_output).squeeze()

prob = torch.sigmoid(y_pred)
pred = (prob >= 0.5).float()

# accuracy
(pred == batch.y).sum() / pred.size(0)

tensor(0.9730)