In [33]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Step 1: Load the quantized model and its tokenizer.
#
# Checkpoints this notebook has been pointed at (assign one of them to MODEL_ID):
#   "TinyLlama-1.1B-Chat-v1.0-Smooth-GPTQ-W8A8-Dynamic-Per-Token"
#   "TinyLlama-1.1B-Chat-v1.0-Smooth-GPTQ-ASYM-W8A8-Dynamic-Per-Token"
#   "TinyLlama-1.1B-Chat-v1.0-Smooth-GPTQ-FP8_DYNAMIC-Per-Token"   <- current
#   "TinyLlama/TinyLlama-1.1B-Chat-v1.0"                           (unquantized baseline)
# Later cells infer the quantization format from the "FP8" / "W8A8" substrings of MODEL_ID.
MODEL_ID = "TinyLlama-1.1B-Chat-v1.0-Smooth-GPTQ-FP8_DYNAMIC-Per-Token"

# torch_dtype="auto" lets transformers pick the dtype from the checkpoint; torch.float16 or
# torch.bfloat16 may be passed instead to force a precision.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")

In [34]:
# Setup input collection for model analysis
import torch.nn as nn
from compressed_tensors.linear.compressed_linear import CompressedLinear

# Check for quantized layers and set up input collection accordingly.
has_quantized = any(isinstance(m, CompressedLinear) for _, m in model.named_modules())

if has_quantized:
    # NOTE(review): no hooks are registered on this path, yet the analysis cell reads
    # `inputs`, `quantized_inputs` and `input_scales` from each CompressedLinear. Those are
    # presumably recorded by a patched CompressedLinear.forward - confirm.
    print("Using quantized model with CompressedLinear layers")
else:
    # Without this, re-running the cell would stack a second hook on every layer and each
    # input would be recorded twice; detach hooks left over from a previous run first.
    for stale_hook in globals().get("hooks", []):
        stale_hook.remove()

    def collect_inputs(module, inputs, _):
        """Forward hook: append a CPU copy of the layer's first positional input to ``module.inputs``."""
        if not hasattr(module, 'inputs'):
            module.inputs = []
        # Copy to host memory: keeping one activation per layer per sample on the accelerator
        # adds up quickly, and the analysis converts everything to NumPy on the CPU anyway.
        module.inputs.append(inputs[0].detach().to("cpu", copy=True))

    # Add input collection hooks to Linear layers (excluding the output layer).
    linear_layers = [(n, m) for n, m in model.named_modules()
                     if isinstance(m, nn.Linear) and 'lm_head' not in n]

    for _, module in linear_layers:
        # Drop activations captured by a previous run so they are not mixed with new ones.
        if hasattr(module, 'inputs'):
            del module.inputs

    hooks = [module.register_forward_hook(collect_inputs) for _, module in linear_layers]
    print(f"Added input collection to {len(hooks)} Linear layers")

Using quantized model with CompressedLinear layers


In [35]:
from datasets import load_dataset

# Step 2: Prepare calibration data.
NUM_CALIBRATION_SAMPLES = 4
MAX_SEQUENCE_LENGTH = 2048

# Take the first NUM_CALIBRATION_SAMPLES conversations of the SFT split. The shuffle only
# permutes these already-selected rows; it does not sample from the whole split.
ds = load_dataset(
    "HuggingFaceH4/ultrachat_200k",
    split=f"train_sft[:{NUM_CALIBRATION_SAMPLES}]",
).shuffle(seed=42)


def preprocess(example):
    """Render a conversation's ``messages`` into one prompt string using the chat template."""
    rendered = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return {"text": rendered}


def tokenize(sample):
    """Tokenize the rendered prompt, truncating to MAX_SEQUENCE_LENGTH without padding.

    add_special_tokens=False because the chat template already inserted the BOS token;
    letting the tokenizer add another would duplicate it.
    """
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(preprocess)
# Drop every original column so only the tokenizer outputs remain.
ds = ds.map(tokenize, remove_columns=ds.column_names)

In [36]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

# Run each calibration sample through the model once; the forward pass is what triggers
# input collection (the nn.Linear hooks or the CompressedLinear path).
# NOTE(review): captured inputs are appended, so re-running this cell without re-running the
# hook-setup cell accumulates duplicate activations.
with torch.no_grad():
    for i, sample in enumerate(ds):
        input_ids = torch.tensor(sample["input_ids"]).unsqueeze(0).to(device)
        attention_mask = torch.tensor(sample["attention_mask"]).unsqueeze(0).to(device)
        print(f"Processing sample: {i}, length: {input_ids.shape[-1]}")
        # The mask was previously built but never passed. Samples are tokenized with
        # padding=False one at a time, so results are unchanged, but the call now stays
        # correct if padding is ever enabled.
        _ = model(input_ids=input_ids, attention_mask=attention_mask)

