# Initializing weights with LoftQ by replacing LoRA weights in-place

This notebook shows how to apply [LoftQ](https://arxiv.org/abs/2310.08659) initialization to our QLoRA model.

In short, LoftQ works as follows. When we use QLoRA, i.e. when we quantize the base model with bitsandbytes to save memory and then train LoRA weights on top of it, we expect a certain performance gap. This is partly because quantization is only an approximation of the "real" weights and thus introduces a quantization error. By default, LoRA weights are initialized so that they are a no-op at the start of training. However, we can instead initialize them so that they minimize the quantization error. This is the idea behind LoftQ.
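
For reference, the LoftQ objective from the paper can be summarized as below. It is rewritten here in PEFT's notation, where a LoRA layer adds $BA$ to the weight with $B \in \mathbb{R}^{d_\text{out} \times r}$ and $A \in \mathbb{R}^{r \times d_\text{in}}$, and the `lora_alpha / r` scaling is ignored for simplicity:

$$
\min_{Q,\,A,\,B}\ \left\lVert W - Q - BA \right\rVert_F
$$

Starting from $B_0 A_0 = 0$, the paper alternates for $t = 1, \dots, T$:

$$
Q_t = q_N\!\left(W - B_{t-1} A_{t-1}\right), \qquad B_t A_t = \mathrm{SVD}_r\!\left(W - Q_t\right)
$$

where $q_N$ denotes $N$-bit quantization and $\mathrm{SVD}_r$ the best rank-$r$ approximation. The default LoRA initialization corresponds to $BA = 0$, which leaves the error at $\lVert W - q_N(W) \rVert_F$. Since the base model in this notebook is already quantized, $Q = q_N(W)$ is presumably held fixed, so only the SVD step can reduce the error.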

Note that this only influences the initialization of the model. Everything that follows stays the same as always.

## Imports

In [1]:
import os
import torch

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

In [3]:
from peft import get_peft_model, LoraConfig, replace_lora_weights_loftq

## Functions

In [4]:
def get_mae(x, y):
    return (x - y).abs().mean()


def get_mse(x, y):
    return torch.pow(x - y, 2).mean()


def error_report(x, y):
    mae = get_mae(x, y)
    mse = get_mse(x, y)
    print(
        f"Mean absolute error: {mae:>8.5f}\n"
        f"Mean squared error:  {mse:>8.5f}"
    )

## Base model

First, let's load a base model and calculate some logits. These logits are the baseline, i.e. we try to match their values as closely as possible. We only need these logits for demonstration purposes. In practice, it is not necessary to load the non-quantized weights to apply LoftQ initialization.

**Note**: We have to choose a model with a `model.safetensors` file. PyTorch checkpoints (pickle) cannot be loaded lazily, so we have to use [safetensors](https://huggingface.co/docs/safetensors/index). If your model does not provide a safetensors file, save the pretrained model as a safetensors file using `save_pretrained` and pass the model path to `replace_lora_weights_loftq`.

In [5]:
model_id = "bigscience/bloomz-560m"

In [6]:
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
model = AutoModelForCausalLM.from_pretrained(model_id)

## Baseline: full-precision model

In [8]:
s = """Beautiful is better than ugly.
Explicit is better than implicit.
Simple is better than complex.
Complex is better than complicated.
Flat is better than nested.
Sparse is better than dense.
Readability counts.
Special cases aren't special enough to break the rules.
Although practicality beats purity.
Errors should never pass silently.
Unless explicitly silenced.
In the face of ambiguity, refuse the temptation to guess.
There should be one-- and preferably only one --obvious way to do it.
Although that way may not be obvious at first unless you're Dutch.
Now is better than never.
Although never is often better than *right* now.
If the implementation is hard to explain, it's a bad idea.
If the implementation is easy to explain, it may be a good idea.
Namespaces are one honking great idea -- let's do more of those!"""

In [9]:
inputs = tokenizer(s.splitlines(), return_tensors="pt", padding=True)

Our baseline logits:

In [10]:
logits_base = model(**inputs).logits

In [11]:
logits_base

tensor([[[278.0033, 283.7906, 297.8988,  ..., 161.5043, 161.5043, 161.4948],
         [278.0033, 283.7906, 297.8988,  ..., 161.5043, 161.5043, 161.4948],
         [278.0033, 283.7906, 297.8988,  ..., 161.5043, 161.5043, 161.4948],
         ...,
         [401.2244, 399.7368, 418.2438,  ..., 207.4830, 207.4830, 207.4720],
         [399.9124, 404.4676, 429.3989,  ..., 206.6692, 206.6694, 206.6578],
         [393.6312, 398.5472, 420.0659,  ..., 207.6603, 207.6605, 207.6490]],

        [[318.6148, 324.2510, 338.7469,  ..., 179.2935, 179.2938, 179.2829],
         [318.6148, 324.2510, 338.7469,  ..., 179.2935, 179.2938, 179.2829],
         [318.6148, 324.2510, 338.7469,  ..., 179.2935, 179.2938, 179.2829],
         ...,
         [397.1699, 397.2800, 416.4278,  ..., 203.6698, 203.6704, 203.6588],
         [399.3676, 403.9751, 427.0152,  ..., 205.7427, 205.7434, 205.7319],
         [400.2210, 402.1478, 423.6199,  ..., 206.3044, 206.3050, 206.2930]],

        [[265.3966, 271.5363, 285.0104,  ...

## Normal LoRA model

Now we load the model quantized with bitsandbytes. For now, only 4-bit quantization is supported.

In [8]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
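
With `load_in_4bit=True`, bitsandbytes stores each weight with 4 bits and keeps one scale per block of weights; the exact 4-bit code book depends on `bnb_4bit_quant_type`, and `bnb_4bit_use_double_quant=True` additionally quantizes those scales. The sketch below is a simplified, hypothetical stand-in rather than the bitsandbytes algorithm: it uses 15 uniform levels and an absmax scale per block of 64 values (block size chosen for illustration) to show where the quantization error that LoftQ targets comes from.

```python
import numpy as np


def quantize_4bit_absmax(w: np.ndarray, block_size: int = 64):
    """Toy blockwise 4-bit quantizer with uniform levels (not bitsandbytes' code book)."""
    flat = w.reshape(-1, block_size)                    # assumes w.size % block_size == 0
    scales = np.abs(flat).max(axis=1, keepdims=True)    # one absmax scale per block
    scales[scales == 0] = 1.0                           # avoid division by zero
    # 15 signed integer levels in [-7, 7] fit into 4 bits
    codes = np.clip(np.round(flat / scales * 7), -7, 7).astype(np.int8)
    return codes, scales


def dequantize_4bit_absmax(codes: np.ndarray, scales: np.ndarray, shape):
    """Map the integer codes back to floats using each block's scale."""
    return (codes.astype(np.float32) / 7 * scales).reshape(shape)


rng = np.random.default_rng(0)
w = rng.standard_normal((128, 256)).astype(np.float32)
codes, scales = quantize_4bit_absmax(w)
w_hat = dequantize_4bit_absmax(codes, scales, w.shape)
err = w - w_hat  # the quantization error
print(f"max |error|: {np.abs(err).max():.4f}, mean |error|: {np.abs(err).mean():.4f}")
```

Rounding to the nearest level bounds each entry's error by its block scale divided by 14, but the error is never zero in general, which is what the LoRA weights can try to compensate for.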

In [None]:
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)

## PTQ model

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
`low_cpu_mem_usage` was None, now default to True since model is quantized.


Next we create a LoRA model using PEFT and compute the logits of that model.

In [11]:
lora_config = LoraConfig(task_type="CAUSAL_LM", target_modules="all-linear")
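
The default initialization described in the introduction is a no-op because the LoRA matrix `B` starts at zero while `A` is random (PEFT uses a Kaiming-uniform init for `A`; this sketch uses a scaled normal for simplicity). A minimal numpy illustration, with made-up shapes and the `lora_alpha / r` scaling assumed to be 1:

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, r = 64, 32, 8  # illustrative shapes; r = 8 is LoraConfig's default rank
scaling = 1.0               # lora_alpha / r, assumed to be 1 here

W = rng.standard_normal((d_out, d_in))      # frozen base weight
A = rng.standard_normal((r, d_in)) * 0.01   # A is randomly initialized
B = np.zeros((d_out, r))                    # B starts at zero, so B @ A == 0

x = rng.standard_normal(d_in)
y_base = W @ x
y_lora = W @ x + scaling * (B @ (A @ x))    # LoRA forward: W x + scaling * B A x
print(np.allclose(y_base, y_lora))          # prints True
```

LoftQ keeps the same LoRA structure but chooses non-zero `A` and `B` so that the sum of the quantized weight and `B @ A` is closer to the original weight.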

## QLoRA

In [15]:
peft_model = get_peft_model(model, lora_config)

In [16]:
# bitsandbytes places the quantized model on the GPU, so move the inputs and the
# baseline logits to the same device to avoid a device-mismatch error
inputs, logits_base = inputs.to(peft_model.device), logits_base.to(peft_model.device)
logits_lora = peft_model(**inputs).logits

Let's check the influence of the quantization error on our logits:

In [16]:
error_report(logits_base, logits_lora)

Mean absolute error:  3.61113
Mean squared error:  36.53259


## LoftQ

Next, let's use LoftQ initialization and see if it helps reduce the error.

In [17]:
replace_lora_weights_loftq(peft_model)
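
Conceptually, with the quantized weight Q held fixed, the best rank-r correction is a truncated SVD of the residual W - Q. The numpy sketch below illustrates that step on a random matrix. `fake_quantize` and `loftq_from_fixed_q` are hypothetical helpers, not PEFT functions; a per-row round-to-nearest quantizer stands in for bitsandbytes, and the LoRA scaling is assumed to be 1.

```python
import numpy as np


def fake_quantize(w: np.ndarray, levels: int = 7) -> np.ndarray:
    """Hypothetical per-row absmax round-to-nearest quantizer (stand-in for bnb 4-bit)."""
    scale = np.abs(w).max(axis=1, keepdims=True)
    scale[scale == 0] = 1.0
    return np.round(w / scale * levels) / levels * scale


def loftq_from_fixed_q(w: np.ndarray, q: np.ndarray, r: int):
    """Return A (r x d_in) and B (d_out x r) such that B @ A approximates w - q."""
    u, s, vt = np.linalg.svd(w - q, full_matrices=False)
    sqrt_s = np.sqrt(s[:r])
    b = u[:, :r] * sqrt_s           # split the singular values evenly between B ...
    a = sqrt_s[:, None] * vt[:r]    # ... and A
    return a, b


rng = np.random.default_rng(0)
w = rng.standard_normal((256, 128))
q = fake_quantize(w)                # the already-quantized base weight stays fixed
a, b = loftq_from_fixed_q(w, q, r=8)

err_default = np.linalg.norm(w - q)          # default LoRA init: B @ A == 0
err_loftq = np.linalg.norm(w - (q + b @ a))  # LoftQ-style init
print(f"default: {err_default:.3f}  LoftQ: {err_loftq:.3f}")
```

On a random matrix the quantization residual is close to full rank, so a rank-8 correction removes only part of it.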

In [18]:
logits_loftq = peft_model(**inputs).logits

In [19]:
error_report(logits_base, logits_loftq)

Mean absolute error:  3.24111
Mean squared error:  31.13725


LoftQ initialization reduced the error, but only slightly: the MSE dropped from 36.53 to 31.14.

## LoftQ with callback

To improve on this, let's write a small callback function and pass it to `replace_lora_weights_loftq`. Each time the LoRA weights of a module have been replaced with LoftQ-initialized weights, the callback checks whether the error against the baseline logits actually decreased. If it did not, the replacement is rolled back. This way, we keep only the replacements that improve the results.
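
The accept-or-rollback logic can be sketched independently of PEFT. `greedy_replace` is a hypothetical helper that works on plain numbers instead of LoRA weights: each entry of `candidates` plays the role of a module's LoftQ-initialized weights, and `evaluate` plays the role of the MSE against the baseline logits. One design choice differs from `my_callback` in this notebook: the sketch starts from the error of the unmodified state instead of `float("inf")`, so even the first replacement must beat the default initialization to be kept.

```python
from typing import Callable, Dict, List


def greedy_replace(
    params: Dict[str, float],
    candidates: Dict[str, float],
    evaluate: Callable[[Dict[str, float]], float],
) -> List[str]:
    """Try each candidate in turn and keep it only if the error goes down."""
    best = evaluate(params)           # error of the unmodified state
    accepted = []
    for name, new_value in candidates.items():
        old_value = params[name]
        params[name] = new_value      # tentatively apply the replacement
        err = evaluate(params)
        if err < best:
            best = err
            accepted.append(name)
        else:
            params[name] = old_value  # roll back
    return accepted


# Toy usage: the "ideal" parameter values are in `target`
target = {"a": 1.0, "b": 2.0, "c": 3.0}
params = {"a": 0.0, "b": 0.0, "c": 0.0}
candidates = {"a": 0.9, "b": 5.0, "c": 3.1}


def evaluate(p: Dict[str, float]) -> float:
    return sum((p[k] - target[k]) ** 2 for k in target)


accepted = greedy_replace(params, candidates, evaluate)
print(accepted)  # prints ['a', 'c']
```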

In [9]:
# Since PEFT has modified the base model, we should reload it
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)

`low_cpu_mem_usage` was None, now default to True since model is quantized.


In [12]:
peft_model = get_peft_model(model, lora_config)

## QLoRA (default)

In [13]:
current_mse = float("inf")

In [14]:
def my_callback(model, module_name):
    """Callable to replace weights with LoftQ if the MSE is lower than the current best one."""
    global current_mse

    logits = model(**inputs).logits
    mse = get_mse(logits_base, logits)
    if mse < current_mse:
        current_mse = mse
        print(f"MSE improved for module {module_name}")
        return True
    print(f"MSE did not improve for module {module_name}")
    return False

In [None]:

replace_lora_weights_loftq(peft_model, callback=my_callback)

In [25]:
logits_loftq_callback = peft_model(**inputs).logits

In [26]:
error_report(logits_base, logits_loftq_callback)

Mean absolute error:  1.79576
Mean squared error:   8.47075


Applying LoftQ with the help of the callback reduced the error substantially: the MSE dropped from 36.53 (default QLoRA initialization) to 8.47.

## Applying LoftQ multiple times

It is possible to run `replace_lora_weights_loftq` multiple times on the same model when using the callback.
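
A second pass can still help because whether a replacement lowers the error may depend on which other modules have already been replaced. The hypothetical `greedy_passes` helper below repeats accept-or-rollback sweeps over toy numeric parameters until a full sweep changes nothing. In its toy error function, replacing `a` only pays off once `b` has been replaced.

```python
from typing import Callable, Dict


def greedy_passes(
    params: Dict[str, float],
    candidates: Dict[str, float],
    evaluate: Callable[[Dict[str, float]], float],
    max_passes: int = 5,
) -> int:
    """Repeat accept/rollback sweeps; return the number of sweeps performed."""
    best = evaluate(params)
    for n_pass in range(1, max_passes + 1):
        changed = False
        for name, new_value in candidates.items():
            if params[name] == new_value:
                continue                   # already applied in an earlier sweep
            old_value = params[name]
            params[name] = new_value       # tentatively apply the replacement
            err = evaluate(params)
            if err < best:
                best, changed = err, True
            else:
                params[name] = old_value   # roll back
        if not changed:
            return n_pass                  # a full sweep without improvement
    return max_passes


def evaluate(p: Dict[str, float]) -> float:
    # Toy error with an interaction between "a" and "b"
    return (p["a"] - p["b"]) ** 2 + 2 * (p["b"] - 1.0) ** 2


params = {"a": 0.0, "b": 0.0}
candidates = {"a": 1.0, "b": 1.0}
n_passes = greedy_passes(params, candidates, evaluate)
print(n_passes, params)
```

In the first sweep `a` is rejected and `b` is accepted; in the second sweep `a` is accepted; the third sweep changes nothing.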

In [27]:
replace_lora_weights_loftq(peft_model, callback=my_callback)

MSE did not improve for module transformer.h.0.self_attention.query_key_value
MSE did not improve for module transformer.h.0.self_attention.dense
MSE did not improve for module transformer.h.0.mlp.dense_h_to_4h
MSE did not improve for module transformer.h.0.mlp.dense_4h_to_h
MSE improved for module transformer.h.1.self_attention.query_key_value
MSE did not improve for module transformer.h.1.self_attention.dense
MSE did not improve for module transformer.h.1.mlp.dense_h_to_4h
MSE did not improve for module transformer.h.1.mlp.dense_4h_to_h
MSE did not improve for module transformer.h.2.self_attention.query_key_value
MSE did not improve for module transformer.h.2.self_attention.dense
MSE did not improve for module transformer.h.2.mlp.dense_h_to_4h
MSE did not improve for module transformer.h.2.mlp.dense_4h_to_h
MSE did not improve for module transformer.h.3.self_attention.query_key_value
MSE did not improve for module transformer.h.3.self_attention.dense
MSE did not improve for module tr

In [28]:
logits_loftq_callback_twice = peft_model(**inputs).logits

In [29]:
error_report(logits_base, logits_loftq_callback_twice)

Mean absolute error:  1.76357
Mean squared error:   8.33938


A second pass brings a further gain, but a small one: the MSE drops from 8.47 to 8.34.
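
To put the reported numbers side by side, the relative MSE reduction against the default QLoRA initialization can be computed directly from the outputs above (the dictionary keys are just descriptive labels):

```python
# MSE values reported in this notebook
mse = {
    "QLoRA (default init)": 36.53259,
    "LoftQ": 31.13725,
    "LoftQ + callback": 8.47075,
    "LoftQ + callback, 2 passes": 8.33938,
}


def relative_reduction(before: float, after: float) -> float:
    """Fraction by which the error dropped, e.g. 0.25 means 25 % lower."""
    return (before - after) / before


baseline = mse["QLoRA (default init)"]
for name, value in mse.items():
    print(f"{name:<28} MSE {value:8.5f}  reduction vs. QLoRA: {relative_reduction(baseline, value):6.1%}")
```

Plain LoftQ lowers the MSE by roughly 15 %, the callback variant by roughly 77 %, and the second pass adds under 2 % relative to the first callback run.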