# Part 1: Quantize Llama-3.2-3B-Instruct

In [None]:
!pip3 install huggingface-hub[cli]
!pip3 install transformers==4.50.3
!pip3 install torch torchvision torchaudio
!pip3 install timm==1.0.15
!pip3 install datasets==3.5.0
!pip3 install accelerate==1.6.0
!pip3 install gemlite==0.4.4
!pip3 install hqq==0.2.5
!pip3 install triton==3.2.0

In [None]:
import torch
import torch.nn as nn
from torch import float16
from tqdm.auto import tqdm

from typing import Union, Callable
from functools import partial
import json
import timm
import os

from hqq.core.quantize import HQQLinear
from hqq.core.utils import cleanup

from hqq.models.base import (
    forward_device_hooked,
    get_all_children_from_model,
    find_parent,
    is_leaf_module,
    BaseHQQModel,
    BasePatch
)

from hqq.models.hf.base import BaseHQQHFModel

# Module types that the patching code below treats as quantizable "linear" layers.
_QUANT_LAYERS = [nn.Linear, HQQLinear]
# Leaf-module names skipped by default when linear tags are auto-discovered.
_IGNORE_LINEAR = ['lm_head']

def get_size_of_model(model):
    """Return the number of bytes occupied by the tensors of ``model``.

    ``HQQLinear`` modules contribute their packed weight (``W_q``), the
    ``scale`` and ``zero`` tensors stored in ``module.meta`` and, when present,
    their bias. Every other leaf module contributes all of its parameters and
    buffers. No de-duplication is performed, so a tensor reachable from two
    modules is counted once per module.

    Args:
        model: Object exposing ``named_modules()`` (e.g. a ``torch.nn.Module``).

    Returns:
        Total size in bytes as an ``int``.
    """

    def _nbytes(tensor):
        return tensor.numel() * tensor.element_size()

    size_in_bytes = 0
    for _, module in model.named_modules():
        if isinstance(module, HQQLinear):
            # W_q / Scale / Zero / Bias
            size_in_bytes += _nbytes(module.W_q)
            size_in_bytes += _nbytes(module.meta['scale'])
            size_in_bytes += _nbytes(module.meta['zero'])

            # Default to None so a layer without a ``bias`` attribute is
            # skipped instead of raising AttributeError.
            bias = getattr(module, 'bias', None)
            if isinstance(bias, torch.Tensor):
                size_in_bytes += _nbytes(bias)

        elif is_leaf_module(module):
            size_in_bytes += sum(_nbytes(param) for param in module.parameters())
            size_in_bytes += sum(_nbytes(buffer) for buffer in module.buffers())

    return size_in_bytes

# Collect the fully-qualified names of all quantizable linear layers
def get_linear_tags_from_model(model, ignore: list) -> list:
    """List the names of every module whose type is in ``_QUANT_LAYERS``.

    A module is skipped when the last dot-separated component of its name
    appears in ``ignore``. The result is built from a set, so it contains no
    duplicates and its order is not meaningful.
    """
    tags = {
        layer_name
        for layer_name, layer in model.named_modules()
        if type(layer) in _QUANT_LAYERS
        and layer_name.rsplit(".", 1)[-1] not in ignore
    }
    return list(tags)

