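### Use this Colab notebook to count outliers in Mamba hidden states. Run the cells in order and follow any instructions in the markdown cells.

The notebook does three things:
- Loads `state-spaces/mamba-370m-hf`.
- Replaces the mixer projections with bitsandbytes 8-bit / 4-bit layers.
- Computes WikiText-2 perplexity, reporting for each window how many hidden-state values lie more than 6 standard deviations from the mean.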




In [None]:
# install dependencies
!pip install -q -U datasets
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U bitsandbytes

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.0/542.0 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m388.9/388.9 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m297.6/297.6 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for peft (pyproject.toml) ... [?25l[?25hdone


In [None]:
# load model

from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

# Hugging Face Hub checkpoint used for both the tokenizer and the model.
MODEL_ID = "state-spaces/mamba-370m-hf"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# output_hidden_states=True makes every forward pass also return
# `outputs.hidden_states`, which the outlier-counting cell iterates over.
# device_map="auto" lets accelerate decide where to place the weights.
# (No quantization_config: the model is loaded in full precision and its
# projections are swapped for bitsandbytes layers manually further down.)
model = MambaForCausalLM.from_pretrained(
    MODEL_ID,
    output_hidden_states=True,
    device_map="auto",
)


In [None]:
# Show the full-precision module tree (before any layer is replaced).
# Ending the cell with the object lets Jupyter display it directly.
model

In [None]:
# Snapshot the original full-precision weights to disk. After the linear
# projections are swapped for bitsandbytes layers, this file is loaded back
# into the model to fill the new layers with the original weights.
original_weights_path = "mamba_model.pt"
torch.save(model.state_dict(), original_weights_path)

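## Manually replacing the mixer projections with quantized layers

The next cell replaces the linear projections inside every Mamba mixer with bitsandbytes layers:
- `in_proj` and `dt_proj` become 8-bit `Linear8bitLt` layers.
- `x_proj` and `out_proj` become 4-bit `Linear4bit` layers.

The `lm_head` is left in full precision. Skip (comment out) any projection you don't want to replace.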

In [None]:
import torch
import torch.nn as nn
import bitsandbytes as bnb

# Outlier threshold handed to every Linear8bitLt layer (the LLM.int8()
# mixed-precision decomposition). It is not passed to the Linear4bit layers.
threshold = 6.0


def _quantized_like(linear, quant_cls, **quant_kwargs):
    """Create a bitsandbytes layer with the same shape and bias setting as `linear`.

    Args:
        linear: the layer being replaced; only its `in_features`,
            `out_features` and whether `bias` is set are read.
        quant_cls: the bitsandbytes layer class to instantiate
            (here `bnb.nn.Linear8bitLt` or `bnb.nn.Linear4bit`).
        **quant_kwargs: extra keyword arguments forwarded to `quant_cls`.

    Returns:
        A new, freshly initialised layer. Its weights are overwritten later by
        loading the saved state_dict; the quantization itself happens when the
        model is moved to CUDA.
    """
    return quant_cls(
        linear.in_features,
        linear.out_features,
        bias=linear.bias is not None,
        **quant_kwargs,
    )


# Keyword arguments for the 8-bit layers. has_fp16_weights=False is the
# inference-focused setting (only the int8 weights are kept).
int8_kwargs = {"has_fp16_weights": False, "threshold": threshold}

# Projection attribute on each MambaMixer -> (replacement class, extra kwargs).
# Entries are applied in this order. Comment out an entry to keep that
# projection in full precision.
QUANTIZATION_PLAN = {
    "in_proj": (bnb.nn.Linear8bitLt, int8_kwargs),
    "x_proj": (bnb.nn.Linear4bit, {}),
    "dt_proj": (bnb.nn.Linear8bitLt, int8_kwargs),
    "out_proj": (bnb.nn.Linear4bit, {}),
}

# Assuming `model` is the pre-trained MambaForCausalLM loaded above, swap the
# listed projections of every block in place.
for block in model.backbone.layers:
    for attr_name, (quant_cls, quant_kwargs) in QUANTIZATION_PLAN.items():
        old_layer = getattr(block.mixer, attr_name)
        # setattr on an nn.Module registers the new submodule exactly like
        # `block.mixer.<attr_name> = ...` would.
        setattr(block.mixer, attr_name, _quantized_like(old_layer, quant_cls, **quant_kwargs))



# Restore the original full-precision weights saved before the swap. This
# fills the freshly initialised bitsandbytes layers as well as every module
# that was left untouched.
saved_state = torch.load("mamba_model.pt")
model.load_state_dict(saved_state)
del saved_state  # release the extra copy of the weights

# Moving to the GPU is what triggers bitsandbytes' internal quantization of
# the Linear8bitLt / Linear4bit weights. nn.Module.to() returns the module
# itself, so `bit_model` is another name for the same object as `model`.
bit_model = model.to('cuda:0')


In [None]:
# Display the module tree to confirm which mixer projections were swapped for
# Linear8bitLt / Linear4bit layers.
bit_model

MambaForCausalLM(
  (backbone): MambaModel(
    (embeddings): Embedding(50280, 1024)
    (layers): ModuleList(
      (0-47): 48 x MambaBlock(
        (norm): MambaRMSNorm()
        (mixer): MambaMixer(
          (conv1d): Conv1d(2048, 2048, kernel_size=(4,), stride=(1,), padding=(3,), groups=2048)
          (act): SiLU()
          (in_proj): Linear8bitLt(in_features=1024, out_features=4096, bias=False)
          (x_proj): Linear4bit(in_features=2048, out_features=96, bias=False)
          (dt_proj): Linear8bitLt(in_features=64, out_features=2048, bias=True)
          (out_proj): Linear4bit(in_features=2048, out_features=1024, bias=False)
        )
      )
    )
    (norm_f): MambaRMSNorm()
  )
  (lm_head): Linear(in_features=1024, out_features=50280, bias=False)
)

# Outlier Counting


In [None]:
from datasets import load_dataset

# WikiText-2 (raw) test split, used below for the perplexity / outlier run.
test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

# Join every line into one document and tokenize it in a single call; the
# evaluation cell slides a fixed-size window over the resulting token ids.
full_text = "\n\n".join(test["text"])
encodings = tokenizer(full_text, return_tensors="pt")

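The next cell runs a sliding-window perplexity evaluation (windows of 1024 tokens, stride 512). For every window, it prints outlier statistics for each hidden state returned by the model. A value counts as an outlier when it lies more than 6 standard deviations from that hidden state's mean.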

In [None]:
import torch
from tqdm import tqdm

import numpy as np

device = "cuda"
max_length = 1024  # sliding-window length in tokens
stride = 512
seq_len = encodings.input_ids.size(1)

# A hidden-state value is an outlier when it lies more than this many
# (population) standard deviations away from the mean of its tensor.
OUTLIER_NUM_STD = 6.0


def summarize_outliers(hidden, num_std=OUTLIER_NUM_STD):
    """Return outlier statistics for one hidden-state array.

    Args:
        hidden: numpy array indexed as hidden[batch][row][col] (at least 3-D).
            Mean, std, max and min are taken over the whole array; outliers are
            counted in batch item 0 only (the batch size is 1 in this notebook).
            Every row and column of that item is scanned, so windows shorter
            than `max_length` and any hidden size are handled.
        num_std: outlier cut-off in standard deviations.

    Returns:
        dict with keys "num_outliers", "num_rows_with_outliers",
        "num_cols_with_outliers", "mean", "std", "max", "min".
    """
    mean = hidden.mean()
    std = hidden.std()  # numpy default ddof=0
    mask = np.abs(hidden[0] - mean) > num_std * std
    return {
        "num_outliers": int(mask.sum()),
        "num_rows_with_outliers": int(mask.any(axis=1).sum()),
        "num_cols_with_outliers": int(mask.any(axis=0).sum()),
        "mean": mean,
        "std": std,
        "max": hidden.max(),
        "min": hidden.min(),
    }


nlls = []
prev_end_loc = 0

for begin_loc in tqdm(range(0, seq_len, stride)):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    # Only the tokens not already scored by the previous window contribute to the loss.
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)

        # `index` is the position in outputs.hidden_states (printed as "layer number").
        for index, hidden_tensor in enumerate(outputs.hidden_states):
            # .float() guards against dtypes numpy cannot represent (e.g. bfloat16).
            hidden = hidden_tensor.detach().float().cpu().numpy()
            stats = summarize_outliers(hidden)
            print("----------")
            print("layer number:", index)
            print(hidden.shape)
            print("num_outliers", stats["num_outliers"])
            print("num_rows_with_outliers", stats["num_rows_with_outliers"])
            print("num_cols_with_outliers", stats["num_cols_with_outliers"])
            print("mean: ", stats["mean"], "std: ", stats["std"])
            print("max: ", stats["max"], "min: ", stats["min"])
            print('---------')

        # loss is calculated using CrossEntropyLoss which averages over valid labels
        # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
        # to the left by 1.
        neg_log_likelihood = outputs.loss

    nlls.append(neg_log_likelihood)
    print("Current:", torch.exp(torch.stack(nlls).mean()))

    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

ppl = torch.exp(torch.stack(nlls).mean())

print("perplexity", ppl)