In [None]:
# Reload edited library modules automatically before each cell executes (development convenience).
%load_ext autoreload
%autoreload 2

# Callbacks

> Library-wide callbacks used within the BLURR library.


In [None]:
# |default_exp callbacks
# |default_cls_lvl 3

In [None]:
# |export
from __future__ import annotations

import os

from dotenv import load_dotenv
from fastcore.all import *
from fastai.callback.all import *
from fastai.imports import *
from fastai.learner import *
from fastai.torch_core import *
import torch
from transformers import PreTrainedModel

In [None]:
# |hide
import pdb

from fastcore.test import *
import nbdev

In [None]:
# |export
# Turn off HuggingFace tokenizers' internal parallelism
# (presumably to avoid fork-related warnings with multi-process data loading — confirm).
os.environ.update(TOKENIZERS_PARALLELISM="false")
# Load variables from a local `.env` file, if present, into the process environment;
# its boolean return value is what this cell displays.
load_dotenv()

False

In [None]:
# |hide
# |cuda
# NOTE(review): device index 1 (the second GPU) is hard-coded for the author's dev box;
# this will likely error on a machine with a single GPU — adjust locally as needed.
torch.cuda.set_device(1)
# Report which device subsequent CUDA work in this notebook will use.
print(f"Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}")

Using GPU #1: NVIDIA GeForce RTX 3090


## Gradient Checkpointing

In [None]:
# |export
class CheckpointingNotSupported(Exception):
    """Raised when the wrapped HuggingFace model cannot use gradient checkpointing.

    Args:
        msg: Human-readable error message; a generic default is provided.
    """

    def __init__(self, msg="Model does not support gradient checkpointing."):
        Exception.__init__(self, msg)

In [None]:
# |export
class GradientCheckpointing(Callback):
    """A fastai callback to enable gradient checkpointing for compatible HuggingFace models.

    Checkpointing is switched on in `before_fit`. In `after_fit` it is switched back off only if this
    callback was the one that switched it on, so a model that already had checkpointing enabled
    before training is left in the state the user put it in.

    Expects the learner's model to expose the underlying HuggingFace model as `hf_model`.
    """

    # True only while checkpointing is on because `before_fit` enabled it; lets `after_fit`
    # undo exactly its own change and nothing more.
    _enabled_by_cb = False

    def before_fit(self):
        """Enable gradient checkpointing on the `before_fit` event.

        Raises:
            CheckpointingNotSupported: if the HuggingFace model does not support gradient checkpointing.
        """
        hf_model = self.model.hf_model

        # Reuse the public support check so both entry points agree on what "supported" means
        if not self.supported(hf_model):
            raise CheckpointingNotSupported()

        # Remember whether we are the ones turning it on, so after_fit can restore the prior state
        self._enabled_by_cb = not hf_model.is_gradient_checkpointing
        if self._enabled_by_cb:
            hf_model.gradient_checkpointing_enable()

    def after_fit(self):
        """Disable gradient checkpointing on the `after_fit` event, if `before_fit` enabled it."""
        hf_model = self.model.hf_model
        if self._enabled_by_cb and hf_model.is_gradient_checkpointing:
            hf_model.gradient_checkpointing_disable()
        self._enabled_by_cb = False

    @staticmethod
    def supported(model: PreTrainedModel):
        """Tests whether a HuggingFace `PreTrainedModel` supports gradient checkpointing.

        Objects that do not define `supports_gradient_checkpointing` are reported as unsupported.
        """
        return getattr(model, "supports_gradient_checkpointing", False)

## Export -

In [None]:
# | hide
import nbdev

# Write this notebook's `# |export` cells out to the library module named by `# |default_exp`.
nbdev.nbdev_export()