# SAM Quantization Example Usage

This notebook demonstrates how to use the SAM quantization utilities, including SmoothQuant weight comparison and A8W8 quantization.

## Overview

This guide covers two main functionalities:
1. **Smooth SAM and SAM SmoothQuant Weight Comparison**: Run `Smooth_sam.py` to produce a smoothed checkpoint, then compare the original and smoothed SAM model weights.
2. **A8W8 Quantization**: Replace `nn.Linear` layers with layers that quantize both weights and activations to 8 bits.

## 1. Setup and Imports

Import necessary libraries and utilities for SAM model quantization.

In [5]:
import os
import sys

# Make only GPU 2 visible unless CUDA_VISIBLE_DEVICES is already set.
# This must happen before CUDA is initialized, so do it before importing torch.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")

import torch
import torch.nn as nn

sys.path.append('/path/to/SAM_Quantization')
# Import our custom quantization utilities
from utils import (
    sam_smoothing_test,
    replace_linear_with_target_and_quantize,
    smooth_sam
)
from per_tensor_channel_group import W8A8Linear
from segment_anything import sam_model_registry, SamPredictor

print("All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

All imports successful!
PyTorch version: 2.6.0+cu118
CUDA available: True


## 2. Smooth SAM and SAM Model Weight Comparison with SmoothQuant

### Smooth SAM
Run `Smooth_sam.py` to produce the checkpoint files and weights for the smoothed SAM model.

### Weight Comparison

In [6]:
# Configuration for SAM model paths (adjust CHECKPOINT_DIR for your machine)
CHECKPOINT_DIR = "/home/ubuntu/21chi.nh/Quantization/SAM_Quantization/SAM_Quantization/checkpoint_sam"
sam_checkpoint = os.path.join(CHECKPOINT_DIR, "sam_hq_vit_l.pth")  # Original SAM-HQ checkpoint
smoothed_sam_checkpoint = os.path.join(CHECKPOINT_DIR, "smoothed_vit_l_sam.pth")  # Smoothed checkpoint from Smooth_sam.py
model_type = "vit_l"  # SAM model type (vit_b, vit_l, vit_h)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Warn early if a checkpoint file is missing
for path in (sam_checkpoint, smoothed_sam_checkpoint):
    if not os.path.isfile(path):
        print(f"Checkpoint not found: {path}")


In [7]:
# Run SAM smoothing test to compare weights
print("=" * 60)
print("RUNNING SAM SMOOTHING WEIGHT COMPARISON TEST")
print("=" * 60)
print()
sam_smoothing_test(
    sam_checkpoint=sam_checkpoint,
    smoothed_sam_checkpoint=smoothed_sam_checkpoint, 
    model_type=model_type,
    device=device
)

print("\n✓ Weight comparison completed!")

RUNNING SAM SMOOTHING WEIGHT COMPARISON TEST
/home/ubuntu/21chi.nh/Quantization/SAM_Quantization/SAM_Quantization/checkpoint_sam/sam_hq_vit_l.pth


FileNotFoundError: [Errno 2] No such file or directory: '/home/ubuntu/21chi.nh/Quantization/SAM_Quantization/SAM_Quantization/checkpoint_sam/sam_hq_vit_l.pth'

### Expected Output Explanation:

On a successful run, the test prints weight comparisons for the first 2 transformer blocks in the SAM image encoder:

- **Norm1/Norm2 weights**: LayerNorm parameters before attention/MLP layers
- **QKV weights**: Query, Key, Value projection weights in attention mechanism
- **MLP weights**: First linear layer in MLP (Feed-Forward) blocks
- **Weight change**: L2 norm showing magnitude of parameter changes after smoothing

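**Key insight**: SmoothQuant divides each activation channel `j` by a smoothing factor `s_j` and multiplies the matching input channel of the following linear layer's weight by the same `s_j`. The division is folded into the preceding LayerNorm's scale and bias, so the layer's output is mathematically unchanged while activation outliers migrate into the weights. This balances quantization difficulty between weights and activations.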
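
The sketch below applies this transformation to one LayerNorm followed by one Linear layer, which is the `norm1` → `attn.qkv` pattern in a ViT block. It assumes the SmoothQuant smoothing factor `s_j = max|X_j|^alpha / max|W_j|^(1 - alpha)` with `alpha = 0.5`. The helper `smooth_ln_linear` is hypothetical; it is not the `smooth_sam` implementation from `utils`.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def smooth_ln_linear(ln: nn.LayerNorm, fc: nn.Linear, act_absmax: torch.Tensor,
                     alpha: float = 0.5) -> torch.Tensor:
    """Fold SmoothQuant factors into a LayerNorm -> Linear pair (hypothetical helper).

    act_absmax: per-channel max |X_j| of the LayerNorm output on calibration data.
    Returns s, one smoothing factor per input channel of `fc`.
    """
    # max |W_j| for each input channel j (nn.Linear weight layout is [out, in]).
    w_absmax = fc.weight.abs().amax(dim=0).clamp(min=1e-5)
    # s_j = max|X_j|^alpha / max|W_j|^(1 - alpha)
    s = (act_absmax.clamp(min=1e-5).pow(alpha) / w_absmax.pow(1 - alpha)).clamp(min=1e-5)
    ln.weight.div_(s)  # X / s, folded into LayerNorm's scale ...
    ln.bias.div_(s)    # ... and shift
    fc.weight.mul_(s)  # W * s, broadcast over each row: [out, in] * [in]
    return s


torch.manual_seed(0)
ln, fc = nn.LayerNorm(16), nn.Linear(16, 48)
x = torch.randn(8, 16)
with torch.no_grad():
    ln.weight[3] = 20.0                   # simulate an outlier channel after LayerNorm
    ref = fc(ln(x))
    act_absmax = ln(x).abs().amax(dim=0)  # "calibration" statistics

s = smooth_ln_linear(ln, fc, act_absmax)

with torch.no_grad():
    out = fc(ln(x))
    new_absmax = ln(x).abs().amax(dim=0)

print("outputs match:", torch.allclose(ref, out, atol=1e-4))
print(f"max |activation| before: {act_absmax.max():.2f}, after: {new_absmax.max():.2f}")
```

Because `1/s` is folded into the LayerNorm parameters, smoothing adds no runtime cost. The equivalence check confirms the pair computes the same output while the largest activation magnitude shrinks.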

### What the numbers mean:
- **Small norm changes (< 0.1)**: Indicates gentle smoothing
- **Larger weight changes**: Shows where smoothing had significant impact
- **Zero changes**: No smoothing was applied to that layer
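
A "Weight change" figure for any pair of checkpoints can be reproduced as the L2 norm of the element-wise difference for each parameter. This minimal sketch assumes both checkpoints are plain state dicts with matching keys. The helper `weight_change_report` and the toy tensors are hypothetical, and `sam_smoothing_test` may compute its figures differently.

```python
import torch


def weight_change_report(orig_sd: dict, smoothed_sd: dict, keys: list[str]) -> dict[str, float]:
    """Return the L2 norm of (smoothed - original) for each requested parameter."""
    report = {}
    for k in keys:
        diff = smoothed_sd[k].float() - orig_sd[k].float()
        report[k] = torch.linalg.vector_norm(diff).item()
    return report


# Toy state dicts: one smoothed LayerNorm weight, one untouched tensor.
orig = {"image_encoder.blocks.0.norm1.weight": torch.ones(4),
        "image_encoder.blocks.0.mlp.lin2.weight": torch.ones(2, 2)}
smoothed = {"image_encoder.blocks.0.norm1.weight": torch.tensor([1.0, 0.5, 1.0, 1.0]),
            "image_encoder.blocks.0.mlp.lin2.weight": torch.ones(2, 2)}

for name, change in weight_change_report(orig, smoothed, list(orig)).items():
    print(f"{name}: {change:.4f}")  # 0.5000 for norm1.weight, 0.0000 for lin2.weight
```

With real files, `orig` and `smoothed` could come from `torch.load(sam_checkpoint, map_location="cpu")` and `torch.load(smoothed_sam_checkpoint, map_location="cpu")`, assuming each file stores a plain state dict.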

## 3. A8W8 Quantization with `replace_linear_with_target_and_quantize`

### Purpose:
The `replace_linear_with_target_and_quantize()` function converts standard PyTorch Linear layers to quantized A8W8 (8-bit weights, 8-bit activations) layers throughout a model.
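
Functions like this typically walk the module tree recursively and swap each eligible `nn.Linear` for the target class. The sketch below shows that pattern, assuming exclusion is checked against each child's own attribute name. `replace_linear` and `WrappedLinear` are hypothetical stand-ins, not the `utils` implementation, and the wrapper performs no quantization.

```python
import torch.nn as nn


class WrappedLinear(nn.Module):
    """Hypothetical target class: holds the original nn.Linear, no quantization."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return self.linear(x)


def replace_linear(module: nn.Module, target_class, exclude: list[str]) -> None:
    """Recursively swap nn.Linear children for target_class, skipping excluded names."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and name not in exclude:
            setattr(module, name, target_class(child))
        else:
            replace_linear(child, target_class, exclude)


toy_model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
replace_linear(toy_model, WrappedLinear, exclude=["2"])  # nn.Sequential children are "0", "1", "2"
print([type(m).__name__ for m in toy_model])  # ['WrappedLinear', 'ReLU', 'Linear']
```

Replacing through `setattr` on the parent keeps the original attribute names. Code that accesses paths such as `block.attn.qkv` therefore keeps working after the swap.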


### 3.1 Simple Example with Dummy Model

Let's start with a simple example using a dummy model to understand how the quantization works.

In [None]:
# Define a simple model for demonstration
class DummyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(1000, 128)  # Embedding layer
        self.linear_1 = nn.Linear(128, 256)  # First linear layer
        self.linear_2 = nn.Linear(256, 128, bias=False)  # Second linear layer (no bias)
        self.lm_head = nn.Linear(128, 1000, bias=False)  # Output head
    
    def forward(self, x):
        x = self.emb(x)
        x = self.linear_1(x)
        x = torch.relu(x)
        x = self.linear_2(x)
        x = torch.relu(x)
        x = self.lm_head(x)
        return x

# Create and inspect the model
model = DummyModel()
print("Before quantization:")
print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")

# List all linear layers
print("\nLinear layers before quantization:")
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        print(f"  {name}: {type(module).__name__} - {module.weight.shape}")

Before quantization:
DummyModel(
  (emb): Embedding(1000, 128)
  (linear_1): Linear(in_features=128, out_features=256, bias=True)
  (linear_2): Linear(in_features=256, out_features=128, bias=False)
  (lm_head): Linear(in_features=128, out_features=1000, bias=False)
)

Total parameters: 321,792

Linear layers before quantization:
  linear_1: Linear - torch.Size([256, 128])
  linear_2: Linear - torch.Size([128, 256])
  lm_head: Linear - torch.Size([1000, 128])


In [None]:
# Apply A8W8 quantization to the model
# We exclude 'emb' (embedding) and 'lm_head' (output head) from quantization
print("Applying A8W8 quantization...")

replace_linear_with_target_and_quantize(
    module=model, 
    target_class=W8A8Linear, 
    module_name_to_exclude=["emb", "lm_head"],  # Skip embedding and output head
    weight_quant="per_channel",  # Per-channel weight quantization
    act_quant="per_token"       # Per-token activation quantization
)

print("\nAfter quantization:")
print(model)

# List all layers after quantization
print("\nLayers after quantization:")
for name, module in model.named_modules():
    if isinstance(module, (nn.Linear, W8A8Linear)):
        layer_type = "QUANTIZED" if isinstance(module, W8A8Linear) else "ORIGINAL"
        print(f"  {name}: {type(module).__name__} ({layer_type})")

Applying A8W8 quantization...

After quantization:
DummyModel(
  (emb): Embedding(1000, 128)
  (linear_1): W8A8Linear(128, 256, bias=True, weight_quant=per_channel, act_quant=per_token, output_quant=None)
  (linear_2): W8A8Linear(256, 128, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)
  (lm_head): Linear(in_features=128, out_features=1000, bias=False)
)

Layers after quantization:
  linear_1: W8A8Linear (QUANTIZED)
  linear_2: W8A8Linear (QUANTIZED)
  lm_head: Linear (ORIGINAL)
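
The `weight_quant="per_channel"` and `act_quant="per_token"` settings choose the granularity of the 8-bit scales. The fake-quantization sketch below assumes symmetric absmax scaling to the int8 range [-127, 127], similar to the SmoothQuant reference code. `fake_quant_rows` and `fake_quant_tensor` are hypothetical names, and `W8A8Linear` may differ in details. Reducing over the last dimension gives one scale per output channel for a `[out, in]` weight and one scale per token for a `[batch, tokens, hidden]` activation.

```python
import torch


def fake_quant_rows(t: torch.Tensor, n_bits: int = 8) -> torch.Tensor:
    """Symmetric absmax fake quantization with one scale per row (last dim reduced)."""
    q_max = 2 ** (n_bits - 1) - 1  # 127 for 8 bits
    scales = t.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / q_max
    return (t / scales).round() * scales


def fake_quant_tensor(t: torch.Tensor, n_bits: int = 8) -> torch.Tensor:
    """Same scheme with a single scale for the whole tensor (per-tensor)."""
    q_max = 2 ** (n_bits - 1) - 1
    scale = t.abs().max().clamp(min=1e-5) / q_max
    return (t / scale).round() * scale


torch.manual_seed(0)
w = torch.randn(256, 128)    # nn.Linear(128, 256).weight layout: [out, in]
x = torch.randn(2, 10, 128)  # activations: [batch, tokens, hidden]
x[0, 0] *= 100               # one outlier token

# Per-channel weights and per-token activations both reduce over the last dim.
y_ref = x @ w.T
y_w8a8 = fake_quant_rows(x) @ fake_quant_rows(w).T

err_token = (fake_quant_rows(x) - x).norm()
err_tensor = (fake_quant_tensor(x) - x).norm()
print(f"activation error, per-token:  {err_token:.3f}")
print(f"activation error, per-tensor: {err_tensor:.3f}")
print(f"W8A8 output relative error: {(y_w8a8 - y_ref).norm() / y_ref.norm():.4f}")
```

A single outlier token inflates a per-tensor scale for every row. That is why per-token activation scales lose less precision in this example.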


In [None]:
# Test forward pass to ensure the quantized model works
print("Testing quantized model...")

# Create a sample input
test_input = torch.randint(0, 1000, (2, 10))  # Batch size 2, sequence length 10
print(f"Input shape: {test_input.shape}")

# Forward pass
with torch.no_grad():
    output = model(test_input)
    print(f"Output shape: {output.shape}")
    print(f"Output range: [{output.min():.3f}, {output.max():.3f}]")

print("✓ Forward pass successful with quantized model!")

Testing quantized model...
Input shape: torch.Size([2, 10])
Output shape: torch.Size([2, 10, 1000])
Output range: [-0.406, 0.390]
✓ Forward pass successful with quantized model!


### 3.2 Applying A8W8 Quantization to SAM Model

Now let's apply quantization to a real SAM model. This is more involved than the dummy model because SAM combines three components: an image encoder, a prompt encoder, and a mask decoder.
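
Before quantizing, it can help to see where the Linear layers sit in the model. The small sketch below counts `nn.Linear` modules under each top-level child. Applied to the loaded `sam_model`, the keys would be component names such as `image_encoder` and `mask_decoder`. The helper `count_linears_by_component` and the toy model are hypothetical.

```python
from collections import Counter

import torch.nn as nn


def count_linears_by_component(model: nn.Module) -> Counter:
    """Count nn.Linear modules under each top-level child module."""
    counts = Counter()
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            counts[name.split(".")[0]] += 1
    return counts


# Toy model with SAM-like component names (hypothetical layer sizes).
toy = nn.Module()
toy.image_encoder = nn.Sequential(nn.Linear(4, 4), nn.GELU(), nn.Linear(4, 4))
toy.prompt_encoder = nn.Embedding(4, 4)
toy.mask_decoder = nn.Sequential(nn.Linear(4, 2))
print(count_linears_by_component(toy))  # Counter({'image_encoder': 2, 'mask_decoder': 1})
```

Calling `count_linears_by_component(sam_model)` after the loading cell below would give the per-component breakdown. The exact counts depend on the SAM-HQ variant.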

In [None]:
sam_model_checkpoint = "/home/ubuntu/21chi.nh/Quantization/SAM_Quantization/SAM_Quantization/checkpoint_sam/sam_hq_vit_l.pth"
sam_model_type = "vit_l"

print(f"Loading SAM model: {sam_model_type}")
sam_model = sam_model_registry[sam_model_type](checkpoint=sam_model_checkpoint)
sam_model.to(device)
sam_model.eval()

print("✓ SAM model loaded successfully!")

# Show the QKV and MLP layer types in the first 2 image-encoder blocks before quantization
print("\n" + "="*50)
print("BEFORE QUANTIZATION")
print("="*50)

for i in range(2):
    if hasattr(sam_model.image_encoder, 'blocks') and i < len(sam_model.image_encoder.blocks):
        block = sam_model.image_encoder.blocks[i]
        
        # Check QKV layer
        if hasattr(block, 'attn') and hasattr(block.attn, 'qkv'):
            qkv_type = type(block.attn.qkv).__name__
            print(f"Layer {i} QKV: {qkv_type}")
        
        # Check MLP layer  
        if hasattr(block, 'mlp'):
            mlp_layer = getattr(block.mlp, 'lin1', getattr(block.mlp, 'fc1', None))
            if mlp_layer is not None:
                mlp_type = type(mlp_layer).__name__
                print(f"Layer {i} MLP: {mlp_type}")

Loading SAM model: vit_l
<All keys matched successfully>
✓ SAM model loaded successfully!

BEFORE QUANTIZATION
Layer 0 QKV: Linear
Layer 0 MLP: Linear
Layer 1 QKV: Linear
Layer 1 MLP: Linear


In [None]:
# Apply A8W8 quantization to SAM model
print("\n" + "="*50)
print("APPLYING QUANTIZATION")
print("="*50)

# Names of modules to exclude from quantization in SAM-HQ
modules_to_exclude = [
    "pos_embed", "cls_token", "patch_embed", 
    "neck", "fpn", "mask_tokens", "iou_token", 
    "output_upscaling", "output_hypernetworks_mlps"
]

# Apply quantization
replace_linear_with_target_and_quantize(
    module=sam_model,
    target_class=W8A8Linear,
    module_name_to_exclude=modules_to_exclude,
    weight_quant="per_channel",    
    act_quant="per_token",           
    quantize_output=False
)

print("✓ Quantization completed!")


APPLYING QUANTIZATION
✓ Quantization completed!


In [None]:
print("\n" + "="*50)
print("AFTER QUANTIZATION")
print("="*50)

# Count quantized layers vs. nn.Linear layers left unquantized
quantized_layers = sum(1 for _, m in sam_model.named_modules() if isinstance(m, W8A8Linear))
linear_layers = sum(1 for _, m in sam_model.named_modules() if isinstance(m, nn.Linear))

print(f"Quantized layers: {quantized_layers}")
print(f"Original layers: {linear_layers}")

# Show the QKV and MLP layer types in the first 2 image-encoder blocks
for i in range(2):
    if hasattr(sam_model.image_encoder, 'blocks') and i < len(sam_model.image_encoder.blocks):
        block = sam_model.image_encoder.blocks[i]
        
        # Check QKV layer
        if hasattr(block, 'attn') and hasattr(block.attn, 'qkv'):
            qkv_type = type(block.attn.qkv).__name__
            status = "✓ Quantized" if isinstance(block.attn.qkv, W8A8Linear) else "⚠️ Not quantized"
            print(f"Layer {i} QKV: {qkv_type} ({status})")
        
        # Check MLP layer
        if hasattr(block, 'mlp'):
            mlp_layer = getattr(block.mlp, 'lin1', getattr(block.mlp, 'fc1', None))
            if mlp_layer is not None:
                mlp_type = type(mlp_layer).__name__
                status = "✓ Quantized" if isinstance(mlp_layer, W8A8Linear) else "⚠️ Not quantized"
                print(f"Layer {i} MLP: {mlp_type} ({status})")

if quantized_layers + linear_layers > 0:
    ratio = quantized_layers / (quantized_layers + linear_layers) * 100
    print(f"\nQuantization ratio: {ratio:.1f}%")


AFTER QUANTIZATION
Quantized layers: 146
Original layers: 0
Layer 0 QKV: W8A8Linear (✓ Quantized)
Layer 0 MLP: W8A8Linear (✓ Quantized)
Layer 1 QKV: W8A8Linear (✓ Quantized)
Layer 1 MLP: W8A8Linear (✓ Quantized)

Quantization ratio: 100.0%
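
The summary reports 0 remaining `nn.Linear` layers. So the Linear layers nested inside `output_hypernetworks_mlps` (a `ModuleList` of small MLPs in SAM's mask decoder) were quantized, even though that name is in `modules_to_exclude`. This is consistent with an exclusion check that compares only a Linear layer's own attribute name, in which case listing a container would have no effect. If those layers should stay in floating point, one option is to match excluded names against every component of a module's dotted path. The sketch below assumes a `make_target` callable that builds the replacement from an `nn.Linear`. `replace_linear_by_path` and `Marker` are hypothetical.

```python
import torch.nn as nn


class Marker(nn.Module):
    """Hypothetical stand-in for a quantized layer; keeps the original Linear inside."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return self.linear(x)


def replace_linear_by_path(module: nn.Module, make_target, exclude: set[str], prefix: str = "") -> None:
    """Swap nn.Linear layers unless any part of their dotted path is in `exclude`."""
    for name, child in module.named_children():
        path = f"{prefix}.{name}" if prefix else name
        if any(part in exclude for part in path.split(".")):
            continue  # skip this child and everything beneath it
        if isinstance(child, nn.Linear):
            setattr(module, name, make_target(child))
        else:
            replace_linear_by_path(child, make_target, exclude, path)


# Toy decoder with SAM-like attribute names.
decoder = nn.Module()
decoder.transformer = nn.Sequential(nn.Linear(8, 8))
decoder.output_hypernetworks_mlps = nn.ModuleList([nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))])
replace_linear_by_path(decoder, Marker, exclude={"output_hypernetworks_mlps"})

quantized = [n for n, m in decoder.named_modules() if isinstance(m, Marker)]
print(quantized)  # ['transformer.0']
print(type(decoder.output_hypernetworks_mlps[0][0]).__name__)  # Linear
```

For SAM, `make_target` would need to construct a `W8A8Linear` from an existing `nn.Linear`. This notebook does not show how `per_tensor_channel_group` exposes that.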
