# SmoothQuant on Llama 2 7B

In [1]:
import os

# Number CUDA devices by PCI bus id so "cuda:N" indices are stable and match nvidia-smi.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID"})
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"

# Local Llama 2 7B checkpoint directory, passed to every from_pretrained call below.
MODEL_PATH = "/workspace/meta-llama/Llama-2-7b"

import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from transformers import LlamaTokenizer
from smooth import smooth_lm
from fake_quant import quantize_llama_like
import tqdm

class Evaluator:
    """Perplexity evaluator over a tokenized text corpus.

    The ``"text"`` column of ``dataset`` is joined with blank lines, tokenized
    once, and cut into consecutive, non-overlapping windows of ``seqlen``
    tokens; :meth:`evaluate` scores the first ``n_samples`` windows, each one
    fed to the model independently (no context is carried between windows).

    Args:
        dataset: Object indexable as ``dataset["text"]`` yielding strings
            (e.g. a Hugging Face ``datasets`` split).
        tokenizer: Callable ``tokenizer(text, return_tensors="pt")`` returning
            an object whose ``input_ids`` is a 2-D ``(batch, tokens)`` tensor.
        device: Device the token ids are stored on.
        n_samples: Number of windows to evaluate. Clamped (with a warning) to
            the number of full windows the corpus actually provides.
        seqlen: Tokens per window; defaults to the previously hard-coded 2048.

    Raises:
        ValueError: If ``n_samples < 1``, ``seqlen < 2``, or the tokenized
            corpus is shorter than a single window.
    """

    def __init__(self, dataset, tokenizer, device, n_samples=40, seqlen=2048):
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")
        if seqlen < 2:
            # A window needs at least one (input, next-token) pair.
            raise ValueError(f"seqlen must be >= 2, got {seqlen}")

        self.tokenizer = tokenizer
        self.device = device
        self.seqlen = seqlen

        # Tokenize the whole split once; `self.dataset` holds the token ids.
        self.dataset = tokenizer(
            "\n\n".join(dataset["text"]), return_tensors="pt"
        ).input_ids.to(device)

        n_windows = self.dataset.size(1) // seqlen
        if n_windows == 0:
            raise ValueError(
                f"corpus has {self.dataset.size(1)} tokens, fewer than one "
                f"window of {seqlen}"
            )
        if n_samples > n_windows:
            import warnings

            # Past the end, slices are short/empty: the loss becomes NaN and
            # the n_samples * seqlen denominator no longer matches the data.
            warnings.warn(
                f"n_samples={n_samples} exceeds the {n_windows} full windows "
                f"available; evaluating {n_windows} instead"
            )
            n_samples = n_windows
        self.n_samples = n_samples

    @torch.no_grad()
    def evaluate(self, model):
        """Return the perplexity of ``model`` on the first ``n_samples`` windows.

        Args:
            model: Causal LM called as ``model(input_ids)`` returning an object
                with ``.logits``; must expose ``.device`` and ``.eval()``.

        Returns:
            0-dim float32 tensor ``exp(sum(nll) / (n_samples * seqlen))``. Since
            every window has the same length this equals ``exp`` of the mean
            per-window cross-entropy.
        """
        model.eval()
        seqlen = self.seqlen
        loss_fct = nn.CrossEntropyLoss()
        nlls = []
        for i in tqdm.tqdm(range(self.n_samples), desc="Evaluating..."):
            batch = self.dataset[:, i * seqlen : (i + 1) * seqlen].to(model.device)
            lm_logits = model(batch).logits
            # Position t predicts token t + 1: drop the last logit, first label.
            shift_logits = lm_logits[:, :-1, :].contiguous().float()
            # Labels come from `batch`, which is already on model.device; the
            # ids in self.dataset may sit on a different GPU than the model.
            shift_labels = batch[:, 1:]
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.reshape(-1)
            )
            # Mean loss scaled by the window length (normalized again below).
            nlls.append(loss.float() * seqlen)

        return torch.exp(torch.stack(nlls).sum() / (self.n_samples * seqlen))
    
