# Llama 2 7B

In [None]:
import os

# GPU selection must be in the environment before `import torch` below, so
# CUDA enumerates devices by PCI bus order and only exposes GPUs 0, 1 and 2.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0,1,2",
    }
)
# Local Hugging Face checkpoint directory used for the model and tokenizer.
MODEL_PATH = "/workspace/meta-llama/Llama-2-7b"

import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from transformers import LlamaTokenizer
from smooth import smooth_lm
from fake_quant import quantize_llama_like
import tqdm

class Evaluator:
    """Perplexity of a causal LM over consecutive fixed-length token windows.

    The ``"text"`` column of ``dataset`` is joined with blank lines, tokenized
    once, and cut into non-overlapping windows of ``seqlen`` tokens; the first
    ``n_samples`` windows are scored.

    Args:
        dataset: object indexable by ``"text"`` returning a sequence of strings.
        tokenizer: callable returning an object with ``.input_ids`` when called
            with ``return_tensors="pt"`` (a Hugging Face tokenizer).
        device: device on which the token ids are stored.
        n_samples: number of windows to score; clamped to the number of full
            windows the tokenized text contains.
        seqlen: window length in tokens (default 2048).

    Raises:
        ValueError: if the tokenized text is shorter than one window.
    """

    def __init__(self, dataset, tokenizer, device, n_samples=40, seqlen=2048):
        self.tokenizer = tokenizer
        self.device = device
        self.seqlen = seqlen

        # Despite the name, ``self.dataset`` holds the token ids, shape (1, num_tokens).
        self.dataset = tokenizer(
            "\n\n".join(dataset["text"]), return_tensors="pt"
        ).input_ids.to(device)

        n_windows = self.dataset.shape[1] // seqlen
        if n_windows == 0:
            raise ValueError(
                f"tokenized text has {self.dataset.shape[1]} tokens, "
                f"fewer than one window of {seqlen}"
            )
        # Slicing past the end would yield short or empty windows that are
        # still weighted as full ``seqlen`` windows in ``evaluate``.
        self.n_samples = min(n_samples, n_windows)

    @torch.no_grad()
    def evaluate(self, model):
        """Return ``exp(mean over windows of the window's mean cross-entropy)``.

        ``model`` must expose ``.device``, ``.eval()`` and return an object with
        ``.logits`` when called on a (1, seqlen) batch of token ids. The result
        is a 0-dim float tensor.
        """
        model.eval()
        loss_fct = nn.CrossEntropyLoss()
        nlls = []
        for i in tqdm.tqdm(range(self.n_samples), desc="Evaluating..."):
            start, stop = i * self.seqlen, (i + 1) * self.seqlen
            batch = self.dataset[:, start:stop].to(model.device)
            lm_logits = model(batch).logits
            # Predict token t+1 from positions <= t.
            shift_logits = lm_logits[:, :-1, :].contiguous().float()
            # Labels come from ``batch`` so they are on the same device as the
            # logits, even when the evaluator and model use different devices.
            shift_labels = batch[:, 1:]
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
            )
            nlls.append(loss.float() * self.seqlen)

        return torch.exp(torch.stack(nlls).sum() / (self.n_samples * self.seqlen))

## Performing AIQ Quantization

In [None]:
def _load_fp16_llama(device):
    """Load the checkpoint at MODEL_PATH in float16, placed entirely on `device`."""
    return LlamaForCausalLM.from_pretrained(
        MODEL_PATH, torch_dtype=torch.float16, device_map=device
    )


# Three independent copies, one per visible GPU: an FP16 reference, a
# SmoothQuant baseline, and the copy that later receives AIQ quantization.
model_fp16 = _load_fp16_llama("cuda:0")
model_smoothquant = _load_fp16_llama("cuda:1")
model_aiq = _load_fp16_llama("cuda:2")

# Saved activation scales (presumably calibration statistics from SmoothQuant's
# scale-collection step — confirm). Both quantized copies are smoothed with the
# same scales and the same factor 0.85 before quantization.
act_scales = torch.load("../act_scales/llama-2-7b.pt")
for _smoothed_model in (model_smoothquant, model_aiq):
    smooth_lm(_smoothed_model, act_scales, 0.85)

