## start

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np

import copy
from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import os
import json

import matplotlib.pyplot as plt
import time
from collections import defaultdict
import tqdm

# 1. get linear mask for effective weight with each weight size [output_size, input_size]

In [2]:
def get_linear_mask(module:nn.Module) -> torch.Tensor:
    x = module.weight.data
    output_size, input_size = x.shape
    x_norm = torch.abs(x) / torch.sum(torch.abs(x), dim=0, keepdim=True)
    neff = torch.floor(1/torch.sum((x_norm ** 2), dim=0, keepdim=True).squeeze(0))
    
    _, indices = torch.sort(x_norm, dim=0, descending=True)
    range_tensor = torch.arange(output_size, device=x.device).unsqueeze(0).expand(input_size, -1).T
    sorted_mask = range_tensor < neff
    
    mask = torch.zeros_like(x, dtype=torch.bool)
    mask.scatter_(0, indices, sorted_mask)
    return mask

# 2. set the edge with ineffective weight = 0 and prune the edge

In [6]:
def prune_model_neff(model, renormalize=False):
    model = copy.deepcopy(model)
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            mask = get_linear_mask(module).to(module.weight.device)
            with torch.no_grad():
                module.weight *= mask
                if renormalize:
                    row_sum = module.weight.sum(dim=0, keepdim=True).clamp(min=1e-8)
                    module.weight.div_(row_sum)
    return model

# 3. train a linear model first and storage

In [4]:
class LinearModel(nn.Module):
    def __init__(self, input_size, output_size, hidden_size= [512, 512, 512]):
        super(LinearModel, self).__init__()
        self.layers = nn.ModuleList()
        
        prev_size = input_size
        for size in hidden_size:
            self.layers.append(nn.Linear(prev_size, size))
            prev_size = size
            
        self.output = nn.Linear(prev_size, output_size)
        
    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten
        
        for layer in self.layers:
            x = F.relu(layer(x))        
        x = self.output(x)
        return F.log_softmax(x, dim=1)
        
def train(model, device, train_loader, optimizer, epoch):
    """Train for one epoch"""
    model.train()
    train_loss = 0
    correct = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
    
    avg_loss = train_loss / len(train_loader)
    accuracy = 100. * correct / len(train_loader.dataset)
    return avg_loss, accuracy


def test(model, device, test_loader):
    """Evaluate model on test set"""
    model.eval()
    test_loss = 0
    correct = 0
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    
    print(f'\nTest set: Average loss: {test_loss:.4f}, '
          f'Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)\n')
    
    return test_loss, accuracy

## data loader

In [5]:
batch_size = 64
test_batch_size = 1000
epochs = 10
lr = 3e-4

# MINIST-10 dataset
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)   
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LinearModel(input_size=28*28, output_size=10, hidden_size=[1024, 512, 512]).to(device)

optimizer = optim.Adam(model.parameters(), lr=lr)


In [6]:
result = {
    'train_loss': [],
    'train_accuracy': [],
    'test_loss': [],
    'test_accuracy': []
}
# initial test
test_loss, test_accuracy = test(model, device, test_loader)
result['test_loss'].append(test_loss)
result['test_accuracy'].append(test_accuracy)

# Training loop
for epoch in range(1, epochs + 1):
    train_loss, train_accuracy = train(model, device, train_loader, optimizer, epoch)
    result['train_loss'].append(train_loss)
    result['train_accuracy'].append(train_accuracy)
    
    # Test after each epoch
    test_loss, test_accuracy = test(model, device, test_loader)
    result['test_loss'].append(test_loss)
    result['test_accuracy'].append(test_accuracy)
    
# Save the model
torch.save(model.state_dict(), 'linear_model.pth')



Test set: Average loss: 2.3056, Accuracy: 1029/10000 (10.29%)


Test set: Average loss: 0.1143, Accuracy: 9639/10000 (96.39%)


Test set: Average loss: 0.0814, Accuracy: 9739/10000 (97.39%)


Test set: Average loss: 0.0894, Accuracy: 9734/10000 (97.34%)


Test set: Average loss: 0.0681, Accuracy: 9796/10000 (97.96%)


Test set: Average loss: 0.0838, Accuracy: 9775/10000 (97.75%)


Test set: Average loss: 0.0741, Accuracy: 9803/10000 (98.03%)


Test set: Average loss: 0.0736, Accuracy: 9815/10000 (98.15%)


Test set: Average loss: 0.0725, Accuracy: 9829/10000 (98.29%)


Test set: Average loss: 0.0801, Accuracy: 9799/10000 (97.99%)


Test set: Average loss: 0.0800, Accuracy: 9809/10000 (98.09%)



# prune the model and comparing the performance with the original model

In [11]:
pruned_model_renormalized = prune_model_neff(model, renormalize=True)
pruned_model_renormalized.to(device)

