# SmoothQuant on Mamba

### Original Authors: Guangxuan Xiao\*, Ji Lin\*, Mickael Seznec, Julien Demouth, Song Han

### Adapted to Mamba by: Mike Qu

In this notebook, we apply SmoothQuant to Mamba (`state-spaces/mamba-790m-hf`). The original SmoothQuant notebook uses the OPT-13B model to demonstrate that SmoothQuant can use 8-bit precision for both weights and activations while matching the accuracy of the FP16 model. Unlike the previous method [[Dettmers *et al.*, 2022]](https://arxiv.org/abs/2208.07339), SmoothQuant enables fully INT8 GEMMs for linear layers and does not require high-precision numbers to represent outliers.
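SmoothQuant moves quantization difficulty from activations to weights with a per-input-channel factor $s_j = \max|X_j|^\alpha / \max|W_j|^{1-\alpha}$, where $\alpha$ is the migration strength (0.5 is a common choice). Because $(X\,\mathrm{diag}(s)^{-1})(\mathrm{diag}(s)\,W) = XW$, the layer output is unchanged while activation outliers shrink. The sketch below checks this on random tensors; `smooth_scales` is a hypothetical helper for illustration, not the `smoothquant` package API.

```python
import torch

def smooth_scales(x, weight, alpha=0.5, eps=1e-5):
    """Per-input-channel factors s_j = max|X_j|^alpha / max|W_j|^(1 - alpha).

    x: activations of shape (tokens, in_features)
    weight: nn.Linear-style weight of shape (out_features, in_features)
    """
    act_max = x.abs().amax(dim=0).clamp(min=eps)     # (in_features,)
    w_max = weight.abs().amax(dim=0).clamp(min=eps)  # (in_features,)
    return act_max.pow(alpha) / w_max.pow(1 - alpha)

torch.manual_seed(0)
x = torch.randn(16, 8)
x[:, 3] *= 50.0          # inject an outlier input channel
weight = torch.randn(4, 8)

s = smooth_scales(x, weight, alpha=0.5)
x_smooth = x / s         # X diag(s)^-1: divide each activation channel by s_j
w_smooth = weight * s    # diag(s) W: multiply the matching weight column by s_j

y_ref = x @ weight.T
y_smooth = x_smooth @ w_smooth.T
print(torch.allclose(y_ref, y_smooth, rtol=1e-4, atol=1e-3))  # same output up to float error
print(x.abs().max().item(), x_smooth.abs().max().item())       # activation range shrinks
```

In practice the division by $s$ is folded into the preceding normalization layer's weights and the multiplication into the linear weights offline, so inference pays no extra cost.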

The original notebook demonstrates SmoothQuant on OPT-13B to accommodate users' resource constraints. We have tested SmoothQuant on models with up to 176 billion parameters (OPT-175B, BLOOM-176B, GLM-130B). You can also change the model name to validate SmoothQuant on other models. `../act_scales/` provides the activation channel scales for OPT and BLOOM models.
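Activation channel scales of this kind record, for each linear layer, the maximum absolute value of every input channel over a calibration set. One way to collect them is with PyTorch forward hooks. The sketch below assumes that approach on a toy two-layer model; `collect_act_scales` is a hypothetical name, not the script that produced the files in `../act_scales/`.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def collect_act_scales(model, batches):
    """Record max |input| per input channel for every nn.Linear in `model`."""
    scales = {}
    hooks = []

    def make_hook(name):
        def hook(module, inputs, output):
            x = inputs[0].detach()
            x = x.reshape(-1, x.shape[-1]).abs()   # flatten to (tokens, in_features)
            cur = x.amax(dim=0).float()
            # Keep a running maximum across calibration batches
            scales[name] = torch.maximum(scales[name], cur) if name in scales else cur
        return hook

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            hooks.append(module.register_forward_hook(make_hook(name)))
    try:
        for batch in batches:
            model(batch)
    finally:
        for h in hooks:
            h.remove()                             # always detach hooks afterwards
    return scales

torch.manual_seed(0)
toy = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
calib = [torch.randn(2, 5, 8) for _ in range(3)]   # 3 batches of (batch, seq, dim)
act_scales = collect_act_scales(toy, calib)
print({name: tuple(v.shape) for name, v in act_scales.items()})
```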

In order to run this notebook, you need to install the following packages:

- smoothquant
- PyTorch
- Transformers
- Accelerate
- Datasets

In [1]:
%load_ext autoreload
%autoreload 2

import torch
from transformers.models.mamba.modeling_mamba import (
    MambaForCausalLM,
)
from transformers import AutoModelForCausalLM, AutoTokenizer
from smoothquant.smooth import smooth_lm
from smoothquant.fake_quant import W8A8Linear, quantize_mamba



In this notebook, we simulate 8-bit dynamic per-tensor quantization of weights and activations in FP16, i.e., fake quantization. We have implemented real 8-bit quantization with INT8 CUTLASS GEMM kernels for both PyTorch and FasterTransformer. Please stay tuned for the release.
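Fake quantization rounds a tensor onto the INT8 grid and immediately maps it back to floating point, so the rounding error is simulated while the matmul itself still runs in floating point. A minimal sketch of symmetric dynamic per-tensor fake quantization follows; `fake_quant_per_tensor` is a hypothetical name, and `smoothquant.fake_quant` may differ in details such as clamping.

```python
import torch

@torch.no_grad()
def fake_quant_per_tensor(t, n_bits=8):
    """Symmetric per-tensor fake quantization: one scale for the whole tensor."""
    q_max = 2 ** (n_bits - 1) - 1                    # 127 for INT8
    scale = t.abs().max().clamp(min=1e-5) / q_max    # "dynamic": computed from t itself
    return (t / scale).round().clamp(-q_max, q_max) * scale

torch.manual_seed(0)
w = torch.randn(4, 8)
w_q = fake_quant_per_tensor(w)
print((w - w_q).abs().max().item())   # bounded by half a quantization step
```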

The following cell defines the evaluators used to measure model quality: last-token accuracy on LAMBADA, perplexity on WikiText-2, and a small text-generation helper. For accuracy, we use a toy dataset (the first 100 examples of the LAMBADA validation set). You can replace it with your own dataset; the conclusions should be the same.

**In this demo, we have simplified the evaluation by using the first 100 samples from the LAMBADA dataset's validation set. We employ "Last Token Prediction Accuracy" as our evaluation metric. This approximate evaluation is intended for demonstration purposes, providing simple but meaningful comparisons of relative performance between methods. For a stricter assessment, we recommend using [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) to obtain the "Last Word Prediction Accuracy" for the LAMBADA dataset, which is the metric reported in our paper.**
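In a causal language model, the logits at position t predict token t + 1, so the last token of an example is predicted by the logits at position -2; this is why `LambadaEvaluator` reads `outputs.logits[:, -2, :]`. A toy illustration with hand-built logits (the vocabulary size and token ids are made up):

```python
import torch

def last_token_hit(logits, input_ids):
    """Return True if the prediction for the final token is correct.

    logits: (1, seq_len, vocab); input_ids: (1, seq_len)
    """
    label = input_ids[:, -1]
    pred = logits[:, -2, :].argmax(dim=-1)   # position -2 predicts position -1
    return bool((pred == label).item())

vocab = 10
input_ids = torch.tensor([[4, 7, 2, 9]])
logits = torch.zeros(1, 4, vocab)
logits[0, 2, 9] = 5.0   # position 2 favors token 9, the true last token
logits[0, 3, 1] = 5.0   # position 3 predicts a token past the input; ignored
print(last_token_hit(logits, input_ids))   # True
```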

In [2]:
import torch.nn as nn
import tqdm
class TextGenerator:
    def __init__(self, tokenizer, device):
        self.tokenizer = tokenizer
        self.device = device

    def generate_text(self, model, input_text, max_length=50):
        input_ids = self.tokenizer.encode(input_text, return_tensors='pt').to(self.device)
        output = model.generate(inputs=input_ids,
                                max_length=max_length,
                                do_sample=True,
                                top_k=30,
                                pad_token_id=self.tokenizer.eos_token_id,
                                attention_mask=input_ids.new_ones(input_ids.shape))
        return self.tokenizer.decode(output[0], skip_special_tokens=True)

    def calculate_perplexity(self, model, text):
        # Encode the text
        encodings = self.tokenizer(text, return_tensors='pt').to(self.device)

        # Define input_ids and target_ids
        input_ids = encodings.input_ids
        target_ids = input_ids.clone()

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)

        # Loss calculation
        neg_log_likelihood = outputs.loss

        # Perplexity calculation
        ppl = torch.exp(neg_log_likelihood)

        return ppl

class LambadaEvaluator:
    def __init__(self, dataset, tokenizer, device):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.device = device

        # tokenize the dataset
        def tokenize_function(examples):
            example = self.tokenizer(examples["text"])
            return example

        self.dataset = self.dataset.map(tokenize_function, batched=True)
        self.dataset.set_format(type="torch", columns=["input_ids"])

    @torch.no_grad()
    def evaluate(self, model):
        model.eval()
        # The task is to predict the last word of the input.
        total, hit = 0, 0
        for batch in self.dataset:
            input_ids = batch["input_ids"].to(self.device).unsqueeze(0)
            label = input_ids[:, -1]
            outputs = model(input_ids)
            last_token_logits = outputs.logits[:, -2, :]
            pred = last_token_logits.argmax(dim=-1)
            total += label.size(0)
            hit += (pred == label).sum().item()
        acc = hit / total
        return acc
    
class WikitextEvaluator:
    def __init__(self, dataset, tokenizer, device, n_samples=40):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.device = device

        self.dataset = tokenizer(
            "\n\n".join(dataset["text"]), return_tensors="pt"
        ).input_ids.to(device)

        self.n_samples = n_samples

    @torch.no_grad()
    def evaluate(self, model):
        model.eval()
        nlls = []
        for i in tqdm.tqdm(range(self.n_samples), desc="Evaluating..."):
            batch = self.dataset[:, (i * 2048) : ((i + 1) * 2048)].to(model.device)
            lm_logits = model(batch).logits
            # Shift so that the logits at position t are scored against token t + 1
            shift_logits = lm_logits[:, :-1, :].contiguous().float()
            shift_labels = batch[:, 1:]
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
            neg_log_likelihood = loss.float() * 2048
            nlls.append(neg_log_likelihood)

        return torch.exp(torch.stack(nlls).sum() / (self.n_samples * 2048))

In [3]:
from datasets import load_dataset

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-2.8b-hf")
dataset = load_dataset("lambada", split="validation[:100]")
lambada_evaluator = LambadaEvaluator(dataset, tokenizer, "cuda")
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
wikitext_evaluator = WikitextEvaluator(dataset, tokenizer, "cuda")
text_generator = TextGenerator(tokenizer, 'cuda')

## FP16 Model Accuracy

Let's first check the performance of the original FP16 model.

In [4]:
model_fp16 = MambaForCausalLM.from_pretrained(
    "state-spaces/mamba-790m-hf", torch_dtype=torch.float16, device_map="auto"
)

The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py.


In [None]:
acc_fp16 = lambada_evaluator.evaluate(model_fp16)
print(f"Original model (fp16) accuracy: {acc_fp16}")

We then quantize the model to W8A8 and check the performance.

## Naive W8A8 Quantized Model Accuracy

In [6]:
model_w8a8 = quantize_mamba(model_fp16)
print(model_w8a8)

MambaForCausalLM(
  (backbone): MambaModel(
    (embeddings): Embedding(50280, 1536)
    (layers): ModuleList(
      (0-47): 48 x MambaBlock(
        (norm): MambaRMSNorm(1536, eps=1e-05)
        (mixer): MambaMixer(
          (conv1d): Conv1d(3072, 3072, kernel_size=(4,), stride=(1,), padding=(3,), groups=3072)
          (act): SiLU()
          (in_proj): W8A8Linear(1536, 6144, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
          (x_proj): W8A8Linear(3072, 128, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
          (dt_proj): W8A8Linear(96, 3072, bias=True, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
          (out_proj): W8A8Linear(3072, 1536, bias=False, weight_quant=per_tensor, act_quant=per_tensor, output_quant=per_tensor)
        )
      )
    )
    (norm_f): MambaRMSNorm(1536, eps=1e-05)
  )
  (lm_head): Linear(in_features=1536, out_features=50280, bias=False)
)
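The printout shows that `quantize_mamba` replaced the four projections in each `MambaMixer` (`in_proj`, `x_proj`, `dt_proj`, `out_proj`) with `W8A8Linear`, while `conv1d` and `lm_head` stay in floating point. A minimal sketch of such a module swap on a toy model follows; `FakeW8A8Linear` and `swap_linears` are hypothetical stand-ins, not the `smoothquant` implementation, and they fake-quantize only weights and input activations per tensor (no output quantization).

```python
import torch
import torch.nn as nn

def fake_quant(t, n_bits=8):
    q_max = 2 ** (n_bits - 1) - 1
    scale = t.abs().max().clamp(min=1e-5) / q_max
    return (t / scale).round().clamp(-q_max, q_max) * scale

class FakeW8A8Linear(nn.Module):
    """Hypothetical stand-in: per-tensor INT8 fake quantization of weights and inputs."""
    def __init__(self, linear: nn.Linear):
        super().__init__()
        # Weights are quantized once, up front
        self.weight = nn.Parameter(fake_quant(linear.weight.detach()), requires_grad=False)
        bias = linear.bias
        self.bias = None if bias is None else nn.Parameter(bias.detach().clone(), requires_grad=False)

    def forward(self, x):
        # Activations are quantized dynamically on every call
        return nn.functional.linear(fake_quant(x), self.weight, self.bias)

def swap_linears(module, names=("in_proj", "x_proj", "dt_proj", "out_proj")):
    """Recursively replace child nn.Linear layers whose attribute name is in `names`."""
    for child_name, child in module.named_children():
        if isinstance(child, nn.Linear) and child_name in names:
            setattr(module, child_name, FakeW8A8Linear(child))
        else:
            swap_linears(child, names)
    return module

class ToyMixer(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(8, 16, bias=False)
        self.out_proj = nn.Linear(16, 8, bias=False)

    def forward(self, x):
        return self.out_proj(torch.relu(self.in_proj(x)))

toy = nn.Sequential(ToyMixer(), nn.Linear(8, 4))   # trailing Linear plays the role of lm_head
swap_linears(toy)
print(toy)
```

Matching on attribute names rather than on every `nn.Linear` is what leaves a layer like `lm_head` untouched.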


In [None]:
acc_w8a8 = lambada_evaluator.evaluate(model_w8a8)
print(f"Naive W8A8 quantized model accuracy: {acc_w8a8}")
perp_w8a8 = wikitext_evaluator.evaluate(model_w8a8)
print(f"Naive W8A8 quantized model perplexity: {perp_w8a8}")

Naive W8A8 quantized model accuracy: 0.07



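We can see there is a significant accuracy drop: naive W8A8 reaches only 0.07 last-token accuracy on the LAMBADA subset. For transformers, this is consistent with LLM.int8()'s finding that once models grow beyond 6.7B parameters, systematic outliers emerge in activations, which makes naive fully-INT8 quantization fail. Mamba-790M is far below that size, so the threshold does not directly explain this result, but it suggests that Mamba's activations are also hard to quantize with a single per-tensor scale.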
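Why a single outlier channel hurts per-tensor INT8: the scale is set by the largest magnitude in the tensor, so every other channel is rounded onto a grid far too coarse for it. A toy illustration with synthetic activations (the 100x outlier channel is an assumption for illustration, not measured from Mamba):

```python
import torch

def fake_quant_per_tensor(t, n_bits=8):
    q_max = 2 ** (n_bits - 1) - 1
    scale = t.abs().max().clamp(min=1e-5) / q_max
    return (t / scale).round().clamp(-q_max, q_max) * scale

torch.manual_seed(0)
x = torch.randn(64, 16)        # 64 tokens, 16 channels
x_outlier = x.clone()
x_outlier[:, 5] *= 100.0       # one channel with much larger magnitude

def rel_error_on_normal_channels(t):
    mask = torch.ones(t.shape[1], dtype=torch.bool)
    mask[5] = False            # measure error only on the non-outlier channels
    err = (t - fake_quant_per_tensor(t))[:, mask]
    return (err.norm() / t[:, mask].norm()).item()

print(f"without outlier: {rel_error_on_normal_channels(x):.4f}")
print(f"with outlier:    {rel_error_on_normal_channels(x_outlier):.4f}")
```

The ordinary channels are identical in both tensors; only the shared scale changes, which is the difficulty SmoothQuant's per-channel smoothing is designed to remove.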