# Mixtral quantization

## Goal

Can I make Mixtral work better if I don't quantize the gate layers?

## References

- [ ] What if MoE models do not handle quantization well, and I should leave some layers unquantized?
  - [ ] https://github.com/mobiusml/hqq/issues/2
  - [ ] https://github.com/vllm-project/vllm/issues/2243
  - [ ] I could try this on the forum fork
  - [ ] Here the Mixtral gates suggestion appears again, although not for the bitsandbytes config.
  - [ ] https://huggingface.co/docs/transformers/main_classes/quantization#transformers.QuantoConfig.modules_to_not_convert
  - [ ] https://huggingface.co/docs/transformers/main_classes/quantization#transformers.BitsAndBytesConfig
  - [ ] `llm_int8_skip_modules` might be used

## Imports

In [1]:
import torch
import gc
import time
import re
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import pandas as pd
from tqdm.auto import tqdm
import yaml
import os
import hashlib

from transformers import (
    AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
    pipeline, TrainingArguments
)
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from datasets import Dataset

plt.plot()
plt.close('all')
plt.rcParams["figure.figsize"] = (20, 5)
mpl.rcParams['lines.linewidth'] = 3
mpl.rcParams['font.size'] = 16

pd.set_option('display.max_colwidth', 200)



## Load model

In [2]:
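bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    llm_int8_enable_fp32_cpu_offload=True,
    # Modules kept in full precision; despite the llm_int8_ prefix, this also applies to 4-bit loading
    llm_int8_skip_modules=['gate', 'lm_head'],
    # Variant used in the last experiment: also keep the expert w1 layers unquantized
    #llm_int8_skip_modules=['gate', 'lm_head', 'w1'],
)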

torch.cuda.empty_cache()

In [3]:
auto_device_map = {
    'model.embed_tokens': 0,
    'model.layers.0': 0,
    'model.layers.1': 0,
    'model.layers.2': 0,
    'model.layers.3': 0,
    'model.layers.4': 0,
    'model.layers.5': 0,
    'model.layers.6': 0,
    'model.layers.7': 0,
    'model.layers.8': 0,
    'model.layers.9': 0,
    'model.layers.10': 0,
    'model.layers.11': 0,
    'model.layers.12': 0,
    'model.layers.13': 0,
    'model.layers.14': 1,
    'model.layers.15': 1,
    'model.layers.16': 1,
    'model.layers.17': 1,
    'model.layers.18': 1,
    'model.layers.19': 1,
    'model.layers.20': 1,
    'model.layers.21': 1,
    'model.layers.22': 1,
    'model.layers.23': 1,
    'model.layers.24': 1,
    'model.layers.25': 1,
    'model.layers.26': 1,
    'model.layers.27': 1,
    'model.layers.28': 1,
    'model.layers.29': 1,
    'model.layers.30': 1,
    'model.layers.31': 1,
    'model.norm': 1,
    'lm_head': 1
 }

def create_shared_device_map(transition_layer):
    # transition_layer is a position in auto_device_map (0 = embed_tokens), not a decoder-layer index
    shared_device_map = {}
    for idx, key in enumerate(auto_device_map):
        if idx <= transition_layer:
            shared_device_map[key] = 0
        else:
            shared_device_map[key] = 1
    return shared_device_map

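def create_intertwined_device_map():
    # Alternate decoder layers between GPUs; embed_tokens goes to GPU 1, model.norm and lm_head to GPU 0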
    device_map = {}
    for idx, key in enumerate(auto_device_map):
        if idx == 0:
            device_map[key] = 1
        elif idx >= 33:
            device_map[key] = 0
        else:
            device_map[key] = idx % 2
    return device_map
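`create_shared_device_map` splits by position in `auto_device_map`, not by decoder-layer number. Since position 0 is `model.embed_tokens`, `create_shared_device_map(16)` should put the embeddings and layers 0-15 on GPU 0, and layers 16-31, `model.norm` and `lm_head` on GPU 1. The quick check below rebuilds the same key order in plain Python; `summarize_device_map` is a hypothetical helper, not part of the notebook.

```python
from collections import Counter

# Same key order as auto_device_map: embeddings, 32 decoder layers, final norm, head.
KEYS = (['model.embed_tokens']
        + [f'model.layers.{i}' for i in range(32)]
        + ['model.norm', 'lm_head'])

def create_shared_device_map(transition_layer, keys=KEYS):
    # Copy of the notebook function: every key up to and including position
    # `transition_layer` goes to GPU 0, the rest to GPU 1.
    return {key: 0 if idx <= transition_layer else 1 for idx, key in enumerate(keys)}

def summarize_device_map(device_map):
    """Hypothetical helper: modules per device and the last decoder layer placed on GPU 0."""
    counts = dict(Counter(device_map.values()))
    layers_on_gpu0 = [int(key.rsplit('.', 1)[1]) for key, device in device_map.items()
                      if key.startswith('model.layers.') and device == 0]
    return counts, max(layers_on_gpu0, default=None)

counts, last_layer_on_gpu0 = summarize_device_map(create_shared_device_map(16))
print(counts, last_layer_on_gpu0)  # {0: 17, 1: 18} 15
```

So the split used to load the model is 16 decoder layers per GPU, with the head on the second GPU.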

In [4]:
model_path = '/home/gbarbadillo/data/mixtral-8x7b-instruct-v0.1-hf/'
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    device_map=create_shared_device_map(16),
    trust_remote_code=True,
    )


In [5]:
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True)

In [6]:
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer)

def chat_with_mixtral(prompt, max_new_tokens=200, verbose=True, do_sample=False, temperature=0.7, top_p=0.95):
    if not prompt.startswith('<s>[INST]'):
        print('Formatting the prompt to the Mixtral instruction format.')
        prompt = f'<s>[INST] {prompt} [/INST]'
    start = time.time()

    if do_sample:
        sampling_kwargs = dict(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        sampling_kwargs = dict(do_sample=False)

    sequences = pipe(
        prompt,
        max_new_tokens=max_new_tokens,
        # https://www.reddit.com/r/LocalLLaMA/comments/184g120/mistral_fine_tuning_eos_and_padding/
        # https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/discussions/106
        pad_token_id=tokenizer.eos_token_id,
        **sampling_kwargs,
        return_full_text=False,
    )
    response = sequences[0]['generated_text']
    #response = re.sub(r'[\'"]', '', response)
    if verbose:
        stop = time.time()
        time_taken = stop-start
        n_tokens = len(tokenizer.tokenize(response))
        print(f"Execution Time : {time_taken:.1f} s, tokens per second: {n_tokens/time_taken:.1f}")
    return response

In [7]:
def print_gpu_memory():
    for device in range(torch.cuda.device_count()):
        print(f'GPU {device} memory allocated: {torch.cuda.memory_allocated(device)/1024**3:.1f} GB, max memory allocated: {torch.cuda.max_memory_allocated(device)/1024**3:.1f} GB')

In [8]:
def print_parameters(model, key=None):
    for module_name, module in model.named_modules():
        if key is not None and key not in module_name:
            continue
        total_params = 0
        for param in module.parameters(recurse=False):
            # Number of elements in each parameter tensor owned directly by this module
            total_params += param.numel()
        if total_params > 0:
            print(f"{module_name} ({type(module).__name__}): {total_params:.1e} parameters")


## Baseline

In [None]:
print_gpu_memory()
model

GPU 0 memory allocated: 11.7 GB, max memory allocated: 11.7 GB
GPU 1 memory allocated: 11.7 GB, max memory allocated: 11.7 GB


MixtralForCausalLM(
  (model): MixtralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MixtralDecoderLayer(
        (self_attn): MixtralSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MixtralRotaryEmbedding()
        )
        (block_sparse_moe): MixtralSparseMoeBlock(
          (gate): Linear4bit(in_features=4096, out_features=8, bias=False)
          (experts): ModuleList(
            (0-7): 8 x MixtralBlockSparseTop2MLP(
              (w1): Linear4bit(in_features=4096, out_features=14336, bias=False)
              (w2): Linear4bit(in_features=14336, out_features=4096, bias=False)
              (w3): Linear4bit(in_features=4096, out_

I believe we should skip quantizing the gate (router) layers: `(gate): Linear4bit(in_features=4096, out_features=8, bias=False)`.
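A toy illustration of why the router is a plausible weak spot: Mixtral sends each token to the two experts with the largest router logits, so when two logits are close, a small error (such as one introduced by 4-bit quantization) can swap which expert runs. The logits below are made up, and `top2_routing` is a hypothetical helper that mimics the softmax, top-2 selection and renormalization, not the transformers implementation.

```python
import math

def top2_routing(logits):
    """Return the indices of the two largest logits and their renormalized softmax weights."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:2]
    # A softmax over all experts followed by renormalizing over the top 2
    # is equivalent to a softmax over just the top-2 logits.
    exps = [math.exp(logits[i]) for i in top]
    total = sum(exps)
    return top, [e / total for e in exps]

# Made-up router logits for one token over 8 experts; experts 1 and 2 are nearly tied.
logits = [2.0, 1.0, 0.99, 0.5, 0.0, -0.5, -1.0, -1.5]
# The same logits with a small error on expert 1, standing in for quantization noise.
perturbed = [2.0, 0.98, 0.99, 0.5, 0.0, -0.5, -1.0, -1.5]

print(top2_routing(logits)[0], top2_routing(perturbed)[0])  # [0, 1] [0, 2]
```

Whether this happens often enough to hurt the quantized model is what the experiment is meant to probe; the sketch only shows the mechanism.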

## Attempt to skip quantization of gate layers

In [None]:
print_gpu_memory()
model

GPU 0 memory allocated: 11.7 GB, max memory allocated: 11.7 GB
GPU 1 memory allocated: 11.5 GB, max memory allocated: 11.5 GB


MixtralForCausalLM(
  (model): MixtralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MixtralDecoderLayer(
        (self_attn): MixtralSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MixtralRotaryEmbedding()
        )
        (block_sparse_moe): MixtralSparseMoeBlock(
          (gate): Linear(in_features=4096, out_features=8, bias=False)
          (experts): ModuleList(
            (0-7): 8 x MixtralBlockSparseTop2MLP(
              (w1): Linear4bit(in_features=4096, out_features=14336, bias=False)
              (w2): Linear4bit(in_features=14336, out_features=4096, bias=False)
              (w3): Linear4bit(in_features=4096, out_feat

It seems to have worked: the gate is now a plain `Linear` layer, `(gate): Linear(in_features=4096, out_features=8, bias=False)`.
The surprising thing is that memory usage is almost unchanged.
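The printout only shows the collapsed `(0-31)` layer block. A hypothetical helper like the one below counts module classes by name, so it can confirm programmatically that the gates are plain `Linear` modules and no `Linear4bit` remains among them. It only relies on `named_modules()`, which every `torch.nn.Module` provides; on this model, `count_module_types(model, 'gate')` should report 32 `Linear` modules if the skip worked.

```python
from collections import Counter

def count_module_types(model, key=None):
    """Count modules by class name, keeping only those whose qualified name contains `key`."""
    return Counter(
        type(module).__name__
        for name, module in model.named_modules()
        if key is None or key in name
    )

# Usage on the loaded model:
# count_module_types(model, 'gate')
```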

In [None]:
print_parameters(model, 'gate')

model.layers.0.block_sparse_moe.gate (Linear): 3.3e+04 parameters
model.layers.1.block_sparse_moe.gate (Linear): 3.3e+04 parameters
model.layers.2.block_sparse_moe.gate (Linear): 3.3e+04 parameters
model.layers.3.block_sparse_moe.gate (Linear): 3.3e+04 parameters
model.layers.4.block_sparse_moe.gate (Linear): 3.3e+04 parameters
model.layers.5.block_sparse_moe.gate (Linear): 3.3e+04 parameters
model.layers.6.block_sparse_moe.gate (Linear): 3.3e+04 parameters
model.layers.7.block_sparse_moe.gate (Linear): 3.3e+04 parameters
model.layers.8.block_sparse_moe.gate (Linear): 3.3e+04 parameters
model.layers.9.block_sparse_moe.gate (Linear): 3.3e+04 parameters
model.layers.10.block_sparse_moe.gate (Linear): 3.3e+04 parameters
model.layers.11.block_sparse_moe.gate (Linear): 3.3e+04 parameters
model.layers.12.block_sparse_moe.gate (Linear): 3.3e+04 parameters
model.layers.13.block_sparse_moe.gate (Linear): 3.3e+04 parameters
model.layers.14.block_sparse_moe.gate (Linear): 3.3e+04 parameters
model

In [None]:
(3.3e+04*32)/1e6

1.056

So all the gate layers together have around 1M parameters (32 layers with about 33k parameters each), which is tiny compared to the whole model. That is why we don't see a big change in GPU memory.
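The same conclusion in bytes, using the router shape from the printout (4096 inputs, 8 outputs, 32 layers). The 4-bit estimate ignores the quantization constants, so the real saving from quantizing the gates would be even smaller:

```python
N_LAYERS = 32
GATE_IN, GATE_OUT = 4096, 8  # router shape: Linear(in_features=4096, out_features=8)

gate_params = N_LAYERS * GATE_IN * GATE_OUT
fp16_mb = gate_params * 2 / 1e6   # float16: 2 bytes per weight
nf4_mb = gate_params * 0.5 / 1e6  # 4-bit: half a byte per weight (constants ignored)

print(f"{gate_params} gate weights: {fp16_mb:.2f} MB in float16 vs {nf4_mb:.2f} MB in 4-bit")
```

A difference of about 1.6 MB is invisible at the 0.1 GB resolution of `print_gpu_memory`.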

## Do not quantize the head either

In [None]:
print_gpu_memory()
model

GPU 0 memory allocated: 11.7 GB, max memory allocated: 11.7 GB
GPU 1 memory allocated: 11.7 GB, max memory allocated: 11.7 GB


MixtralForCausalLM(
  (model): MixtralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MixtralDecoderLayer(
        (self_attn): MixtralSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MixtralRotaryEmbedding()
        )
        (block_sparse_moe): MixtralSparseMoeBlock(
          (gate): Linear(in_features=4096, out_features=8, bias=False)
          (experts): ModuleList(
            (0-7): 8 x MixtralBlockSparseTop2MLP(
              (w1): Linear4bit(in_features=4096, out_features=14336, bias=False)
              (w2): Linear4bit(in_features=14336, out_features=4096, bias=False)
              (w3): Linear4bit(in_features=4096, out_feat

In [None]:
print_parameters(model, 'lm_head')

lm_head (Linear): 1.3e+08 parameters


The head has around 1.3e8 parameters, which take about 0.26 GB of memory in float16 and about 0.07 GB in 4-bit. Thus the difference is only about 0.2 GB.
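The head's size follows from its shape: the 4096-dimensional hidden state is projected onto the 32000-token vocabulary (the same sizes as in the `embed_tokens` printout). The 4-bit figure ignores the small overhead of the quantization constants:

$$
\begin{aligned}
N_{\text{head}} &= 32000 \times 4096 = 131{,}072{,}000 \approx 1.3 \times 10^{8} \\
M_{\text{fp16}} &= 2\ \text{bytes} \times N_{\text{head}} \approx 0.26\ \text{GB} \\
M_{\text{4bit}} &= 0.5\ \text{bytes} \times N_{\text{head}} \approx 0.07\ \text{GB} \\
M_{\text{fp16}} - M_{\text{4bit}} &\approx 0.2\ \text{GB}
\end{aligned}
$$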

## Do not quantize a heavy layer (`w1`)

In [None]:
2.9e+07*32*8/1e9

7.424

In [None]:
print_parameters(model, 'w1')

model.layers.0.block_sparse_moe.experts.0.w1 (Linear4bit): 2.9e+07 parameters
model.layers.0.block_sparse_moe.experts.1.w1 (Linear4bit): 2.9e+07 parameters
model.layers.0.block_sparse_moe.experts.2.w1 (Linear4bit): 2.9e+07 parameters
model.layers.0.block_sparse_moe.experts.3.w1 (Linear4bit): 2.9e+07 parameters
model.layers.0.block_sparse_moe.experts.4.w1 (Linear4bit): 2.9e+07 parameters
model.layers.0.block_sparse_moe.experts.5.w1 (Linear4bit): 2.9e+07 parameters
model.layers.0.block_sparse_moe.experts.6.w1 (Linear4bit): 2.9e+07 parameters
model.layers.0.block_sparse_moe.experts.7.w1 (Linear4bit): 2.9e+07 parameters
model.layers.1.block_sparse_moe.experts.0.w1 (Linear4bit): 2.9e+07 parameters
model.layers.1.block_sparse_moe.experts.1.w1 (Linear4bit): 2.9e+07 parameters
model.layers.1.block_sparse_moe.experts.2.w1 (Linear4bit): 2.9e+07 parameters
model.layers.1.block_sparse_moe.experts.3.w1 (Linear4bit): 2.9e+07 parameters
model.layers.1.block_sparse_moe.experts.4.w1 (Linear4bit): 2.9e+

In [10]:
print_gpu_memory()
model

GPU 0 memory allocated: 22.0 GB, max memory allocated: 22.0 GB
GPU 1 memory allocated: 22.0 GB, max memory allocated: 22.0 GB


MixtralForCausalLM(
  (model): MixtralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MixtralDecoderLayer(
        (self_attn): MixtralSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MixtralRotaryEmbedding()
        )
        (block_sparse_moe): MixtralSparseMoeBlock(
          (gate): Linear(in_features=4096, out_features=8, bias=False)
          (experts): ModuleList(
            (0-7): 8 x MixtralBlockSparseTop2MLP(
              (w1): Linear(in_features=4096, out_features=14336, bias=False)
              (w2): Linear4bit(in_features=14336, out_features=4096, bias=False)
              (w3): Linear4bit(in_features=4096, out_features

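This works correctly: `w1` is now a plain `Linear` layer, and the memory allocated on each GPU goes from 11.7 GB to 22.0 GB, since `w1` holds about a third of the expert weights.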
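One caveat when reading the `w1` numbers: the 2.9e+07 reported for each `Linear4bit` is half of 4096 × 14336 ≈ 5.9e+07. This is consistent with bitsandbytes packing two 4-bit values into each stored byte, so `numel()` counts packed bytes rather than weights, and the 7.424 computed above is about half the real number of `w1` weights. The sketch below redoes the arithmetic under that assumption and compares it with the measured jump; note that `print_gpu_memory` divides by 1024**3, so its "GB" are really GiB.

```python
IN_FEATURES, OUT_FEATURES = 4096, 14336  # w1 shape from the model printout
N_LAYERS, N_EXPERTS = 32, 8
GIB = 1024 ** 3  # print_gpu_memory divides by 1024**3

weights_per_w1 = IN_FEATURES * OUT_FEATURES
# Assumption: two 4-bit values are packed per stored uint8, so a Linear4bit
# weight reports half as many elements as it has weights.
packed_numel_per_w1 = weights_per_w1 // 2

total_weights = weights_per_w1 * N_LAYERS * N_EXPERTS
fp16_gib = total_weights * 2 / GIB    # 2 bytes per float16 weight
nf4_gib = total_weights * 0.5 / GIB   # half a byte per 4-bit weight, constants ignored
expected_extra_gib = fp16_gib - nf4_gib
observed_extra_gib = 2 * (22.0 - 11.7)  # both GPUs went from 11.7 to 22.0

print(f"packed elements per w1: {packed_numel_per_w1:.1e}")
print(f"all w1 weights: {total_weights / 1e9:.2f}B, float16 {fp16_gib:.1f} GiB vs 4-bit {nf4_gib:.1f} GiB")
print(f"expected extra memory: {expected_extra_gib:.1f} GiB, observed: {observed_extra_gib:.1f} GiB")
```

The estimate (about 21 GiB) and the measurement (about 20.6 GiB) agree closely. The small gap goes in the direction expected from the ignored quantization constants, which are also freed when `w1` is left unquantized, and the printed memory values are rounded to 0.1.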

## Summary

We have verified that it is possible to skip the quantization of some layers.

If we don't quantize the gates and the `lm_head` of Mixtral, memory usage is almost unchanged, since they are small layers.

Just add `llm_int8_skip_modules=['gate', 'lm_head']` to the `BitsAndBytesConfig`.