Processing sample: 0, length: torch.Size([1, 499])
Processing sample: 1, length: torch.Size([1, 2048])
Processing sample: 1, length: torch.Size([1, 2048])
Processing sample: 2, length: torch.Size([1, 1799])
Processing sample: 2, length: torch.Size([1, 1799])
Processing sample: 3, length: torch.Size([1, 787])
Processing sample: 3, length: torch.Size([1, 787])


In [37]:
import numpy as np
from compressed_tensors.linear.compressed_linear import CompressedLinear
import torch

def to_native_numpy(tensor):
    """Move ``tensor`` to the CPU and return it as a NumPy array.

    float16 and float32 tensors keep their dtype. Every other dtype (bfloat16, which NumPy
    cannot represent, as well as integer and float64 tensors) is converted to float32 first.
    """
    keeps_dtype = tensor.dtype in (torch.float16, torch.float32)
    host_tensor = tensor if keeps_dtype else tensor.float()
    return host_tensor.cpu().numpy()

# Determine the quantization format from MODEL_ID (also read by the plotting cell).
is_fp8 = "FP8" in MODEL_ID
is_int8 = "W8A8" in MODEL_ID and not is_fp8

# Largest finite magnitude representable in FP8 E4M3 (torch.float8_e4m3fn).
FP8_E4M3_MAX = 448.0


def _weight_to_quantized_values(weight, weight_scale):
    """Map a dequantized weight back onto its quantized value grid as a flat NumPy array.

    ``weight / weight_scale`` should land on (or very near) the quantized values.

    * FP8: those values are fractional (e.g. 0.5, 1.375), so they are snapped to the
      float8_e4m3fn grid and returned as float32. Truncating with ``.int()`` would collapse
      everything in (-1, 1) to 0.
    * INT8: the quotient is rounded to the nearest integer (int32) instead of truncated
      toward zero, so float error such as 2.9999 is not recorded as 2.

    NOTE(review): zero points are ignored, which assumes symmetric weight quantization;
    confirm before analysing an asymmetric-weight checkpoint.
    """
    scaled = weight.detach().float() / weight_scale.detach().float()
    if is_fp8:
        # Clamp first: out-of-range values may become NaN rather than saturate when cast.
        scaled = scaled.clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn).float()
    else:
        scaled = scaled.round().int()
    return scaled.cpu().numpy().flatten()


def _quantized_to_numpy(tensor):
    """Convert captured quantized activations to NumPy.

    Integer tensors are converted as-is. Floating tensors go through ``to_native_numpy``
    because NumPy has no bfloat16/FP8 dtype and ``.numpy()`` raises on them.
    """
    tensor = tensor.detach()
    if tensor.is_floating_point():
        return to_native_numpy(tensor)
    return tensor.cpu().numpy()


def _concat_flat(tensors, convert):
    """Flatten each captured tensor, convert it with ``convert`` and join into one 1-D array."""
    return np.concatenate([convert(t.detach().flatten()) for t in tensors])


# NOTE: every captured activation of every layer is concatenated in host memory; the total
# grows with (number of layers) x (captured tokens) x (hidden size).
layer_distributions = {}

for name, module in model.named_modules():
    captured_inputs = getattr(module, "inputs", None)
    has_captures = captured_inputs is not None and len(captured_inputs) > 0

    if isinstance(module, CompressedLinear):
        # NOTE(review): no hooks are registered for CompressedLinear in this notebook, so
        # `inputs`, `quantized_inputs` and `input_scales` are presumably recorded by a patched
        # CompressedLinear.forward - confirm.
        if not has_captures:
            print(f"Skipping {name}: no captured inputs")
            continue
        weight_scale = module.weight_scale.detach()
        layer_distributions[name] = {
            "weight_scale": to_native_numpy(weight_scale),
            # Key kept as "weight_int8" for downstream cells; holds FP8-grid values for FP8 models.
            "weight_int8": _weight_to_quantized_values(module.weight.data, weight_scale),
            "inputs": _concat_flat(captured_inputs, to_native_numpy),
            "inputs_scales": [to_native_numpy(s) for s in module.input_scales],
            "inputs_int8": _concat_flat(module.quantized_inputs, _quantized_to_numpy),
        }
    elif has_captures and hasattr(module, "weight"):
        # Plain hooked layer: keep the float weights and inputs.
        layer_distributions[name] = {
            "weight": to_native_numpy(module.weight.data.flatten()),
            "inputs": _concat_flat(captured_inputs, to_native_numpy),
        }

In [38]:
import os
import csv

os.makedirs("output", exist_ok=True)

csv_rows = [
    [
        "layer",
        "weight_zero_pct", "weight_neg1_pct", "weight_pos1_pct",
        "input_zero_pct", "input_neg1_pct", "input_pos1_pct",
        "input_scale_min", "input_scale_max", "input_scale_mean"
    ]
]

# Three blank CSV cells, used when a statistic is unavailable.
_BLANK_TRIPLE = ("", "", "")


