# Snippet, how the method can be used

Link for this method

https://arxiv.org/pdf/2402.04291

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 1. Define a simple MLP
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

# 2. Load MNIST data
transform = transforms.Compose([transforms.ToTensor()])
train_set = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_set = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = DataLoader(test_set, batch_size=1000, shuffle=False)

# 3. Train the MLP
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MLP().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(3):  # quick training
    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1} done.")

# Evaluate before quantization
def test(model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    print(f"Accuracy: {correct/total:.4f}")

print("Before quantization:")
test(model)


Epoch 1 done.
Epoch 2 done.
Epoch 3 done.
Before quantization:
Accuracy: 0.9698


In [2]:
def collect_layer_inputs(model, data_loader, device, layer_name, num_batches=10):
    model.eval()
    inputs = []
    count = 0
    with torch.no_grad():
        for x, _ in data_loader:
            if count >= num_batches:
                break
            x = x.to(device)
            if layer_name == 'fc1':
                inp = x.view(x.size(0), -1)
            elif layer_name == 'fc2':
                x1 = model.relu(model.fc1(x.view(x.size(0), -1)))
                inp = x1
            else:
                raise ValueError("Unknown layer name")
            inputs.append(inp)
            count += 1
    return torch.cat(inputs, dim=0)


def compute_hessian_diag(inputs):
    # Approximate diagonal of Hessian: E[x^2]
    return torch.mean(inputs ** 2, dim=0)


def billm_quantize_linear_layer(layer, hessian_diag, salient_ratio=0.1):
    importance = hessian_diag.abs()
    threshold = torch.quantile(importance, 1 - salient_ratio)
    salient_mask = importance >= threshold
    non_salient_mask = ~salient_mask

    W = layer.weight.data.clone()
    for i in range(W.shape[0]):
        # Salient
        if salient_mask.sum() > 0:
            scale = W[i, salient_mask].abs().mean()
            W[i, salient_mask] = scale * torch.sign(W[i, salient_mask])
        # Non-salient
        W[i, non_salient_mask] = torch.sign(W[i, non_salient_mask])
    layer.weight.data = W


# Quantize both layers
print("Quantizing...")

with torch.no_grad():
    # For fc1
    fc1_inputs = collect_layer_inputs(model, train_loader, device, 'fc1')
    hess1 = compute_hessian_diag(fc1_inputs)
    billm_quantize_linear_layer(model.fc1, hess1, salient_ratio=0.1)
    # For fc2
    fc2_inputs = collect_layer_inputs(model, train_loader, device, 'fc2')
    hess2 = compute_hessian_diag(fc2_inputs)
    billm_quantize_linear_layer(model.fc2, hess2, salient_ratio=0.1)
print("Quantization done.")


Quantizing...
Quantization done.


In [3]:
print("After quantization:")
test(model)

After quantization:
Accuracy: 0.7757


In [5]:
def model_size_in_mb(model, bits_per_param=32):
    total_params = sum(p.numel() for p in model.parameters())
    total_bits = total_params * bits_per_param
    return total_bits / 8 / 1024 / 1024  # Convert bits to MB

print("Original model size (MB):", model_size_in_mb(model, bits_per_param=32))
print("Quantized model size (MB):", model_size_in_mb(model, bits_per_param=1))

Original model size (MB): 0.7764053344726562
Quantized model size (MB): 0.024262666702270508


Link for this method

https://proceedings.mlsys.org/paper_files/paper/2024/file/42a452cbafa9dd64e9ba4aa95cc1ef21-Paper-Conference.pdf

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np

# Simple 2-layer MLP for MNIST
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.fc1(x)
        self.act1 = x.detach()  # Save activations for AWQ
        x = self.relu(x)
        x = self.fc2(x)
        return x

def train(model, device, train_loader, optimizer, criterion, epochs=3):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")

def evaluate(model, device, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    accuracy = correct / total
    return accuracy

def collect_activation_stats(model, device, calib_loader, max_batches=10):
    model.eval()
    activations = []
    with torch.no_grad():
        for i, (data, _) in enumerate(calib_loader):
            data = data.to(device)
            _ = model(data)
            activations.append(model.act1.cpu())
            if i + 1 >= max_batches:
                break
    activations = torch.cat(activations, dim=0)
    # Per output channel max absolute activation
    act_max = torch.max(torch.abs(activations), dim=0)[0]
    return act_max

def awq_quantize_weight(weight, act_max, num_bits=4, protect_ratio=0.01):
    w = weight.data.cpu()
    out_channels = w.size(0)

    k = max(1, int(protect_ratio * out_channels))
    _, topk_idx = torch.topk(act_max, k)

    scale = torch.ones_like(act_max)
    scale[topk_idx] = 2.0  # scale factor for protected channels

    scaled_w = w * scale.view(-1, 1)

    qmin = -2**(num_bits - 1)
    qmax = 2**(num_bits - 1) - 1
    max_val = scaled_w.abs().max()
    scale_factor = max_val / qmax + 1e-8

    w_q = torch.clamp((scaled_w / scale_factor).round(), qmin, qmax)
    w_deq = w_q * scale_factor

    w_deq[topk_idx, :] /= 2.0  # reverse scaling on protected channels

    weight.data.copy_(w_deq.to(weight.device))

def model_size_bytes(model):
    total_params = sum(p.numel() for p in model.parameters())
    # Assume FP32 (4 bytes) for full precision
    size_fp32 = total_params * 4
    return size_fp32

def model_size_bytes_quantized(model, bitwidth=4):
    total_params = sum(p.numel() for p in model.parameters())
    size_bits = total_params * bitwidth
    size_bytes = size_bits / 8
    return size_bytes

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Data loaders
    transform = transforms.ToTensor()
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

    model = MLP().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    print("Training full-precision model...")
    train(model, device, train_loader, optimizer, criterion, epochs=3)

    acc_fp32 = evaluate(model, device, test_loader)
    size_fp32 = model_size_bytes(model)
    print(f"Full-precision model accuracy: {acc_fp32*100:.2f}%")
    print(f"Full-precision model size: {size_fp32/1024:.2f} KB")

    # Collect activation stats for AWQ
    act_max = collect_activation_stats(model, device, test_loader)

    # Apply AWQ-inspired quantization on first layer weights
    awq_quantize_weight(model.fc1.weight, act_max, num_bits=4, protect_ratio=0.01)

    # Evaluate quantized model accuracy
    acc_awq = evaluate(model, device, test_loader)
    size_awq = model_size_bytes_quantized(model, bitwidth=4)
    print(f"AWQ 4-bit quantized model accuracy: {acc_awq*100:.2f}%")
    print(f"AWQ 4-bit quantized model size (approx): {size_awq/1024:.2f} KB")

if __name__ == "__main__":
    main()


Training full-precision model...
Epoch 1, Loss: 0.3041
Epoch 2, Loss: 0.1269
Epoch 3, Loss: 0.0855
Full-precision model accuracy: 97.65%
Full-precision model size: 795.04 KB
AWQ 4-bit quantized model accuracy: 97.40%
AWQ 4-bit quantized model size (approx): 99.38 KB


In [None]:
import os
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import math
from datasets import load_dataset

def get_model_size(model, model_name="model"):

    temp_dir = f"./temp_{model_name.replace('/', '_')}"
    os.makedirs(temp_dir, exist_ok=True)
    
    if hasattr(model, "save_pretrained"):
        model.save_pretrained(temp_dir)
    else:
        torch.save(model.state_dict(), os.path.join(temp_dir, "pytorch_model.bin"))
    
    total_size = 0
    for dirpath, _, filenames in os.walk(temp_dir):
        for f in filenames:
            fp = os.path.join(dirpath, f)
            total_size += os.path.getsize(fp)
    
    # Clean up
    for dirpath, _, filenames in os.walk(temp_dir):
        for f in filenames:
            os.remove(os.path.join(dirpath, f))
    os.rmdir(temp_dir)
    
    return total_size / (1024 * 1024)  # Convert to MB


def quantize_weights_awq(weight, scale, zero_point, bits=4):
    q_min = -2 ** (bits - 1)
    q_max = 2 ** (bits - 1) - 1
    weight = weight / scale.unsqueeze(1)
    weight_q = torch.round(weight) + zero_point.unsqueeze(1)
    return torch.clamp(weight_q, q_min, q_max)

def dequantize_weights_awq(weight_q, scale, zero_point):
    return (weight_q - zero_point.unsqueeze(1)) * scale.unsqueeze(1)

def find_scale_zero_point(weight, bits=4):
    q_min = -2 ** (bits - 1)
    q_max = 2 ** (bits - 1) - 1
    max_val = torch.max(torch.abs(weight), dim=1).values
    scale = max_val / (q_max - q_min)
    zero_point = torch.zeros_like(scale)
    return scale, zero_point


def prepare_calibration_dataset(num_samples=32, seq_len=512):
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
    dataset = load_dataset("wikitext", "wikitext-103-v1", split="test")

    samples = []
    for i in range(min(num_samples, len(dataset))):
        text = dataset[i]["text"]
        if text.strip():
            tokenizer.pad_token = tokenizer.eos_token
            inputs = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True, padding="max_length")
            samples.append(inputs["input_ids"][0])
    return samples


def apply_awq_to_model(model, calibration_data, bits=4):
    device = next(model.parameters()).device
    model.eval()

    max_len = max(tensor.size(0) for tensor in calibration_data)

    activation_scales = {}
    with torch.no_grad():
        for input_ids in tqdm(calibration_data, desc="Collecting activations"):
            input_ids = input_ids.unsqueeze(0).to(device)
            attention_mask = torch.ones_like(input_ids)
            
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True
            )
            
            for i, hidden_state in enumerate(outputs.hidden_states):
                act_scale = hidden_state.abs().amax(dim=(0,1))
                if i in activation_scales:
                    activation_scales[i] = torch.maximum(activation_scales[i], act_scale)
                else:
                    activation_scales[i] = act_scale
    

    quantized_layers = {}
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            weight = module.weight.data
            

            try:
                parts = name.split('.')
                if 'h.' in name:
                    layer_idx = int(parts[parts.index('h') + 1])
                else:
                    layer_idx = int([p for p in parts if p.isdigit()][-1])
            except:
                layer_idx = 0
            
            act_scale = activation_scales.get(layer_idx, torch.ones(weight.shape[0], device=device))
            if act_scale.shape[0] != weight.shape[0]:
                act_scale = torch.ones(weight.shape[0], device=device)
            
            with torch.no_grad():
                weight_fp32 = weight.float()
                scaling_factor = act_scale.sqrt().view(-1, 1)
                scaled_weights = weight_fp32 * scaling_factor
                
                scale, zero_point = find_scale_zero_point(scaled_weights, bits=bits)
                weight_q = quantize_weights_awq(weight_fp32, scale, zero_point, bits=bits)
                
                quantized_layers[name] = {
                    'weight_q': weight_q.to(torch.int8),
                    'scale': scale.to(weight.dtype),
                    'zero_point': zero_point.to(torch.int8),
                    'original_shape': weight.shape
                }
    
    return quantized_layers