class CustomPatch(BasePatch):
    """Patching helpers that address linear layers by their full module name.

    Each quantizable layer is looked up in ``patch_params`` under the exact
    name reported by ``model.named_modules()``, which allows a different
    configuration per individual layer.
    """

    @classmethod
    def patch_linearlayers(
        cls,
        model,
        patch_fct: Callable,
        patch_params: Union[dict, None],
        verbose: bool = True,
    ) -> None:
        """Replace every quantizable layer with ``patch_fct(layer, params)``.

        Args:
            model: Model whose layers are replaced in place.
            patch_fct: Callable receiving ``(module, params)`` and returning
                the module to install in its place.
            patch_params: Mapping from full module name to the parameters
                handed to ``patch_fct``. Names missing from the mapping (or a
                ``None`` mapping) result in ``params`` being ``None``.
            verbose: Show a progress bar when True.
        """
        # A set gives O(1) membership tests while scanning all modules.
        ignore_tags = set(cls.get_ignore_layers(model))

        # The signature allows ``None``; treat it as "no per-layer parameters"
        # instead of failing on the membership test.
        params_by_name = patch_params if patch_params is not None else {}

        # Snapshot the targets first: replacing modules while iterating over
        # ``named_modules()`` would modify the structure being traversed.
        tmp_mapping = {
            name: module
            for name, module in model.named_modules()
            if (type(module) in _QUANT_LAYERS) and (name not in ignore_tags)
        }

        for name, module in tqdm(tmp_mapping.items(), disable=not verbose):
            setattr(
                find_parent(model, name),
                name.split(".")[-1],
                patch_fct(module, params_by_name.get(name)),
            )

        cleanup()

    @classmethod
    def set_auto_linear_tags(cls, model, ignore: list = _IGNORE_LINEAR) -> None:
        """Attach ``linear_tags`` and ``base_class`` to ``model`` if not yet set.

        The tags select which layers ``patch_linearlayers()`` receives
        parameters for. Tags declared by the class (``cls.get_linear_tags()``)
        take precedence; when there are none, the tags are discovered from the
        model itself, skipping leaf names listed in ``ignore``.
        """
        if not hasattr(model, "linear_tags"):
            linear_tags = cls.get_linear_tags()
            model.linear_tags = (
                linear_tags
                if len(linear_tags) > 0
                else get_linear_tags_from_model(model, ignore=ignore)
            )
            model.base_class = cls

