In [None]:
# Load IPython's autoreload extension; mode 2 re-imports edited modules
# before each cell execution.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook output.
# NOTE(review): nothing in the visible cells plots anything - confirm this is still needed.
%matplotlib inline

In [None]:
!pip install -Uq transformers datasets torch numpy

In [None]:
import torch, transformers, datasets
import numpy as np
# Bare tuple as the cell's last expression: the notebook displays the versions
# of the libraries this run actually uses.
torch.__version__, transformers.__version__, datasets.__version__, np.__version__

In [None]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, TrainingArguments, Trainer, set_seed
from torch import nn
from torch.nn import functional as F
import math

# Hugging Face checkpoint name; used below for both the tokenizer and the
# sequence-classification model.
hf_ckp = 'roberta-base'
# Fix the random seed (via transformers' set_seed helper) so runs are repeatable.
set_seed(100)

def compute_metrics(eval_pred):
    """Compute classification accuracy for an evaluation pass.

    Args:
        eval_pred: pair ``(logits, labels)``; the predicted class is the
            argmax of ``logits`` over its last axis.

    Returns:
        Dict with a single key ``"accuracy"`` holding the fraction of
        predictions that equal ``labels``.
    """
    scores, gold = eval_pred
    predicted = np.argmax(scores, axis=-1)
    hits = predicted == gold
    return {"accuracy": hits.mean()}

def count_parameters(m, verbose=True):
    """Count total and trainable parameters of a module and print a summary.

    Args:
        m: object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        verbose: when true, additionally print one line per parameter with its
            name, its ``requires_grad`` flag (rendered as 0/1) and its number
            of elements.

    Returns:
        Tuple ``(total_count, learnable_count)`` of element counts, where
        ``learnable_count`` only includes parameters with ``requires_grad``.
        A module without parameters yields ``(0, 0)`` instead of raising.
    """
    # Materialize once: the listing is needed for the column width and the loop.
    named_params = list(m.named_parameters())
    total_count = 0
    learnable_count = 0
    output_width = 0
    if verbose:
        print("Parameters (name, tunable, count):")
        # The width is only needed for the per-parameter listing; default=0
        # keeps a parameter-free module from raising on an empty max().
        output_width = max((len(n) for n, _ in named_params), default=0)

    for n, p in named_params:
        count = p.numel()
        if verbose:
            print(f" {n:{output_width}} {p.requires_grad:5b} {count:>11d}")
        total_count += count
        if p.requires_grad:
            learnable_count += count

    # Avoid ZeroDivisionError when the module has no parameters at all.
    learnable_share = learnable_count / total_count * 100. if total_count else 0.0
    print(
        f"Total parameters: {total_count:,}, "
        f"thereof learnable: {learnable_count:,} "
        f"({learnable_share:5.4f}%)"
    )

    return total_count, learnable_count

def adapt_model(model, r=1, dropout=0.1):
    """Freeze the RoBERTa backbone of ``model`` and attach minimal LoRA adapters in place.

    Every parameter under ``model.roberta`` (embeddings included) is frozen;
    anything outside of it, such as the classifier head, is left untouched.
    Each ``nn.Linear`` found under ``model.roberta.encoder`` then gets two
    trainable parameters, ``lora_A`` (in_features x r) and ``lora_B``
    (r x out_features), and its ``forward`` is replaced so that it returns
    ``original(x) + dropout(x) @ lora_A @ lora_B``.

    Args:
        model: module exposing ``model.roberta`` and ``model.roberta.encoder``.
        r: rank of the low-rank update (default 1, the original hard-coded value).
        dropout: dropout probability applied to the adapter input
            (default 0.1, the original hard-coded value).

    Returns:
        None. The model is modified in place.
    """

    # Minimalized example in place of the original LoRA-from-Scratch
    # implementation from the article:
    # https://towardsdatascience.com/dive-into-lora-adapters-38f4da488ede
    class MinimalLoRAAdapter(nn.Module):
        """Wraps one linear layer by hijacking its ``forward``."""

        def __init__(self,
                     adaptee):
            super().__init__()

            self.adaptee = adaptee

            # Keep the original bound forward, then route calls made to the
            # adaptee through this adapter instead.
            self.orig_forward = adaptee.forward
            adaptee.forward = self.forward

            # Sample on the default device (same RNG stream as before), then
            # move to the adaptee's device/dtype so the adapter also works
            # when the layer already lives on an accelerator or in half precision.
            ref = adaptee.weight
            lora_a_init = torch.randn(adaptee.in_features, r) / math.sqrt(adaptee.in_features)
            adaptee.lora_A = nn.Parameter(lora_a_init.to(device=ref.device, dtype=ref.dtype))
            # B starts at zero, so the adapter contributes nothing initially.
            adaptee.lora_B = nn.Parameter(
                torch.zeros(r, adaptee.out_features, device=ref.device, dtype=ref.dtype)
            )

        def forward(self, x, *args, **kwargs):
            # F.dropout defaults to training=True; follow the adaptee's mode so
            # that dropout is switched off by model.eval() during evaluation.
            lora_input = F.dropout(x, dropout, training=self.adaptee.training)
            return (
                self.orig_forward(x, *args, **kwargs)
                + lora_input @ self.adaptee.lora_A @ self.adaptee.lora_B
            )

    # freeze all layers, incl. embeddings, except for the classifier
    # (requires_grad_ on a module already recurses into all of its submodules)
    model.roberta.requires_grad_(False)

    # Adapt linear modules in transformer layers. Collect them first so the
    # module tree is not being walked while layers are being patched.
    linear_layers = [m for m in model.roberta.encoder.modules() if isinstance(m, nn.Linear)]
    for linear in linear_layers:
        MinimalLoRAAdapter(linear)

