# Training Large Language Models in 2bit with `aqlm`, `transformers` and `PEFT`

<a target="_blank" href="https://colab.research.google.com/github/Vahe1994/AQLM/blob/main/notebooks/aqlm_2bit_training.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

Welcome to this notebook that goes through the recent `aqlm` integration that introduces minimal performance degradation 2bit quantization techniques.

In this notebook, we will learn how to load a large model in 2bit (`Mixtral-8x7b`) and train it using Google Colab and PEFT library from Hugging Face 🤗.


**Install the `aqlm` library**
- It's the only extra dependency to run AQLM models.
- Add `[gpu]` to install the required CUDA specific dependencies.
- Install the latest `accelerate` and `transformers` releases to properly support it.

In [1]:
%%capture
!pip install aqlm[gpu]>=1.1.0
!pip install git+https://github.com/huggingface/peft.git@main
!pip install accelerate>=0.27.0
!pip install git+https://github.com/huggingface/transformers.git@main
!pip install datasets
!pip install bitsandbytes # for 8-bit optimizer only
!pip install torch --upgrade
!pip install huggingface --upgrade

In [18]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "ISTA-DASLab/Meta-Llama-3-8B-Instruct-AQLM-2Bit-1x16"
model_id = "meta-llama/Meta-Llama-3-8B"


tokenizer = AutoTokenizer.from_pretrained(model_id, token="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")
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype="auto", low_cpu_mem_usage=True, token="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")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


config.json:   0%|          | 0.00/654 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/177 [00:00<?, ?B/s]

In [34]:
from datasets import load_dataset

data = load_dataset("Abirate/english_quotes")
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)



Map:   0%|                                                                                                        | 0/2508 [00:00<?, ? examples/s]


NameError: name 'tokenizer' is not defined

In [None]:
data["train"]["quote"]

**Add LoRA**

To alter model's behavior, we have to make it trainable. We can do that by addind a small set of trainable parameters on top of the untrainable quantized ones.

In [19]:
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_prok", "k_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, config)
model.print_trainable_parameters()
model.enable_input_require_grads() # it's needed for gradient checkpointing

trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424


Here we add a trainable adapter ontop of every `q_prok`, `k_proj` and `o_proj` linear layer.

**Loading a dataset**

Let's load a common dataset, english quotes, to fine tune our model on famous quotes.

In [4]:
from datasets import load_dataset

raw_data = load_dataset("m-a-p/CodeFeedback-Filtered-Instruction")

def get_code(s):
    s = s.split("```")
    for p in s:
        if p.startswith('python'):
            return p[6:]

new_data = []
for x in raw_data['train']:

    code = get_code(x['answer'])
    if code is None or code=="None":
        continue
    
    inp = {"prompt": x["query"], "answer": str({"code": code})}
    new_data.append(inp)

Downloading readme:   0%|          | 0.00/2.93k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/371M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/156526 [00:00<?, ? examples/s]

In [7]:
from datasets import Dataset

tokenizer.pad_token = tokenizer.eos_token
def tokenize_function(examples):
    inputs = tokenizer(examples["prompt"], max_length=512, padding=True, truncation=True, return_tensors="pt")
    targets = tokenizer(examples["answer"], max_length=1024, padding=True, truncation=True, return_tensors="pt")
    return {
        "input_ids": inputs.input_ids,
        "attention_mask": inputs.attention_mask,
        "labels": targets.input_ids,  # Set labels as the tokenized targets
    }

dataset = Dataset.from_list(new_data)
# dataset['prompt']
tokenized_dataset = dataset.map(tokenize_function, batched=True)

Map:   0%|          | 0/80121 [00:00<?, ? examples/s]

Run the cell below to run the training! For the sake of the demo, we just ran it for few steps just to showcase how to use this integration with existing tools on the HF ecosystem.

In [20]:
import transformers

tokenizer.pad_token = tokenizer.eos_token

trainer = transformers.Trainer(
    model=model,
    train_dataset=tokenized_dataset,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        warmup_steps=2,
        max_steps=10,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
        optim="adamw_bnb_8bit",
        logging_first_step=True,
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()

max_steps is given, it will override any value given in num_train_epochs


Step,Training Loss
1,12.2426
2,17.1154
3,14.3796
4,12.1799
5,16.3675
6,15.1065
7,14.7574
8,12.1773
9,13.5281
10,11.9413



Cannot access gated repo for url https://huggingface.co/meta-llama/Meta-Llama-3-8B/resolve/main/config.json.
Access to model meta-llama/Meta-Llama-3-8B is restricted. You must be authenticated to access it. - silently ignoring the lookup for the file config.json in meta-llama/Meta-Llama-3-8B.


TrainOutput(global_step=10, training_loss=13.979569149017333, metrics={'train_runtime': 11.5547, 'train_samples_per_second': 3.462, 'train_steps_per_second': 0.865, 'total_flos': 922623903006720.0, 'train_loss': 13.979569149017333, 'epoch': 0.0004992448921006977})

In [None]:
### KV Cache

In [None]:
model.generation_config.cache_implementation = "static"
compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)

In [None]:
  input_text = "Hi what is your name?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = compiled_model.generate(**input_ids)
tokenizer.batch_decode(outputs, skip_special_tokens=True)

In [None]:
prompts = [
    "Simply put, the theory of relativity states that ",
    "My favorite all time favorite condiment is ketchup.",
]
NUM_TOKENS_TO_GENERATE = 40
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
    logits = model(
        cur_token,
        position_ids=input_pos,
        cache_position=cache_position,
        past_key_values=past_key_values,
        return_dict=False,
        use_cache=True
    )[0]
    new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    return new_token

from transformers import StaticCache


batch_size, seq_length = inputs["input_ids"].shape
with torch.no_grad():
    past_key_values = StaticCache(
        config=model.config, max_batch_size=2, max_cache_len=4096, device= "cuda", dtype=model.dtype
    )
    cache_position = torch.arange(seq_length, device="cuda")
    generated_ids = torch.zeros(
        batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device="cuda"
    )
    generated_ids[:, cache_position] = inputs["input_ids"].to("cuda").to(torch.int)

    logits = model(
        **inputs, cache_position=cache_position, past_key_values=past_key_values,return_dict=False, use_cache=True
    )[0]
    next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    generated_ids[:, seq_length] = next_token[:, 0]

    decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
    cache_position = torch.tensor([seq_length + 1], device="cuda")
    for _ in range(1, NUM_TOKENS_TO_GENERATE):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            next_token = decode_one_tokens(model, next_token.clone(), None, cache_position, past_key_values)
            generated_ids[:, cache_position] = next_token.int()
        cache_position += 1

text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [None]:
text

['Simply put, the theory of relativity states that 1) the speed of light is constant, and 2) the speed of light is the fastest speed possible. The theory of relativity also states that the speed of light is not affected by the!',
 'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on eggs, burgers, fries, chicken, hot dogs, and even on my pizza. I love it so much that I have a whole drawer full of!']

In [None]:
# prompt: write inference code for the above model

# Generate text from the trained model
input_text = "hello what is your name"
input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
output = model.generate(input_ids=input_ids, max_new_tokens=50)
print(tokenizer.decode(output[0], skip_special_tokens=True))


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


hello what is your name?'
    else:
        print('what is your name?')
    print('what is your name?')
    print('what is your name?')
    print('what is your name?')
    print('what is your name?')
   