# Test the pruned model
test_loss, test_accuracy = test(pruned_model_renormalized, device, test_loader)
print(f'Pruned Model - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
# Save the pruned model
torch.save(pruned_model_renormalized.state_dict(), 'pruned_linear_model_renormalized.pth')


prune_model = prune_model_neff(model, renormalize=False)
prune_model.to(device)
# Test the pruned model without renormalization
test_loss, test_accuracy = test(prune_model, device, test_loader)
print(f'Pruned Model without Renormalization - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
# Save the pruned model without renormalization
torch.save(prune_model.state_dict(), 'pruned_linear_model.pth')


Test set: Average loss: 12157359320308445798030016249856.0000, Accuracy: 9469/10000 (94.69%)

Pruned Model - Test Loss: 12157359320308445798030016249856.0000, Test Accuracy: 94.69%

Test set: Average loss: 0.0744, Accuracy: 9810/10000 (98.10%)

Pruned Model without Renormalization - Test Loss: 0.0744, Test Accuracy: 98.10%


In [12]:
# test 10 times and show the average performance
test_loss_acc = {'prune_loss': [], 'prune_accuracy': [], 'prune_renorm_loss': [], 'prune_renorm_accuracy': []}

for i in range(10):
    pruned_model_renormalized = prune_model_neff(model, renormalize=True)
    pruned_model_renormalized.to(device)

    # Test the pruned model
    test_loss, test_accuracy = test(pruned_model_renormalized, device, test_loader)
    print(f'Pruned Model - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
    test_loss_acc['prune_renorm_loss'].append(test_loss)
    test_loss_acc['prune_renorm_accuracy'].append(test_accuracy)
    
    pruned_model = prune_model_neff(model, renormalize=False)
    pruned_model.to(device)
    # Test the pruned model without renormalization
    test_loss, test_accuracy = test(pruned_model, device, test_loader)
    print(f'Pruned Model without Renormalization - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
    test_loss_acc['prune_loss'].append(test_loss)
    test_loss_acc['prune_accuracy'].append(test_accuracy)
    
# average performance
avg_prune_loss = np.mean(test_loss_acc['prune_loss'])
avg_prune_accuracy = np.mean(test_loss_acc['prune_accuracy'])
avg_prune_renorm_loss = np.mean(test_loss_acc['prune_renorm_loss'])
avg_prune_renorm_accuracy = np.mean(test_loss_acc['prune_renorm_accuracy'])

print(f'Average Pruned Model - Test Loss: {avg_prune_loss:.4f}, Test Accuracy: {avg_prune_accuracy:.2f}%')
print(f'Average Pruned Model with Renormalization - Test Loss: {avg_prune_renorm_loss:.4f}, Test Accuracy: {avg_prune_renorm_accuracy:.2f}%')
    


Test set: Average loss: 12157359320308445798030016249856.0000, Accuracy: 9469/10000 (94.69%)

Pruned Model - Test Loss: 12157359320308445798030016249856.0000, Test Accuracy: 94.69%

Test set: Average loss: 0.0744, Accuracy: 9810/10000 (98.10%)

Pruned Model without Renormalization - Test Loss: 0.0744, Test Accuracy: 98.10%

Test set: Average loss: 12157359320308445798030016249856.0000, Accuracy: 9469/10000 (94.69%)

Pruned Model - Test Loss: 12157359320308445798030016249856.0000, Test Accuracy: 94.69%

Test set: Average loss: 0.0744, Accuracy: 9810/10000 (98.10%)

Pruned Model without Renormalization - Test Loss: 0.0744, Test Accuracy: 98.10%

Test set: Average loss: 12157359320308445798030016249856.0000, Accuracy: 9469/10000 (94.69%)

Pruned Model - Test Loss: 12157359320308445798030016249856.0000, Test Accuracy: 94.69%

Test set: Average loss: 0.0744, Accuracy: 9810/10000 (98.10%)

Pruned Model without Renormalization - Test Loss: 0.0744, Test Accuracy: 98.10%

Test set: Average los

# measure the sparisity of the model

In [10]:
def model_sparsity(model):
    """Calculate the sparsity of the model"""
    total_params = 0
    zero_params = 0
    
    for name, param in model.named_parameters():
        if 'weight' in name:
            total_params += param.numel()
            zero_params += torch.sum(param == 0).item()
    
    sparsity = zero_params / total_params
    return sparsity

In [None]:
sparsity = model_sparsity(pruned_model_renormalized)
print(f'Sparsity of the pruned model with renormalization: {sparsity:.4f}')
sparsity = model_sparsity(prune_model)
print(f'Sparsity of the pruned model without renormalization: {sparsity:.4f}')
sparsity = model_sparsity(model)
print(f'Sparsity of the original model: {sparsity:.4f}')

Sparsity of the pruned model with renormalization: 0.3419
Sparsity of the pruned model without renormalization: 0.3419
Sparsity of the original model: 0.0000


# test in huggingface language model

In [14]:
def compute_accuracy(model, tokenized_dataset, batch_size=32):
    from torch.utils.data import DataLoader
    model.eval()
    device = next(model.parameters()).device
    correct = 0
    total = 0
    loader = DataLoader(tokenized_dataset, batch_size=batch_size)
    with torch.no_grad():
        for batch in loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k in ['input_ids', 'attention_mask']}
            labels = batch['labels'].to(device)
            outputs = model(**inputs)
            preds = outputs.logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += len(labels)
    return correct / total

In [16]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
raw_dataset = load_dataset("ag_news")
train_dataset = raw_dataset["train"]
test_dataset = raw_dataset["test"]
def tokenize_function(example):
    return tokenizer(
        example["text"],
        truncation=True,
        padding="max_length",
        max_length=128
    )

# Tokenize and format
tokenized_train_dataset = train_dataset.map(tokenize_function, batched=True)
tokenized_test_dataset = test_dataset.map(tokenize_function, batched=True)
tokenized_train_dataset = tokenized_train_dataset.rename_column("label", "labels")
tokenized_test_dataset = tokenized_test_dataset.rename_column("label", "labels")
tokenized_train_dataset.set_format(type="torch", columns=['input_ids', 'attention_mask', 'labels'])
tokenized_test_dataset.set_format(type="torch", columns=['input_ids', 'attention_mask', 'labels'])

# 4. Load and fine-tune BERT on train split
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4).to(device)
finetune_args = TrainingArguments(
    output_dir="./tmp_finetuned_bert",
    per_device_train_batch_size=16,
    num_train_epochs=5,
    logging_steps=1000,
    save_strategy="no",
    report_to=[]
)
trainer = Trainer(
    model=model,
    args=finetune_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_test_dataset,
    tokenizer=tokenizer,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


In [17]:
print("\n=== Training Original Model ===")
trainer.train()

orig_acc = compute_accuracy(model, tokenized_test_dataset)
sparsity = model_sparsity(model)
print(f"\nFine-tuned original model: Sparsity={sparsity:.4f}, Acc={orig_acc:.4f}")

torch.save(model.state_dict(), 'bert_origin.pth')


=== Training Original Model ===


Step,Training Loss
1000,0.339
2000,0.2639
3000,0.2671
4000,0.2409
5000,0.2293
6000,0.2223
7000,0.2154
8000,0.1935
9000,0.1679
10000,0.1688



Fine-tuned original model: Sparsity=0.0000, Acc=0.9439


In [18]:
pruned_model_renormalized = prune_model_neff(model, renormalize=True)
pruned_model_renormalized.to(device)
pruned_model = prune_model_neff(model, renormalize=False)
pruned_model.to(device)

prune_acc = compute_accuracy(pruned_model, tokenized_test_dataset)
prune_renorm_acc = compute_accuracy(pruned_model_renormalized, tokenized_test_dataset)
pruned_sparsity = model_sparsity(pruned_model)
pruned_renorm_sparsity = model_sparsity(pruned_model_renormalized)

print(f"\nPruned model: Sparsity={pruned_sparsity:.4f}, Acc={prune_acc:.4f}")
print(f"Pruned model with renormalization: Sparsity={pruned_renorm_sparsity:.4f}, Acc={prune_renorm_acc:.4f}")


Pruned model: Sparsity=0.2923, Acc=0.9451
Pruned model with renormalization: Sparsity=0.2923, Acc=0.2500


In [19]:
torch.save(pruned_model.state_dict(), 'bert_pruned.pth')
torch.save(pruned_model_renormalized.state_dict(), 'bert_pruned_renorm.pth')

## test for LLM

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from datasets import load_dataset
import torch
from tqdm import tqdm

# Configuration
MODEL_NAME = "Qwen/Qwen-7B"  # Use "Qwen/Qwen-7B" for smaller variant
DATASET_NAME = "wikitext"
DATASET_CONFIG = "wikitext-2-raw-v1"
DEVICE_MAP = "auto"  # Automatically distributes across GPUs
BATCH_SIZE = 1  # Reduce if OOM errors occur

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Load dataset
test_dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, split="test")
texts = [text for text in test_dataset["text"] if text.strip()]  # Remove empty strings

In [4]:
# Load model with quantization (4-bit) to reduce VRAM usage
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map=DEVICE_MAP,
    torch_dtype=torch.float16,
    quantization_config={"load_in_4bit": True},
    trust_remote_code=True
)

# Calculate perplexity
model.eval()
total_log_likelihood = 0
total_tokens = 0

with torch.no_grad():
    for text in tqdm(texts, desc="Calculating Perplexity"):
        # Tokenize text
        inputs = tokenizer(text, return_tensors="pt", truncation=True).to(model.device)
        
        # Forward pass to get loss
        outputs = model(**inputs, labels=inputs["input_ids"])
        loss = outputs.loss.item()
        
        # Update metrics
        total_log_likelihood += loss * inputs["input_ids"].size(1)
        total_tokens += inputs["input_ids"].size(1)

# Final perplexity calculation
perplexity = torch.exp(torch.tensor(total_log_likelihood / total_tokens)).item()
print(f"Perplexity: {perplexity:.2f}")

torch.save(model.state_dict(), 'qwen_model.pth')

The model is automatically converting to bf16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".
Try importing flash-attention for faster inference...


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

Calculating Perplexity: 100%|██████████| 2891/2891 [03:12<00:00, 15.00it/s]


Perplexity: 17.44


In [5]:
def prune_model_neff_llm(model, renormalize=False):
    """
    Prune LLM model by targeting only standard Linear layers
    Avoids quantized layers and special layer types
    """
    model = copy.deepcopy(model)
    pruned_layers = []
    
    for name, module in model.named_modules():
        # Only prune standard nn.Linear layers, avoid quantized layers
        if isinstance(module, nn.Linear) and not hasattr(module, 'quant_state'):
            try:
                mask = get_linear_mask(module).to(module.weight.device)
                with torch.no_grad():
                    module.weight *= mask.float()
                    
                    if renormalize:
                        # More stable renormalization
                        row_sum = module.weight.abs().sum(dim=1, keepdim=True).clamp(min=1e-8)
                        module.weight.div_(row_sum)
                    
                    pruned_layers.append(name)
                    
                    # Check sparsity of this layer
                    sparsity = (module.weight == 0).float().mean().item()
                    print(f"Pruned {name}: {sparsity:.2%} sparsity")
                    
            except Exception as e:
                print(f"Skipping {name}: {e}")
                continue
    
    print(f"Successfully pruned {len(pruned_layers)} layers")
    return model

In [6]:
pruned_Qwen = prune_model_neff_llm(model, renormalize=False)
pruned_Qwen.to("cuda" if torch.cuda.is_available() else "cpu")

with torch.no_grad():
    for text in tqdm(texts, desc="Calculating Perplexity for Pruned Model"):
        inputs = tokenizer(text, return_tensors="pt", truncation=True).to(pruned_Qwen.device)
        outputs = pruned_Qwen(**inputs, labels=inputs["input_ids"])
        loss = outputs.loss.item()
        
        total_log_likelihood += loss * inputs["input_ids"].size(1)
        total_tokens += inputs["input_ids"].size(1)
        
perplexity_pruned = torch.exp(torch.tensor(total_log_likelihood / total_tokens)).item()
print(f"Perplexity of Pruned Model: {perplexity_pruned:.2f}")

Pruned lm_head: 39.97% sparsity
Successfully pruned 1 layers


Calculating Perplexity for Pruned Model: 100%|██████████| 2891/2891 [04:20<00:00, 11.10it/s]

Perplexity of Pruned Model: 17.98





In [7]:
model

QWenLMHeadModel(
  (transformer): QWenModel(
    (wte): Embedding(151936, 4096)
    (drop): Dropout(p=0.0, inplace=False)
    (rotary_emb): RotaryEmbedding()
    (h): ModuleList(
      (0-31): 32 x QWenBlock(
        (ln_1): RMSNorm()
        (attn): QWenAttention(
          (c_attn): Linear4bit(in_features=4096, out_features=12288, bias=True)
          (c_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): RMSNorm()
        (mlp): QWenMLP(
          (w1): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (w2): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (c_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
        )
      )
    )
    (ln_f): RMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=151936, bias=False)
)

# test test

In [None]:


# ---- 2. Storage and nonzero counting ----
def get_model_size(model, tmp_file="tmp_model.bin"):
    torch.save(model.state_dict(), tmp_file)
    size_mb = os.path.getsize(tmp_file) / (1024 * 1024)
    os.remove(tmp_file)
    return size_mb

def count_nonzero_params(model):
    nonzero = 0
    total = 0
    for p in model.parameters():
        total += p.numel()
        nonzero += (p != 0).sum().item()
    return total, nonzero

def get_folder_size_mb(folder):
    total_size = 0
    for dirpath, dirnames, filenames in os.walk(folder):
        for f in filenames:
            fp = os.path.join(dirpath, f)
            total_size += os.path.getsize(fp)
    return total_size / (1024 * 1024)

# ---- 3. Sparse export/load ----
def export_model_sparse(model, out_dir="sparse_export"):
    os.makedirs(out_dir, exist_ok=True)
    meta = {}
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            w = module.weight.data.cpu().numpy()
            nonzero = np.nonzero(w)
            values = w[nonzero]
            indices = np.vstack(nonzero).T
            np.save(os.path.join(out_dir, f"{name}_values.npy"), values)
            np.save(os.path.join(out_dir, f"{name}_indices.npy"), indices)
            meta[name] = {"shape": w.shape, "n_nonzero": len(values)}
            if module.bias is not None:
                np.save(os.path.join(out_dir, f"{name}_bias.npy"), module.bias.data.cpu().numpy())
    with open(os.path.join(out_dir, "meta.json"), "w") as f:
        json.dump(meta, f)
    print(f"Exported sparse weights to {out_dir}")

def load_model_sparse(model, sparse_dir="sparse_export"):
    with open(os.path.join(sparse_dir, "meta.json")) as f:
        meta = json.load(f)
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and name in meta:
            shape = tuple(meta[name]["shape"])
            w = np.zeros(shape, dtype=np.float32)
            values = np.load(os.path.join(sparse_dir, f"{name}_values.npy"))
            indices = np.load(os.path.join(sparse_dir, f"{name}_indices.npy"))
            w[indices[:,0], indices[:,1]] = values
            module.weight.data = torch.tensor(w, dtype=module.weight.data.dtype, device=module.weight.data.device)
            bias_path = os.path.join(sparse_dir, f"{name}_bias.npy")
            if os.path.exists(bias_path):
                module.bias.data = torch.tensor(np.load(bias_path), dtype=module.bias.data.dtype, device=module.bias.data.device)
    print(f"Loaded sparse weights from {sparse_dir}")
    return model

# ---- 4. Accuracy function ----
def compute_accuracy(model, tokenized_dataset, batch_size=32):
    from torch.utils.data import DataLoader
    model.eval()
    device = next(model.parameters()).device
    correct = 0
    total = 0
    loader = DataLoader(tokenized_dataset, batch_size=batch_size)
    with torch.no_grad():
        for batch in loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k in ['input_ids', 'attention_mask']}
            labels = batch['labels'].to(device)
            outputs = model(**inputs)
            preds = outputs.logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += len(labels)
    return correct / total

# ---- 5. Tokenization and data setup ----
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
raw_dataset = load_dataset("ag_news")
train_dataset = raw_dataset["train"]
test_dataset = raw_dataset["test"]

def tokenize_function(example):
    return tokenizer(
        example["text"],
        truncation=True,
        padding="max_length",
        max_length=128
    )

tokenized_train_dataset = train_dataset.map(tokenize_function, batched=True)
tokenized_test_dataset = test_dataset.map(tokenize_function, batched=True)
tokenized_train_dataset = tokenized_train_dataset.rename_column("label", "labels")
tokenized_test_dataset = tokenized_test_dataset.rename_column("label", "labels")
tokenized_train_dataset.set_format(type="torch", columns=['input_ids', 'attention_mask', 'labels'])
tokenized_test_dataset.set_format(type="torch", columns=['input_ids', 'attention_mask', 'labels'])

# ---- 6. Train original model ----
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4).to(device)
finetune_args = TrainingArguments(
    output_dir="./tmp_finetuned_bert",
    per_device_train_batch_size=16,
    num_train_epochs=5,
    logging_steps=100,
    save_strategy="no",
    report_to=[]
)
trainer = Trainer(
    model=model,
    args=finetune_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_test_dataset,
    tokenizer=tokenizer,
)
print("\n=== Training Original Model ===")
trainer.train()

# ---- 7. Evaluate and analyze original ----
orig_acc = compute_accuracy(model, tokenized_test_dataset)
orig_size = get_model_size(model)
orig_total, orig_nonzero = count_nonzero_params(model)
print(f"\n original model: Size={orig_size:.2f} MB, Acc={orig_acc:.4f}, Nonzeros={orig_nonzero}/{orig_total}")

torch.save(model.state_dict(), "dense_model.pt")

# ---- 8. Prune, save, and analyze pruned model ----
pruned_model = prune_model_neff(model, renormalize=False).to(device)
pruned_acc = compute_accuracy(pruned_model, tokenized_test_dataset)
pruned_size = get_model_size(pruned_model)
pruned_total, pruned_nonzero = count_nonzero_params(pruned_model)
print(f"Pruned model:             Size={pruned_size:.2f} MB, Acc={pruned_acc:.4f}, Nonzeros={pruned_nonzero}/{pruned_total}")

torch.save(pruned_model.state_dict(), "dense_pruned_model.pt")
export_model_sparse(pruned_model, out_dir="sparse_export")
sparse_disk_size = get_folder_size_mb("sparse_export")
print(f"Sparse export folder size: {sparse_disk_size:.2f} MB")

# ---- 10. Load sparse weights into a new model and compare ----
reloaded_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4).to(device)
reloaded_model = load_model_sparse(reloaded_model, sparse_dir="sparse_export")
reload_acc = compute_accuracy(reloaded_model, tokenized_test_dataset)
reload_total, reload_nonzero = count_nonzero_params(reloaded_model)
print(f"Reloaded sparse model:    Acc={reload_acc:.4f}, Nonzeros={reload_nonzero}/{reload_total}")