from datasets import load_dataset

# Perplexity harness: raw wikitext-2 test split, tokenized with the Llama tokenizer.
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
dataset = load_dataset(path="wikitext", name="wikitext-2-raw-v1", split="test")
evaluator = Evaluator(dataset=dataset, tokenizer=tokenizer, device="cuda")

## FP16 Model Perplexity

In [2]:
# Unmodified FP16 baseline, loaded entirely onto device cuda:0.
model_fp16 = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16,
    device_map="cuda:0",
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# ppl_fp16 = evaluator.evaluate(model_fp16)
# print(f"Baseline model perplexity: {ppl_fp16}")

In [None]:
import matplotlib.pyplot as plt
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP
import numpy as np

BIN_SIZE = 1000


def _plot_weight_histogram(ax, weight, bins, title):
    """Draw a histogram of the values in ``weight`` on ``ax``.

    Bars sit at the real bin edges (in weight units), one bar per bin, so the
    x and height arrays have the same length. The float32 copy of the weight
    is local to this function and is released when it returns.
    """
    w = weight.detach().float()
    lo, hi = w.min().item(), w.max().item()
    if lo == hi:
        # Constant tensor: widen the range so bins have non-zero width.
        lo, hi = lo - 1.0, hi + 1.0
    counts = torch.histc(w, bins=bins, min=lo, max=hi).cpu()
    edges = torch.linspace(lo, hi, bins + 1)
    ax.bar(edges[:-1].numpy(), counts.numpy(), width=(hi - lo) / bins, align="edge")
    # Axes.title is a Text attribute, not a method; set_title is the setter.
    ax.set_title(title)


# One figure per attention block of the FP16 baseline: value distributions of
# the query / key / value projection weights.
for name, m in model_fp16.model.named_modules():
    if isinstance(m, LlamaAttention):
        fig, axs = plt.subplots(3)
        for ax, proj_name in zip(axs, ("q_proj", "k_proj", "v_proj")):
            _plot_weight_histogram(
                ax, getattr(m, proj_name).weight, BIN_SIZE, f"{name}.{proj_name}"
            )

        fig.suptitle(name)
        # Display the figure with subplots
        fig.tight_layout()
        plt.show()
        # Release the figure so per-layer figures do not accumulate.
        plt.close(fig)


    # if isinstance(m, LlamaMLP):

    #     plt.figure(0)
    #     hist = torch.histogram(m.gate_proj.weight.float().cpu().detach(), bins=BIN_SIZE)
    #     plt.bar(x=hist[1].numpy()[:-1], height=hist[0].numpy())
    #     plt.title('gate_proj')

    #     # # print(m.gate_proj)
    #     # hist = torch.histogram(m.gate_proj.weight.float().cpu().detach(), bins=BIN_SIZE)
    #     # axs[0].bar(x=hist[1].numpy()[:-1], height=hist[0].numpy())
    #     # axs[0].set_title('gate_proj')
    #     # # print(m.up_proj)
    #     # hist = torch.histogram(m.up_proj.weight.float().cpu().detach(), bins=BIN_SIZE)
    #     # axs[1].bar(x=hist[1].numpy()[:-1], height=hist[0].numpy())
    #     # axs[1].set_title('up_proj')
    #     # # print(m.down_proj)
    #     # hist = torch.histogram(m.down_proj.weight.float().cpu().detach(), bins=BIN_SIZE)
    #     # axs[2].bar(x=hist[1].numpy()[:-1], height=hist[0].numpy())
    #     # axs[2].set_title('down_proj')

    #     # plt.suptitle(name)
    #     # # Display the figure with subplots
    #     # plt.tight_layout()
    #     # plt.savefig(f'out/mlp/baseline/{name}.png')
    #     plt.show()

In [None]:
import os

import matplotlib.pyplot as plt
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP
import numpy as np

BIN_SIZE = 1000
# savefig does not create missing directories, so make the target up front.
OUT_DIR = "out/mlp/baseline"
os.makedirs(OUT_DIR, exist_ok=True)