# SmoothQuant baseline: per-channel weights, per-token activations, 4 bits each.
# (Use bits=(6, 6) instead for a W6A6 baseline.)
model_smoothquant = quantize_llama_like(
    model_smoothquant,
    weight_quant='per_channel',
    act_quant='per_token',
    bits=(4, 4),
)

In [None]:
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP

# Projections inspected in each attention / MLP block, in insertion order.
_ATTN_PROJS = ("q_proj", "k_proj", "v_proj")
_MLP_PROJS = ("gate_proj", "up_proj", "down_proj")


def _row_abs_diff(w_ref, w_other):
    """Per-output-row L1 distance between |w_ref| and |w_other|, as float32 on CPU."""
    return (w_ref.abs().float().cpu() - w_other.abs().float().cpu()).abs().sum(dim=-1)


# val_dict["<module>.<proj>"] -> one score per output row: how much the weight
# magnitudes of model_aiq (smoothed by smooth_lm, not yet quantized) differ from
# the FP16 reference. These scores drive the bit allocation below.
val_dict = dict()

# Build the name -> module map once; it was previously rebuilt twice for every
# module visited, which is quadratic in the number of modules.
aiq_modules_by_name = dict(model_aiq.named_modules())

with torch.no_grad():
    for name, m_fp16 in model_fp16.named_modules():
        m_smooth = aiq_modules_by_name.get(name)
        if m_smooth is None:
            continue
        if isinstance(m_fp16, LlamaAttention):
            projs = _ATTN_PROJS
        elif isinstance(m_fp16, LlamaMLP):
            projs = _MLP_PROJS
        else:
            continue
        print(f"Module: {name}")
        for proj in projs:
            val_dict[f"{name}.{proj}"] = _row_abs_diff(
                getattr(m_fp16, proj).weight, getattr(m_smooth, proj).weight
            )

In [None]:
bit_dict = dict()

def create_bit_allocation_vector(x, threshold, bit, low_delta=1, high_delta=2):
    """Assign a bit width to every entry of `x` by its distance from the mean.

    `x` is min-max normalized to [0, 1]. Entries more than `threshold`
    standard deviations below the mean get ``bit - low_delta`` bits, entries
    more than `threshold` standard deviations above it get
    ``bit + high_delta`` bits, and all others keep `bit`.

    Args:
        x: array-like of scores (list, numpy array or CPU torch tensor).
        threshold: distance from the mean, in standard deviations.
        bit: base bit width.
        low_delta: bits removed from low outliers (default 1).
        high_delta: bits added to high outliers (default 2).

    Returns:
        Integer numpy array with the same shape as `x`.
    """
    import numpy as np

    q = np.asarray(x)
    if q.size == 0:
        return np.full(q.shape, bit)
    q_min, q_max = np.min(q), np.max(q)
    span = q_max - q_min
    if span == 0:
        # All scores equal: normalizing would divide by zero and yield NaNs
        # (which compare False everywhere), so every entry keeps the base width.
        return np.full(q.shape, bit)
    q_normalized = (q - q_min) / span
    mu, omega = q_normalized.mean(), q_normalized.std()
    low, high = mu - omega * threshold, mu + omega * threshold
    return np.where(
        q_normalized < low,
        bit - low_delta,
        np.where(q_normalized > high, bit + high_delta, bit),
    )

# Per-row weight bit widths for every inspected projection: base width 6, with
# rows whose score lies beyond 2 standard deviations moved down or up.
with torch.no_grad():
    for proj_name, row_scores in val_dict.items():
        bit_dict[proj_name] = create_bit_allocation_vector(row_scores, 2, 6)

In [None]:
import torch
from torch import nn
from functools import partial
from fake_quant import W8A8Linear, quantize_activation_per_token_absmax