finetune_args_pruned = TrainingArguments(
    output_dir="./tmp_pruned_bert",
    per_device_train_batch_size=16,
    num_train_epochs=1,
    logging_steps=100,
    save_strategy="no",
    report_to=[]
)
trainer_pruned = Trainer(
    model=reloaded_model,
    args=finetune_args_pruned,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_test_dataset,
    tokenizer=tokenizer,
)
print("\n=== Fine-tuning Pruned Model ===")
trainer_pruned.train()
finetuned_acc = compute_accuracy(reloaded_model, tokenized_test_dataset)
finetuned_size = get_model_size(reloaded_model)
finetuned_total, finetuned_nonzero = count_nonzero_params(reloaded_model)
print(f"Fine-tuned pruned model:  Size={finetuned_size:.2f} MB, Acc={finetuned_acc:.4f}, Nonzeros={finetuned_nonzero}/{finetuned_total}")

# ---- 11. Dense vs. Sparse Storage ----
import os
dense_size = os.path.getsize("dense_pruned_model.pt") / (1024*1024)
print(f"Dense .pt pruned model size: {dense_size:.2f} MB")
print(f"Sparse export folder size:   {sparse_disk_size:.2f} MB")

print("\n=== Summary ===")
print(f"Original   : {orig_size:.2f} MB, Acc={orig_acc:.4f}, Nonzeros={orig_nonzero}")
print(f"Pruned     : {pruned_size:.2f} MB, Acc={pruned_acc:.4f}, Nonzeros={pruned_nonzero}")
print(f"Finetuned  : {finetuned_size:.2f} MB, Acc={finetuned_acc:.4f}, Nonzeros={finetuned_nonzero}")
print(f"Sparse export folder size:   {sparse_disk_size:.2f} MB")
print(f"Dense .pt pruned model size: {dense_size:.2f} MB")
print(f"Reloaded   : Acc={reload_acc:.4f}, Nonzeros={reload_nonzero}")