def calculate_perplexity(model, tokenizer, test_texts, device, max_length=512):
    model.eval()
    total_loss = 0
    total_tokens = 0
    
    with torch.no_grad():
        for text in tqdm(test_texts, desc="Calculating perplexity"):
            tokenizer.pad_token = tokenizer.eos_token
            inputs = tokenizer(
                text, 
                return_tensors="pt", 
                max_length=max_length, 
                truncation=True,
                padding="max_length"  # Ensure consistent length
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            
            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss
            total_loss += loss.item() * inputs["input_ids"].numel()
            total_tokens += inputs["input_ids"].numel()
    
    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)
    return perplexity

def load_test_dataset(num_samples=32):
    dataset = load_dataset("wikitext", "wikitext-103-v1", split="test")
    return [dataset[i]["text"] for i in range(min(num_samples, len(dataset))) if dataset[i]["text"].strip()]


def save_quantized_model(model, quantized_layers, tokenizer, save_path):
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)

class QuantizedLinear(nn.Module):
    def __init__(self, weight_q, scale, zero_point, original_shape):
        super().__init__()
        self.register_buffer('weight_q', weight_q)
        self.register_buffer('scale', scale)
        self.register_buffer('zero_point', zero_point)
        self.out_features, self.in_features = original_shape
        
    def forward(self, x):
        weight = dequantize_weights_awq(self.weight_q, self.scale, self.zero_point)
        return nn.functional.linear(x, weight)