@torch.no_grad()
def quantize_weight_per_channel_absmax_map(w, map):
    """Fake-quantize `w` in place with a separate bit width for each output row.

    Row ``i`` of ``w`` (shape ``(out_features, in_features)``) is divided by
    ``absmax(row_i) / (2 ** (map[i] - 1) - 1)``, rounded, and multiplied back.
    This is symmetric absmax quantization to ``map[i]`` bits.

    Args:
        w: 2-D weight tensor; modified in place.
        map: per-row bit widths (list, numpy array or tensor), one per row of
            ``w``. The name shadows the builtin but is kept for keyword callers.

    Returns:
        ``w`` itself, after in-place quantization.

    Raises:
        ValueError: if ``map`` does not have exactly one entry per row of ``w``,
            or any width is below 2 (``2 ** 0 - 1 == 0`` would divide by zero).
    """
    bits = torch.as_tensor(map, device=w.device).reshape(-1)
    if bits.numel() != w.shape[0]:
        raise ValueError(
            f"map has {bits.numel()} entries but weight has {w.shape[0]} rows"
        )
    if bool((bits < 2).any()):
        raise ValueError("every bit width in map must be >= 2")

    scales = w.abs().max(dim=-1, keepdim=True)[0]
    scales.clamp_(min=1e-5)
    # Largest integer level per row (e.g. 31 for 6 bits), computed for all rows
    # at once instead of a Python loop over every output channel. Kept in
    # float32 so the division is done at float32 precision even for fp16 `w`.
    levels = (2.0 ** (bits - 1) - 1).unsqueeze(-1)
    scales.div_(levels)
    w.div_(scales).round_().mul_(scales)
    return w

class AIQLinear(nn.Module):
    """Linear layer with per-output-row weight bit widths and quantized activations.

    Weights are quantized once, in ``from_float``, with one bit width per
    output row (``quantize_weight_per_channel_absmax_map``), and stored as
    rounded-then-rescaled floating-point values. On every forward pass the
    input goes through ``quantize_activation_per_token_absmax`` with
    ``n_bits=act_bits``; the output does too when ``quantize_output`` is set.

    Args:
        in_features: input dimension.
        out_features: output dimension.
        bias: whether the layer has a bias.
        act_quant: activation quantization scheme; only ``"per_token"``.
        quantize_output: also pass the output through the activation quantizer.
        act_bits: activation bit width (default 6).

    Raises:
        ValueError: for an unsupported ``act_quant``.
    """

    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        act_quant="per_token",
        quantize_output=False,
        act_bits=6,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.act_bits = act_bits
        # Set by from_float; the defaults keep __repr__ usable on a layer built directly.
        self.weight_quant_name = "None"
        self.weight_bits = None

        # Placeholder that from_float replaces. zeros (not randn) avoids drawing
        # out_features * in_features random numbers per layer only to discard them.
        self.register_buffer(
            "weight",
            torch.zeros(
                self.out_features,
                self.in_features,
                dtype=torch.float16,
                requires_grad=False,
            ),
        )
        if bias:
            self.register_buffer(
                "bias",
                torch.zeros(
                    (1, self.out_features), dtype=torch.float16, requires_grad=False
                ),
            )
        else:
            self.register_buffer("bias", None)

        if act_quant == "per_token":
            self.act_quant_name = "per_token"
            self.act_quant = partial(
                quantize_activation_per_token_absmax, n_bits=act_bits
            )
        else:
            raise ValueError(f"Invalid act_quant: {act_quant}")

        if quantize_output:
            self.output_quant_name = self.act_quant_name
            self.output_quant = self.act_quant
        else:
            self.output_quant_name = "None"
            self.output_quant = lambda x: x

    def to(self, *args, **kwargs):
        super(AIQLinear, self).to(*args, **kwargs)
        self.weight = self.weight.to(*args, **kwargs)
        if self.bias is not None:
            self.bias = self.bias.to(*args, **kwargs)
        return self

    @torch.no_grad()
    def forward(self, x):
        q_x = self.act_quant(x)
        y = torch.nn.functional.linear(q_x, self.weight, self.bias)
        return self.output_quant(y)

    @staticmethod
    def from_float(
        module,
        map,
        weight_quant="per_channel",
        act_quant="per_token",
        quantize_output=False,
        act_bits=6,
    ):
        """Build an AIQLinear from an ``nn.Linear``, quantizing its weight in place.

        Args:
            module: source ``torch.nn.Linear``. Its weight tensor is quantized
                in place and shared with the new layer.
            map: per-output-row weight bit widths (one entry per row).
            weight_quant: weight scheme; only ``"per_channel"``.
            act_quant: activation scheme; only ``"per_token"``.
            quantize_output: also quantize the layer output.
            act_bits: activation bit width (default 6).

        Raises:
            ValueError: for an unsupported ``weight_quant`` or ``act_quant``.
        """
        assert isinstance(module, torch.nn.Linear)
        new_module = AIQLinear(
            module.in_features,
            module.out_features,
            module.bias is not None,
            act_quant=act_quant,
            quantize_output=quantize_output,
            act_bits=act_bits,
        )

        if weight_quant == "per_channel":
            new_module.weight = quantize_weight_per_channel_absmax_map(
                module.weight, map=map  # per-row weight bits from argument
            )
        else:
            raise ValueError(f"Invalid weight_quant: {weight_quant}")
        new_module.weight_quant_name = weight_quant
        # Range of weight widths, reported by __repr__.
        new_module.weight_bits = (int(min(map)), int(max(map)))
        if module.bias is not None:
            new_module.bias = module.bias
        return new_module

    def _weight_bits_repr(self):
        """Format the weight width range as "6", "5-8", or "?" when unknown."""
        bits = getattr(self, "weight_bits", None)
        if bits is None:
            return "?"
        lo, hi = bits
        return str(lo) if lo == hi else f"{lo}-{hi}"

    def __repr__(self):
        # The original read a never-assigned `self.bits`, so printing a model
        # containing this layer raised AttributeError.
        return (
            f"AIQLinear({self.in_features}, {self.out_features}, "
            f"bias={self.bias is not None}, weight_quant={self.weight_quant_name}, "
            f"act_quant={self.act_quant_name}, output_quant={self.output_quant_name}, "
            f"bits=W{self._weight_bits_repr()}A{self.act_bits})"
        )


