In [1]:
import torch
import gc
def absmax_quantize(X):
    """Symmetric per-tensor absmax quantization to int8.

    A single scale ``127 / max(|X|)`` is applied to the whole tensor, the
    result is rounded to the nearest integer, and then divided by the same
    scale to obtain the dequantized approximation.

    Args:
        X: floating-point tensor to quantize.

    Returns:
        Tuple ``(X_quant, X_dequant)``: the int8 codes and the dequantized
        tensor (same dtype as the scaled input).
    """
    # Calculate scale. An all-zero tensor has absmax 0, which would make the
    # scale infinite and turn every output into NaN; substitute 1 so that
    # zeros quantize to zeros.
    absmax = torch.max(torch.abs(X))
    absmax = torch.where(absmax == 0, torch.ones_like(absmax), absmax)
    scale = 127 / absmax

    # Quantize
    X_quant = (scale * X).round()

    # Dequantize
    X_dequant = X_quant / scale

    return X_quant.to(torch.int8), X_dequant

def absmax_perchannel_quantize(X, channel_axis=0):
    """Symmetric absmax int8 quantization with one scale per slice.

    The absolute maximum is taken by reducing over ``channel_axis`` (with
    ``keepdim=True``), so one scale is produced for every index of the
    remaining axes and is broadcast back over ``channel_axis``.

    Args:
        X: floating-point tensor to quantize.
        channel_axis: axis that is reduced when computing the absolute maximum.

    Returns:
        Tuple ``(X_quant, X_dequant)``: the int8 codes and the dequantized tensor.
    """
    peak = torch.abs(X).amax(dim=channel_axis, keepdim=True)
    # Slices that are entirely zero would give a division by zero below.
    peak = peak.masked_fill(peak == 0, 1e-8)
    scale = 127 / peak

    codes = torch.round(scale * X)
    return codes.to(torch.int8), codes / scale

def tanh_quantize(X, k=3):
    """Non-uniform int8 quantization through a tanh companding curve.

    The tensor is clipped to ``mean +/- k * std``, standardized, squashed with
    ``tanh`` and then absmax-quantized to int8. Dequantization inverts the
    steps: divide by the scale, apply ``arctanh`` (after clamping to
    ``[-0.99, 0.99]`` so the inverse stays finite) and undo the
    standardization.

    Args:
        X: floating-point tensor to quantize.
        k: number of standard deviations kept before clipping.

    Returns:
        Tuple ``(X_quant, X_dequant)``: the int8 codes and the dequantized tensor.
    """
    mean = X.mean()
    std = X.std()

    # A constant tensor has std == 0 and a single-element tensor has std == NaN;
    # either would make the normalization below return NaN everywhere. In that
    # case every value sits at the mean, i.e. at code 0.
    if not torch.isfinite(std) or std == 0:
        return torch.zeros_like(X, dtype=torch.int8), torch.zeros_like(X) + mean

    # k sigma clipping (3 sigma by default)
    X_clipped = torch.clamp(X, mean - k * std, mean + k * std)
    X_normalized = (X_clipped - mean) / std

    X_tanh = torch.tanh(X_normalized)
    scale = 127 / torch.max(torch.abs(X_tanh))

    X_quant = (scale * X_tanh).round()
    X_dequant = X_quant / scale
    # arctanh diverges at +/-1, so keep its argument strictly inside (-1, 1)
    X_dequant = torch.arctanh(X_dequant.clamp(-0.99, 0.99)) * std + mean

    return X_quant.to(torch.int8), X_dequant

