# SmoothQuant on Llama 3.2 1B & 3B

In this notebook, we use Llama-3.2 models to demonstrate that SmoothQuant can use 8-bit quantization for both weights and activations (W8A8) while achieving perplexity similar to that of the FP16 models.

In order to run this notebook, you need to install the following packages:

- smoothquant
- PyTorch
- Transformers
- Accelerate

In [1]:
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaMLP,
)
from transformers import AutoTokenizer
from smoothquant.smooth import smooth_lm
from smoothquant.fake_quant import quantize_llama_like
import tqdm

  from .autonotebook import tqdm as notebook_tqdm


The following are two evaluators for measuring the performance of the model. `WikitextEvaluator` computes perplexity on a toy dataset (the first 40 windows of 2048 tokens each from the test set of the Wikitext-2 dataset), and `LambadaEvaluator` measures last-token prediction accuracy on LAMBADA. You can replace them with your own dataset. The conclusion should be the same.

In [2]:
class WikitextEvaluator:
    """Measure language-model perplexity on fixed-length windows of a corpus.

    The ``text`` column of ``dataset`` is joined with blank lines, tokenized
    once as a single stream, and scored in consecutive, non-overlapping
    windows of ``seq_len`` tokens.

    Args:
        dataset: Object whose ``["text"]`` entry is an iterable of strings.
        tokenizer: Callable that, given a string and ``return_tensors="pt"``,
            returns an object exposing ``input_ids``.
        device: Device on which the tokenized stream is stored.
        n_samples: Maximum number of windows to score.
        seq_len: Number of tokens per window.
    """

    def __init__(self, dataset, tokenizer, device, n_samples=40, seq_len=2048):
        self.tokenizer = tokenizer
        self.device = device
        self.n_samples = n_samples
        self.seq_len = seq_len

        # Tokenize the whole corpus in one call; the result is indexed below
        # as a 2-D tensor (first axis kept whole, second axis sliced).
        self.dataset = tokenizer(
            "\n\n".join(dataset["text"]), return_tensors="pt"
        ).input_ids.to(device)

    @torch.no_grad()
    def evaluate(self, model):
        """Return the perplexity of ``model`` over the scored windows.

        Args:
            model: Callable module exposing ``eval()`` and ``device`` whose
                output has a ``logits`` attribute.

        Returns:
            A 0-dim tensor holding the perplexity.

        Raises:
            ValueError: If the tokenized stream is shorter than one window.
        """
        model.eval()
        seq_len = self.seq_len

        # Only score full windows so that the final normalization matches the
        # number of tokens that were actually evaluated.
        n_windows = min(self.n_samples, self.dataset.size(1) // seq_len)
        if n_windows <= 0:
            raise ValueError(
                "dataset has fewer tokens than one evaluation window "
                f"({self.dataset.size(1)} < {seq_len})"
            )

        loss_fct = nn.CrossEntropyLoss()
        nlls = []
        for i in tqdm.tqdm(range(n_windows), desc="Evaluating..."):
            batch = self.dataset[:, i * seq_len : (i + 1) * seq_len].to(model.device)
            lm_logits = model(batch).logits

            # Position t predicts token t + 1: drop the last logit and the
            # first label so the two line up.
            shift_logits = lm_logits[:, :-1, :].contiguous().float()
            # Take the labels from the batch and move them to the logits'
            # device so the loss never mixes tensors from different devices.
            shift_labels = batch[:, 1:].to(shift_logits.device)

            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
            )
            # The loss is a per-token mean; scale it back to a window total.
            nlls.append(loss.float() * seq_len)

        return torch.exp(torch.stack(nlls).sum() / (n_windows * seq_len))
    
class LambadaEvaluator:
    """Measure how often a model predicts the final token of each example.

    Args:
        dataset: Object providing ``map`` and ``set_format`` and a ``"text"``
            column; it is tokenized once at construction time.
        tokenizer: Callable applied to batches of ``examples["text"]``.
        device: Device the token ids are moved to before the forward pass.
    """

    def __init__(self, dataset, tokenizer, device):
        self.tokenizer = tokenizer
        self.device = device

        def _encode(examples):
            # Runs on batches of rows because ``batched=True`` is passed below.
            return self.tokenizer(examples["text"])

        tokenized = dataset.map(_encode, batched=True)
        # Only ``input_ids`` is exposed, as torch tensors.
        tokenized.set_format(type="torch", columns=["input_ids"])
        self.dataset = tokenized

    @torch.no_grad()
    def evaluate(self, model):
        """Return the fraction of examples whose final token is predicted.

        Each example is run on its own. The logits at the second-to-last
        position are arg-maxed and compared with the final token id.
        """
        model.eval()
        n_correct = 0
        n_seen = 0
        for example in self.dataset:
            # Add a leading batch axis of size one.
            ids = example["input_ids"].to(self.device).unsqueeze(0)
            target = ids[:, -1]
            logits = model(ids).logits
            # The prediction for the final token is made at the position
            # immediately before it.
            guess = logits[:, -2, :].argmax(dim=-1)
            n_seen += target.size(0)
            n_correct += (guess == target).sum().item()
        return n_correct / n_seen

In [3]:
from datasets import load_dataset

# One tokenizer is shared by both evaluators.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
# LAMBADA: only the first 100 validation examples are used.
dataset = load_dataset("lambada", split="validation[:100]")
lambada_evaluator = LambadaEvaluator(dataset, tokenizer, "cuda")
# `dataset` is rebound to the WikiText-2 (raw) test split; the LAMBADA data
# stays reachable through `lambada_evaluator`.
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
wikitext_evaluator = WikitextEvaluator(dataset, tokenizer, "cuda")