class CustomHQQTimmModel(BaseHQQModel):
    """HQQ model wrapper for timm models whose repeated blocks live in ``blocks``.

    Provides config caching / re-creation through timm and a
    ``quantize_model`` that assigns every leaf module to a device and replaces
    the linear layers with ``HQQLinear``.
    """

    # Create empty model
    @classmethod
    def create_model(cls, save_dir, kwargs):
        """Build an un-initialised timm model from the config stored in ``save_dir``.

        ``kwargs`` is accepted for interface compatibility and is not used.
        """
        with open(cls.get_config_file(save_dir), "r") as file:
            config = json.load(file)
        model = timm.create_model(
            config["architecture"] + "." + config["tag"], pretrained=False
        )
        return model

    # Save model architecture
    @classmethod
    def cache_model(cls, model, save_dir):
        """Write ``model.default_cfg`` as JSON to the config file of ``save_dir``."""
        try:
            os.makedirs(save_dir, exist_ok=True)
        except Exception as error:
            # Best effort: report the problem and let the open() below decide
            # whether the location is actually unusable.
            print(error)

        with open(cls.get_config_file(save_dir), "w") as file:
            json.dump(model.default_cfg, file)

    # Main function to quantize a model. Basically goes through the linear layers specified in the patching function and replaces them with HQQLinear
    @classmethod
    def quantize_model(
        cls,
        model,
        quant_config: dict,
        compute_dtype: torch.dtype = float16,
        device: Union[str, list, dict] = "cuda",
    ):
        """Quantize ``model`` in place and return it.

        Args:
            model: timm model exposing ``blocks`` directly or via ``model.model``.
            quant_config: Either one quantization config applied to every
                linear tag, a mapping ``{linear_tag: config_or_None}`` (tags
                not listed are left unquantized), or ``{}`` to quantize nothing.
            compute_dtype: dtype that non-quantized modules are cast to and
                that is forwarded to ``HQQLinear``.
            device: A single device string, a list of devices the blocks are
                spread over, or a mapping ``{block_name: device}``.

        Returns:
            The same ``model`` object (also when it was already quantized).

        Raises:
            ValueError: The block structure could not be detected and
                ``device`` is not a mapping, or ``device`` is an empty list.
            TypeError: ``device`` is not a str, list or dict.
        """
        # Check if the model was already quantized
        if getattr(model, "hqq_quantized", False):
            print("Model was already quantized")
            # Return the model on this path too so callers always get it back.
            return model

        # Set linear tags automatically (inherited helper; the code below
        # relies on it having set ``model.linear_tags`` and each module's ``name``).
        cls.setup_model(model)

        # Use the same quantization config for all linear layers. Use None to skip quantizing a specific layer.
        if any(key in model.linear_tags for key in quant_config):
            # If the user doesn't specify a key from get_linear_tags, the layer is not quantized via (key, None)
            patch_params = {key: None for key in model.linear_tags}
            patch_params.update(quant_config)
        elif quant_config == {}:
            patch_params = {key: None for key in model.linear_tags}
        else:
            # Same quant_config for all layers
            patch_params = {key: quant_config for key in model.linear_tags}

        # Get list of all nodes in order
        all_nodes = get_all_children_from_model(model, [])  # ordered nodes
        try:
            # Extract block names from the timm layout.
            num_blocks = (
                len(model.model.blocks)
                if hasattr(model, "model")
                else len(model.blocks)
            )
            all_blocks = ["blocks." + str(i) for i in range(num_blocks)]
        except Exception:
            all_blocks = None

        if isinstance(device, dict):
            # input as {module block name (str): device (str or torch.device)}
            # Work on a copy: node entries are added below and must not leak
            # into the caller's dictionary.
            device_map = dict(device)
            num_devices = len(set(device_map.values()))
            all_blocks = list(device_map.keys())
        elif all_blocks is None:
            raise ValueError(
                "Default model structure not supported. Make sure you feed device as dictionary as {name_block: device}"
            )

        # Associate each node with the last block whose name it contains;
        # nodes outside every block map to themselves.
        node_to_block = {}
        for node in all_nodes:
            res = [block for block in all_blocks if (block in node)]
            node_to_block[node] = res[-1] if (len(res) > 0) else node

        # Set device-map
        if isinstance(device, str):  # single device as str
            device_map = {k: device for k in all_blocks + all_nodes}
            num_devices = 1
        elif isinstance(device, list):  # list of devices
            if not device:
                raise ValueError("device list must not be empty")
            num_devices = len(device)
            device_map = {}

            def _inside_blocks(node_name):
                # Prefixing "." makes top-level names ("blocks.0.attn") match
                # as well as nested ones ("model.blocks.0.attn").
                return ".blocks." in "." + node_name

            # Nodes before the first block go to the first device ...
            for node in all_nodes:
                if _inside_blocks(node):
                    break
                device_map[node] = device[0]

            # ... and nodes after the last block go to the last device.
            for node in all_nodes[::-1]:
                if _inside_blocks(node):
                    break
                device_map[node] = device[-1]

            # Spread the blocks over the devices in contiguous chunks. A step
            # of at least 1 keeps range() valid when devices outnumber blocks.
            step, k = max(1, len(all_blocks) // num_devices), 0
            for i in range(0, len(all_blocks), step):
                for j in range(i, i + step):
                    device_map[all_blocks[min(j, len(all_blocks) - 1)]] = device[
                        min(k, num_devices - 1)
                    ]
                k += 1
        elif not isinstance(device, dict):
            raise TypeError(
                "device must be a str, a list of devices or a dict {name_block: device}"
            )

        # Map nodes to block devices
        for node in all_nodes:
            device_map[node] = device_map[node_to_block[node]]

        # We replace the nn.Linear layers with HQQLinear
        def _patch_linear(linear_layer, quant_config):
            if type(linear_layer) is HQQLinear:
                return linear_layer

            current_device = device_map[linear_layer.name]
            if quant_config is not None:
                out_module = HQQLinear(
                    linear_layer,
                    quant_config,
                    compute_dtype=compute_dtype,
                    device=current_device,
                )
            else:
                out_module = linear_layer.to(device=current_device, dtype=compute_dtype)

            out_module.device = current_device
            return out_module

        def _patch_other(layer):
            current_device = device_map[layer.name]
            layer.device = current_device
            return layer.to(device=current_device, dtype=compute_dtype)

        cls.patch_model(model, _patch_other, _patch_linear, patch_params)

        # Insert device switcher
        if num_devices > 1:
            core_model = model if hasattr(model, "blocks") else model.model

            # Make sure the input (first node) has the input in the right device during generation
            input_node_child_name = all_nodes[0].split(".")[-1]
            input_node = getattr(core_model, input_node_child_name)
            input_node.device = device_map[all_nodes[0]]
            input_node.forward_orig = input_node.forward
            input_node.forward = partial(forward_device_hooked, input_node)
            setattr(core_model, input_node_child_name, input_node)

            # Make sure all inputs to the blocks are in the right device
            for block in core_model.blocks:
                block.device = device_map[block.name]
                block.forward_orig = block.forward
                block.forward = partial(forward_device_hooked, block)

        # Set base class
        model.base_class = cls

        model.hqq_quantized = True

        return model

class CustomHQQHFModel(BaseHQQHFModel):
    """HQQ model wrapper for Hugging Face models whose decoder blocks live in ``layers``."""

    # Main function to quantize a model. Basically goes through the linear layers specified in the patching function and replaces them with HQQLinear
    @classmethod
    def quantize_model(
        cls,
        model,
        quant_config: dict,
        compute_dtype: torch.dtype = float16,
        device: Union[str, list, dict] = "cuda",
    ):
        """Quantize ``model`` in place and return it.

        Args:
            model: Model exposing ``layers`` directly or via ``model.model``.
            quant_config: Either one quantization config applied to every
                linear tag, a mapping ``{linear_tag: config_or_None}`` (tags
                not listed are left unquantized), or ``{}`` to quantize nothing.
            compute_dtype: dtype that non-quantized modules are cast to and
                that is forwarded to ``HQQLinear``.
            device: A single device string, a list of devices the layers are
                spread over, or a mapping ``{block_name: device}``.

        Returns:
            The same ``model`` object (also when it was already quantized).

        Raises:
            ValueError: The layer structure could not be detected and
                ``device`` is not a mapping, or ``device`` is an empty list.
            TypeError: ``device`` is not a str, list or dict.
        """
        # Check if the model was already quantized
        if getattr(model, "hqq_quantized", False):
            print("Model was already quantized")
            # Return the model on this path too so callers always get it back.
            return model

        # Set linear tags automatically (inherited helper; the code below
        # relies on it having set ``model.linear_tags`` and each module's ``name``).
        cls.setup_model(model)

        # Use the same quantization config for all linear layers. Use None to skip quantizing a specific layer.
        if any(key in model.linear_tags for key in quant_config):
            # If the user doesn't specify a key from get_linear_tags, the layer is not quantized via (key, None)
            patch_params = {key: None for key in model.linear_tags}
            patch_params.update(quant_config)
        elif quant_config == {}:
            patch_params = {key: None for key in model.linear_tags}
        else:
            # Same quant_config for all layers
            patch_params = {key: quant_config for key in model.linear_tags}

        # Get list of all nodes in order
        all_nodes = get_all_children_from_model(model, [])  # ordered nodes
        try:
            # Extract block names: This is following Hugging Face models.
            # The prefix must mirror where the layers were found, otherwise
            # the block names never match the node names.
            if hasattr(model, "model"):
                num_blocks = len(model.model.layers)
                block_prefix = "model.layers."
            else:
                num_blocks = len(model.layers)
                block_prefix = "layers."
            all_blocks = [block_prefix + str(i) for i in range(num_blocks)]
        except Exception:
            all_blocks = None

        if isinstance(device, dict):
            # input as {module block name (str): device (str or torch.device)}
            # Work on a copy: node entries are added below and must not leak
            # into the caller's dictionary.
            device_map = dict(device)
            num_devices = len(set(device_map.values()))
            all_blocks = list(device_map.keys())
        elif all_blocks is None:
            raise ValueError(
                "Default model structure not supported. Make sure you feed device as dictionary as {name_block: device}"
            )

        # Associate each node with the last block whose name it contains;
        # nodes outside every block map to themselves.
        node_to_block = {}
        for node in all_nodes:
            res = [block for block in all_blocks if (block in node)]
            node_to_block[node] = res[-1] if (len(res) > 0) else node

        # Set device-map
        if isinstance(device, str):  # single device as str
            device_map = {k: device for k in all_blocks + all_nodes}
            num_devices = 1
        elif isinstance(device, list):  # list of devices
            if not device:
                raise ValueError("device list must not be empty")
            num_devices = len(device)
            device_map = {}

            def _inside_layers(node_name):
                # Prefixing "." makes top-level names ("layers.0.mlp") match
                # as well as nested ones ("model.layers.0.mlp").
                return ".layers." in "." + node_name

            # Nodes before the first layer go to the first device ...
            for node in all_nodes:
                if _inside_layers(node):
                    break
                device_map[node] = device[0]

            # ... and nodes after the last layer go to the last device.
            for node in all_nodes[::-1]:
                if _inside_layers(node):
                    break
                device_map[node] = device[-1]

            # Spread the layers over the devices in contiguous chunks. A step
            # of at least 1 keeps range() valid when devices outnumber layers.
            step, k = max(1, len(all_blocks) // num_devices), 0
            for i in range(0, len(all_blocks), step):
                for j in range(i, i + step):
                    device_map[all_blocks[min(j, len(all_blocks) - 1)]] = device[
                        min(k, num_devices - 1)
                    ]
                k += 1
        elif not isinstance(device, dict):
            raise TypeError(
                "device must be a str, a list of devices or a dict {name_block: device}"
            )

        # Map nodes to block devices
        for node in all_nodes:
            device_map[node] = device_map[node_to_block[node]]

        # We replace the nn.Linear layers with HQQLinear
        def _patch_linear(linear_layer, quant_config):
            if type(linear_layer) is HQQLinear:
                return linear_layer

            current_device = device_map[linear_layer.name]
            if quant_config is not None:
                out_module = HQQLinear(
                    linear_layer,
                    quant_config,
                    compute_dtype=compute_dtype,
                    device=current_device,
                )
            else:
                out_module = linear_layer.to(device=current_device, dtype=compute_dtype)

            out_module.device = current_device
            return out_module

        def _patch_other(layer):
            current_device = device_map[layer.name]
            layer.device = current_device
            return layer.to(device=current_device, dtype=compute_dtype)

        cls.patch_model(model, _patch_other, _patch_linear, patch_params)

        # Insert device switcher
        if num_devices > 1:
            core_model = model if hasattr(model, "layers") else model.model

            # Make sure the input (first node) has the input in the right device during generation
            input_node_child_name = all_nodes[0].split(".")[-1]
            input_node = getattr(core_model, input_node_child_name)
            input_node.device = device_map[all_nodes[0]]
            input_node.forward_orig = input_node.forward
            input_node.forward = partial(forward_device_hooked, input_node)
            setattr(core_model, input_node_child_name, input_node)

            # Make sure all inputs to the blocks are in the right device
            for layer in core_model.layers:
                layer.device = device_map[layer.name]
                layer.forward_orig = layer.forward
                layer.forward = partial(forward_device_hooked, layer)

        # Set base class
        model.base_class = cls

        model.hqq_quantized = True

        return model

# Auto class used for HF models if no architecture was manually set up
class AutoHQQHFModel(CustomHQQHFModel, CustomPatch):
    """Combine the HF quantization entry points with name-based patching."""

class AutoHQQTimmModel(CustomHQQTimmModel, CustomPatch):
    """Combine the timm quantization entry points with name-based patching."""

In [None]:
import numpy as np
import torch
import os
from tqdm.auto import tqdm

from torchvision import datasets, transforms
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torch.utils.data import DataLoader

def build_transform(is_train, input_size=224, eval_crop_ratio=1.0):
    """Build the image preprocessing pipeline for training or evaluation.

    Args:
        is_train: When True, return timm's augmenting training transform;
            otherwise return a deterministic resize / crop / normalize pipeline.
        input_size: Side length of the square images handed to the model.
        eval_crop_ratio: Ratio of the centre crop to the resized image at
            evaluation time; the image is resized to
            ``int(input_size / eval_crop_ratio)`` before cropping.

    Returns:
        A callable transform.
    """
    # Inputs of 32 pixels or fewer are not resized.
    resize_im = input_size > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=input_size,
            is_training=True,
            color_jitter=0.3,
            auto_augment='rand-m9-mstd0.5-inc1',
            interpolation='bicubic',
            re_prob=0.0,
            re_mode='pixel',
            re_count=1,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with
            # RandomCrop
            transform.transforms[0] = transforms.RandomCrop(
                input_size, padding=4)
        return transform

    steps = []
    if resize_im:
        size = int(input_size / eval_crop_ratio)
        steps.append(
            # Bicubic, spelled with the enum: integer interpolation codes
            # (3 == bicubic) are deprecated in torchvision.
            transforms.Resize(
                size, interpolation=transforms.InterpolationMode.BICUBIC
            ),  # to maintain same ratio w.r.t. 224 images
        )
        steps.append(transforms.CenterCrop(input_size))

    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(steps)

def build_dataset_CIFAR100(is_train, data_path):
    """Return ``(dataset, number_of_classes)`` for one CIFAR-100 split.

    The split is chosen by ``is_train``; the data is downloaded into
    ``data_path`` when needed and uses the transform from ``build_transform``.
    """
    nb_classes = 100
    dataset = datasets.CIFAR100(
        data_path,
        train=is_train,
        transform=build_transform(is_train),
        download=True,
    )
    return dataset, nb_classes

def prepare_data(batch_size):
    """Create CIFAR-100 loaders stored under ``./data``.

    Returns:
        ``(train_loader, test_loader, nb_classes)``. The training loader is
        shuffled, the test loader is not; both drop the last incomplete batch.
    """
    data_root = './data'
    train_set, nb_classes = build_dataset_CIFAR100(is_train=True, data_path=data_root)
    test_set, _ = build_dataset_CIFAR100(is_train=False, data_path=data_root)

    loader_options = dict(batch_size=batch_size, drop_last=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_options)
    test_loader = DataLoader(test_set, shuffle=False, **loader_options)
    return train_loader, test_loader, nb_classes

def evaluate_model(model, data_loader, device):
    """Return the top-1 accuracy (in percent) of ``model`` on ``data_loader``.

    The model is moved to ``device`` and evaluated in inference mode; its
    previous training/eval mode is restored before returning.

    Args:
        model: ``torch.nn.Module`` mapping an image batch to class scores.
        data_loader: Iterable of ``(images, labels)`` batches.
        device: Device on which the evaluation runs.

    Returns:
        Accuracy as a float in ``[0, 100]``.

    Raises:
        ValueError: ``data_loader`` produced no samples.
    """
    model.to(device)
    # Layers such as dropout and batch-norm behave differently in training
    # mode, so switch to eval mode for a meaningful accuracy.
    was_training = model.training
    model.eval()

    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in tqdm(data_loader):
                images, labels = images.to(device), labels.to(device)
                predicted = model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError("data_loader yielded no samples; cannot compute accuracy")

    return 100 * correct / total


In [None]:
from hqq.core.quantize import BaseQuantizeConfig
from huggingface_hub import login
# Alternatively, pass your token directly, e.g. login(token="...")
login()

In [None]:
import gc
import torch
# Collect unreachable Python objects first so that any tensors they still
# reference are released ...
gc.collect()
# ... then ask PyTorch to release its cached, unused CUDA memory.
torch.cuda.empty_cache() 

In [None]:
# Results recorded with this configuration:
# Throughput: 60.09128305820182 toks/s
# Perplexity (PPL): 11.403003692626953
def get_quant_config_slm(model):
    """Build the per-layer HQQ quantization config for a decoder-only LM.

    Every projection of every decoder layer gets its own entry, keyed by the
    full module name (``model.layers.<i>.<block>.<projection>``):

    * ``q_proj``: 4-bit, group size 32
    * ``k_proj`` / ``v_proj``: 8-bit, group size 256
    * ``o_proj``: 8-bit, group size 128
    * ``gate_proj`` / ``up_proj`` / ``down_proj``: 4-bit, group size 128

    Args:
        model: Model whose ``config.num_hidden_layers`` gives the layer count.

    Returns:
        Dict mapping module name to its quantization config. Layers sharing a
        precision level share the same config object.
    """
    n_layers = model.config.num_hidden_layers

    # Quantization settings for the different precision levels
    q_proj_cfg = BaseQuantizeConfig(nbits=4, group_size=32)
    kv_proj_cfg = BaseQuantizeConfig(nbits=8, group_size=256)
    o_proj_cfg = BaseQuantizeConfig(nbits=8, group_size=128)
    mlp_cfg = BaseQuantizeConfig(nbits=4, group_size=128)

    per_projection = (
        ('self_attn.q_proj', q_proj_cfg),
        ('self_attn.k_proj', kv_proj_cfg),
        ('self_attn.v_proj', kv_proj_cfg),
        ('self_attn.o_proj', o_proj_cfg),
        ('mlp.gate_proj', mlp_cfg),
        ('mlp.up_proj', mlp_cfg),
        ('mlp.down_proj', mlp_cfg),
    )

    quant_config = {}
    for i in range(n_layers):
        for suffix, cfg in per_projection:
            quant_config[f'model.layers.{i}.{suffix}'] = cfg

    return quant_config

In [None]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
from tqdm.auto import tqdm
from datasets import load_dataset
import random
import numpy as np

from hqq.utils.patching import recommended_inductor_config_setter

def generate(model, input_ids, past_key_values, max_new_tokens, activate_timing, verbose=True):
    """Greedily generate tokens with a separate prefill and decode phase.

    The prompt is processed once through ``model.prefill_forward`` (which
    yields the first new token); afterwards ``max_new_tokens`` single-token
    decode steps are run through ``model(...)``. The returned sequence
    therefore holds the prompt followed by ``max_new_tokens + 1`` new tokens.

    Args:
        model: Callable model that also exposes ``prefill_forward``. Both must
            return an object with ``logits`` and ``past_key_values``.
        input_ids: Prompt token ids; indexed as ``(batch, sequence)``.
        past_key_values: KV cache handed to the model.
        max_new_tokens: Number of decode steps.
        activate_timing: Measure the decode loop with CUDA events when True
            (requires CUDA).
        verbose: Print phase markers when True.

    Returns:
        ``(token_ids, tput)`` where ``tput`` is decode steps per second, or
        ``None`` when timing is disabled.
    """
    input_ids = input_ids.clone()
    tput = None

    with torch.no_grad():
        # Run an initial forward pass to compute and store the static KV cache
        if verbose:
            print('Prefilling...')
        outputs = model.prefill_forward(input_ids, past_key_values=past_key_values, position_ids=None, attention_mask=None, cache_position=None, logits_to_keep=1)
        past_key_values = outputs.past_key_values
        next_token = torch.argmax(outputs.logits, dim=-1)
        input_ids = torch.cat([input_ids, next_token], dim=-1)

        # Generate tokens one by one using a for loop and update the KV cache
        if verbose:
            print('Decoding...')
        if activate_timing:
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
        for _ in range(max_new_tokens):
            # ``next_token`` is already the last element of ``input_ids``, so
            # its zero-based position is the sequence length minus one. Using
            # the length itself would skip one cache slot.
            pos = input_ids.shape[1] - 1
            cache_position = torch.arange(pos, pos + 1, device=input_ids.device, dtype=torch.long)

            # Run the model on the last token using the cached key-value pairs
            outputs = model(
                next_token,
                past_key_values=past_key_values,
                position_ids=cache_position.unsqueeze(0),
                cache_position=cache_position
            )

            # Greedily select the token with the highest probability
            next_token = torch.argmax(outputs.logits, dim=-1)

            # Append the predicted token to the generated sequence
            input_ids = torch.cat([input_ids, next_token], dim=-1)

            # Update the KV cache for the next iteration
            past_key_values = outputs.past_key_values
        if activate_timing:
            end_event.record()
        # Wait for queued GPU work; skipped on hosts without CUDA so the
        # function stays usable there when timing is off.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    if activate_timing:
        # elapsed_time() is in milliseconds.
        tput = max_new_tokens / start_event.elapsed_time(end_event) * 1000
    return input_ids, tput

def evaluate_ppl(model, tokenizer, device="cuda:0"):
    """Compute the perplexity of ``model`` on the WikiText-2 (raw) test split.

    The test texts are joined with blank lines, tokenized once and cut into
    consecutive windows of 2048 tokens (a trailing partial window is dropped).
    ``model.seqlen`` is set to that window length as a side effect.

    Returns:
        Perplexity as a Python float.
    """
    corpus = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    encoded = tokenizer("\n\n".join(corpus["text"]), return_tensors="pt")
    model.seqlen = 2048
    token_ids = encoded.input_ids.to(device)

    nsamples = token_ids.numel() // model.seqlen
    criterion = nn.CrossEntropyLoss()
    window_nlls = []
    for idx in tqdm(range(nsamples), desc="Evaluating..."):
        begin = idx * model.seqlen
        end = (idx + 1) * model.seqlen
        window = token_ids[:, begin:end]

        with torch.no_grad():
            logits = model(window).logits

        # Predict token t+1 from position t: drop the last logit and the
        # first label.
        shift_logits = logits[:, :-1, :].contiguous().float()
        shift_labels = window[:, 1:]

        mean_loss = criterion(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
        window_nlls.append(mean_loss.float() * model.seqlen)

    total_nll = torch.stack(window_nlls).sum()
    ppl = torch.exp(total_nll / (nsamples * model.seqlen))

    return ppl.item()

In [None]:
import os

torch._dynamo.reset()
############## Set Up ##############
torch.manual_seed(0)
random.seed(0)
recommended_inductor_config_setter()

max_new_tokens = 256    # Number of new tokens to generate
device = 'cuda:0'
backend = 'gemlite'

model_name = "meta-llama/Llama-3.2-3B-Instruct"
# Read the access token from the environment instead of hard-coding it in the
# notebook. When HF_TOKEN is unset this is None and the credentials stored by
# the earlier login() call are used.
hf_token = os.environ.get("HF_TOKEN")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map=device,
    token=hf_token,
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)

# Separate Prefill & Decode Forwarding Function:
# prefill keeps the eager forward, decode uses the compiled one.
model.prefill_forward = model.forward
model.forward = torch.compile(model.forward, mode='max-autotune', dynamic=False, fullgraph=True)

print(f'Model Size Before Quant: {get_size_of_model(model) / (1024 ** 2)} MiB')

# Quantize with the per-layer configuration
quant_config = get_quant_config_slm(model)

AutoHQQHFModel.quantize_model(model, quant_config=quant_config, compute_dtype=torch.float16, device=device)

# Persist the quantized weights
save_dir = "/kaggle/working/hqq_Llama3.2-3B-Instruct"
AutoHQQHFModel.save_quantized(model, save_dir)

# Part 2: Optimized Inference/Prediction/Eval

In [None]:
from hqq.utils.patching import prepare_for_inference
# Patch the quantized model for the inference backend chosen in the setup
# cell (``backend``), then release cached CUDA memory that is no longer used.
prepare_for_inference(model, backend=backend)
torch.cuda.empty_cache()

In [None]:
# ---- Warm-up: untimed generations on a throwaway prompt ----
warmup_prompt = "Explain what AI is."
input_ids = tokenizer(warmup_prompt, return_tensors="pt").input_ids.to(device)
# One cache object is shared by every run below and reset after each run.
# NOTE(review): the cache length leaves 16 positions beyond max_new_tokens
# for the prompt - presumably enough for these short prompts; confirm if the
# prompts change.
past_key_values = StaticCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=max_new_tokens + 16,
    device=model.device,
    dtype=torch.float16
)
for i in tqdm(range(5), desc="Warm Up..."):
    # The (tokens, throughput) result of warm-up runs is not used.
    generated = generate(model, input_ids, past_key_values, max_new_tokens, activate_timing=False, verbose=False)
    past_key_values.reset()

# ---- Timed runs on the benchmark prompt ----
prompt = "How to learn a new language?"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

tputs = []
for _ in tqdm(range(10), desc="Test Inference"):
    generated, tput = generate(model, input_ids, past_key_values, max_new_tokens, activate_timing=True, verbose=False)
    past_key_values.reset()
    tputs.append(tput)
# Decode only the newly generated part of the last run.
response = tokenizer.decode(generated[0][input_ids.shape[1]:], skip_special_tokens=True)
# Drop the two slowest and two fastest of the 10 runs and average the rest.
tputs = np.sort(tputs)[2:-2]
quant_tput = np.mean(tputs)
print(f'Prompt: {prompt}\nResponse: {response}\nThroughput: {quant_tput} toks/s')

print(f'Model Size After Quant: {get_size_of_model(model) / (1024 ** 2)} MiB')

ppl = evaluate_ppl(model, tokenizer, device)
print(f"Perplexity (PPL): {ppl}")
# print(f"Speedup: {quant_tput / org_tput} x")

# ---- Scoring: 10 points at >= 31 toks/s, 30 more at >= 54 toks/s;
# a perplexity above 11.5 voids the score. ----
score = 0
score += 10 if quant_tput >= 31.0 else 0
score += 30 if quant_tput >= 54.0 else 0
if ppl > 11.5:
    score = 0 
print(f'Score: {score}')

# ---- Save model and tokenizer ----
# torch.save(model.state_dict(), "/kaggle/working/llama3_quantized.pth")
save_path = "/kaggle/working/quantized_Llama3.2-3B-Instruct"
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
print(f"Quantized model saved to {save_path}")

# Save results to CSV
import csv
rounded_tput = round(quant_tput, 1)
ppl = round(ppl, 2)

# Row 0 holds the rounded perplexity, row 1 the rounded throughput.
with open("result.csv", mode="w", newline="") as file:
    writer = csv.writer(file)
    writer.writerow(["Id", "value"])
    writer.writerow([0, ppl])
    writer.writerow([1, rounded_tput])