In [2]:
import torch.nn as nn
import tqdm
class TextGenerator:
    """Samples text from a causal language model and scores text by perplexity.

    The instance only stores the tokenizer and the device that encoded inputs
    are moved to; the model is passed to each method.
    """

    def __init__(self, tokenizer, device):
        self.tokenizer, self.device = tokenizer, device

    def generate_text(self, model, input_text, max_length=50):
        """Sample a continuation of ``input_text`` and return the decoded string.

        Generation uses top-k sampling (k=30), pads with the tokenizer's EOS
        token id and passes an all-ones attention mask over the prompt.
        """
        prompt_ids = self.tokenizer.encode(input_text, return_tensors='pt')
        prompt_ids = prompt_ids.to(self.device)
        generation_kwargs = dict(
            inputs=prompt_ids,
            max_length=max_length,
            do_sample=True,
            top_k=30,
            pad_token_id=self.tokenizer.eos_token_id,
            attention_mask=prompt_ids.new_ones(prompt_ids.shape),
        )
        generated = model.generate(**generation_kwargs)
        return self.tokenizer.decode(generated[0], skip_special_tokens=True)

    def calculate_perplexity(self, model, text):
        """Return ``exp(loss)`` of ``model`` on ``text``, using the text as its own labels."""
        encoded = self.tokenizer(text, return_tensors='pt').to(self.device)
        token_ids = encoded.input_ids

        # No gradients are needed for scoring
        with torch.no_grad():
            result = model(token_ids, labels=token_ids.clone())

        # The model's loss is used as the negative log-likelihood
        return torch.exp(result.loss)

class LambadaEvaluator:
    """Last-token prediction accuracy over a text dataset.

    On construction the dataset's ``"text"`` field is tokenized in batches and
    the dataset is formatted to yield ``input_ids`` as torch tensors.
    """

    def __init__(self, dataset, tokenizer, device):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.device = device

        # Tokenize every example, then expose only the token ids as tensors
        self.dataset = self.dataset.map(
            lambda examples: self.tokenizer(examples["text"]), batched=True
        )
        self.dataset.set_format(type="torch", columns=["input_ids"])

    @torch.no_grad()
    def evaluate(self, model):
        """Return the fraction of examples whose final token is predicted correctly.

        Examples are fed one at a time; the prediction is the argmax of the
        logits at the second-to-last position, compared with the last token.
        """
        model.eval()
        hit = 0
        total = 0
        for example in self.dataset:
            input_ids = example["input_ids"].to(self.device).unsqueeze(0)
            label = input_ids[:, -1]
            # Logits at position -2 are the model's guess for the token at -1
            logits = model(input_ids).logits
            pred = logits[:, -2, :].argmax(dim=-1)
            hit += (pred == label).sum().item()
            total += label.size(0)
        return hit / total
    
class WikitextEvaluator:
    """Perplexity over fixed-length windows of a tokenized text corpus.

    The dataset's ``"text"`` entries are joined with blank lines and tokenized
    once into a single ``(1, n_tokens)`` id tensor stored on ``device``.

    Args:
        dataset: object whose ``["text"]`` is an iterable of strings.
        tokenizer: callable returning an object with ``input_ids`` when called
            with ``return_tensors="pt"``.
        device: device the token ids are stored on.
        n_samples: maximum number of windows to evaluate.
        seq_len: number of tokens per window (2048 by default).
    """

    def __init__(self, dataset, tokenizer, device, n_samples=40, seq_len=2048):
        self.tokenizer = tokenizer
        self.device = device

        # After construction `self.dataset` holds token ids, not the raw dataset
        self.dataset = tokenizer(
            "\n\n".join(dataset["text"]), return_tensors="pt"
        ).input_ids.to(device)

        self.n_samples = n_samples
        self.seq_len = seq_len

    @torch.no_grad()
    def evaluate(self, model):
        """Return the perplexity of ``model`` over the first windows of the corpus.

        At most ``n_samples`` windows are used; the count is capped at the
        number of complete ``seq_len``-token windows available so that no
        empty or partial window is scored.

        Raises:
            ValueError: if the corpus is shorter than one window.
        """
        model.eval()
        seq_len = self.seq_len
        n_windows = min(self.n_samples, self.dataset.size(1) // seq_len)
        if n_windows <= 0:
            raise ValueError(
                f"need at least {seq_len} tokens and n_samples >= 1 to evaluate "
                f"(have {self.dataset.size(1)} tokens, n_samples={self.n_samples})"
            )

        loss_fct = nn.CrossEntropyLoss()
        nlls = []
        for i in tqdm.tqdm(range(n_windows), desc="Evaluating..."):
            batch = self.dataset[:, (i * seq_len) : ((i + 1) * seq_len)].to(model.device)
            lm_logits = model(batch).logits
            # Position t predicts token t + 1, so drop the last logit and the first label
            shift_logits = lm_logits[:, :-1, :].contiguous().float()
            # Labels come from `batch` so they live on the same device as the logits
            shift_labels = batch[:, 1:]
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.reshape(-1)
            )
            # Mean loss is scaled by the window length, as in the final normalization below
            nlls.append(loss.float() * seq_len)

        return torch.exp(torch.stack(nlls).sum() / (n_windows * seq_len))

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
import torch
from datasets import load_dataset
torch.manual_seed(0)

