In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
!pip install -Uq transformers datasets torch numpy

In [3]:
import torch, transformers, datasets
import numpy as np
torch.__version__, transformers.__version__, datasets.__version__, np.__version__

('2.1.2+cu121', '4.36.2', '2.16.1', '1.26.3')

In [4]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, TrainingArguments, Trainer, set_seed
from torch import nn
from torch.nn import functional as F
import math

hf_ckp = 'roberta-base'
set_seed(100)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {f"accuracy": (predictions == labels).mean()}

def count_parameters(m, verbose=True):
    total_count = 0
    learnable_count = 0
    if verbose:
        print("Parameters (name, tunable, count):")

    output_width = max([len(n) for n, _ in m.named_parameters()])
    for n, p in m.named_parameters():
        count = p.data.numel()
        if verbose:
            print(f" {n:{output_width}} {p.requires_grad:5b} {count:>11d}")
        total_count += count
        if p.requires_grad:
            learnable_count += count

    print(
        f"Total parameters: {total_count:,}, "
        f"thereof learnable: {learnable_count:,} "
        f"({learnable_count/total_count*100.:5.4f}%)"
    )

    return total_count, learnable_count

def adapt_model(model):
    
    # Minimalized example in place of the original LoRA-from-Scratch 
    # implementation from the article: 
    # https://towardsdatascience.com/dive-into-lora-adapters-38f4da488ede
    class MinimalLoRAAdapter(nn.Module): 
        def __init__(self, 
                     adaptee):
            super().__init__()

            self.adaptee = adaptee

            self.orig_forward = adaptee.forward
            adaptee.forward = self.forward
            
            r = 1
            adaptee.lora_A = nn.Parameter(
                torch.randn(adaptee.in_features, r) / math.sqrt(adaptee.in_features)
            )
            adaptee.lora_B = nn.Parameter(torch.zeros(r, adaptee.out_features))

        def forward(self, x, *args, **kwargs):
            return (
                self.orig_forward(x, *args, **kwargs)
                + F.dropout(x, 0.1) @ self.adaptee.lora_A @ self.adaptee.lora_B
            )
   
    # freeze all layers, incl. embeddings, except for the classifier
    for m in model.roberta.modules():    
        m.requires_grad_(False)

    # Adapt linear modules in transformer layers
    for m in model.roberta.encoder.modules():    
        if isinstance(m, nn.Linear):
            MinimalLoRAAdapter(m)

2024-01-17 13:07:33.587614: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [5]:
%%time

tokenizer = AutoTokenizer.from_pretrained(hf_ckp)
collator = DataCollatorWithPadding(tokenizer=tokenizer)

datasets.logging.disable_progress_bar()
dataset = datasets.load_dataset("glue", "sst2")
train = dataset["train"]
valid = dataset["validation"]

def preprocess_function(examples):
        return tokenizer(examples['sentence'], padding=False, truncation=True)

tokenized_train = train.map(preprocess_function, batched=False)
tokenized_valid = valid.map(preprocess_function, batched=False)

def train(cp_enabled, model):
     
    if cp_enabled:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant":False})
    
    training_args = TrainingArguments(
        gradient_checkpointing=cp_enabled,
        output_dir="out",    
        per_device_train_batch_size=224,
        learning_rate=3e-5,
        save_steps=10_000,
        eval_steps=   250,
        max_steps = 1_000,
        evaluation_strategy="steps",
        save_strategy="steps",
        save_total_limit=1,
        disable_tqdm=True,
        metric_for_best_model='eval_accuracy',
        report_to="none", # Disable wandb, tensorboard
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_valid,
        tokenizer=tokenizer,
        data_collator=collator,
        compute_metrics=compute_metrics,
    )
    print(f'{model.is_gradient_checkpointing=}')
    total, learnable = count_parameters(model, verbose=False)
    
    trainer.train()
    trainer.evaluate()


print('\n---- without gradient checkpointing ----\n')
model = AutoModelForSequenceClassification.from_pretrained(hf_ckp, num_labels=2)   
adapt_model(model)
train(False, model)

del(model) # essential!

print('\n---- with gradient checkpointing ----\n')
model = AutoModelForSequenceClassification.from_pretrained(hf_ckp, num_labels=2)
adapt_model(model)

train(True, model)


---- without gradient checkpointing ----



Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.out_proj.weight', 'classifier.out_proj.bias', 'classifier.dense.bias', 'classifier.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.is_gradient_checkpointing=False
Total parameters: 124,813,058, thereof learnable: 758,018 (0.6073%)


You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'eval_loss': 0.33385390043258667, 'eval_accuracy': 0.8841743119266054, 'eval_runtime': 1.7162, 'eval_samples_per_second': 508.097, 'eval_steps_per_second': 63.512, 'epoch': 0.83}
{'loss': 0.4722, 'learning_rate': 1.5e-05, 'epoch': 1.66}
{'eval_loss': 0.2571539282798767, 'eval_accuracy': 0.9013761467889908, 'eval_runtime': 1.6767, 'eval_samples_per_second': 520.059, 'eval_steps_per_second': 65.007, 'epoch': 1.66}
{'eval_loss': 0.2381727248430252, 'eval_accuracy': 0.908256880733945, 'eval_runtime': 1.6953, 'eval_samples_per_second': 514.363, 'eval_steps_per_second': 64.295, 'epoch': 2.49}
{'loss': 0.2937, 'learning_rate': 0.0, 'epoch': 3.32}
{'eval_loss': 0.23466329276561737, 'eval_accuracy': 0.9071100917431193, 'eval_runtime': 1.692, 'eval_samples_per_second': 515.359, 'eval_steps_per_second': 64.42, 'epoch': 3.32}
{'train_runtime': 457.1886, 'train_samples_per_second': 489.951, 'train_steps_per_second': 2.187, 'train_loss': 0.38296363830566404, 'epoch': 3.32}
{'eval_loss': 0.235939592

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.out_proj.weight', 'classifier.out_proj.bias', 'classifier.dense.bias', 'classifier.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.is_gradient_checkpointing=True
Total parameters: 124,813,058, thereof learnable: 758,018 (0.6073%)




{'eval_loss': 0.6846632361412048, 'eval_accuracy': 0.5091743119266054, 'eval_runtime': 1.6902, 'eval_samples_per_second': 515.902, 'eval_steps_per_second': 64.488, 'epoch': 0.83}
{'loss': 0.6764, 'learning_rate': 1.5e-05, 'epoch': 1.66}
{'eval_loss': 0.6755141615867615, 'eval_accuracy': 0.5091743119266054, 'eval_runtime': 1.6956, 'eval_samples_per_second': 514.279, 'eval_steps_per_second': 64.285, 'epoch': 1.66}
{'eval_loss': 0.6665772199630737, 'eval_accuracy': 0.5103211009174312, 'eval_runtime': 1.6849, 'eval_samples_per_second': 517.534, 'eval_steps_per_second': 64.692, 'epoch': 2.49}
{'loss': 0.6586, 'learning_rate': 0.0, 'epoch': 3.32}
{'eval_loss': 0.6635248064994812, 'eval_accuracy': 0.5194954128440367, 'eval_runtime': 1.6993, 'eval_samples_per_second': 513.142, 'eval_steps_per_second': 64.143, 'epoch': 3.32}
{'train_runtime': 227.8506, 'train_samples_per_second': 983.101, 'train_steps_per_second': 4.389, 'train_loss': 0.6675097045898437, 'epoch': 3.32}
{'eval_loss': 0.663524806