In [None]:
# install this at the beginning to get all dependency
!pip install -q -U datasets
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git

In [None]:
# Load the pretrained Mamba-370m checkpoint and its tokenizer from the Hugging Face Hub,
# move the model to the GPU, and run a short generation as a smoke test.
# NOTE(review): MambaConfig is not used in this cell.
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-370m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-370m-hf")
model.to('cuda')  # nn.Module.to moves the parameters in place, so `model` itself is now on the GPU

input_ids = tokenizer("The weather is good.", return_tensors="pt")["input_ids"].to('cuda')

# Generate up to 10 new tokens after the prompt and decode prompt + continuation back to text.
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))

In [None]:
# Save the original (unquantized) Mamba weights. The quantization cell reloads them into the
# model after its linear layers have been swapped for 8-bit ones.
torch.save(model.state_dict(), "mamba_model.pt")

We quantize only the `in_proj` and `dt_proj` layers of every Mamba block, using the 8-bit `Linear8bitLt` layer from bitsandbytes.

In [None]:
import torch
import torch.nn as nn
import bitsandbytes as bnb

# Swap `in_proj` and `dt_proj` in every Mamba block for bitsandbytes 8-bit linear layers.
# The replacements are freshly initialised with the same shapes; their real weights come
# from the unquantized checkpoint saved in the previous cell.
for block in model.backbone.layers:
  for name in ("in_proj", "dt_proj"):
    fp_layer = getattr(block.mixer, name)
    # has_fp16_weights=False: keep only int8 weights (inference-focused quantization).
    setattr(block.mixer, name, bnb.nn.Linear8bitLt(
        fp_layer.in_features, fp_layer.out_features,
        bias=fp_layer.bias is not None, has_fp16_weights=False))

# Copy the original weights into the new layers. weights_only=True restricts unpickling to
# tensors and plain containers, which is all a state dict needs.
model.load_state_dict(torch.load("mamba_model.pt", weights_only=True))

# Moving to the GPU is what actually quantizes the Linear8bitLt weights in bitsandbytes.
# NOTE: .to() returns the same module, so `quant_model` and `model` are one object; the
# unquantized `model` no longer exists after this cell.
quant_model = model.to('cuda:0')

In [None]:
# Display the module tree to check which projections are now Linear8bitLt layers.
quant_model

In [None]:
# Load the WikiText-2 (raw) training split and tokenize it as one long sequence (documents
# joined with blank lines) for the sliding-window perplexity evaluation below.
# NOTE(review): perplexity is usually reported on the "test" split; "train" is used here.
from datasets import load_dataset

train = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
encodings = tokenizer("\n\n".join(train["text"]), return_tensors="pt")

Downloading readme:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/733k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.36M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

In [None]:
import torch
from tqdm import tqdm

def _sliding_windows(seq_len, max_length, stride):
  """Yield ``(begin, end, trg_len)`` for each strided evaluation window.

  ``trg_len`` is the number of tokens at the end of ``[begin, end)`` that no earlier
  window has scored yet (it can differ from ``stride`` on the first and last window).
  Iteration stops at the first window that reaches ``seq_len``.
  """
  prev_end = 0
  for begin in range(0, seq_len, stride):
    end = min(begin + max_length, seq_len)
    yield begin, end, end - prev_end
    prev_end = end
    if end == seq_len:
      break


def perplexity_score(model, data, encodings, num_iterations,
                     device="cuda", max_length=1024, stride=512):
  """Sliding-window perplexity of ``model`` on a pre-tokenized corpus.

  Args:
    model: causal LM called as ``model(input_ids, labels=...)`` that returns an object
      with a scalar ``.loss``; it is moved to ``device``.
    data: unused; kept for backward compatibility with existing callers.
    encodings: object with an ``input_ids`` tensor indexed as ``[:, begin:end]``
      (e.g. tokenizer output with ``return_tensors="pt"``).
    num_iterations: maximum number of windows to evaluate; ``None`` or a value < 1
      evaluates every window.
    device: device the model and inputs are moved to.
    max_length: window length in tokens.
    stride: step between window starts; only the last ``trg_len`` tokens of each window
      are scored, so each token is counted once while earlier tokens give context.

  Returns:
    0-d tensor ``exp(mean of per-window losses)``. Averaging window losses is the usual
    approximation; the token-weighted mean differs slightly when windows differ in size.

  Raises:
    ValueError: if ``encodings`` contains no tokens.
  """
  seq_len = encodings.input_ids.size(1)
  if seq_len == 0:
    raise ValueError("encodings contain no tokens; cannot compute perplexity")

  windows = list(_sliding_windows(seq_len, max_length, stride))
  if num_iterations is not None and num_iterations > 0:
    windows = windows[:num_iterations]

  model = model.to(device)
  nlls = []
  # Iterating the precomputed list gives tqdm the true number of windows that will run.
  for begin_loc, end_loc, trg_len in tqdm(windows):
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    # Mask the context tokens that an earlier window has already scored.
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
      outputs = model(input_ids, labels=target_ids)
    # loss is a CrossEntropyLoss mean over the unmasked labels; HF causal LMs shift labels
    # left by one internally, so only trg_len - 1 tokens actually contribute.
    nlls.append(outputs.loss)

  return torch.exp(torch.stack(nlls).mean())

We compute the perplexity of this model and use it as the baseline (default) perplexity score that the profiling approach below is compared against.

In [None]:
# Baseline: perplexity of the model with only in_proj/dt_proj quantized, measured on the first
# `num_iterations` sliding windows. The profiling cells below compare their trials against it.
num_iterations = 5  # only a few windows for speed, so the score is a rough estimate
default_perplexity_score = perplexity_score(quant_model, train, encodings, num_iterations)
print(f"Default Perplexity Score: {default_perplexity_score}")

  0%|          | 4/4750 [00:25<8:29:53,  6.45s/it]

Default Perplexity Score: 864.5032348632812





# Profiling-Based Approach

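Profiling for the `out_proj` layer: going from the last block to the first, we additionally quantize `out_proj` in one candidate block (on top of the blocks already accepted). We keep that change if the perplexity stays below 1.2× the current perplexity score.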

In [None]:
import copy

# Greedy per-block profiling for `out_proj`. Every trial model has `in_proj` and `dt_proj`
# quantized in all blocks (the same setup as the baseline). On top of that, `out_proj` is
# quantized in the blocks already accepted (`mapping`) plus the candidate block `i`. The
# candidate is accepted if the trial perplexity stays below `tolerance` x the current score.
# NOTE(review): `current_perplexity_score` is updated on every acceptance, so the allowed
# degradation compounds (up to tolerance**k after k acceptances); compare against
# `default_perplexity_score` instead if a fixed budget is intended.
tolerance = 1.2

# Load the unquantized weights once. Each trial starts from a deep copy of this CPU template
# instead of re-loading from the Hub and round-tripping the state dict through a temp file.
base_model = MambaForCausalLM.from_pretrained("state-spaces/mamba-370m-hf")
base_state_dict = base_model.state_dict()

mapping = set()
num_block = len(base_model.backbone.layers)
current_perplexity_score = default_perplexity_score
for i in reversed(range(num_block)):
  print("BLOCK",i)
  quant_model = copy.deepcopy(base_model)
  for j, block in enumerate(quant_model.backbone.layers):
    layer_names = ["in_proj", "dt_proj"]
    if j in mapping or j == i:
      layer_names.append("out_proj")
    for name in layer_names:
      fp_layer = getattr(block.mixer, name)
      # Fresh int8 layer of the same shape; weights are filled in by load_state_dict below.
      setattr(block.mixer, name, bnb.nn.Linear8bitLt(
          fp_layer.in_features, fp_layer.out_features,
          bias=fp_layer.bias is not None, has_fp16_weights=False))

  quant_model.load_state_dict(base_state_dict)
  quant_model = quant_model.to('cuda:0') # This also triggers the internal quantization process in bitsandbytes

  quant_perplexity_score = perplexity_score(quant_model, train, encodings, num_iterations)
  print(f"Quant Perplexity Score: {quant_perplexity_score}")
  print(f"Current Perplexity Score: {current_perplexity_score}")
  if quant_perplexity_score < tolerance * current_perplexity_score:
    current_perplexity_score = quant_perplexity_score
    print(i,"block was quantized")
    mapping.add(i)

# Keep this cell's result under its own name, because later cells rebind `mapping`.
# NOTE: `quant_model` is now the last *trial* model (block 0 tentatively quantized), not
# necessarily the accepted configuration in `out_proj_mapping`.
out_proj_mapping = mapping
print("out_proj quantized in blocks:", sorted(out_proj_mapping))


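Profiling for the `x_proj` layer, using the same greedy per-block procedure as for `out_proj`. In these trials `out_proj` is left unquantized.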

In [None]:
import copy

# Greedy per-block profiling for `x_proj`. Every trial model has `in_proj` and `dt_proj`
# quantized in all blocks (the same setup as the baseline; `out_proj` stays unquantized). On
# top of that, `x_proj` is quantized in the blocks already accepted (`mapping`) plus the
# candidate block `i`. The candidate is accepted if the trial perplexity stays below
# `tolerance` x the current score.
# NOTE(review): `current_perplexity_score` is updated on every acceptance, so the allowed
# degradation compounds (up to tolerance**k after k acceptances); compare against
# `default_perplexity_score` instead if a fixed budget is intended.
tolerance = 1.2

# Load the unquantized weights once. Each trial starts from a deep copy of this CPU template
# instead of re-loading from the Hub and round-tripping the state dict through a temp file.
base_model = MambaForCausalLM.from_pretrained("state-spaces/mamba-370m-hf")
base_state_dict = base_model.state_dict()

mapping = set()
num_block = len(base_model.backbone.layers)
current_perplexity_score = default_perplexity_score
for i in reversed(range(num_block)):
  print("BLOCK",i)
  quant_model = copy.deepcopy(base_model)
  for j, block in enumerate(quant_model.backbone.layers):
    layer_names = ["in_proj", "dt_proj"]
    if j in mapping or j == i:
      layer_names.append("x_proj")
    for name in layer_names:
      fp_layer = getattr(block.mixer, name)
      # Fresh int8 layer of the same shape; weights are filled in by load_state_dict below.
      setattr(block.mixer, name, bnb.nn.Linear8bitLt(
          fp_layer.in_features, fp_layer.out_features,
          bias=fp_layer.bias is not None, has_fp16_weights=False))

  quant_model.load_state_dict(base_state_dict)
  quant_model = quant_model.to('cuda:0') # This also triggers the internal quantization process in bitsandbytes

  quant_perplexity_score = perplexity_score(quant_model, train, encodings, num_iterations)
  print(f"Quant Perplexity Score: {quant_perplexity_score}")
  print(f"Current Perplexity Score: {current_perplexity_score}")
  if quant_perplexity_score < tolerance * current_perplexity_score:
    current_perplexity_score = quant_perplexity_score
    print(i,"block was quantized")
    mapping.add(i)

# Keep this cell's result under its own name so a later rebinding of `mapping` cannot lose it.
# NOTE: `quant_model` is now the last *trial* model (block 0 tentatively quantized), not
# necessarily the accepted configuration in `x_proj_mapping`.
x_proj_mapping = mapping
print("x_proj quantized in blocks:", sorted(x_proj_mapping))