# Prefer the GPU; fall back to the CPU so the notebook still runs without CUDA
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load model and tokenizer
# model_id = 'state-spaces/mamba-790m-hf'
model_id = 'state-spaces/mamba-1.4b-hf'
# model_id = 'state-spaces/mamba-2.8b-hf'
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Build the evaluators on the same device as the model
dataset = load_dataset("lambada", split="validation[:100]")
lambada_evaluator = LambadaEvaluator(dataset, tokenizer, device)
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
wikitext_evaluator = WikitextEvaluator(dataset, tokenizer, device)
text_generator = TextGenerator(tokenizer, device)

# Print model size
print(f"Model size: {model.get_memory_footprint():,} bytes")

# Cross-check the reported footprint by summing the parameters' native storage:
# element count times bytes per element (this is the unquantized size).
total_memory = 0
for param in model.parameters():
    total_memory += param.numel() * param.element_size()

print(f"Calculated model size: {total_memory:,} bytes")

  from .autonotebook import tqdm as notebook_tqdm
The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py.
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 11.02it/s]


Model size: 5,488,713,728 bytes
Calculated model size: 5,488,713,728 bytes


# Understanding the Mamba Architecture

In [4]:
# Print the module tree so the layer types and shapes can be read off directly
print(model)

MambaForCausalLM(
  (backbone): MambaModel(
    (embeddings): Embedding(50280, 2048)
    (layers): ModuleList(
      (0-47): 48 x MambaBlock(
        (norm): MambaRMSNorm(2048, eps=1e-05)
        (mixer): MambaMixer(
          (conv1d): Conv1d(4096, 4096, kernel_size=(4,), stride=(1,), padding=(3,), groups=4096)
          (act): SiLU()
          (in_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (x_proj): Linear(in_features=4096, out_features=160, bias=False)
          (dt_proj): Linear(in_features=128, out_features=4096, bias=True)
          (out_proj): Linear(in_features=4096, out_features=2048, bias=False)
        )
      )
    )
    (norm_f): MambaRMSNorm(2048, eps=1e-05)
  )
  (lm_head): Linear(in_features=2048, out_features=50280, bias=False)
)


### Embeddings
The model first maps each input token (drawn from a vocabulary of 50,280 tokens) to a 2048-dimensional embedding. The resulting sequence of embeddings acts as the input sequence to the Mamba model.

### 48 x Mamba Blocks
Recall the state-space model update system of equations:   
$\frac{dx(t)}{dt} = Ax(t) + Bu(t)$      
$y(t) = Cx(t) + Du(t)$    
where $x(t) \in \mathbb{R}^H$ is the hidden state vector, $u(t) \in \mathbb{R}^N$ is the input vector, and $y(t) \in \mathbb{R}^M$ is the output vector.

We may discretize this as    
$x_{t+1} = Ax_t + Bu_t$   
$y_t = Cx_t + Du_t$

These matrices should have the following dimensions: $A \in \mathbb{R}^{H \times H}$, $B \in \mathbb{R}^{H \times N}$, $C \in \mathbb{R}^{M \times H}$, and $D \in \mathbb{R}^{M \times N}$.

Moving on to the implementation of the model: (here, we consult the Mamba's source code by the authors to understand how the model works)

`in_proj` projects the input (a.k.a. `d_model=2048`) to a combined inner state dimension (a.k.a. `2 * d_inner=8192`). The idea behind projecting our input to a higher-dimensional space is to allow for richer feature interactions and redundancies. It is important to note that the true state dimension `d_inner` is 4096, and that `in_proj` actually projects the input to the concatenation of the state vector $x$ itself and a residual stream $z$ (hence the multiplication by 2).

`conv1d` performs a simple 1d convolution on `d_inner`, which aims to blend local temporal dependencies in the input vector. 

`x_proj` projects the input vector we obtained into a space that is the concatenation of dimensions (`dt_rank`, `d_state`, `d_state`). In other words, this projection maps the current input vector to `dt`, `B`, and `C`. `dt` here has dimension `dt_rank`, which refers to the rank used in the `dt_proj` layer. The intuition behind `dt_rank` is that it controls how complex the adjustments should be when it comes to modeling temporal dynamics of the sequence. This makes sense because for a higher `dt_rank`, we have more parameters in the linear layer that projects back up to `d_inner` in the `dt_proj` layer. A similar intuition applies to `B`, `C` and their dimension `d_state`. Instead of using the original SSM formulation, where $B \in \mathbb{R}^{H \times N} = \mathbb{R}^{4096 \times 2048}$ and $C \in \mathbb{R}^{M \times H} = \mathbb{R}^{2048 \times 4096}$, $B$ (which determines how inputs affect state transitions) and $C$ (which controls how states are mapped to outputs) instead operate on a latent representation of the state vector, and this latent representation is obtained from the $A$ matrix, which, instead of being a $\mathbb{R}^{H \times H}$ matrix, maps from H=4096 to the latent space `d_state`. This is what makes Mamba different from traditional SSM models, as the $B$ and $C$ matrices can "selectively scan" for the context that can best predict the next token, depending on what the input is. The purpose of `x_proj` is therefore to generate the parameters in order to dynamically adjust the state vector.

`dt_proj` projects `dt_rank` to `d_inner`. After obtaining the optimal parameters for state updates, `dt_proj` performs the state updates to update the state vector.

In [5]:
# Pull the in_proj weight matrix out of the first Mamba block
weights = model.backbone.layers[0].mixer.in_proj.weight
print("Original weights:", weights, sep="\n")

# Per-tensor absmax quantization; only the int8 codes are kept here
weights_abs_quant = absmax_quantize(weights)[0]
print("\nAbsmax quantized weights:", weights_abs_quant, sep="\n")

Original weights:
Parameter containing:
tensor([[-0.0093,  0.0195,  0.0045,  ..., -0.0280,  0.0205,  0.0303],
        [-0.0455, -0.0419, -0.0181,  ...,  0.0200, -0.0017, -0.0090],
        [ 0.0152, -0.0286,  0.0063,  ..., -0.0122, -0.0423, -0.0070],
        ...,
        [ 0.0288, -0.0310,  0.0342,  ...,  0.0013, -0.0933, -0.0034],
        [ 0.0006,  0.0376, -0.0062,  ...,  0.0058,  0.0233, -0.0073],
        [ 0.0035, -0.0039, -0.0507,  ..., -0.0188, -0.0253,  0.0127]],
       device='cuda:0', requires_grad=True)

Absmax quantized weights:
tensor([[ -1,   3,   1,  ...,  -4,   3,   4],
        [ -6,  -6,  -3,  ...,   3,   0,  -1],
        [  2,  -4,   1,  ...,  -2,  -6,  -1],
        ...,
        [  4,  -4,   5,  ...,   0, -13,   0],
        [  0,   5,  -1,  ...,   1,   3,  -1],
        [  0,  -1,  -7,  ...,  -3,  -4,   2]], device='cuda:0',
       dtype=torch.int8)


In [None]:
# Baseline metrics for the unquantized model: LAMBADA last-token accuracy,
# WikiText perplexity, and the perplexity of a sampled continuation.
print("\nLAMBADA Accuracy Evaluation")
acc_original = lambada_evaluator.evaluate(model)
print(f"accuracy on LAMBADA: {acc_original}")

print("\nPerplexity Evaluation on WikiText")
pp_wikitext = wikitext_evaluator.evaluate(model)
print(f'perplexity on wikitext: {pp_wikitext}')

print("\nPerplexity Evaluation on custom text")
# The text is sampled from the model itself and then scored by the same model
original_text = text_generator.generate_text(model, "I have a dream")
print(f"output text:\n{original_text}")
ppl = text_generator.calculate_perplexity(model, original_text)
print(f"original model perplexity: {ppl.item():.2f}")


# Quantized Models

In [7]:
import numpy as np
from copy import deepcopy

# Select which quantization experiment to run. The experiment cells below each
# compare `cur_model` against their own tag and only run when it matches.
# Exactly one assignment should be left uncommented.
# cur_model = "w8_all"
# cur_model = "w8_pc_all"
# cur_model = "w8_tanh_all"
cur_model = "w8_inout"
# cur_model = "w8_pc_inout"
# cur_model = "w8_tanh_inout"

### W8_All: Quantizing all weights of the Mamba Model naively

In [8]:
import numpy as np
from copy import deepcopy

if cur_model != "w8_all":
    print("skipped")
else:
    # Work on a copy, then release the original model and its GPU memory
    model_abs = deepcopy(model)
    model.cpu()
    del model
    gc.collect()
    torch.cuda.empty_cache()

    # Replace every parameter by its absmax-dequantized values and count one
    # byte per element, i.e. the size the int8 codes would occupy.
    dtype_size = 8 // 8
    total_memory = 0
    for param in model_abs.parameters():
        param.data = absmax_quantize(param.data)[1]
        total_memory += dtype_size * param.numel()

    print(f"Calculated model size: {total_memory:,} bytes")

skipped


In [9]:
# Score the w8_all model (only when it is the selected experiment) on LAMBADA
# last-token accuracy, WikiText perplexity, and the perplexity of its own sample.
if (cur_model == "w8_all"):
    print("\nLAMBADA Accuracy Evaluation")
    # NOTE(review): despite its name, `acc_original` holds the quantized model's
    # accuracy here and overwrites any value set by the baseline evaluation cell.
    acc_original = lambada_evaluator.evaluate(model_abs)
    print(f"accuracy on LAMBADA: {acc_original}")

    print("\nPerplexity Evaluation on WikiText")
    pp_wikitext = wikitext_evaluator.evaluate(model_abs)
    print(f'perplexity on wikitext: {pp_wikitext}')

    print("\nPerplexity Evaluation on custom text")
    # The text is sampled from the quantized model and scored by that same model
    original_text = text_generator.generate_text(model_abs, "I have a dream")
    print(f"output text:\n{original_text}")
    ppl = text_generator.calculate_perplexity(model_abs, original_text)
    print(f"quantized model perplexity: {ppl.item():.2f}")
else:
    print("skipped")

skipped


### W8_PC_All: Naive per-channel quantization

In [10]:
import numpy as np
from copy import deepcopy

if cur_model != "w8_pc_all":
    print("skipped")
else:
    # Work on a copy, then release the original model and its GPU memory
    model_pcabs = deepcopy(model)
    model.cpu()
    del model
    gc.collect()
    torch.cuda.empty_cache()

    # Replace every parameter by its per-channel absmax-dequantized values and
    # count one byte per element, i.e. the size the int8 codes would occupy.
    dtype_size = 8 // 8
    total_memory = 0
    for param in model_pcabs.parameters():
        param.data = absmax_perchannel_quantize(param.data)[1]
        total_memory += dtype_size * param.numel()

    print(f"Calculated model size: {total_memory:,} bytes")

skipped


In [11]:
# Score the w8_pc_all model (only when it is the selected experiment) on LAMBADA
# last-token accuracy, WikiText perplexity, and the perplexity of its own sample.
if (cur_model == "w8_pc_all"):
    print("\nLAMBADA Accuracy Evaluation")
    # NOTE(review): despite its name, `acc_original` holds the quantized model's
    # accuracy here and overwrites any value set by the baseline evaluation cell.
    acc_original = lambada_evaluator.evaluate(model_pcabs)
    print(f"accuracy on LAMBADA: {acc_original}")

    print("\nPerplexity Evaluation on WikiText")
    pp_wikitext = wikitext_evaluator.evaluate(model_pcabs)
    print(f'perplexity on wikitext: {pp_wikitext}')

    print("\nPerplexity Evaluation on custom text")
    # The text is sampled from the quantized model and scored by that same model
    original_text = text_generator.generate_text(model_pcabs, "I have a dream")
    print(f"output text:\n{original_text}")
    ppl = text_generator.calculate_perplexity(model_pcabs, original_text)
    print(f"quantized model perplexity: {ppl.item():.2f}")
else:
    print("skipped")

skipped


### W8_Tanh_All: Using Tanh-Based function due to the normal distribution of weights to reduce ambiguity and quantization error around mean

In [12]:
import numpy as np
from copy import deepcopy

if cur_model != "w8_tanh_all":
    print("skipped")
else:
    # The original `model` is kept alive here: the plotting cells that follow
    # compare it against the quantized copy.
    model_tanh = deepcopy(model)

    # Replace every parameter by its tanh-dequantized values and count one
    # byte per element, i.e. the size the int8 codes would occupy.
    dtype_size = 8 // 8
    total_memory = 0
    for param in model_tanh.parameters():
        param.data = tanh_quantize(param.data)[1]
        total_memory += dtype_size * param.numel()

    print(f"Calculated model size: {total_memory:,} bytes")

skipped


In [13]:
# Why do we do this? Let's visualize the weight distributions of the last block
# before and after tanh quantization.

if (cur_model == "w8_tanh_all"):
    i = 47
    import matplotlib.pyplot as plt

    fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(9,6), sharex=True)

    # (projection name, subplot, colour used for the quantized histogram)
    panels = [
        ("in_proj", axs[0][0], 'red'),
        ("out_proj", axs[0][1], 'green'),
        ("x_proj", axs[1][0], 'red'),
        ("dt_proj", axs[1][1], 'green'),
    ]
    hist_kwargs = dict(bins=150, alpha=0.5, range=(-0.2, 0.2))

    for proj_name, ax, quant_color in panels:
        # Original weights first (blue), then the tanh-quantized ones on top
        for net, tag, color in ((model, 'original', 'blue'), (model_tanh, 'quantized', quant_color)):
            flat_weights = getattr(net.backbone.layers[i].mixer, proj_name).weight.cpu().detach().numpy().flatten()
            ax.hist(flat_weights, label=f'{proj_name} {tag} weights', color=color, **hist_kwargs)

        # Add grid, axis labels, legend and title
        ax.grid(True, linestyle='--', alpha=0.6)
        ax.set_xlabel('Weights', fontsize=14)
        ax.set_ylabel('Count', fontsize=14)
        ax.legend()
        ax.set_title(f'{proj_name} layer', fontsize=16)

    plt.tight_layout()
    plt.show()

In [14]:
# What does tanh quantization do? Plot the distribution of the int8 codes it
# assigns to the first block's in_proj and out_proj weights.

if (cur_model == "w8_tanh_all"):
    # Imported here as well so this cell does not rely on the previous plotting cell
    import matplotlib.pyplot as plt

    first_mixer = model.backbone.layers[0].mixer
    quantized_in_proj = tanh_quantize(first_mixer.in_proj.weight.detach().cpu())[0].numpy().flatten()
    quantized_out_proj = tanh_quantize(first_mixer.out_proj.weight.detach().cpu())[0].numpy().flatten()

    fig, axs = plt.subplots(2, figsize=(6,4), sharex=True)

    # 256 unit-width bins over [-128, 128]: one bin per int8 code, so integer
    # values are not unevenly grouped into bins.
    axs[0].hist(quantized_in_proj, bins=256, alpha=0.5, label='in_proj tanh-quantized weights', color='blue', range=(-128, 128))
    axs[1].hist(quantized_out_proj, bins=256, alpha=0.5, label='out_proj tanh-quantized weights', color='blue', range=(-128, 128))

    # Add grid
    for ax in axs:
        ax.grid(True, linestyle='--', alpha=0.6)

    axs[0].legend()
    axs[1].legend()
    axs[0].set_title('in_proj: tanh-quantized int8 codes', fontsize=16)
    axs[1].set_title('out_proj: tanh-quantized int8 codes', fontsize=16)

    for ax in axs:
        ax.set_xlabel('Quantized value', fontsize=14)
        ax.set_ylabel('Count', fontsize=14)

    plt.tight_layout()
    plt.show()

In [15]:
# Score the w8_tanh_all model (only when it is the selected experiment) on LAMBADA
# last-token accuracy, WikiText perplexity, and the perplexity of its own sample.
if (cur_model == "w8_tanh_all"):
    print("\nLAMBADA Accuracy Evaluation")
    # NOTE(review): despite its name, `acc_original` holds the quantized model's
    # accuracy here and overwrites any value set by the baseline evaluation cell.
    acc_original = lambada_evaluator.evaluate(model_tanh)
    print(f"accuracy on LAMBADA: {acc_original}")

    print("\nPerplexity Evaluation on WikiText")
    pp_wikitext = wikitext_evaluator.evaluate(model_tanh)
    print(f'perplexity on wikitext: {pp_wikitext}')

    print("\nPerplexity Evaluation on custom text")
    # The text is sampled from the quantized model and scored by that same model
    original_text = text_generator.generate_text(model_tanh, "I have a dream")
    print(f"output text:\n{original_text}")
    ppl = text_generator.calculate_perplexity(model_tanh, original_text)
    print(f"quantized model perplexity: {ppl.item():.2f}")
else:
    print("skipped")

skipped


### W8_InOut: Quantizing only in_proj and out_proj
Why only in_proj and out_proj? These two projections expand and compress our inputs to and back from a richer, higher dimensional space. This is because higher dimensional spaces are used to model lots of feature interactions among inputs. This makes them not as sensitive to quantization errors due to potential redundancies in the high-D representation.

Furthermore, in_proj and out_proj account for the majority of parameters in the model, as they are both used for projecting between the input embedding (2048) and a higher-dimensional state space (4096 and 8192). In contrast, x_proj and dt_proj, which enable state updates and are the backbone of the model, only operate on a 16-dimensional low-rank state space.

In [16]:
import numpy as np
from copy import deepcopy

if cur_model != "w8_inout":
    print("skipped")
else:
    # Work on a copy, then release the original model and its GPU memory
    model_w8inout = deepcopy(model)
    model.cpu()
    del model
    gc.collect()
    torch.cuda.empty_cache()

    total_memory = num_in_out_params = num_total_params = 0
    for name, param in model_w8inout.named_parameters():
        n_elems = param.numel()
        if ("in_proj" in name) or ("out_proj" in name):
            # Projection weights: store absmax-dequantized values, count 1 byte each
            param.data = absmax_quantize(param.data)[1]
            dtype_size = 8 // 8
            num_in_out_params += n_elems
        else:
            # Everything else is left untouched and counted at its native width
            dtype_size = torch.finfo(param.dtype).bits // 8
        total_memory += n_elems * dtype_size
        num_total_params += n_elems

    print(f"Calculated model size: {total_memory:,} bytes")
    print(f"Proportion of in_proj and out_proj params: {num_in_out_params/num_total_params}")

Calculated model size: 1,864,835,072 bytes
Proportion of in_proj and out_proj params: 0.8803225031305549


In [17]:
# Score the w8_inout model (only when it is the selected experiment) on LAMBADA
# last-token accuracy, WikiText perplexity, and the perplexity of its own sample.
if (cur_model == "w8_inout"):
    print("\nLAMBADA Accuracy Evaluation")
    # NOTE(review): despite its name, `acc_original` holds the quantized model's
    # accuracy here and overwrites any value set by the baseline evaluation cell.
    acc_original = lambada_evaluator.evaluate(model_w8inout)
    print(f"accuracy on LAMBADA: {acc_original}")

    print("\nPerplexity Evaluation on WikiText")
    pp_wikitext = wikitext_evaluator.evaluate(model_w8inout)
    print(f'perplexity on wikitext: {pp_wikitext}')

    print("\nPerplexity Evaluation on custom text")
    # The text is sampled from the quantized model and scored by that same model
    original_text = text_generator.generate_text(model_w8inout, "I have a dream")
    print(f"output text:\n{original_text}")
    ppl = text_generator.calculate_perplexity(model_w8inout, original_text)
    print(f"quantized model perplexity: {ppl.item():.2f}")
else:
    print("skipped")


LAMBADA Accuracy Evaluation
accuracy on LAMBADA: 0.82

Perplexity Evaluation on WikiText


Evaluating...: 100%|██████████| 40/40 [01:48<00:00,  2.71s/it]


perplexity on wikitext: 10.558557510375977

Perplexity Evaluation on custom text
output text:
I have a dream, I have a vision." "So I'll ask the question again, as a prefect you're expected to attend the council every day." "That's very good." "And, as you will see, your mother and
quantized model perplexity: 14.90


### W8_PC_InOut: Per-Channel Quantization of only in_proj and out_proj layers

In [18]:
import numpy as np
from copy import deepcopy

if cur_model != "w8_pc_inout":
    print("skipped")
else:
    # Work on a copy, then release the original model and its GPU memory
    model_w8_pc_inout = deepcopy(model)
    model.cpu()
    del model
    gc.collect()
    torch.cuda.empty_cache()

    total_memory = 0
    for name, param in model_w8_pc_inout.named_parameters():
        if ("in_proj" in name) or ("out_proj" in name):
            # Projection weights: store per-channel dequantized values, count 1 byte each
            param.data = absmax_perchannel_quantize(param.data)[1]
            dtype_size = 8 // 8
        else:
            # Everything else is left untouched and counted at its native width
            dtype_size = torch.finfo(param.dtype).bits // 8
        total_memory += param.numel() * dtype_size

    print(f"Calculated model size: {total_memory:,} bytes")

skipped


In [19]:
# Score the w8_pc_inout model (only when it is the selected experiment) on LAMBADA
# last-token accuracy, WikiText perplexity, and the perplexity of its own sample.
if (cur_model == "w8_pc_inout"):
    print("\nLAMBADA Accuracy Evaluation")
    # NOTE(review): despite its name, `acc_original` holds the quantized model's
    # accuracy here and overwrites any value set by the baseline evaluation cell.
    acc_original = lambada_evaluator.evaluate(model_w8_pc_inout)
    print(f"accuracy on LAMBADA: {acc_original}")


    print("\nPerplexity Evaluation on WikiText")
    pp_wikitext = wikitext_evaluator.evaluate(model_w8_pc_inout)
    print(f'perplexity on wikitext: {pp_wikitext}')

    print("\nPerplexity Evaluation on custom text")
    # The text is sampled from the quantized model and scored by that same model
    original_text = text_generator.generate_text(model_w8_pc_inout, "I have a dream")
    print(f"output text:\n{original_text}")
    ppl = text_generator.calculate_perplexity(model_w8_pc_inout, original_text)
    print(f"quantized model perplexity: {ppl.item():.2f}")
else:
    print("skipped")

skipped


### W8_Tanh_InOut

In [20]:
import numpy as np
from copy import deepcopy

if cur_model != "w8_tanh_inout":
    print("skipped")
else:
    # Quantize a copy; the original `model` is left in place
    model_w8_tanh_inout = deepcopy(model)

    total_memory = 0
    for name, param in model_w8_tanh_inout.named_parameters():
        if ("in_proj" in name) or ("out_proj" in name):
            # Projection weights: store tanh-dequantized values, count 1 byte each
            param.data = tanh_quantize(param.data)[1]
            dtype_size = 8 // 8
        else:
            # Everything else is left untouched and counted at its native width
            dtype_size = torch.finfo(param.dtype).bits // 8
        total_memory += param.numel() * dtype_size

    print(f"Calculated model size: {total_memory:,} bytes")

skipped


In [21]:
# Score the w8_tanh_inout model (only when it is the selected experiment) on LAMBADA
# last-token accuracy, WikiText perplexity, and the perplexity of its own sample.
if (cur_model == "w8_tanh_inout"):
    print("\nLAMBADA Accuracy Evaluation")
    # NOTE(review): despite its name, `acc_original` holds the quantized model's
    # accuracy here and overwrites any value set by the baseline evaluation cell.
    acc_original = lambada_evaluator.evaluate(model_w8_tanh_inout)
    print(f"accuracy on LAMBADA: {acc_original}")

    print("\nPerplexity Evaluation on WikiText")
    pp_wikitext = wikitext_evaluator.evaluate(model_w8_tanh_inout)
    print(f'perplexity on wikitext: {pp_wikitext}')

    print("\nPerplexity Evaluation on custom text")
    # The text is sampled from the quantized model and scored by that same model
    original_text = text_generator.generate_text(model_w8_tanh_inout, "I have a dream")
    print(f"output text:\n{original_text}")
    ppl = text_generator.calculate_perplexity(model_w8_tanh_inout, original_text)
    print(f"quantized model perplexity: {ppl.item():.2f}")
else:
    print("skipped")

skipped