## test2

In [None]:
import torch
import torch.nn as nn
import copy
from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments, TrainerCallback
from datasets import load_dataset
import os
import numpy as np
import json
import scipy.sparse as sp
import pickle

# Custom callback to print loss
class LossLoggingCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None and 'loss' in logs:
            print(f"\nStep {state.global_step}: Loss = {logs['loss']:.4f}")

# ---- 1. N_eff-based mask/prune functions ----
def get_neff_mask_linear(module):
    w = module.weight.data
    out_features, in_features = w.shape
    w_abs = torch.abs(w)
    norm_factor = w_abs.sum(dim=1, keepdim=True).clamp(min=1e-8)
    w_norm = w_abs / norm_factor
    w_norm_sum_sq = (w_norm**2).sum(dim=1)
    neff = torch.clamp(torch.floor(1.0 / w_norm_sum_sq), min=1).long()
    k_max = neff.max().item()
    topk_vals, _ = torch.topk(w_norm, k=k_max, dim=1, sorted=False)
    thresholds = topk_vals[torch.arange(out_features), neff-1].unsqueeze(1)
    mask = (w_norm >= thresholds).float()
    return mask

def prune_model_neff(model, renormalize=False):
    model = copy.deepcopy(model)
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            mask = get_neff_mask_linear(module).to(module.weight.device)
            with torch.no_grad():
                module.weight *= mask
                if renormalize:
                    row_sum = module.weight.abs().sum(dim=1, keepdim=True).clamp(min=1e-8)
                    module.weight.div_(row_sum)
    return model