In [None]:
%%time

# Tokenizer and padding collator for the checkpoint configured in `hf_ckp`.
tokenizer = AutoTokenizer.from_pretrained(hf_ckp)
collator = DataCollatorWithPadding(tokenizer=tokenizer)

datasets.logging.disable_progress_bar()
dataset = datasets.load_dataset("glue", "sst2")
# Named `train_split` (not `train`) so the dataset is not confused with the
# `train` function defined further down in this cell, which rebinds that name.
train_split = dataset["train"]
valid = dataset["validation"]

def preprocess_function(examples):
    """Tokenize the 'sentence' column; padding is left to the collator."""
    return tokenizer(examples['sentence'], padding=False, truncation=True)

# batched=True hands whole batches of sentences to the tokenizer instead of
# calling it once per example; the resulting columns are the same.
tokenized_train = train_split.map(preprocess_function, batched=True)
tokenized_valid = valid.map(preprocess_function, batched=True)

def train(cp_enabled, model):
    """Fine-tune a LoRA-adapted classifier, with or without gradient checkpointing.

    Runs 1,500 optimizer steps with batch size 10 on ``tokenized_train``,
    evaluating on ``tokenized_valid`` every 250 steps and once more at the end.
    Before training, prints whether gradient checkpointing is active and the
    parameter counts of the adapted model.

    Args:
        cp_enabled: whether to enable (non-reentrant) gradient checkpointing.
        model: model to train; when ``None`` a fresh sequence-classification
            model is loaded from ``hf_ckp``.

    Returns:
        None.
    """
    import gc
    import inspect

    # Honor a model handed in by the caller; only load a fresh one otherwise.
    if model is None:
        model = AutoModelForSequenceClassification.from_pretrained(hf_ckp, num_labels=2)

    gradient_checkpointing_kwargs = None
    if cp_enabled:
        # https://github.com/huggingface/transformers/issues/26221#issuecomment-2031611304
        gradient_checkpointing_kwargs = {"use_reentrant": False}
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs=dict(gradient_checkpointing_kwargs)
        )

    # Freeze the backbone and attach the adapters before the Trainer is built,
    # so the Trainer (device placement, optimizer) sees the final parameter set
    # and the counts printed below describe what is actually trained.
    adapt_model(model)

    # Newer transformers releases renamed `evaluation_strategy` to
    # `eval_strategy` and Trainer's `tokenizer` to `processing_class`; pick
    # whichever keyword the installed version accepts.
    args_params = inspect.signature(TrainingArguments.__init__).parameters
    eval_strategy_key = "eval_strategy" if "eval_strategy" in args_params else "evaluation_strategy"
    trainer_params = inspect.signature(Trainer.__init__).parameters
    tokenizer_key = "processing_class" if "processing_class" in trainer_params else "tokenizer"

    training_args = TrainingArguments(
        gradient_checkpointing=cp_enabled,
        gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
        output_dir="out",
        per_device_train_batch_size=10,
        learning_rate=3e-5,
        save_steps=10_000,
        eval_steps=   250,
        max_steps = 1_500,
        save_strategy="steps",
        save_total_limit=1,
        disable_tqdm=True,
        metric_for_best_model='eval_accuracy',
        report_to="none", # Disable wandb, tensorboard
        **{eval_strategy_key: "steps"},
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_valid,
        data_collator=collator,
        compute_metrics=compute_metrics,
        **{tokenizer_key: tokenizer},
    )
    print(f'{model.is_gradient_checkpointing=}')
    count_parameters(model, verbose=False)

    try:
        trainer.train()
        trainer.evaluate()
    finally:
        # essential! Release the model between runs. Dropping the local name
        # alone is not enough: the trainer references the model too, and each
        # adapted Linear is part of a reference cycle (layer -> patched forward
        # -> adapter -> layer) that only the cyclic garbage collector reclaims.
        del trainer, model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# Run the identical experiment twice: first without, then with gradient
# checkpointing. Each run is announced by its banner.
runs = (
    ('\n---- without gradient checkpointing ----\n', False),
    ('\n---- with gradient checkpointing ----\n', True),
)
for banner, cp_enabled in runs:
    print(banner)
    train(cp_enabled, None)