def quantize_aiq(
    model,
    weight_quant="per_channel",
    act_quant="per_token",
    quantize_bmm_input=False,
    bit_maps=None,
    o_proj_bits=(6, 6),
    down_proj_bits=(8, 8),
):
    """Swap a Llama model's linear projections for quantized versions, in place.

    In every ``LlamaAttention`` and ``LlamaMLP`` module, ``q_proj``, ``k_proj``,
    ``v_proj``, ``gate_proj`` and ``up_proj`` become ``AIQLinear`` layers. Their
    per-row weight bit widths are ``bit_maps[f"{module_name}.{proj}"]``.
    ``o_proj`` and ``down_proj`` become uniform-width ``W8A8Linear`` layers with
    ``o_proj_bits`` / ``down_proj_bits``.

    Args:
        model: Llama model to modify.
        weight_quant: weight quantization scheme passed to every new layer.
        act_quant: activation quantization scheme passed to every new layer.
        quantize_bmm_input: also quantize q/k/v outputs (the attention bmm inputs).
        bit_maps: mapping from projection name to per-row bit widths; defaults
            to the notebook-level ``bit_dict``.
        o_proj_bits: ``bits`` argument for the ``o_proj`` W8A8Linear (default (6, 6)).
        down_proj_bits: ``bits`` argument for the ``down_proj`` W8A8Linear (default (8, 8)).

    Returns:
        ``model`` (modified in place).

    Raises:
        KeyError: if ``bit_maps`` has no entry for a projection being replaced.
    """
    from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP

    if bit_maps is None:
        bit_maps = bit_dict

    def _to_aiq(linear, key, **extra):
        """Variable-bit-width replacement for one projection, keyed by its full name."""
        return AIQLinear.from_float(
            linear,
            weight_quant=weight_quant,
            act_quant=act_quant,
            map=bit_maps[key],
            **extra,
        )

    # Snapshot the module list: the loop replaces children of the modules it
    # visits, so it should not walk a live generator over the tree.
    for name, m in tqdm.tqdm(list(model.named_modules())):
        if isinstance(m, LlamaMLP):
            for proj in ("gate_proj", "up_proj"):
                setattr(m, proj, _to_aiq(getattr(m, proj), f"{name}.{proj}"))
            m.down_proj = W8A8Linear.from_float(
                m.down_proj, weight_quant=weight_quant, act_quant=act_quant, bits=down_proj_bits
            )
        elif isinstance(m, LlamaAttention):
            for proj in ("q_proj", "k_proj", "v_proj"):
                setattr(
                    m,
                    proj,
                    _to_aiq(
                        getattr(m, proj),
                        f"{name}.{proj}",
                        quantize_output=quantize_bmm_input,
                    ),
                )
            m.o_proj = W8A8Linear.from_float(
                m.o_proj, weight_quant=weight_quant, act_quant=act_quant, bits=o_proj_bits
            )
    return model