# ---- 2. Storage and nonzero counting ----
def get_model_size(model, tmp_file="tmp_model.bin"):
    torch.save(model.state_dict(), tmp_file)
    size_mb = os.path.getsize(tmp_file) / (1024 * 1024)
    os.remove(tmp_file)
    return size_mb

def count_nonzero_params(model):
    nonzero = 0
    total = 0
    for p in model.parameters():
        total += p.numel()
        nonzero += (p != 0).sum().item()
    return total, nonzero

def get_folder_size_mb(folder):
    total_size = 0
    for dirpath, dirnames, filenames in os.walk(folder):
        for f in filenames:
            fp = os.path.join(dirpath, f)
            total_size += os.path.getsize(fp)
    return total_size / (1024 * 1024)

# ---- 3. Efficient Sparse export/load ----
def export_model_sparse_efficient(model, out_dir="sparse_export"):
    """Export pruned model in sparse format for efficient storage"""
    os.makedirs(out_dir, exist_ok=True)
    
    state_dict = model.state_dict()
    sparse_data = {
        'sparse_weights': {},
        'dense_params': {},
        'metadata': {}
    }
    
    total_params = 0
    sparse_params = 0
    
    for name, param in state_dict.items():
        param_numpy = param.cpu().numpy()
        
        # Only make 2D weight matrices sparse
        if 'weight' in name and len(param.shape) == 2:
            # Convert to CSR sparse matrix
            sparse_matrix = sp.csr_matrix(param_numpy)
            sparse_data['sparse_weights'][name] = {
                'data': sparse_matrix.data.astype(np.float16),  # Use float16 for compression
                'indices': sparse_matrix.indices.astype(np.int32),  # Use int32 instead of int64
                'indptr': sparse_matrix.indptr.astype(np.int32),
                'shape': sparse_matrix.shape
            }
            
            nnz = sparse_matrix.nnz
            total = param_numpy.shape[0] * param_numpy.shape[1]
            sparsity = 1 - (nnz / total)
            
            sparse_data['metadata'][name] = {
                'shape': list(param.shape),
                'nnz': nnz,
                'total': total,
                'sparsity': sparsity,
                'dtype': str(param.dtype)
            }
            
            total_params += total
            sparse_params += nnz
            
            print(f"  {name}: {nnz}/{total} ({sparsity:.1%} sparse)")
        else:
            # Keep other parameters dense
            sparse_data['dense_params'][name] = param_numpy
    
    # Save everything in one compressed file
    with open(os.path.join(out_dir, "model_sparse.pkl"), "wb") as f:
        pickle.dump(sparse_data, f, protocol=pickle.HIGHEST_PROTOCOL)
    
    overall_sparsity = 1 - (sparse_params / total_params)
    print(f"\nOverall sparsity: {overall_sparsity:.1%}")
    print(f"Exported to {out_dir}")
    
    return sparse_data['metadata']

def load_model_sparse_efficient(model, sparse_dir="sparse_export"):
    """Load sparse model back"""
    sparse_path = os.path.join(sparse_dir, "model_sparse.pkl")
    
    with open(sparse_path, "rb") as f:
        sparse_data = pickle.load(f)
    
    state_dict = model.state_dict()
    new_state_dict = {}
    
    # Reconstruct sparse weights
    for name in state_dict:
        if name in sparse_data['sparse_weights']:
            sw = sparse_data['sparse_weights'][name]
            # First convert data back to float32 before creating sparse matrix
            data_float32 = sw['data'].astype(np.float32)
            # Reconstruct CSR matrix with float32 data
            sparse_matrix = sp.csr_matrix(
                (data_float32, sw['indices'], sw['indptr']), 
                shape=sw['shape'],
                dtype=np.float32
            )
            # Convert back to dense tensor
            dense_array = sparse_matrix.toarray()
            new_state_dict[name] = torch.from_numpy(dense_array).to(state_dict[name].dtype)
        elif name in sparse_data['dense_params']:
            new_state_dict[name] = torch.from_numpy(
                sparse_data['dense_params'][name]
            ).to(state_dict[name].dtype)
        else:
            # Keep original if not found
            new_state_dict[name] = state_dict[name]
    
    model.load_state_dict(new_state_dict)
    print(f"Loaded sparse model from {sparse_dir}")
    return model