In [None]:


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        print("Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            "EleutherAI/gpt-neo-2.7B", 
            torch_dtype=torch.float16
        ).to(device)
        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
        
        orig_size = get_model_size(model, "original_model")
        print(f"Original model size: {orig_size:.2f} MB")
        
        print("Preparing data...")
        calibration_data = prepare_calibration_dataset(num_samples=32)
        test_texts = [t for t in load_test_dataset()]
        print("Calculating original perplexity...")
        orig_perplexity = calculate_perplexity(model, tokenizer, test_texts, device)
        print(f"Original perplexity: {orig_perplexity:.2f}")
        
        # Quantization
        print("Applying AWQ...")
        quantized_layers = apply_awq_to_model(model, calibration_data, bits=4)
        
        for name, module in model.named_modules():
            if name in quantized_layers:
                q_data = quantized_layers[name]
                new_layer = QuantizedLinear(
                    q_data['weight_q'],
                    q_data['scale'],
                    q_data['zero_point'],
                    q_data['original_shape']
                )
                parent = model
                parts = name.split('.')
                for part in parts[:-1]:
                    parent = getattr(parent, part)
                setattr(parent, parts[-1], new_layer)
        
        quant_size = get_model_size(model, "quantized_model")
        print(f"Quantized model size: {quant_size:.2f} MB")
        print(f"Size reduction: {orig_size - quant_size:.2f} MB ({quant_size/orig_size*100:.1f}% of original)")
        
        print("Calculating quantized perplexity...")
        quant_perplexity = calculate_perplexity(model, tokenizer, test_texts, device)
        print(f"Quantized perplexity: {quant_perplexity:.2f}")
        
        print("Saving model...")
        save_quantized_model(model, quantized_layers, tokenizer, "quantized_gpt_neo")
        
        print("\n=== Results ===")
        print(f"Original size: {orig_size:.2f} MB | Perplexity: {orig_perplexity:.2f}")
        print(f"Quantized size: {quant_size:.2f} MB | Perplexity: {quant_perplexity:.2f}")
        print(f"Size reduction: {orig_size - quant_size:.2f} MB ({quant_size/orig_size*100:.1f}% of original)")
        print(f"Perplexity difference: {quant_perplexity - orig_perplexity:.2f} (+{(quant_perplexity - orig_perplexity)/orig_perplexity*100:.1f}%)")
    
    except Exception as e:
        print(f"Error: {str(e)}")
        raise

if __name__ == "__main__":
    main()

Loading model...
Original model size: 5057.05 MB
Preparing data...
Calculating original perplexity...


Calculating perplexity: 100%|██████████| 14/14 [00:00<00:00, 25.52it/s]


Original perplexity: 54.72
Applying AWQ...


Collecting activations: 100%|██████████| 14/14 [00:00<00:00, 26.13it/s]


Quantized model size: 2781.06 MB
Size reduction: 2275.99 MB (55.0% of original)
Calculating quantized perplexity...


Calculating perplexity: 100%|██████████| 14/14 [00:00<00:00, 16.37it/s]


Quantized perplexity: 61.30
Saving model...

=== Results ===
Original size: 5057.05 MB | Perplexity: 54.72
Quantized size: 2781.06 MB | Perplexity: 61.30
Size reduction: 2275.99 MB (55.0% of original)
Perplexity difference: 6.58 (+12.0%)