Token indices sequence length is longer than the specified maximum sequence length for this model (289077 > 131072). Running this sequence through the model will result in indexing errors


## FP16 Model Perplexity

Let's first check the performance of the original FP16 model.

In [4]:
# Load the checkpoint in half precision to serve as the FP16 baseline.
model_fp16 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B", torch_dtype=torch.float16, device_map="auto"
)

In [5]:
# Baseline numbers for the FP16 model: LAMBADA accuracy first, then WikiText
# perplexity.
print("\nLAMBADA Accuracy Evaluation")
acc_original = lambada_evaluator.evaluate(model_fp16)
print(f"accuracy on LAMBADA: {acc_original}")

print("\nPerplexity Evaluation on WikiText")
pp_wikitext = wikitext_evaluator.evaluate(model_fp16)
print(f'perplexity on wikitext: {pp_wikitext}')


LAMBADA Accuracy Evaluation
accuracy on LAMBADA: 0.81

Perplexity Evaluation on WikiText


Evaluating...: 100%|██████████| 40/40 [00:04<00:00,  8.90it/s]


perplexity on wikitext: 9.292794227600098


We then quantize the model to W8A8 and check the performance.

## Naive W8A8 Quantized Model Perplexity

In [6]:
# Naive W8A8: quantize the FP16 model directly, with no smoothing step.
# NOTE(review): if quantize_llama_like modifies its argument in place,
# model_fp16 and model_w8a8 refer to the same quantized model afterwards --
# confirm before reusing model_fp16 as a baseline.
model_w8a8 = quantize_llama_like(model_fp16)
print(model_w8a8)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): W8A8Linear(2048, 2048, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (k_proj): W8A8Linear(2048, 512, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (v_proj): W8A8Linear(2048, 512, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (o_proj): W8A8Linear(2048, 2048, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): W8A8Linear(2048, 8192, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (up_proj): W8A8Linear(2048, 8192, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)


In [7]:
# Same two measurements for the naively quantized model.
# NOTE: despite its name, `acc_original` holds the W8A8 model's accuracy here;
# it and `pp_wikitext` overwrite the values from the FP16 run.
print("\nLAMBADA Accuracy Evaluation")
acc_original = lambada_evaluator.evaluate(model_w8a8)
print(f"accuracy on LAMBADA: {acc_original}")

print("\nPerplexity Evaluation on WikiText")
pp_wikitext = wikitext_evaluator.evaluate(model_w8a8)
print(f'perplexity on wikitext: {pp_wikitext}')


LAMBADA Accuracy Evaluation
accuracy on LAMBADA: 0.8

Perplexity Evaluation on WikiText


Evaluating...: 100%|██████████| 40/40 [00:05<00:00,  7.77it/s]


perplexity on wikitext: 9.416200637817383


We can see there is a perplexity increase (from about 9.29 to about 9.42 on WikiText). We then use SmoothQuant to quantize the model and check the performance.

## SmoothQuant W8A8 Quantized Model Perplexity

In [9]:
# Reload the checkpoint so smoothing starts from a freshly loaded FP16 model
# rather than from one that has already been passed to quantize_llama_like.
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B", torch_dtype=torch.float16, device_map="auto"
)
# weights_only=True restricts unpickling to tensors and plain containers, so
# loading the scales file cannot execute arbitrary pickled code.
# NOTE(review): assumes the file holds only tensors in plain containers --
# confirm against the script that produced it.
act_scales = torch.load("./act_scales/llama-3.2-1B.pt", weights_only=True)
# Smooth first, then quantize. The third argument is presumably the smoothing
# strength (alpha) -- verify against smooth_lm's signature.
smooth_lm(model, act_scales, 0.85)
model_smoothquant_w8a8 = quantize_llama_like(model)
print(model_smoothquant_w8a8)

  act_scales = torch.load("./act_scales/llama-3.2-1B.pt")


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): W8A8Linear(2048, 2048, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (k_proj): W8A8Linear(2048, 512, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (v_proj): W8A8Linear(2048, 512, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (o_proj): W8A8Linear(2048, 2048, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): W8A8Linear(2048, 8192, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
          (up_proj): W8A8Linear(2048, 8192, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)


We can see the smoothed model has a slightly lower perplexity than the naive W8A8 model (about 9.414 vs. 9.416 in the results below), and it stays close to the FP16 model's (about 9.293). This is because SmoothQuant smooths the outliers in activations and balances the quantization difficulty of activations and weights.

In [10]:
# Same two measurements for the SmoothQuant W8A8 model.
# NOTE: despite its name, `acc_original` holds the smoothed model's accuracy
# here; it and `pp_wikitext` overwrite the values from the earlier runs.
print("\nLAMBADA Accuracy Evaluation")
acc_original = lambada_evaluator.evaluate(model_smoothquant_w8a8)
print(f"accuracy on LAMBADA: {acc_original}")

print("\nPerplexity Evaluation on WikiText")
pp_wikitext = wikitext_evaluator.evaluate(model_smoothquant_w8a8)
print(f'perplexity on wikitext: {pp_wikitext}')


LAMBADA Accuracy Evaluation
accuracy on LAMBADA: 0.8

Perplexity Evaluation on WikiText


Evaluating...: 100%|██████████| 40/40 [00:05<00:00,  7.55it/s]


perplexity on wikitext: 9.413980484008789