# ---- 4. Accuracy function ----
def compute_accuracy(model, tokenized_dataset, batch_size=32):
    from torch.utils.data import DataLoader
    model.eval()
    device = next(model.parameters()).device
    correct = 0
    total = 0
    loader = DataLoader(tokenized_dataset, batch_size=batch_size)
    with torch.no_grad():
        for batch in loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k in ['input_ids', 'attention_mask']}
            labels = batch['labels'].to(device)
            outputs = model(**inputs)
            preds = outputs.logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += len(labels)
    return correct / total

# ---- Main execution ----
if __name__ == "__main__":
    # Setup
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    raw_dataset = load_dataset("ag_news")
    train_dataset = raw_dataset["train"]
    test_dataset = raw_dataset["test"]

    def tokenize_function(example):
        return tokenizer(
            example["text"],
            truncation=True,
            padding="max_length",
            max_length=128
        )

    tokenized_train_dataset = train_dataset.map(tokenize_function, batched=True)
    tokenized_test_dataset = test_dataset.map(tokenize_function, batched=True)
    tokenized_train_dataset = tokenized_train_dataset.rename_column("label", "labels")
    tokenized_test_dataset = tokenized_test_dataset.rename_column("label", "labels")
    tokenized_train_dataset.set_format(type="torch", columns=['input_ids', 'attention_mask', 'labels'])
    tokenized_test_dataset.set_format(type="torch", columns=['input_ids', 'attention_mask', 'labels'])

    # Train original model
    print("=== Phase 1: Training Original Model ===")
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4).to(device)
    finetune_args = TrainingArguments(
        output_dir="./tmp_finetuned_bert",
        per_device_train_batch_size=16,
        num_train_epochs=5,
        logging_steps=1000,  # Log every 1000 steps
        logging_dir='./logs',  # TensorBoard logs
        save_strategy="no",
        report_to=["tensorboard"],  # Enable TensorBoard logging
        log_level="info",  # Ensure info level logging
        disable_tqdm=False,  # Keep progress bar
        evaluation_strategy="steps",  # Evaluate during training
        eval_steps=1000,  # Evaluate every 1000 steps
    )
    trainer = Trainer(
        model=model,
        args=finetune_args,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_test_dataset,
        tokenizer=tokenizer,
        callbacks=[LossLoggingCallback()],  # Add custom callback
    )
    trainer.train()

    # Evaluate original
    orig_acc = compute_accuracy(model, tokenized_test_dataset)
    orig_size = get_model_size(model)
    orig_total, orig_nonzero = count_nonzero_params(model)
    print(f"\n✓ Original model: {orig_size:.2f} MB, Accuracy={orig_acc:.4f}")

    # Prune model
    print("\n=== Phase 2: Pruning Model ===")
    pruned_model = prune_model_neff(model, renormalize=False).to(device)
    pruned_acc = compute_accuracy(pruned_model, tokenized_test_dataset)
    pruned_total, pruned_nonzero = count_nonzero_params(pruned_model)
    sparsity = 1 - (pruned_nonzero / pruned_total)
    print(f"✓ Pruned model: Accuracy={pruned_acc:.4f} (dropped {orig_acc-pruned_acc:.4f})")
    print(f"  Sparsity: {sparsity:.1%} ({pruned_total-pruned_nonzero:,} zeros)")

    # Export sparse
    print("\n=== Phase 3: Saving Sparse Model ===")
    metadata = export_model_sparse_efficient(pruned_model, out_dir="sparse_export")
    sparse_disk_size = get_folder_size_mb("sparse_export")
    compression_ratio = orig_size / sparse_disk_size
    print(f"✓ Sparse storage: {sparse_disk_size:.2f} MB (compression ratio: {compression_ratio:.1f}x)")

    # Compare with dense storage
    torch.save(pruned_model.state_dict(), "pruned_dense.pt")
    dense_pruned_size = os.path.getsize("pruned_dense.pt") / (1024 * 1024)
    print(f"  Dense storage would be: {dense_pruned_size:.2f} MB")
    os.remove("pruned_dense.pt")

    # Load sparse model
    print("\n=== Phase 4: Loading Sparse Model ===")
    reloaded_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4).to(device)
    reloaded_model = load_model_sparse_efficient(reloaded_model, sparse_dir="sparse_export")
    reload_acc = compute_accuracy(reloaded_model, tokenized_test_dataset)
    reload_total, reload_nonzero = count_nonzero_params(reloaded_model)
    
    # Check if loading worked correctly
    acc_diff = abs(pruned_acc - reload_acc)
    if acc_diff > 0.001:
        print(f"⚠️  WARNING: Accuracy changed by {acc_diff:.4f} after reload!")
    else:
        print(f"✓ Loaded successfully: Accuracy={reload_acc:.4f} (preserved)")

    # ONE-SHOT RECOVERY TRAINING
    print("\n=== Phase 5: ONE-SHOT Recovery Training ===")
    print("This is the key innovation: recovering accuracy with just 1 epoch!")
    
    oneshot_args = TrainingArguments(
        output_dir="./tmp_oneshot",
        per_device_train_batch_size=16,
        num_train_epochs=1,  # Just ONE epoch!
        logging_steps=1000,  # Log every 1000 steps
        logging_dir='./logs_oneshot',
        save_strategy="no",
        report_to=["tensorboard"],
        log_level="info",
        disable_tqdm=False,
        evaluation_strategy="steps",
        eval_steps=1000,
    )
    oneshot_trainer = Trainer(
        model=reloaded_model,
        args=oneshot_args,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_test_dataset,
        tokenizer=tokenizer,
        callbacks=[LossLoggingCallback()],  # Add custom callback
    )
    oneshot_trainer.train()
    
    # Evaluate after one-shot training
    oneshot_acc = compute_accuracy(reloaded_model, tokenized_test_dataset)
    oneshot_total, oneshot_nonzero = count_nonzero_params(reloaded_model)
    
    print(f"\n✓ After ONE-SHOT training: Accuracy={oneshot_acc:.4f}")
    print(f"  Accuracy recovered: {oneshot_acc - reload_acc:.4f} → almost back to {orig_acc:.4f}!")
    print(f"  Nonzero params: {oneshot_nonzero:,} (zeros filled back during training)")

    # Final summary
    print("\n" + "="*60)
    print("SUMMARY: Your One-Shot Recovery Technique")
    print("="*60)
    print(f"1. Original model    : {orig_size:.2f} MB, Acc={orig_acc:.4f}")
    print(f"2. After pruning     : Acc={pruned_acc:.4f} (↓{orig_acc-pruned_acc:.4f}), {sparsity:.1%} sparse")
    print(f"3. Sparse storage    : {sparse_disk_size:.2f} MB ({compression_ratio:.1f}x smaller)")
    print(f"4. After loading     : Acc={reload_acc:.4f}")
    print(f"5. ONE-SHOT recovery : Acc={oneshot_acc:.4f} (↑{oneshot_acc-reload_acc:.4f})")
    print(f"\n✨ Key insight: Just 1 epoch recovers {(oneshot_acc/orig_acc)*100:.1f}% of original accuracy!")
    print(f"   Even though zeros get filled back, the sparse storage + one-shot")
    print(f"   training gives you a compressed model delivery pipeline!")