def _plot_weight_histogram(ax, weight, bins, title):
    """Draw a histogram of the values in ``weight`` on ``ax``.

    Bars sit at the real bin edges (in weight units) rather than at the bin
    index, so plots of different layers -- and of the baseline vs. the smoothed
    model, whose weights are rescaled -- share a meaningful x axis. The float32
    copy of the weight is local to this function and is released on return.
    """
    w = weight.detach().float()
    lo, hi = w.min().item(), w.max().item()
    if lo == hi:
        # Constant tensor: widen the range so bins have non-zero width.
        lo, hi = lo - 1.0, hi + 1.0
    counts = torch.histc(w, bins=bins, min=lo, max=hi).cpu()
    edges = torch.linspace(lo, hi, bins + 1)
    ax.bar(edges[:-1].numpy(), counts.numpy(), width=(hi - lo) / bins, align="edge")
    ax.set_title(title)


for name, m in model_fp16.model.named_modules():
    # if isinstance(m, LlamaAttention):
    #     # Here we simulate quantizing BMM inputs by quantizing the output of q_proj, k_proj, v_proj
    #     # print(m.q_proj)
    #     # print(m.k_proj)
    #     # print(m.v_proj)
    #     # print(m.o_proj)
    #     key_value_slicing = (m.num_key_value_heads * m.head_dim) // m.config.pretraining_tp
    #     query_slices = m.q_proj.weight.split(
    #         (m.num_heads * m.head_dim) // m.config.pretraining_tp, dim=0
    #     )
    #     key_slices = m.k_proj.weight.split(key_value_slicing, dim=0)
    #     value_slices = m.v_proj.weight.split(key_value_slicing, dim=0)
    #     fig, axs = plt.subplots(3)

    #     hist = torch.histc(query_slices[0].float(), bins=1000)
    #     axs[0].bar(range(1000), hist.cpu().detach().numpy())
    #     axs[0].set_title('Query channel #0')

    #     hist = torch.histc(key_slices[0].float(), bins=1000)
    #     axs[1].bar(range(1000), hist.cpu().detach().numpy())
    #     axs[1].set_title('Key channel #0')

    #     hist = torch.histc(value_slices[0].float(), bins=1000)
    #     axs[2].bar(range(1000), hist.cpu().detach().numpy())
    #     axs[2].set_title('Value channel #0')

    #     plt.suptitle(name)
    #     # Display the figure with subplots
    #     plt.tight_layout()
    #     plt.show()

    if isinstance(m, LlamaMLP):
        # One figure per MLP block of the FP16 baseline, saved to OUT_DIR.
        fig, axs = plt.subplots(3)
        for ax, proj_name in zip(axs, ("gate_proj", "up_proj", "down_proj")):
            _plot_weight_histogram(ax, getattr(m, proj_name).weight, BIN_SIZE, proj_name)

        fig.suptitle(name)
        fig.tight_layout()
        fig.savefig(os.path.join(OUT_DIR, f"{name}.png"))
        plt.show()
        # Release the figure so per-layer figures do not accumulate.
        plt.close(fig)

In [None]:
import gc

# Drop the FP16 baseline before loading the smoothed model.
del model_fp16
# Collect anything still reachable only through reference cycles first:
# empty_cache can only return cached blocks whose tensors are already freed.
gc.collect()
torch.cuda.empty_cache()

## SmoothQuant W8A8 Quantized Model Perplexity

In [None]:
# Load a fresh FP16 copy onto cuda:1, pass it with the saved activation scales
# to smooth_lm, then build the W8A8 model (per-tensor weight and activation
# quantization) with quantize_llama_like.
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16,
    device_map="cuda:1",
)
act_scales = torch.load("../act_scales/llama-2-7b.pt")
# NOTE(review): 0.85 is presumably SmoothQuant's migration strength alpha -- confirm in smooth.py.
smooth_lm(model, act_scales, 0.85)
model_smoothquant_w8a8 = quantize_llama_like(
    model,
    weight_quant="per_tensor",
    act_quant="per_tensor",
    bits=(8, 8),
)