def _special_value_fractions(values):
    """Return the fractions of entries in ``values`` equal to 0, -1 and +1.

    Despite the ``*_pct`` column names these are fractions in [0, 1], not percentages.
    Integer (quantized) data is compared exactly; for float data exact matches are rare.
    Returns three empty strings when ``values`` is None or empty.
    """
    if values is None:
        return _BLANK_TRIPLE
    values = np.asarray(values)
    if values.size == 0:
        return _BLANK_TRIPLE
    total = values.size
    return tuple(np.count_nonzero(values == target) / total for target in (0, -1, 1))


def _inverse_scale_stats(scales):
    """Return (min, max, mean) of 1 / scale over all captured input scales.

    The CSV columns are named input_scale_* but hold statistics of the *reciprocal* scale.
    ``scales`` may be a list of arrays/scalars or a single array. Returns three empty strings
    when there are no scales.
    """
    if scales is None:
        return _BLANK_TRIPLE
    if isinstance(scales, list):
        if not scales:
            return _BLANK_TRIPLE
        # np.ravel also accepts scalars and 0-d arrays, which np.concatenate alone rejects.
        flat = np.concatenate([np.ravel(s) for s in scales])
    else:
        flat = np.ravel(scales)
    if flat.size == 0:
        return _BLANK_TRIPLE
    # Compute in at least float32: in float16 the 1e-8 epsilon underflows to zero.
    flat = flat.astype(np.promote_types(flat.dtype, np.float32), copy=False)
    inv = 1.0 / (flat + 1e-8)  # epsilon avoids division by zero
    return inv.min(), inv.max(), inv.mean()


for k, v in layer_distributions.items():
    # Prefer the quantized representation when present, otherwise fall back to the floats.
    weight_stats = _special_value_fractions(v.get("weight_int8", v.get("weight")))
    input_stats = _special_value_fractions(v.get("inputs_int8", v.get("inputs")))
    scale_stats = _inverse_scale_stats(v.get("inputs_scales"))
    csv_rows.append([k, *weight_stats, *input_stats, *scale_stats])

with open(f"output/{MODEL_ID.split('/')[-1]}_layer_distribution.csv", "w", newline="") as f:
    csv.writer(f).writerows(csv_rows)

In [39]:
import os
import matplotlib.pyplot as plt

FIGURE_DIR = f"figures/{MODEL_ID.split('/')[-1]}"
os.makedirs(FIGURE_DIR, exist_ok=True)


def _percent_hist(ax, values, color, bins, value_range=None):
    """Draw a histogram of ``values`` whose bar heights are percentages of all entries.

    Counts come from ``np.histogram`` and are drawn from the bin edges. This gives the same
    bars as ``ax.hist(values, weights=np.full(n, 100 / n))`` without allocating a weights
    array as large as ``values``. Entries outside ``value_range`` are not drawn but still
    count toward the denominator.
    """
    counts, edges = np.histogram(values, bins=bins, range=value_range)
    ax.hist(edges[:-1], bins=edges, weights=counts * (100.0 / values.size),
            color=color, alpha=0.7)


def _plot_distribution(ax, dist, layer_name, quant_key, float_key, label, color):
    """Draw one panel (weights or inputs) of a layer, preferring quantized values over floats."""
    quantized = dist.get(quant_key)
    floats = dist.get(float_key)

    if quantized is not None and quantized.size > 0:
        if is_fp8:
            # FP8 E4M3 format
            _percent_hist(ax, quantized, color, bins=897, value_range=(-448, 448))
            ax.set_title(f"{layer_name} {label} Distribution (FP8)")
        else:
            # INT8 format
            _percent_hist(ax, quantized, color, bins=256, value_range=(-128, 127))
            ax.set_title(f"{layer_name} {label} Distribution (INT8)")
        ax.set_xlabel("Quantized Value")
    elif floats is not None and floats.size > 0:
        _percent_hist(ax, floats, color, bins=1000)
        ax.set_title(f"{layer_name} {label} Distribution (Float)")
        ax.set_xlabel("Float Value")
    else:
        ax.text(0.5, 0.5, f"No {label.lower()} data", ha='center', va='center',
                transform=ax.transAxes)
        ax.set_title(f"{layer_name} {label} Distribution (No Data)")
        ax.set_xlabel("Value")

    ax.set_ylabel("Percentage (%)")


for layer_name, v in layer_distributions.items():
    fig, (ax_weight, ax_input) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        _plot_distribution(ax_weight, v, layer_name, "weight_int8", "weight", "Weight", "blue")
        _plot_distribution(ax_input, v, layer_name, "inputs_int8", "inputs", "Input", "green")
        fig.tight_layout()
        fig.savefig(f"{FIGURE_DIR}/{layer_name.replace('/', '_')}_distributions.png", dpi=300)
    finally:
        # Always close, even on error, so open figures do not accumulate across layers.
        plt.close(fig)