## test3


In [None]:
import torch
import torch.nn as nn
import copy
from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import os
import numpy as np
import json
import zipfile
import io

# ---- 1. Improved N_eff-based pruning ----
def get_neff_mask_linear(module):
    w = module.weight.data
    out_features, in_features = w.shape
    w_abs = torch.abs(w)
    norm_factor = w_abs.sum(dim=1, keepdim=True).clamp(min=1e-8)
    w_norm = w_abs / norm_factor
    w_norm_sum_sq = (w_norm**2).sum(dim=1)
    neff = torch.clamp(torch.floor(1.0 / w_norm_sum_sq), min=1).long()
    k_max = neff.max().item()
    
    # Efficient top-k using torch.topk
    topk_vals, _ = torch.topk(w_norm, k=k_max, dim=1, sorted=True)
    thresholds = topk_vals[torch.arange(out_features), neff-1].unsqueeze(1)
    mask = (w_norm >= thresholds).float()
    return mask

def prune_model_neff(model, renormalize=True):  # Changed to True for better stability
    model = copy.deepcopy(model)
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            mask = get_neff_mask_linear(module).to(module.weight.device)
            with torch.no_grad():
                module.weight.mul_(mask)
                if renormalize:
                    # L1 renormalization
                    row_sum = module.weight.abs().sum(dim=1, keepdim=True).clamp(min=1e-8)
                    module.weight.div_(row_sum)
    return model

# ---- 2. Storage improvements ----
def get_model_size(model, tmp_file="tmp_model.bin"):
    torch.save(model.state_dict(), tmp_file)
    size_mb = os.path.getsize(tmp_file) / (1024 * 1024)
    os.remove(tmp_file)
    return size_mb

def count_nonzero_params(model):
    nonzero = 0
    total = 0
    for p in model.parameters():
        total += p.numel()
        nonzero += (p != 0).sum().item()
    return total, nonzero

def get_folder_size_mb(folder):
    total_size = 0
    for dirpath, dirnames, filenames in os.walk(folder):
        for f in filenames:
            fp = os.path.join(dirpath, f)
            total_size += os.path.getsize(fp)
    return total_size / (1024 * 1024)

# ---- 3. Fixed sparse export/load ----
def export_model_sparse(model, out_dir="sparse_export"):
    os.makedirs(out_dir, exist_ok=True)
    meta = {}
    state_dict = model.state_dict()
    
    # Save all non-linear parameters normally
    non_linear_state_dict = {}
    linear_keys = set()
    for module_name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            linear_keys.add(f"{module_name}.weight")
            if module.bias is not None:
                linear_keys.add(f"{module_name}.bias")
    
    for k, v in state_dict.items():
        if k not in linear_keys:
            non_linear_state_dict[k] = v.cpu()
    
    torch.save(non_linear_state_dict, os.path.join(out_dir, "non_linear_state_dict.pt"))
    
    # Save linear layers in sparse format
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            w = module.weight.data.cpu().numpy()
            nonzero = w != 0
            values = w[nonzero]
            
            # Efficient sparse storage: CSR format
            sparse_format = {
                "data": values,
                "indices": np.where(nonzero)[1],  # column indices
                "indptr": np.concatenate([[0], np.cumsum(nonzero.sum(axis=1))])  # row pointers
            }
            
            np.savez_compressed(
                os.path.join(out_dir, f"{name}.npz"),
                **sparse_format
            )
            meta[name] = {
                "shape": w.shape,
                "dtype": str(w.dtype),
                "nnz": len(values)
            }
            
            if module.bias is not None:
                np.save(
                    os.path.join(out_dir, f"{name}_bias.npy"),
                    module.bias.data.cpu().numpy()
                )
    
    with open(os.path.join(out_dir, "meta.json"), "w") as f:
        json.dump(meta, f)
    
    print(f"Exported sparse weights to {out_dir}")
    return get_folder_size_mb(out_dir)

def load_model_sparse(model, sparse_dir="sparse_export"):
    # Load non-linear parameters
    non_linear_path = os.path.join(sparse_dir, "non_linear_state_dict.pt")
    if os.path.exists(non_linear_path):
        non_linear_state_dict = torch.load(non_linear_path)
        model.load_state_dict(non_linear_state_dict, strict=False)
    
    # Load meta data
    with open(os.path.join(sparse_dir, "meta.json")) as f:
        meta = json.load(f)
    
    # Load sparse linear layers
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and name in meta:
            sparse_path = os.path.join(sparse_dir, f"{name}.npz")
            if not os.path.exists(sparse_path):
                continue
                
            sparse_data = np.load(sparse_path)
            shape = tuple(meta[name]["shape"])
            w = np.zeros(shape, dtype=np.float32)
            
            # Reconstruct from CSR format
            rows = []
            for i in range(len(sparse_data["indptr"]) - 1):
                start = sparse_data["indptr"][i]
                end = sparse_data["indptr"][i+1]
                cols = sparse_data["indices"][start:end]
                w[i, cols] = sparse_data["data"][start:end]
            
            module.weight.data = torch.tensor(w, device=module.weight.data.device)
            
            bias_path = os.path.join(sparse_dir, f"{name}_bias.npy")
            if os.path.exists(bias_path):
                module.bias.data = torch.tensor(
                    np.load(bias_path),
                    device=module.bias.data.device
                )
    
    print(f"Loaded sparse weights from {sparse_dir}")
    return model