In [None]:
# ppl_smoothquant_w8a8 = evaluator.evaluate(model_smoothquant_w8a8)
# print(f"SmoothQuant W8A8 quantized model perplexity: {ppl_smoothquant_w8a8}")

In [None]:
import os

import matplotlib.pyplot as plt
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP
import numpy as np

BIN_SIZE = 1000
# savefig does not create missing directories, so make the target up front.
OUT_DIR = "out/mlp/smooth_w8a8"
os.makedirs(OUT_DIR, exist_ok=True)


def _plot_weight_histogram(ax, weight, bins, title):
    """Draw a histogram of the values in ``weight`` on ``ax``.

    Bars sit at the real bin edges (in weight units) rather than at the bin
    index, so plots of different layers -- and of the baseline vs. the smoothed
    model, whose weights are rescaled -- share a meaningful x axis. The float32
    copy of the weight is local to this function and is released on return.
    """
    w = weight.detach().float()
    lo, hi = w.min().item(), w.max().item()
    if lo == hi:
        # Constant tensor: widen the range so bins have non-zero width.
        lo, hi = lo - 1.0, hi + 1.0
    counts = torch.histc(w, bins=bins, min=lo, max=hi).cpu()
    edges = torch.linspace(lo, hi, bins + 1)
    ax.bar(edges[:-1].numpy(), counts.numpy(), width=(hi - lo) / bins, align="edge")
    ax.set_title(title)


for name, m in model_smoothquant_w8a8.model.named_modules():
    # if isinstance(m, LlamaAttention):
    #     # Here we simulate quantizing BMM inputs by quantizing the output of q_proj, k_proj, v_proj
    #     # print(m.q_proj)
    #     # print(m.k_proj)
    #     # print(m.v_proj)
    #     # print(m.o_proj)
    #     key_value_slicing = (m.num_key_value_heads * m.head_dim) // m.config.pretraining_tp
    #     query_slices = m.q_proj.weight.split(
    #         (m.num_heads * m.head_dim) // m.config.pretraining_tp, dim=0
    #     )
    #     key_slices = m.k_proj.weight.split(key_value_slicing, dim=0)
    #     value_slices = m.v_proj.weight.split(key_value_slicing, dim=0)
    #     fig, axs = plt.subplots(3)

    #     hist = torch.histc(query_slices[0].float(), bins=1000)
    #     axs[0].bar(range(1000), hist.cpu().detach().numpy())
    #     axs[0].set_title('Query channel #0')

    #     hist = torch.histc(key_slices[0].float(), bins=1000)
    #     axs[1].bar(range(1000), hist.cpu().detach().numpy())
    #     axs[1].set_title('Key channel #0')

    #     hist = torch.histc(value_slices[0].float(), bins=1000)
    #     axs[2].bar(range(1000), hist.cpu().detach().numpy())
    #     axs[2].set_title('Value channel #0')

    #     plt.suptitle(name)
    #     # Display the figure with subplots
    #     plt.tight_layout()
    #     plt.show()

    if isinstance(m, LlamaMLP):
        # One figure per MLP block of the smoothed W8A8 model, saved to OUT_DIR.
        fig, axs = plt.subplots(3)
        for ax, proj_name in zip(axs, ("gate_proj", "up_proj", "down_proj")):
            _plot_weight_histogram(ax, getattr(m, proj_name).weight, BIN_SIZE, proj_name)

        fig.suptitle(name)
        fig.tight_layout()
        fig.savefig(os.path.join(OUT_DIR, f"{name}.png"))
        plt.show()
        # Release the figure so per-layer figures do not accumulate.
        plt.close(fig)

In [None]:
import gc

# `model` (from the loading cell) still references the LlamaForCausalLM loaded
# on cuda:1, whatever quantize_llama_like returned; deleting only
# model_smoothquant_w8a8 would keep those weights alive on the GPU.
del model_smoothquant_w8a8, model
# Collect reference cycles first: empty_cache only returns blocks whose
# tensors are already freed.
gc.collect()
torch.cuda.empty_cache()