## Lightning IR: Create a Custom Cross-Encoder

In [18]:
import warnings

# Silence every warning category for the rest of this kernel session so the
# cell outputs below stay readable.
warnings.simplefilter("ignore")

In [19]:
import torch
from torch.optim import AdamW
from transformers import AutoConfig, AutoModel, AutoTokenizer, BatchEncoding

from lightning_ir import (
    CrossEncoderModel,
    CrossEncoderModule,
    CrossEncoderOutput,
    CrossEncoderTokenizer,
    LightningIRDataModule,
    LightningIRTrainer,
    RankNet,
    TupleDataset,
)
from lightning_ir.cross_encoder.config import CrossEncoderConfig

In [20]:
# Prefer the GPU when PyTorch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device_str = "cuda"
else:
    device_str = "cpu"

In [21]:
class DemoCrossEncoderConfig(CrossEncoderConfig):
    """Cross-encoder configuration carrying one extra option.

    On top of whatever ``CrossEncoderConfig`` accepts, this config stores the
    boolean ``additional_linear_layer`` flag.
    """

    model_type = "custom-cross-encoder"

    # The parent's ADDED_ARGS extended with the name of the new option.
    ADDED_ARGS = CrossEncoderConfig.ADDED_ARGS | {"additional_linear_layer"}

    def __init__(self, additional_linear_layer: bool = True, **kwargs):
        """Create the config.

        Args:
            additional_linear_layer: Value stored on the config under the same
                name. Defaults to ``True``.
            **kwargs: Forwarded unchanged to ``CrossEncoderConfig.__init__``.
        """
        super().__init__(**kwargs)
        self.additional_linear_layer = additional_linear_layer


class DemoCrossEncoderModel(CrossEncoderModel):
    """Cross-encoder with an optional extra linear layer before the scoring head.

    When ``config.additional_linear_layer`` is truthy, the pooled embedding is
    passed through a ``hidden_size -> hidden_size`` linear layer before it is
    scored by ``self.linear``.
    """

    config_class = DemoCrossEncoderConfig

    def __init__(self, config: DemoCrossEncoderConfig, *args, **kwargs):
        """Build the model and, if enabled in ``config``, the extra layer.

        Args:
            config: Configuration; ``additional_linear_layer`` and
                ``hidden_size`` are read from it here.
            *args: Forwarded to ``CrossEncoderModel.__init__``.
            **kwargs: Forwarded to ``CrossEncoderModel.__init__``.
        """
        super().__init__(config, *args, **kwargs)
        # Stays None when the option is disabled, so forward() can skip it.
        self.additional_linear_layer = None
        if config.additional_linear_layer:
            self.additional_linear_layer = torch.nn.Linear(
                config.hidden_size, config.hidden_size
            )

    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        """Score the query-document pairs contained in ``encoding``.

        Args:
            encoding: Tokenized batch. It is unpacked as keyword arguments into
                the backbone; its ``"attention_mask"`` entry (``None`` when
                absent) is passed on to the pooling step.

        Returns:
            A ``CrossEncoderOutput`` whose ``scores`` are flattened to one
            dimension and whose ``embeddings`` are the pooled embeddings, after
            the additional linear layer when that layer is enabled.
        """
        # _backbone_forward, _pooling and self.linear are not defined in this
        # class; presumably they are inherited from CrossEncoderModel.
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self._pooling(
            embeddings,
            encoding.get("attention_mask", None),
            pooling_strategy=self.config.pooling_strategy,
        )
        if self.additional_linear_layer is not None:
            # NOTE(review): no non-linearity is applied between this layer and
            # the scoring head below.
            embeddings = self.additional_linear_layer(embeddings)
        scores = self.linear(embeddings).view(-1)
        return CrossEncoderOutput(scores=scores, embeddings=embeddings)

In [22]:
# Register the custom classes with the Hugging Face Auto* factories: the
# model_type string is associated with the config class, and the config class
# is associated with the model class and the tokenizer class.
AutoConfig.register(DemoCrossEncoderConfig.model_type, DemoCrossEncoderConfig)
AutoModel.register(DemoCrossEncoderConfig, DemoCrossEncoderModel)
# The stock CrossEncoderTokenizer is reused; this notebook defines no custom tokenizer.
AutoTokenizer.register(DemoCrossEncoderConfig, CrossEncoderTokenizer)

In [23]:
# Fine-tuning module: bert-base-uncased checkpoint, custom config, RankNet loss.
demo_config = DemoCrossEncoderConfig()  # our custom config
ranking_losses = [RankNet()]
module = CrossEncoderModule(
    model_name_or_path="bert-base-uncased",
    config=demo_config,
    loss_functions=ranking_losses,
)
module.set_optimizer(AdamW, lr=1e-5)

# Training data: the small MS MARCO passage triples set, in batches of two.
train_triples = TupleDataset("msmarco-passage/train/triples-small")
data_module = LightningIRDataModule(
    train_dataset=train_triples,
    train_batch_size=2,
)

You are using a model of type bert to instantiate a model of type bert-custom-cross-encoder. This is not supported for all configurations of models and can yield errors.
Some weights of CustomCrossEncoderBertModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.additional_linear_layer.bias', 'bert.additional_linear_layer.weight', 'bert.linear.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BertTokenizer'. 
The class this function is called from is 'CustomCrossEncoderBertTokenizerFast'.


In [29]:
# Short demonstration run on the device selected earlier, capped at 100
# optimizer steps and 3 epochs.
trainer = LightningIRTrainer(
    accelerator=device_str,
    max_epochs=3,
    max_steps=100,
)
trainer.fit(module, data_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                        | Params | Mode 
--------------------------------------------------------------
0 | model | CustomCrossEncoderBertModel | 109 M  | train
--------------------------------------------------------------
109 M     Trainable params
0         Non-trainable params
109 M     Total params
437.932   Total estimated model params size (MB)


Epoch 0: |          | 100/? [00:35<00:00,  2.83it/s, v_num=15, loss=0.0107]

`Trainer.fit` stopped: `max_steps=100` reached.


Epoch 0: |          | 100/? [01:04<00:00,  1.55it/s, v_num=15, loss=0.0107]


In [34]:
# Write a checkpoint of the fitted trainer to a relative path. weights_only is
# left at its default (False), so the checkpoint is not restricted to the
# model weights.
trainer.save_checkpoint("custom-cross-encoder.ckpt")

In [30]:
import gc

# Run the garbage collector first: tensors kept alive only by unreachable
# Python objects are freed, which marks their blocks as unused in PyTorch's
# caching allocator. empty_cache() can only release unused cached blocks, so
# it has to come second to be effective.
gc.collect()
torch.cuda.empty_cache()

312

In [33]:
# Print the signature and docstring of save_checkpoint to see its options
# (filepath, weights_only, storage_options).
help(trainer.save_checkpoint)

Help on method save_checkpoint in module lightning.pytorch.trainer.trainer:

save_checkpoint(filepath: Union[str, pathlib.Path], weights_only: bool = False, storage_options: Optional[Any] = None) -> None method of lightning_ir.main.LightningIRTrainer instance
    Runs routine to create a checkpoint.
    
    This method needs to be called on all processes in case the selected strategy is handling distributed
    checkpointing.
    
    Args:
        filepath: Path where checkpoint is saved.
        weights_only: If ``True``, will only save the model weights.
        storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin
    
    Raises:
        AttributeError:
            If the model is not attached to the Trainer before calling this method.

