In [1]:
# Hub repository id of the checkpoint loaded and analysed in this notebook.
model_id = "ai21labs/Jamba-v0.1"

In [2]:
import torch
import transformers
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import bitsandbytes as bnb
print("torch", torch.__version__)
print("transformers", transformers.__version__)

# Place layers across GPU 0 and GPU 1 in order, capped per device by max_memory.
device_map = "sequential"
max_memory = {0: "20GiB", 1: "22GiB"}

# Quantize Linear layers to 4 bit, except the LM head and the Mamba mixers,
# which are listed in the skip list and therefore stay unquantized.
# bnb_4bit_compute_dtype defaults to float32; matching it to torch_dtype keeps
# Linear4bit matmuls in bfloat16 rather than float32.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_skip_modules=["lm_head", "mamba"],
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=device_map,
    max_memory=max_memory,
    quantization_config=quantization_config,
)
model

torch 2.1.2+cu121
transformers 4.39.2


Loading checkpoint shards:   0%|          | 0/21 [00:00<?, ?it/s]

JambaForCausalLM(
  (model): JambaModel(
    (embed_tokens): Embedding(65536, 4096, padding_idx=0)
    (layers): ModuleList(
      (0): JambaMambaDecoderLayer(
        (mamba): JambaMambaMixer(
          (conv1d): Conv1d(8192, 8192, kernel_size=(4,), stride=(1,), padding=(3,), groups=8192)
          (act): SiLU()
          (in_proj): Linear(in_features=4096, out_features=16384, bias=False)
          (x_proj): Linear(in_features=8192, out_features=288, bias=False)
          (dt_proj): Linear(in_features=256, out_features=8192, bias=True)
          (out_proj): Linear(in_features=8192, out_features=4096, bias=False)
          (dt_layernorm): JambaRMSNorm()
          (B_layernorm): JambaRMSNorm()
          (C_layernorm): JambaRMSNorm()
        )
        (moe): JambaSparseMoeBlock(
          (experts): ModuleList(
            (0): JambaMLP(
              (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
              (down_proj): Linear4bit(in_features=14336, out_fea

In [3]:
import bitsandbytes.functional as F
from bitsandbytes.nn import Linear4bit
from torch.nn.modules.linear import Linear
import numpy as np
import gc

# Qualified module name -> SNR value; populated by assess_layers_snr below.
layer_snr = dict()

def marchenko_pastur_threshold(sigma, n, m):
    """Return the singular-value cutoff used to separate signal from noise.

    The aspect ratio ``beta`` of an ``n x m`` matrix is always taken as
    ``min(n, m) / max(n, m)`` (so it is symmetric in ``n`` and ``m``), and the
    cutoff is ``sigma * sqrt((1 + sqrt(beta)) ** 2)``. That expression is
    mathematically equal to ``sigma * (1 + sqrt(beta))``.

    Args:
        sigma: noise-scale estimate (a float or 0-dim tensor).
        n, m: matrix dimensions; assumed positive.

    Returns:
        ``sigma`` scaled by the Marchenko-Pastur edge factor (same kind as
        ``sigma``).
    """
    aspect = m / n if n >= m else n / m
    edge_squared = (1 + np.sqrt(aspect)) ** 2
    return sigma * np.sqrt(edge_squared)

def estimate_sigma_with_full_iqr(S):
    """Estimate a Gaussian scale parameter from the interquartile range of ``S``.

    For a normal distribution each quartile lies 0.6745 * sigma from the
    median, so the interquartile range Q3 - Q1 is about 1.349 * sigma.
    Dividing the IQR by 1.349 therefore gives a sigma estimate that is not
    dominated by extreme values.

    Args:
        S: floating-point tensor (flattened by ``torch.quantile``); here it
            holds a layer's singular values.

    Returns:
        0-dim tensor with the estimated sigma.
    """
    first_quartile, third_quartile = (torch.quantile(S, q) for q in (0.25, 0.75))
    return (third_quartile - first_quartile) / 1.349

def _weight_as_float64(module):
    """Return ``module``'s weight as a dense float64 tensor on the weight's device."""
    # Linear4bit subclasses torch.nn.Linear, so it has to be matched first.
    if isinstance(module, Linear4bit):
        # Prefer the module-level quant state (what this notebook originally
        # used); fall back to the one attached to the Params4bit weight.
        quant_state = getattr(module, "quant_state", None)
        if quant_state is None:
            quant_state = getattr(module.weight, "quant_state", None)
        # Transposing does not change the singular values.
        return F.dequantize_4bit(module.weight, quant_state).t().double()
    if isinstance(module, Linear):
        return module.weight.double()
    # The original `assert "not Linear"` could never fire (a non-empty string
    # is truthy), so unsupported modules crashed later with UnboundLocalError.
    raise TypeError(
        f"calculate_snr_for_layer expects Linear or Linear4bit, got {type(module).__name__}"
    )


def calculate_snr_for_layer(module):
    """Estimate the signal-to-noise ratio of a linear layer's weight matrix.

    The singular values of the (dequantized) weight are split at the
    threshold from ``marchenko_pastur_threshold``, using a sigma estimated
    by ``estimate_sigma_with_full_iqr``. The SNR is the sum of singular
    values above the threshold divided by the sum of those at or below it.

    Args:
        module: a ``bitsandbytes.nn.Linear4bit`` or ``torch.nn.Linear`` layer.

    Returns:
        A 0-dim float64 tensor with the SNR, or ``float("inf")`` when the
        noise sum is zero.

    Raises:
        TypeError: if ``module`` is neither ``Linear4bit`` nor ``Linear``.
    """
    with torch.no_grad():
        # With autograd enabled, svdvals may also compute U and V so it can
        # backprop; under no_grad only the singular values are needed.
        weights = _weight_as_float64(module)
        n, m = weights.shape
        S = torch.linalg.svdvals(weights).cpu()
        # Only the shape was needed from the matrix: free the float64 copy
        # instead of copying it to host memory.
        del weights

    sigma_estimated = estimate_sigma_with_full_iqr(S)
    mp_threshold = marchenko_pastur_threshold(sigma_estimated, n, m)

    signal = S[S > mp_threshold].sum()
    noise = S[S <= mp_threshold].sum()
    print("signal", signal, "noise", noise)
    snr = signal / noise if noise != 0 else float("inf")

    del S
    # Collect first so tensors held by reference cycles are released before
    # PyTorch's CUDA cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()
    return snr

def assess_layers_snr(model):
    """Compute and log the SNR of every ``Linear``/``Linear4bit`` sub-module.

    Iterates ``model.named_modules()``, calls ``calculate_snr_for_layer`` on
    each matching module, and stores the result in the module-level
    ``layer_snr`` dict under the qualified module name. Returns ``None``.
    """
    separator = "*" * 50
    for layer_name, layer in model.named_modules():
        if not isinstance(layer, (Linear, Linear4bit)):
            continue
        print(separator, flush=True)
        print(f"Calculating Signal to Noise Ratio at layer {layer_name}", flush=True)
        ratio = calculate_snr_for_layer(layer)
        layer_snr[layer_name] = ratio
        print(f"Signal to Noise Ratio at layer {layer_name} = {ratio}", flush=True)
        print(separator, flush=True)

In [4]:
# Fills the global layer_snr dict with one entry per Linear/Linear4bit module.
assess_layers_snr(model=model)

**************************************************
Calculating Signal to Noise Ratio at layer model.layers.0.mamba.in_proj
signal tensor(4593.8420, dtype=torch.float64) noise tensor(494.2332, dtype=torch.float64)
Signal to Noise Ratio at layer model.layers.0.mamba.in_proj = 9.294888408316409
**************************************************
**************************************************
Calculating Signal to Noise Ratio at layer model.layers.0.mamba.x_proj
signal tensor(195.6183, dtype=torch.float64) noise tensor(19.8755, dtype=torch.float64)
Signal to Noise Ratio at layer model.layers.0.mamba.x_proj = 9.84220223913894
**************************************************
**************************************************
Calculating Signal to Noise Ratio at layer model.layers.0.mamba.dt_proj
signal tensor(222.5460, dtype=torch.float64) noise tensor(29.6763, dtype=torch.float64)
Signal to Noise Ratio at layer model.layers.0.mamba.dt_proj = 7.4991185326259915
*************************

In [5]:
import json

def save_layers_to_json(layer_snr, filename):
    """Write per-layer SNR values to ``filename`` as indented JSON.

    The file maps each layer name to ``{"snr": <number>}``. Single-element
    tensors are converted with ``.item()``; other values are written as-is.

    The whole document is serialized before the file is opened. A value that
    ``json`` cannot encode therefore raises ``TypeError`` without leaving a
    truncated file behind.

    Note: infinite SNRs are written as the token ``Infinity`` (Python
    ``json``'s default). Python can read it back, but strict JSON parsers
    reject it.

    Args:
        layer_snr: mapping of layer name -> SNR (tensor or plain number).
        filename: output path; overwritten if it exists.
    """
    serializable_data = {
        name: {"snr": value.item() if isinstance(value, torch.Tensor) else value}
        for name, value in layer_snr.items()
    }
    payload = json.dumps(serializable_data, indent=4)
    with open(filename, "w") as file:
        file.write(payload)

# Derive the output name from model_id so it stays in sync if the checkpoint
# changes; for "ai21labs/Jamba-v0.1" this is "laser_Jamba-v0.1_layer_snr_info.json".
save_layers_to_json(
    layer_snr,
    f"laser_{model_id.split('/')[-1]}_layer_snr_info.json",
)

In [28]:
def select_layers_for_modification(layer_snr, k):
    """Return the names of the ``k`` layers with the lowest SNR, lowest first.

    The sort is stable, so layers with equal SNR keep the insertion order of
    ``layer_snr``. A ``k`` larger than the number of layers returns every name.
    """
    ranked_names = sorted(layer_snr, key=layer_snr.get)
    return ranked_names[:k]

# The 16 layer names with the lowest SNR, in ascending SNR order.
top_k_layers = select_layers_for_modification(layer_snr, k=16)
print(top_k_layers, flush=True)

['model.layers.28.self_attn.o_proj', 'model.layers.20.self_attn.o_proj', 'model.layers.12.self_attn.o_proj', 'model.layers.4.self_attn.o_proj', 'model.layers.12.self_attn.q_proj', 'model.layers.20.self_attn.q_proj', 'model.layers.4.self_attn.k_proj', 'model.layers.4.self_attn.q_proj', 'model.layers.28.self_attn.q_proj', 'model.layers.18.mamba.out_proj', 'model.layers.28.self_attn.v_proj', 'model.layers.11.mamba.out_proj', 'model.layers.17.mamba.out_proj', 'model.layers.22.mamba.out_proj', 'model.layers.19.mamba.out_proj', 'model.layers.10.mamba.out_proj']