model_aiq = quantize_aiq(model_aiq, weight_quant='per_channel', act_quant='per_token')

In [None]:
# Persist the AIQ model's weights.
# NOTE(review): the directory says "w6a6", but quantize_aiq gives q/k/v/gate/up
# per-row weight widths from bit_dict (base 6, outliers at 5 or 8 bits) and
# down_proj 8-bit weights/activations — confirm the label.
# NOTE(review): the saved weights are the rounded-then-rescaled float values;
# activation quantization lives in AIQLinear.forward, so reloading with
# LlamaForCausalLM presumably does not restore it — confirm before reuse.
model_aiq.save_pretrained("llama-2-7b-aiq-w6a6")
# model_smoothquant.save_pretrained("llama-2-7b-smoothquant")

## Evaluation

In [None]:
from datasets import load_dataset

tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
# One evaluator per GPU, matching where each model was loaded (model_fp16 on
# cuda:0, model_smoothquant on cuda:1, model_aiq on cuda:2), so each evaluator's
# token ids are on the same device as the model it scores. The FP16 and
# SmoothQuant perplexity cells below use evaluator_0 and evaluator_1, so they
# must be defined here.
evaluator_0 = Evaluator(dataset, tokenizer, "cuda:0")
evaluator_1 = Evaluator(dataset, tokenizer, "cuda:1")
evaluator_2 = Evaluator(dataset, tokenizer, "cuda:2")

In [None]:
# Reference perplexity of the unquantized FP16 model (loaded on cuda:0).
# NOTE(review): evaluator_0 exists only if the `Evaluator(dataset, tokenizer, "cuda:0")`
# line in the evaluation setup cell is active.
ppl_fp16 = evaluator_0.evaluate(model_fp16)
print(f"Perplexity of FP16 model: {ppl_fp16}")

In [None]:
# Perplexity of the SmoothQuant baseline (loaded on cuda:1, quantized earlier by
# quantize_llama_like with bits=(4, 4)).
# NOTE(review): evaluator_1 exists only if the `Evaluator(dataset, tokenizer, "cuda:1")`
# line in the evaluation setup cell is active.
ppl_smoothquant = evaluator_1.evaluate(model_smoothquant)
print(f"SmoothQuant perplexity: {ppl_smoothquant}")

In [None]:
# Perplexity of the AIQ mixed-precision model (loaded on cuda:2), scored with the
# evaluator whose token ids were placed on cuda:2.
ppl_aiq = evaluator_2.evaluate(model_aiq)
print(f"AIQ perplexity: {ppl_aiq}")

In [None]:
import numpy as np

# Net change in weight storage, in bytes, implied by bit_dict relative to a
# uniform 6-bit allocation. Each entry of bit_dict[key] is one output row holding
# in_features weights, so it contributes (bits - 6) * in_features bits.
# in_features is read per projection from the FP16 reference model rather than
# fixed at 4096: down_proj maps the MLP intermediate size back to the hidden
# size, so its rows are longer than those of the other projections.
# NOTE(review): quantize_aiq quantizes down_proj with W8A8Linear, not with its
# bit_dict widths, so the down_proj term does not describe the deployed model.
fp16_modules_by_name = dict(model_fp16.named_modules())
val = list()
for key in bit_dict:
    row_len = fp16_modules_by_name[key].in_features
    val.append((bit_dict[key] - 6).sum().item() * row_len)
np.array(val).sum() / 8

In [None]:
# Drop the notebook's reference to the AIQ model (its GPU memory is reclaimed
# only if nothing else still references it), then let PyTorch's CUDA caching
# allocator release cached blocks that are no longer in use.
del model_aiq
torch.cuda.empty_cache()