# ---- 4. Accuracy function ----
def compute_accuracy(model, tokenized_dataset, batch_size=32):
    from torch.utils.data import DataLoader
    model.eval()
    device = next(model.parameters()).device
    correct = 0
    total = 0
    loader = DataLoader(tokenized_dataset, batch_size=batch_size)
    with torch.no_grad():
        for batch in loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k in ['input_ids', 'attention_mask']}
            labels = batch['labels'].to(device)
            outputs = model(**inputs)
            preds = outputs.logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += len(labels)
    return correct / total

# ---- 5. Tokenization and data setup ----
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
raw_dataset = load_dataset("ag_news")
train_dataset = raw_dataset["train"]
test_dataset = raw_dataset["test"]

def tokenize_function(example):
    return tokenizer(
        example["text"],
        truncation=True,
        padding="max_length",
        max_length=128
    )

tokenized_train_dataset = train_dataset.map(tokenize_function, batched=True)
tokenized_test_dataset = test_dataset.map(tokenize_function, batched=True)
tokenized_train_dataset = tokenized_train_dataset.rename_column("label", "labels")
tokenized_test_dataset = tokenized_test_dataset.rename_column("label", "labels")
tokenized_train_dataset.set_format(type="torch", columns=['input_ids', 'attention_mask', 'labels'])
tokenized_test_dataset.set_format(type="torch", columns=['input_ids', 'attention_mask', 'labels'])

# ---- 6. Train original model ----
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4).to(device)
finetune_args = TrainingArguments(
    output_dir="./tmp_finetuned_bert",
    per_device_train_batch_size=16,
    num_train_epochs=3,  # Reduced for faster testing
    logging_steps=100,
    save_strategy="no",
    report_to=[]
)
trainer = Trainer(
    model=model,
    args=finetune_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_test_dataset,
    tokenizer=tokenizer,
)
print("\n=== Training Original Model ===")
trainer.train()

# ---- 7. Evaluate and analyze original ----
orig_acc = compute_accuracy(model, tokenized_test_dataset)
orig_size = get_model_size(model)
orig_total, orig_nonzero = count_nonzero_params(model)
print(f"\n original model: Size={orig_size:.2f} MB, Acc={orig_acc:.4f}, Nonzeros={orig_nonzero}/{orig_total}")

torch.save(model.state_dict(), "dense_model.pt")

# ---- 8. Prune, save, and analyze pruned model ----
pruned_model = prune_model_neff(model, renormalize=True).to(device)
pruned_acc = compute_accuracy(pruned_model, tokenized_test_dataset)
pruned_size = get_model_size(pruned_model)
pruned_total, pruned_nonzero = count_nonzero_params(pruned_model)
print(f"Pruned model:             Size={pruned_size:.2f} MB, Acc={pruned_acc:.4f}, Nonzeros={pruned_nonzero}/{pruned_total}")

torch.save(pruned_model.state_dict(), "dense_pruned_model.pt")
sparse_disk_size = export_model_sparse(pruned_model, out_dir="sparse_export")
print(f"Sparse export folder size: {sparse_disk_size:.2f} MB")

# ---- 9. Load sparse weights correctly ----
reloaded_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4).to(device)
reloaded_model = load_model_sparse(reloaded_model, sparse_dir="sparse_export")
reload_acc = compute_accuracy(reloaded_model, tokenized_test_dataset)
reload_total, reload_nonzero = count_nonzero_params(reloaded_model)
print(f"Reloaded sparse model:    Acc={reload_acc:.4f}, Nonzeros={reload_nonzero}/{reload_total}")

# ---- 10. Fine-tuning ----
finetune_args_pruned = TrainingArguments(
    output_dir="./tmp_pruned_bert",
    per_device_train_batch_size=16,
    num_train_epochs=1,
    learning_rate=2e-5,  # Lower learning rate for fine-tuning
    logging_steps=100,
    save_strategy="no",
    report_to=[]
)
trainer_pruned = Trainer(
    model=reloaded_model,
    args=finetune_args_pruned,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_test_dataset,
    tokenizer=tokenizer,
)
print("\n=== Fine-tuning Pruned Model ===")
trainer_pruned.train()
finetuned_acc = compute_accuracy(reloaded_model, tokenized_test_dataset)
finetuned_size = get_model_size(reloaded_model)
finetuned_total, finetuned_nonzero = count_nonzero_params(reloaded_model)
print(f"Fine-tuned pruned model:  Size={finetuned_size:.2f} MB, Acc={finetuned_acc:.4f}, Nonzeros={finetuned_nonzero}/{finetuned_total}")

# ---- 11. Final summary ----
dense_size = os.path.getsize("dense_pruned_model.pt") / (1024*1024)
print(f"Dense .pt pruned model size: {dense_size:.2f} MB")
print(f"Sparse export folder size:   {sparse_disk_size:.2f} MB")

print("\n=== Summary ===")
print(f"Original   : {orig_size:.2f} MB, Acc={orig_acc:.4f}, Nonzeros={orig_nonzero}")
print(f"Pruned     : {pruned_size:.2f} MB, Acc={pruned_acc:.4f}, Nonzeros={pruned_nonzero}")
print(f"Finetuned  : {finetuned_size:.2f} MB, Acc={finetuned_acc:.4f}, Nonzeros={finetuned_nonzero}")
print(f"Sparse export folder size:   {sparse_disk_size:.2f} MB")
print(f"Dense .pt pruned model size: {dense_size:.2f} MB")
print(f"Reloaded   : Acc={reload_acc:.4f}, Nonzeros={reload_nonzero}")