# Implementation

[![Twitter Handle](https://img.shields.io/badge/Twitter-@gaohongnan-blue?style=social&logo=twitter)](https://twitter.com/gaohongnan)
[![LinkedIn Profile](https://img.shields.io/badge/@gaohongnan-blue?style=social&logo=linkedin)](https://linkedin.com/in/gao-hongnan)
[![GitHub Profile](https://img.shields.io/badge/GitHub-gao--hongnan-lightgrey?style=social&logo=github)](https://github.com/gao-hongnan)
![Tag](https://img.shields.io/badge/Tag-Brain_Dump-red)
![Tag](https://img.shields.io/badge/Level-Beginner-green)
[![Code](https://img.shields.io/badge/View-Code-blue?style=flat-square&logo=github)](https://github.com/gao-hongnan/omniverse/tree/main/omnivault/modules/lora.py)

```{contents}
:local:
```

## Merge And Quantize

In [37]:
# %pip install -U omniverse

## Dependencies

In [38]:
from __future__ import annotations

import copy
import math
from typing import Any, Dict, List, Optional, TypedDict, Union

import numpy as np
import psutil
import torch
from datasets import load_dataset
from pydantic import BaseModel, Field
from rich.pretty import pprint
from scipy.special import softmax
from sklearn.metrics import (
    accuracy_score,
    auc,
    average_precision_score,
    brier_score_loss,
    confusion_matrix,
    f1_score,
    log_loss,
    precision_recall_curve,
    precision_score,
    recall_score,
    roc_auc_score,
    roc_curve,
)
from torch import nn
from transformers import (
    DataCollatorWithPadding,
    Qwen2ForSequenceClassification,
    Qwen2Tokenizer,
    Trainer,
    TrainingArguments,
)
from transformers.trainer_utils import EvalPrediction

from omnivault.utils.reproducibility.seed import seed_all

## Setting Up

In [39]:
seed_all(42, seed_torch=True, set_torch_deterministic=False)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

MAX_LENGTH = 32
PADDING = "longest"
BATCH_SIZE = 32
TRUNCATION = True
RETURN_TENSORS = "pt"

## Dataset Preparation

In [40]:
class Batch(TypedDict):
    sentence: List[str]
    labels: List[int]


class TokenizedBatch(TypedDict):
    input_ids: List[int]
    attention_mask: List[int]
    labels: List[int]

tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen1.5-0.5B", padding_side="left")

def preprocess_function(batch: Batch, **kwargs: Any) -> TokenizedBatch:
    return tokenizer(batch["sentence"], **kwargs)

dataset = load_dataset("financial_phrasebank", "sentences_allagree", trust_remote_code=True)["train"]
dataset = dataset.rename_column("label", "labels")

train_valid_split = dataset.train_test_split(test_size=0.1, shuffle=True, stratify_by_column="labels")

train_dataset = train_valid_split["train"]
valid_dataset = train_valid_split["test"]

tokenized_train_dataset = train_dataset.map(
    preprocess_function,
    fn_kwargs={"truncation": TRUNCATION, "padding": PADDING, "max_length": MAX_LENGTH},
    batched=True,
    num_proc=psutil.cpu_count(logical=True),
    batch_size=1000,
).remove_columns(["sentence"])

tokenized_valid_dataset = valid_dataset.map(
    preprocess_function,
    fn_kwargs={"truncation": TRUNCATION, "padding": PADDING, "max_length": MAX_LENGTH},
    batched=True,
    num_proc=psutil.cpu_count(logical=True),
    batch_size=1000,
).remove_columns(["sentence"])

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

id2label = {0: "negative", 1: "neutral", 2: "positive"}
label2id = {"negative": 0, "neutral": 1, "positive": 2}
num_labels = len(id2label)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Map (num_proc=8):   0%|          | 0/2037 [00:00<?, ? examples/s]

Map (num_proc=8):   0%|          | 0/227 [00:00<?, ? examples/s]

## Base Model

In [41]:
base_model = Qwen2ForSequenceClassification.from_pretrained(
    "Qwen/Qwen1.5-0.5B",
    id2label=id2label,
    label2id=label2id,
    num_labels=num_labels,
    problem_type="single_label_classification",
)
base_model.config.pad_token_id = tokenizer.pad_token_id

base_model = base_model.to(DEVICE)
pprint(base_model)

Some weights of Qwen2ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen1.5-0.5B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [42]:
def total_trainable_parameters(module: nn.Module) -> int:
    """Returns the number of trainable parameters in the model."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


def total_parameters(module: nn.Module) -> int:
    """Returns the total number of parameters in the model, including non-trainable."""
    return sum(p.numel() for p in module.parameters())

base_model_total_trainable = total_trainable_parameters(base_model)
print(f"Total trainable parameters before LoRA: {base_model_total_trainable:,}")

Total trainable parameters before LoRA: 463,990,784


## Metrics

In [43]:
def compute_metrics_for_single_label_classification(eval_prediction: EvalPrediction) -> Dict[str, float | List[float]]:
    logits, labels = eval_prediction.predictions, eval_prediction.label_ids
    probs = softmax(logits, axis=-1)

    num_classes = logits.shape[1]
    preds = np.argmax(probs, axis=1)

    metrics = {
        "eval_log_loss": log_loss(labels, probs),
        "eval_accuracy": accuracy_score(labels, preds),
        "eval_precision_macro": precision_score(labels, preds, average="macro", zero_division=0),
        "eval_recall_macro": recall_score(labels, preds, average="macro", zero_division=0),
        "eval_f1_score_macro": f1_score(labels, preds, average="macro", zero_division=0),
        "eval_precision_micro": precision_score(labels, preds, average="micro", zero_division=0),
        "eval_recall_micro": recall_score(labels, preds, average="micro", zero_division=0),
        "eval_f1_score_micro": f1_score(labels, preds, average="micro", zero_division=0),
        "eval_confusion_matrix": confusion_matrix(labels, preds).tolist(),
        "eval_roc_auc": roc_auc_score(labels, probs, multi_class="ovr"),
        "eval_pr_auc": average_precision_score(labels, probs, average="macro")
    }

    if num_classes == 2:
        metrics["eval_brier_score"] = brier_score_loss(labels, probs[:, 1], pos_label=1)
    else:
        brier_scores = [brier_score_loss(labels == i, probs[:, i]) for i in range(num_classes)]
        metrics["eval_brier_score"] = np.mean(brier_scores)

    if num_classes > 2:
        for class_index in range(num_classes):
            fpr, tpr, _ = roc_curve(labels == class_index, probs[:, class_index])
            roc_auc = auc(fpr, tpr)
            precision, recall, _ = precision_recall_curve(labels == class_index, probs[:, class_index])
            pr_auc = auc(recall, precision)
            metrics[f"eval_roc_auc_class_{class_index}"] = roc_auc
            metrics[f"eval_pr_auc_class_{class_index}"] = pr_auc

    return metrics

## Evaluate With Pretrained Model

In [44]:
trainer = Trainer(
    model=base_model,
    data_collator=data_collator,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_valid_dataset,
    compute_metrics=compute_metrics_for_single_label_classification,
)

valid_metrics = trainer.predict(tokenized_valid_dataset, metric_key_prefix="eval")
pprint(valid_metrics.metrics)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


  0%|          | 0/29 [00:00<?, ?it/s]

## LoRA Implementation

In [45]:
class LoraConfig(BaseModel):
    r: int = Field(..., description="Lora attention dimension (the 'rank').")
    lora_alpha: int = Field(..., description="The alpha parameter for Lora scaling.")
    lora_dropout: float = Field(..., description="The dropout probability for Lora layers.")
    target_modules: List[str] = Field(
        default=None,
        description=(
            "The names of the modules to apply the adapter to. If specified, only the modules with the specified "
            "names will be replaced. When passing a string, a regex match will be performed. When passing a list of "
            "strings, either an exact match will be performed or it is checked if the name of the module ends with any "
            "of the passed strings. If specified as 'all-linear', all linear/Conv1D modules are chosen, excluding the "
            "output layer. If not specified, modules are chosen according to the model architecture. If the architecture "
            "is unknown, an error will be raised—manual specification of target modules is required in such cases."
        ),
    )
    linear_bias: bool = Field(default=True, description="To include linear bias or not.")
    modules_to_save: List[str] = Field(
        default=None,
        description=(
            """List of modules apart from adapter layers to be set as
               trainable and saved in the final checkpoint."""
        ),
    )

In [47]:
lora_config = LoraConfig(
    r=4, lora_alpha=8, lora_dropout=0.1, target_modules=["q_proj", "k_proj", "v_proj"], modules_to_save=["score"]
)
pprint(lora_config)

We print out the target modules below. For simplicity, we target only the `q`, `k` and `v` layers for now.

In [49]:
for module_name, _module in base_model.named_modules():
    if any(target_module in module_name for target_module in lora_config.target_modules):
        print(module_name)

model.layers.0.self_attn.q_proj
model.layers.0.self_attn.k_proj
model.layers.0.self_attn.v_proj
model.layers.1.self_attn.q_proj
model.layers.1.self_attn.k_proj
model.layers.1.self_attn.v_proj
model.layers.2.self_attn.q_proj
model.layers.2.self_attn.k_proj
model.layers.2.self_attn.v_proj
model.layers.3.self_attn.q_proj
model.layers.3.self_attn.k_proj
model.layers.3.self_attn.v_proj
model.layers.4.self_attn.q_proj
model.layers.4.self_attn.k_proj
model.layers.4.self_attn.v_proj
model.layers.5.self_attn.q_proj
model.layers.5.self_attn.k_proj
model.layers.5.self_attn.v_proj
model.layers.6.self_attn.q_proj
model.layers.6.self_attn.k_proj
model.layers.6.self_attn.v_proj
model.layers.7.self_attn.q_proj
model.layers.7.self_attn.k_proj
model.layers.7.self_attn.v_proj
model.layers.8.self_attn.q_proj
model.layers.8.self_attn.k_proj
model.layers.8.self_attn.v_proj
model.layers.9.self_attn.q_proj
model.layers.9.self_attn.k_proj
model.layers.9.self_attn.v_proj
model.layers.10.self_attn.q_proj
model.l

In [54]:
"""LoRA: Low-Rank Adaptation of Large Language Models.

References
----------
[1] https://pytorch.org/torchtune/stable/tutorials/lora_finetune.html
"""


from __future__ import annotations

import math
from typing import List

import torch
from torch import nn


def _lora_a_init_params(x: nn.Linear) -> None:
    """
    Initialize LoRA A weight to Kaiming uniform.
    """
    nn.init.kaiming_uniform_(x.weight, a=math.sqrt(5))


def _lora_b_init_params(x: nn.Linear) -> None:
    """
    Initialize LoRA B weight to zeros.
    """
    nn.init.zeros_(x.weight)


class LoRALinear(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, bias: bool, rank: int, alpha: float, dropout: float) -> None:
        super().__init__()

        # These are the weights from the original pretrained model
        self.linear = nn.Linear(in_dim, out_dim, bias=bias)  # weight shape=[out_dim, in_dim]

        # These are the new LoRA params. In general rank << in_dim, out_dim - do not put bias here
        self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)  # weight shape=[rank, in_dim]
        self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)  # weight shape=[out_dim, rank]

        self.rank = rank
        self.alpha = alpha
        self.dropout = nn.Dropout(p=dropout)

        self._init_weights()

    def _init_weights(self) -> None:
        """See https://github.com/microsoft/LoRA/blob/4c0333854cb905966f8cc4e9a74068c1e507c7b7/loralib/layers.py#L119."""

        _lora_a_init_params(self.lora_a)
        _lora_b_init_params(self.lora_b)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # This would be the output of the original model
        frozen_out = x @ self.linear.weight.T
        if self.linear.bias is not None:
            frozen_out += self.linear.bias

        # lora_a projects inputs down to the much smaller self.rank,
        # then lora_b projects back up to the output dimension
        x = self.dropout(x)
        lora_out = (x @ self.lora_a.weight.T) @ self.lora_b.weight.T
        # Finally, scale by the alpha parameter (normalized by rank)
        # and add to the original model's outputs
        return frozen_out + (self.alpha / self.rank) * lora_out


def apply_lora_to_base_model(
    model: nn.Module, rank: int, alpha: float, dropout: float, target_modules: List[str] | None = None
) -> None:
    """Recursively apply LoRA to a model. Only supports applying on `nn.Linear` layers."""

    for module_name, module in model.named_children():
        if isinstance(module, nn.Linear):
            if target_modules is None or any(target in module_name for target in target_modules):
                setattr(
                    model,
                    module_name,
                    LoRALinear(
                        in_dim=module.in_features,
                        out_dim=module.out_features,
                        rank=rank,
                        alpha=alpha,
                        dropout=dropout,
                        bias=module.bias is not None,
                    ),
                )
        else:
            # Recursively apply LoRA to children modules
            apply_lora_to_base_model(model=module, rank=rank, alpha=alpha, dropout=dropout, target_modules=target_modules)


In [51]:
base_model_with_adapter = copy.deepcopy(base_model)

We apply recursively the `LoRA` module to the `q`, `k` and `v` layers
via `apply_lora_to_base_model`.

In [55]:
apply_lora_to_base_model(
    model=base_model_with_adapter,
    rank=lora_config.r,
    alpha=lora_config.lora_alpha,
    dropout=lora_config.lora_dropout,
    target_modules=lora_config.target_modules,
)

In [56]:
base_model_with_adapter_total_trainable = total_trainable_parameters(base_model_with_adapter)
print(f"Total trainable parameters after LoRA before freezing: {base_model_with_adapter_total_trainable:,}")

Total trainable parameters after LoRA before freezing: 464,580,608


First `bias` is default to `True` in original model, but in LoRA we need to have it as `False`. 
You also see that currently the total trainable parameters are more than base model. Why? 

In [58]:
base_model_with_adapter_total_trainable - base_model_total_trainable

589824

In [67]:
dim = base_model_with_adapter.model.layers[0].self_attn.q_proj.linear.weight.shape[0]
layers = base_model_with_adapter.model.layers.__len__()
rank = lora_config.r
num_target_modules = len(lora_config.target_modules)

qkv_lora_weight_params = (dim * rank * 2) * layers * num_target_modules # 2 is the AB 1 each

base_model_with_adapter_total_trainable - base_model_total_trainable ==  qkv_lora_weight_params

True

The additional parameters is basically because we apply to `qkv` where each `qkv` has 24 layers each, so for each layer, say `q_proj` we would have an additional
of `1024 * 4 * 2` because matrix A and B are mirrored to version of `[dim, rank]`. 

Now of course the next step is to freeze the base pretrained weights.
Note we DO NOT want to freeze the `score` module as that is our classification head.

In [68]:
for parameter_name, parameter in base_model_with_adapter.named_parameters():
    # We will set requires_grad to False if 'lora_' is not in the parameter name AND the parameter name does not contain any of the module names specified in modules_to_save
    if "lora_" not in parameter_name and not any(
        module_name in parameter_name for module_name in lora_config.modules_to_save
    ):
        parameter.requires_grad = False
    else:
        # Safeguard here parameters that are part of LoRA or specified modules are trainable
        parameter.requires_grad = True

In [69]:
base_model_with_adapter_total_trainable = total_trainable_parameters(base_model_with_adapter)
print(f"Total trainable parameters after LoRA after freezing: {base_model_with_adapter_total_trainable:,}")

Total trainable parameters after LoRA after freezing: 592,896


In [71]:
(base_model_with_adapter_total_trainable / base_model_total_trainable) * 100

0.1277818483567122

We are only training on `~0.1277%` of the total parameters.

## Train LoRA

In [None]:
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    modules_to_save=["score"],
)
pprint(lora_config)

In [None]:
training_args = TrainingArguments(
    do_eval=True,
    do_predict=False,
    do_train=True,
    warmup_ratio=0.0,
    learning_rate=6e-4,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=10,
    report_to="none",
    output_dir="./artifacts",
    overwrite_output_dir=True,
    gradient_accumulation_steps=1,
    logging_steps=25,
    evaluation_strategy="steps",
    eval_steps=32,
    save_strategy="steps",
    save_steps=128,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    lr_scheduler_type="cosine",
    weight_decay=0.0,
    save_total_limit=2,
    seed=42,
    data_seed=42,
    half_precision_backend="auto",
    optim="adamw_torch",
    label_smoothing_factor=0.0,
    max_grad_norm=1.0,
)

In [46]:
trainer = Trainer(
    model=base_model_with_adapter,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_valid_dataset,
    compute_metrics=compute_metrics_for_single_label_classification,
)

In [48]:
trainer.train()

Step,Training Loss,Validation Loss,Log Loss,Accuracy,Precision Macro,Recall Macro,F1 Score Macro,Precision Micro,Recall Micro,F1 Score Micro,Confusion Matrix,Roc Auc,Pr Auc,Brier Score,Roc Auc Class 0,Pr Auc Class 0,Roc Auc Class 1,Pr Auc Class 1,Roc Auc Class 2,Pr Auc Class 2
32,3.492,1.867992,1.867991,0.436123,0.429408,0.397243,0.374103,0.436123,0.436123,0.436123,"[[5, 9, 16], [4, 60, 76], [7, 16, 34]]",0.597633,0.434514,0.284848,0.537902,0.18168,0.668309,0.717904,0.586687,0.380112
64,1.3967,1.043167,1.043167,0.590308,0.419821,0.488931,0.413872,0.590308,0.590308,0.590308,"[[0, 3, 27], [0, 85, 55], [0, 8, 49]]",0.780956,0.545472,0.19595,0.700677,0.241,0.865928,0.894582,0.776264,0.482198
96,1.0635,0.784113,0.784113,0.757709,0.808035,0.553718,0.575001,0.757709,0.757709,0.757709,"[[5, 11, 14], [0, 138, 2], [0, 28, 29]]",0.848063,0.663932,0.131309,0.81489,0.475911,0.905255,0.913337,0.824045,0.588326
128,0.7343,1.009085,1.009085,0.731278,0.881481,0.532749,0.579467,0.731278,0.731278,0.731278,"[[9, 20, 1], [0, 140, 0], [0, 40, 17]]",0.895576,0.791378,0.148736,0.867513,0.659774,0.940887,0.956902,0.878328,0.751395
160,0.5026,1.525477,1.525477,0.585903,0.397773,0.569841,0.408853,0.585903,0.585903,0.585903,"[[29, 1, 0], [36, 104, 0], [52, 5, 0]]",0.893914,0.804567,0.219748,0.90643,0.707196,0.933826,0.958734,0.841486,0.742477
192,0.8881,0.520843,0.520843,0.828194,0.790419,0.766291,0.758882,0.828194,0.828194,0.828194,"[[23, 5, 2], [9, 131, 0], [10, 13, 34]]",0.933432,0.861995,0.088546,0.91709,0.735546,0.943268,0.957575,0.939938,0.888554
224,0.567,0.563086,0.563086,0.784141,0.73008,0.794486,0.747735,0.784141,0.784141,0.784141,"[[23, 0, 7], [10, 106, 24], [5, 3, 49]]",0.941179,0.882236,0.099551,0.948393,0.822162,0.967816,0.965795,0.907327,0.855146
256,0.4873,0.777104,0.777104,0.779736,0.779006,0.691688,0.646585,0.779736,0.779736,0.779736,"[[25, 5, 0], [3, 137, 0], [24, 18, 15]]",0.932729,0.8667,0.11096,0.935195,0.808896,0.966913,0.973352,0.896078,0.814109
288,0.3015,0.51801,0.518011,0.881057,0.832276,0.789348,0.806632,0.881057,0.881057,0.881057,"[[18, 2, 10], [0, 137, 3], [5, 7, 45]]",0.960663,0.891463,0.06648,0.954653,0.788685,0.980172,0.98332,0.947162,0.898182
320,0.2765,0.517759,0.517759,0.859031,0.815993,0.774102,0.792039,0.859031,0.859031,0.859031,"[[20, 4, 6], [1, 136, 3], [6, 12, 39]]",0.950659,0.884996,0.072842,0.946701,0.812426,0.971429,0.976206,0.933849,0.862842


TrainOutput(global_step=640, training_loss=0.47741791264852507, metrics={'train_runtime': 267.3562, 'train_samples_per_second': 76.19, 'train_steps_per_second': 2.394, 'total_flos': 1235801533317120.0, 'train_loss': 0.47741791264852507, 'epoch': 10.0})

The accuracy hits around $90\%$ after 10 epochs. This is a far cry from what
an encoder like Deberta can achieve. But this is a good start and shows that
the implementation is working.