In [None]:
# default_exp callbacks
# default_cls_lvl 3

In [None]:
#hide
%reload_ext autoreload
%autoreload 2
%matplotlib inline

# callbacks

> `Callback`s used by the `BLURR` library.

In [None]:
import os
# Disable the HuggingFace `tokenizers` library's internal parallelism for this process. This is set in
# its own cell, before the transformers/blurr imports further down; presumably it avoids the tokenizers
# fork warning when data-loading workers fork — TODO confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
# export
import importlib, sys, torch
from typing import Any, Callable, Dict, List, Optional, Union, Type

from fastcore.all import *
from fastai.callback.all import *
from fastai.imports import *
from fastai.learner import *
from fastai.torch_core import *
from transformers import PreTrainedModel

In [None]:
# hide_input
import pdb

from IPython.display import display
from fastcore.test import *
from nbdev.showdoc import show_doc


In [None]:
# hide
# cuda
# Pin this notebook to CUDA device index 1, so it needs a machine with at least two GPUs.
# NOTE(review): `gpu_memory` below also defaults to index 1, but it reads GPUtil's device list.
# GPUtil's ordering may differ from CUDA's unless CUDA_DEVICE_ORDER=PCI_BUS_ID is set, so confirm
# that both indices refer to the same card.
torch.cuda.set_device(1)
print(f"Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}")


Using GPU #1: NVIDIA GeForce RTX 3080


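## Gradient Checkpointing

Gradient checkpointing trades extra compute for lower GPU memory. Instead of keeping every intermediate activation for the backward pass, activations are recomputed when they are needed. The `GradientCheckpointing` callback turns this on during a fit, for HuggingFace models that support it.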

In [None]:
# exporti
class CheckpointingNotSupported(Exception):
    """Raised when a HuggingFace model cannot use gradient checkpointing.

    The optional `msg` overrides the default error message.
    """

    def __init__(self, msg="Model does not support gradient checkpointing."):
        Exception.__init__(self, msg)

In [None]:
# export
class GradientCheckpointing(Callback):
    """A fastai callback to enable gradient checkpointing for compatible HuggingFace models.

    Checkpointing is switched on in `before_fit` and switched back off in `after_fit`, but only if this
    callback was the one that turned it on. A model that already had checkpointing enabled before
    training keeps it enabled afterwards.
    """

    # Set to True by `before_fit` when it enabled checkpointing itself, so `after_fit` knows to undo it.
    _ckpt_enabled_by_cb = False

    def before_fit(self):
        """Enable gradient checkpointing on `self.model.hf_model` before training starts.

        Raises:
            CheckpointingNotSupported: if the HuggingFace model does not support gradient checkpointing.
        """
        hf_model = self.model.hf_model

        # Reuse the public helper so the support check lives in one place
        if not self.supported(hf_model):
            raise CheckpointingNotSupported()

        # Record whether we are the ones turning it on, so the model's prior state can be restored later
        self._ckpt_enabled_by_cb = not hf_model.is_gradient_checkpointing
        if self._ckpt_enabled_by_cb:
            hf_model.gradient_checkpointing_enable()

    def after_fit(self):
        """Disable gradient checkpointing after training, but only if `before_fit` enabled it."""
        hf_model = self.model.hf_model
        if self._ckpt_enabled_by_cb and hf_model.is_gradient_checkpointing:
            hf_model.gradient_checkpointing_disable()
        self._ckpt_enabled_by_cb = False

    @staticmethod
    def supported(model: PreTrainedModel):
        """Tests whether a HuggingFace `PreTrainedModel` supports gradient checkpointing.

        Returns the model's `supports_gradient_checkpointing` attribute.
        """
        return model.supports_gradient_checkpointing

In [None]:
import gc
from fastai.text.all import *
import GPUtil as GPU
from blurr.text.modeling.all import *

In [None]:
def clear_memory():
    """Free unreachable Python objects, then release PyTorch's unused cached CUDA memory.

    `gc.collect()` must run first. Tensors that are kept alive only by unreachable objects are returned
    to PyTorch's caching allocator when the collector frees them. `torch.cuda.empty_cache()` then hands
    the allocator's unoccupied cached blocks back to the device, so the drop shows up in device-level
    readings such as `gpu_memory()`. In the reverse order, memory freed by the collection would stay
    cached.
    """
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
def gpu_memory(device_idx=1):
    """Return GPUtil's `memoryUsed` reading for the GPU at position `device_idx` in `GPU.getGPUs()`.

    The notebook prints this value as "MBs used".
    NOTE(review): GPUtil enumerates GPUs in its own order, which may differ from CUDA's. Confirm that
    `device_idx` points at the device chosen with `torch.cuda.set_device`.
    """
    all_gpus = GPU.getGPUs()
    target_gpu = all_gpus[device_idx]
    return target_gpu.memoryUsed

In [None]:
# Load Data: fetch the IMDB sample with fastai's `untar_data` and read its `texts.csv` into a DataFrame
model_path = Path("models")
path = untar_data(URLs.IMDB_SAMPLE)
imdb_df = pd.read_csv(path.joinpath("texts.csv"))

In [None]:
# Create Learner: a roberta-large sequence classifier built from `imdb_df`, with dataloader batch size 4
learn = BlearnerForSequenceClassification.from_data(imdb_df, "roberta-large", dl_kwargs=dict(bs=4))

In [None]:
# Train for a single epoch for baseline memory usage
# (no GradientCheckpointing callback here, so this is the reference number for the comparison below)
learn.fit_one_cycle(1, lr_max=1e-3)

# Read GPU memory right after training, before anything is released.
# NOTE(review): this is GPUtil's device-level reading, so it presumably also counts memory still cached
# by PyTorch's allocator and memory used by other processes on the same GPU — confirm.
base_mem = gpu_memory()
print(f"{base_mem} MBs used.")
# Clear gpu memory
# (runs after the reading above so the next experiment starts with the cache released)
clear_memory()

epoch,train_loss,valid_loss,f1_score,accuracy,time
0,0.337935,0.203397,0.933333,0.94,00:25


8902.0 MBs used.


In [None]:
# Train with GradientCheckpointing
# The callback is passed only for this fit. Note that this continues training the same `learn` object,
# starting from the weights produced by the baseline run above.
learn.fit_one_cycle(1, lr_max=1e-3, cbs=[GradientCheckpointing()])

check_mem = gpu_memory()
print(f"{check_mem} MBs used.")
# Checkpointing is expected to use less GPU memory than the baseline run.
# NOTE(review): both readings are device-level GPUtil numbers, so other activity on the same GPU
# could affect this comparison.
assert base_mem > check_mem

epoch,train_loss,valid_loss,f1_score,accuracy,time
0,0.234959,0.209658,0.921348,0.93,00:34


3814.0 MBs used.


## Export -

In [None]:
# hide
# Convert the project's notebooks into library modules with nbdev (see the "Converted ..." output below).
from nbdev.export import notebook2script

notebook2script()


Converted 00_callbacks.ipynb.
Converted 00_utils.ipynb.
Converted 01_text-callbacks.ipynb.
Converted 01_text-utils.ipynb.
Converted 11_text-data-core.ipynb.
Converted 11_text-modeling-core.ipynb.
Converted 12_text-data-language-modeling.ipynb.
Converted 12_text-modeling-language-modeling.ipynb.
Converted 13_text-data-token-classification.ipynb.
Converted 13_text-modeling-token-classification.ipynb.
Converted 14_text-data-question-answering.ipynb.
Converted 14_text-modeling-question-answering.ipynb.
Converted 20_text-data-seq2seq-core.ipynb.
Converted 20_text-modeling-seq2seq-core.ipynb.
Converted 21_text-data-seq2seq-summarization.ipynb.
Converted 21_text-modeling-seq2seq-summarization.ipynb.
Converted 22_text-data-seq2seq-translation.ipynb.
Converted 22_text-modeling-seq2seq-translation.ipynb.
Converted 99a_text-examples-high-level-api.ipynb.
Converted 99b_text-examples-glue.ipynb.
Converted 99c_text-examples-glue-plain-pytorch.ipynb.
Converted 99d_text-examples-multilabel.ipynb.
Conv