## start

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np

import copy
from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import os
import json

import matplotlib.pyplot as plt
import time

from collections import defaultdict
import tqdm

# 1. get linear mask for effective weight with each weight size [output_size, input_size]

In [2]:
def get_linear_mask(module:nn.Module) -> torch.Tensor:
    x = module.weight.data
    output_size, input_size = x.shape
    x_norm = torch.abs(x) / torch.sum(torch.abs(x), dim=0, keepdim=True)
    neff = torch.floor(1/torch.sum((x_norm ** 2), dim=0, keepdim=True).squeeze(0))
    
    _, indices = torch.sort(x_norm, dim=0, descending=True)
    range_tensor = torch.arange(output_size, device=x.device).unsqueeze(0).expand(input_size, -1).T
    sorted_mask = range_tensor < neff
    
    mask = torch.zeros_like(x, dtype=torch.bool)
    mask.scatter_(0, indices, sorted_mask)
    return mask

# 2. set the edge with ineffective weight = 0 and prune the edge

In [3]:
def prune_model_neff(model, renormalize=False):
    model = copy.deepcopy(model)
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            mask = get_linear_mask(module).to(module.weight.device)
            with torch.no_grad():
                module.weight *= mask
                if renormalize:
                    row_sum = module.weight.sum(dim=0, keepdim=True).clamp(min=1e-8)
                    module.weight.div_(row_sum)
    return model

def model_sparsity(model):
    """Calculate the sparsity of the model"""
    total_params = 0
    zero_params = 0
    
    for name, param in model.named_parameters():
        if 'weight' in name:
            total_params += param.numel()
            zero_params += torch.sum(param == 0).item()
    
    sparsity = zero_params / total_params
    return sparsity

# 3. train a linear model first and storage

In [4]:
class LinearModel(nn.Module):
    def __init__(self, input_size, output_size, hidden_size= [512, 512, 512]):
        super(LinearModel, self).__init__()
        self.layers = nn.ModuleList()
        
        prev_size = input_size
        for size in hidden_size:
            self.layers.append(nn.Linear(prev_size, size))
            prev_size = size
            
        self.output = nn.Linear(prev_size, output_size)
        
    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten
        
        for layer in self.layers:
            x = F.relu(layer(x))        
        x = self.output(x)
        return F.log_softmax(x, dim=1)
        
def train(model, device, train_loader, optimizer, epoch):
    """Train for one epoch"""
    model.train()
    train_loss = 0
    correct = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
    
    avg_loss = train_loss / len(train_loader)
    accuracy = 100. * correct / len(train_loader.dataset)
    return avg_loss, accuracy


def test(model, device, test_loader):
    """Evaluate model on test set"""
    model.eval()
    test_loss = 0
    correct = 0
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    
    print(f'\nTest set: Average loss: {test_loss:.4f}, '
          f'Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)\n')
    
    return test_loss, accuracy

## data loader

In [5]:
batch_size = 64
test_batch_size = 1000
epochs = 10
lr = 3e-4

# MINIST-10 dataset
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)   
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LinearModel(input_size=28*28, output_size=10, hidden_size=[1024, 512, 512]).to(device)

optimizer = optim.Adam(model.parameters(), lr=lr)

result = {
    'train_loss': [],
    'train_accuracy': [],
    'test_loss': [],
    'test_accuracy': []
}


In [6]:
# initial test
test_loss, test_accuracy = test(model, device, test_loader)
result['test_loss'].append(test_loss)
result['test_accuracy'].append(test_accuracy)

# Training loop
for epoch in range(1, epochs + 1):
    train_loss, train_accuracy = train(model, device, train_loader, optimizer, epoch)
    result['train_loss'].append(train_loss)
    result['train_accuracy'].append(train_accuracy)
    
    # Test after each epoch
    test_loss, test_accuracy = test(model, device, test_loader)
    result['test_loss'].append(test_loss)
    result['test_accuracy'].append(test_accuracy)
    
# Save the model
torch.save(model.state_dict(), 'linear_model.pth')



Test set: Average loss: 2.3062, Accuracy: 1258/10000 (12.58%)


Test set: Average loss: 0.0909, Accuracy: 9708/10000 (97.08%)


Test set: Average loss: 0.1003, Accuracy: 9694/10000 (96.94%)


Test set: Average loss: 0.0657, Accuracy: 9795/10000 (97.95%)


Test set: Average loss: 0.0670, Accuracy: 9788/10000 (97.88%)


Test set: Average loss: 0.0758, Accuracy: 9763/10000 (97.63%)


Test set: Average loss: 0.0820, Accuracy: 9778/10000 (97.78%)


Test set: Average loss: 0.0694, Accuracy: 9823/10000 (98.23%)


Test set: Average loss: 0.0662, Accuracy: 9833/10000 (98.33%)


Test set: Average loss: 0.0877, Accuracy: 9794/10000 (97.94%)


Test set: Average loss: 0.0735, Accuracy: 9815/10000 (98.15%)



# prune the model and comparing the performance with the original model

In [7]:
pruned_model_renormalized = prune_model_neff(model, renormalize=True)
pruned_model_renormalized.to(device)

# Test the pruned model
test_loss, test_accuracy = test(pruned_model_renormalized, device, test_loader)
print(f'Pruned Model - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
# Save the pruned model
torch.save(pruned_model_renormalized.state_dict(), 'pruned_linear_model_renormalized.pth')


prune_model = prune_model_neff(model, renormalize=False)
prune_model.to(device)
# Test the pruned model without renormalization
test_loss, test_accuracy = test(prune_model, device, test_loader)
print(f'Pruned Model without Renormalization - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
# Save the pruned model without renormalization
torch.save(prune_model.state_dict(), 'pruned_linear_model.pth')


Test set: Average loss: 28035697270544202374587336359936.0000, Accuracy: 8889/10000 (88.89%)

Pruned Model - Test Loss: 28035697270544202374587336359936.0000, Test Accuracy: 88.89%

Test set: Average loss: 0.0667, Accuracy: 9813/10000 (98.13%)

Pruned Model without Renormalization - Test Loss: 0.0667, Test Accuracy: 98.13%


In [8]:
# test 10 times and show the average performance
test_loss_acc = {'prune_loss': [], 'prune_accuracy': [], 'prune_renorm_loss': [], 'prune_renorm_accuracy': []}

for i in range(10):
    pruned_model_renormalized = prune_model_neff(model, renormalize=True)
    pruned_model_renormalized.to(device)

    # Test the pruned model
    test_loss, test_accuracy = test(pruned_model_renormalized, device, test_loader)
    print(f'Pruned Model - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
    test_loss_acc['prune_renorm_loss'].append(test_loss)
    test_loss_acc['prune_renorm_accuracy'].append(test_accuracy)
    
    pruned_model = prune_model_neff(model, renormalize=False)
    pruned_model.to(device)
    # Test the pruned model without renormalization
    test_loss, test_accuracy = test(pruned_model, device, test_loader)
    print(f'Pruned Model without Renormalization - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
    test_loss_acc['prune_loss'].append(test_loss)
    test_loss_acc['prune_accuracy'].append(test_accuracy)
    
# average performance
avg_prune_loss = np.mean(test_loss_acc['prune_loss'])
avg_prune_accuracy = np.mean(test_loss_acc['prune_accuracy'])
avg_prune_renorm_loss = np.mean(test_loss_acc['prune_renorm_loss'])
avg_prune_renorm_accuracy = np.mean(test_loss_acc['prune_renorm_accuracy'])

print(f'Average Pruned Model - Test Loss: {avg_prune_loss:.4f}, Test Accuracy: {avg_prune_accuracy:.2f}%')
print(f'Average Pruned Model with Renormalization - Test Loss: {avg_prune_renorm_loss:.4f}, Test Accuracy: {avg_prune_renorm_accuracy:.2f}%')
    


Test set: Average loss: 28035697270544202374587336359936.0000, Accuracy: 8889/10000 (88.89%)

Pruned Model - Test Loss: 28035697270544202374587336359936.0000, Test Accuracy: 88.89%

Test set: Average loss: 0.0667, Accuracy: 9813/10000 (98.13%)

Pruned Model without Renormalization - Test Loss: 0.0667, Test Accuracy: 98.13%

Test set: Average loss: 28035697270544202374587336359936.0000, Accuracy: 8889/10000 (88.89%)

Pruned Model - Test Loss: 28035697270544202374587336359936.0000, Test Accuracy: 88.89%

Test set: Average loss: 0.0667, Accuracy: 9813/10000 (98.13%)

Pruned Model without Renormalization - Test Loss: 0.0667, Test Accuracy: 98.13%

Test set: Average loss: 28035697270544202374587336359936.0000, Accuracy: 8889/10000 (88.89%)

Pruned Model - Test Loss: 28035697270544202374587336359936.0000, Test Accuracy: 88.89%

Test set: Average loss: 0.0667, Accuracy: 9813/10000 (98.13%)

Pruned Model without Renormalization - Test Loss: 0.0667, Test Accuracy: 98.13%

Test set: Average los

## measure the sparsity of the model

In [11]:
sparsity = model_sparsity(pruned_model_renormalized)
print(f'Sparsity of the pruned model with renormalization: {sparsity:.4f}')
sparsity = model_sparsity(prune_model)
print(f'Sparsity of the pruned model without renormalization: {sparsity:.4f}')
sparsity = model_sparsity(model)
print(f'Sparsity of the original model: {sparsity:.4f}')

Sparsity of the pruned model with renormalization: 0.3418
Sparsity of the pruned model without renormalization: 0.3418
Sparsity of the original model: 0.0000


## one-shot fine-tuning

In [9]:
pruned_model_one_shot = LinearModel(input_size=28*28, output_size=10, hidden_size=[1024, 512, 512]).to(device)
pruned_model_one_shot.load_state_dict(torch.load('pruned_linear_model.pth'))

# Test the pruned model
test_loss, test_accuracy = test(pruned_model_one_shot, device, test_loader)
print(f'Pruned Model - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')


Test set: Average loss: 0.0667, Accuracy: 9813/10000 (98.13%)

Pruned Model - Test Loss: 0.0667, Test Accuracy: 98.13%


In [10]:
train_loss, train_accuracy = train(pruned_model_one_shot, device, train_loader, optimizer, 1)
result['train_loss'].append(train_loss)
result['train_accuracy'].append(train_accuracy)
    
    # Test after each epoch
test_loss, test_accuracy = test(pruned_model_one_shot, device, test_loader)
result['test_loss'].append(test_loss)
result['test_accuracy'].append(test_accuracy)

sparsity = model_sparsity(pruned_model_one_shot)
print(f'Model Sparsity: {sparsity:.2%}')


Test set: Average loss: 0.0667, Accuracy: 9813/10000 (98.13%)

Model Sparsity: 34.18%


# test in huggingface language model

In [5]:
def compute_accuracy(model, tokenized_dataset, batch_size=32):
    from torch.utils.data import DataLoader
    model.eval()
    device = next(model.parameters()).device
    correct = 0
    total = 0
    loader = DataLoader(tokenized_dataset, batch_size=batch_size)
    with torch.no_grad():
        for batch in loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k in ['input_ids', 'attention_mask']}
            labels = batch['labels'].to(device)
            outputs = model(**inputs)
            preds = outputs.logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += len(labels)
    return correct / total

In [13]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
raw_dataset = load_dataset("ag_news")
train_dataset = raw_dataset["train"]
test_dataset = raw_dataset["test"]
def tokenize_function(example):
    return tokenizer(
        example["text"],
        truncation=True,
        padding="max_length",
        max_length=128
    )

# Tokenize and format
tokenized_train_dataset = train_dataset.map(tokenize_function, batched=True)
tokenized_test_dataset = test_dataset.map(tokenize_function, batched=True)
tokenized_train_dataset = tokenized_train_dataset.rename_column("label", "labels")
tokenized_test_dataset = tokenized_test_dataset.rename_column("label", "labels")
tokenized_train_dataset.set_format(type="torch", columns=['input_ids', 'attention_mask', 'labels'])
tokenized_test_dataset.set_format(type="torch", columns=['input_ids', 'attention_mask', 'labels'])

# 4. Load and fine-tune BERT on train split
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4).to(device)
finetune_args = TrainingArguments(
    output_dir="./tmp_finetuned_bert",
    per_device_train_batch_size=16,
    num_train_epochs=5,
    logging_steps=1000,
    save_strategy="no",
    report_to=[]
)
trainer = Trainer(
    model=model,
    args=finetune_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_test_dataset,
    tokenizer=tokenizer,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


In [13]:
print("\n=== Training Original Model ===")
trainer.train()

orig_acc = compute_accuracy(model, tokenized_test_dataset)
sparsity = model_sparsity(model)
print(f"\nFine-tuned original model: Sparsity={sparsity:.4f}, Acc={orig_acc:.4f}")

torch.save(model.state_dict(), 'bert_origin.pth')


=== Training Original Model ===


Step,Training Loss
1000,0.3472
2000,0.2622
3000,0.2706
4000,0.241
5000,0.2265
6000,0.2271
7000,0.2213
8000,0.1992
9000,0.1664
10000,0.1628



Fine-tuned original model: Sparsity=0.0000, Acc=0.9428


In [14]:
pruned_model_renormalized = prune_model_neff(model, renormalize=True)
pruned_model_renormalized.to(device)
pruned_model = prune_model_neff(model, renormalize=False)
pruned_model.to(device)

prune_acc = compute_accuracy(pruned_model, tokenized_test_dataset)
prune_renorm_acc = compute_accuracy(pruned_model_renormalized, tokenized_test_dataset)
pruned_sparsity = model_sparsity(pruned_model)
pruned_renorm_sparsity = model_sparsity(pruned_model_renormalized)

print(f"\nPruned model: Sparsity={pruned_sparsity:.4f}, Acc={prune_acc:.4f}")
print(f"Pruned model with renormalization: Sparsity={pruned_renorm_sparsity:.4f}, Acc={prune_renorm_acc:.4f}")


Pruned model: Sparsity=0.2923, Acc=0.9447
Pruned model with renormalization: Sparsity=0.2923, Acc=0.2500


In [15]:
torch.save(pruned_model.state_dict(), 'bert_pruned.pth')
torch.save(pruned_model_renormalized.state_dict(), 'bert_pruned_renorm.pth')

In [15]:
pruned_model_one_shot = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4).to(device)
pruned_model_one_shot.load_state_dict(torch.load('bert_pruned.pth'))

# Test the pruned model
inital_acc = compute_accuracy(pruned_model_one_shot, tokenized_test_dataset)
print(f'Pruned Model - Test Accuracy: {inital_acc:.4f}')

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Pruned Model - Test Accuracy: 0.9447


In [16]:
finetune_args = TrainingArguments(
    output_dir="./tmp_finetuned_bert_pruned",
    per_device_train_batch_size=16,
    num_train_epochs=1,
    logging_steps=1000,
    save_strategy="no",
    report_to=[]
)

trainer = Trainer(
    model=pruned_model_one_shot,
    args=finetune_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_test_dataset,
    tokenizer=tokenizer,
)
print("\n=== Training Pruned Model ===")
trainer.train()
pruned_acc = compute_accuracy(pruned_model_one_shot, tokenized_test_dataset)
print(f"Pruned model after fine-tuning: Acc={pruned_acc:.4f}")

sparsity = model_sparsity(pruned_model_one_shot)
print(f'Sparsity of the pruned model after fine-tuning: {sparsity:.4f}')

  trainer = Trainer(



=== Training Pruned Model ===


Step,Training Loss
1000,0.0771
2000,0.0804
3000,0.0912
4000,0.0729
5000,0.064
6000,0.056
7000,0.0554


Pruned model after fine-tuning: Acc=0.9421
Sparsity of the pruned model after fine-tuning: 0.0000


## test for LLM

In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from datasets import load_dataset
import torch
from tqdm import tqdm

# Configuration
MODEL_NAME = "Qwen/Qwen-7B"  # Use "Qwen/Qwen-7B" for smaller variant
DATASET_NAME = "wikitext"
DATASET_CONFIG = "wikitext-2-raw-v1"
DEVICE_MAP = "auto"  # Automatically distributes across GPUs
BATCH_SIZE = 1  # Reduce if OOM errors occur

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Load dataset
test_dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, split="test")
texts = [text for text in test_dataset["text"] if text.strip()]  # Remove empty strings


In [5]:
# Load model with quantization (4-bit) to reduce VRAM usage
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map=DEVICE_MAP,
    torch_dtype=torch.float16,
    quantization_config={"load_in_4bit": True},
    trust_remote_code=True
)

The model is automatically converting to bf16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".
Try importing flash-attention for faster inference...


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [6]:
# Calculate perplexity
model.eval()
total_log_likelihood = 0
total_tokens = 0

test_text = texts[:100]  # Limit to first 100 texts for testing

with torch.no_grad():
    for text in tqdm(test_text, desc="Calculating Perplexity"):
        # Tokenize text
        inputs = tokenizer(text, return_tensors="pt", truncation=True).to(model.device)
        
        # Forward pass to get loss
        outputs = model(**inputs, labels=inputs["input_ids"])
        loss = outputs.loss.item()
        
        # Update metrics
        total_log_likelihood += loss * inputs["input_ids"].size(1)
        total_tokens += inputs["input_ids"].size(1)

# Final perplexity calculation
perplexity = torch.exp(torch.tensor(total_log_likelihood / total_tokens)).item()
print(f"Perplexity: {perplexity:.2f}")

sparsity = model_sparsity(model)
print(f'Sparsity of the original model: {sparsity:.4f}')

torch.save(model.state_dict(), 'qwen_model.pth')

Calculating Perplexity: 100%|██████████| 100/100 [00:06<00:00, 15.35it/s]


Perplexity: 19.62
Sparsity of the original model: 0.0000


In [5]:
def prune_model_neff_llm(model, renormalize=False):
    """
    Prune LLM model by targeting only standard Linear layers
    Avoids quantized layers and special layer types
    """
    model = copy.deepcopy(model)
    pruned_layers = []
    
    for name, module in model.named_modules():
        # Only prune standard nn.Linear layers, avoid quantized layers
        if isinstance(module, nn.Linear) and not hasattr(module, 'quant_state'):
            try:
                mask = get_linear_mask(module).to(module.weight.device)
                with torch.no_grad():
                    module.weight *= mask.float()
                    
                    if renormalize:
                        # More stable renormalization
                        row_sum = module.weight.abs().sum(dim=1, keepdim=True).clamp(min=1e-8)
                        module.weight.div_(row_sum)
                    
                    pruned_layers.append(name)
                    
                    # Check sparsity of this layer
                    sparsity = (module.weight == 0).float().mean().item()
                    print(f"Pruned {name}: {sparsity:.2%} sparsity")
                    
            except Exception as e:
                print(f"Skipping {name}: {e}")
                continue
    
    print(f"Successfully pruned {len(pruned_layers)} layers")
    return model

In [8]:
def prune_model_neff_inplace(model, renormalize=False):
    """
    Prune model in-place to avoid memory issues
    """
    pruned_layers = []
    
    for name, module in model.named_modules():
        # Only prune standard nn.Linear layers, avoid quantized layers
        if isinstance(module, nn.Linear) and not hasattr(module, 'quant_state'):
            try:
                # Create mask
                mask = get_linear_mask(module).to(module.weight.device)
                
                # Apply pruning in-place
                with torch.no_grad():
                    module.weight.data *= mask.float()
                    
                    if renormalize:
                        row_sum = module.weight.data.abs().sum(dim=1, keepdim=True).clamp(min=1e-8)
                        module.weight.data.div_(row_sum)
                    
                    pruned_layers.append(name)
                    sparsity = (module.weight.data == 0).float().mean().item()
                    print(f"Pruned {name}: {sparsity:.2%} sparsity")
                    
            except Exception as e:
                print(f"Skipping {name}: {e}")
                continue
    
    print(f"Successfully pruned {len(pruned_layers)} layers")
    return model

In [9]:
pruned_Qwen = prune_model_neff_inplace(model, renormalize=False)

pruned_Qwen.to("cuda" if torch.cuda.is_available() else "cpu")

# Calculate perplexity
pruned_Qwen.eval()
total_log_likelihood = 0
total_tokens = 0

with torch.no_grad():
    for text in tqdm(test_text, desc="Calculating Perplexity for Pruned Model"):
        inputs = tokenizer(text, return_tensors="pt", truncation=True).to(pruned_Qwen.device)
        outputs = pruned_Qwen(**inputs, labels=inputs["input_ids"])
        loss = outputs.loss.item()
        
        total_log_likelihood += loss * inputs["input_ids"].size(1)
        total_tokens += inputs["input_ids"].size(1)
        
perplexity_pruned = torch.exp(torch.tensor(total_log_likelihood / total_tokens)).item()
print(f"Perplexity of Pruned Model: {perplexity_pruned:.2f}")

sparsity = model_sparsity(pruned_Qwen)
print(f'Sparsity of the pruned model: {sparsity:.4f}')

torch.save(pruned_Qwen.state_dict(), 'pruned_qwen_model.pth')

Pruned lm_head: 39.97% sparsity
Successfully pruned 1 layers


Calculating Perplexity for Pruned Model: 100%|██████████| 100/100 [00:15<00:00,  6.52it/s]


Perplexity of Pruned Model: 21.05
Sparsity of the pruned model: 0.0555


In [5]:
def prune_model_neff_inplace(model, renormalize=False):
    """
    Prune model in-place to avoid memory issues
    """
    pruned_layers = []
    
    for name, module in model.named_modules():
        # Only prune standard nn.Linear layers
        if isinstance(module, nn.Linear) and not hasattr(module, 'quant_state'):
            try:
                # Create mask
                mask = get_linear_mask(module).to(module.weight.device)
                
                # Apply pruning in-place
                with torch.no_grad():
                    module.weight.data *= mask.float()
                    
                    if renormalize:
                        row_sum = module.weight.data.abs().sum(dim=1, keepdim=True).clamp(min=1e-8)
                        module.weight.data.div_(row_sum)
                    
                    pruned_layers.append(name)
                    sparsity = (module.weight.data == 0).float().mean().item()
                    print(f"Pruned {name}: {sparsity:.2%} sparsity")
                    
            except Exception as e:
                print(f"Skipping {name}: {e}")
                continue
    
    print(f"Successfully pruned {len(pruned_layers)} layers")
    return model

def calculate_perplexity(model, tokenizer, texts, max_length=512):
    """
    Calculate perplexity for a list of texts
    """
    model.eval()
    total_log_likelihood = 0
    total_tokens = 0
    
    with torch.no_grad():
        for text in tqdm(texts, desc="Calculating Perplexity"):
            # Tokenize text
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length).to(model.device)
            
            # Forward pass to get loss
            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss.item()
            
            # Update metrics
            total_log_likelihood += loss * inputs["input_ids"].size(1)
            total_tokens += inputs["input_ids"].size(1)
    
    # Final perplexity calculation
    perplexity = torch.exp(torch.tensor(total_log_likelihood / total_tokens)).item()
    return perplexity

In [6]:
print("Loading model for pruning...")
model_for_pruning = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map=DEVICE_MAP,
    torch_dtype=torch.float16,  # No quantization
    trust_remote_code=True
)

# Step 2: Calculate original perplexity
print("Calculating original perplexity...")
original_perplexity = calculate_perplexity(model_for_pruning, tokenizer, texts[:100])
print(f"Original Perplexity: {original_perplexity:.2f}")

Loading model for pruning...


The model is automatically converting to bf16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".
Try importing flash-attention for faster inference...


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

Calculating original perplexity...


Calculating Perplexity: 100%|██████████| 100/100 [00:03<00:00, 29.09it/s]

Original Perplexity: 17.18





In [None]:


# Step 3: Prune the model in-place
print("Pruning model...")
pruned_model = prune_model_neff_inplace(model_for_pruning, renormalize=False)

# Step 4: Calculate pruned perplexity
print("Calculating pruned perplexity...")
pruned_perplexity = calculate_perplexity(pruned_model, tokenizer, texts[:100])
print(f"Pruned Perplexity: {pruned_perplexity:.2f}")

# Step 5: Calculate sparsity
sparsity = model_sparsity(pruned_model)
print(f'Model Sparsity: {sparsity:.4f}')

# Step 6: Save the pruned model (optional)
torch.save(pruned_model.state_dict(), 'pruned_qwen_fp16.pth')

Loading model for pruning...


The model is automatically converting to bf16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".
Try importing flash-attention for faster inference...


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

Calculating original perplexity...


Calculating Perplexity: 100%|██████████| 100/100 [00:03<00:00, 30.28it/s]


Original Perplexity: 17.18
Pruning model...
Pruned transformer.h.0.attn.c_attn: 44.46% sparsity
Pruned transformer.h.0.attn.c_proj: 37.45% sparsity
Pruned transformer.h.0.mlp.w1: 36.58% sparsity
Pruned transformer.h.0.mlp.w2: 36.59% sparsity
Pruned transformer.h.0.mlp.c_proj: 37.07% sparsity
Pruned transformer.h.1.attn.c_attn: 46.15% sparsity
Pruned transformer.h.1.attn.c_proj: 37.27% sparsity
Pruned transformer.h.1.mlp.w1: 36.65% sparsity
Pruned transformer.h.1.mlp.w2: 36.79% sparsity
Pruned transformer.h.1.mlp.c_proj: 37.05% sparsity
Pruned transformer.h.2.attn.c_attn: 45.42% sparsity
Pruned transformer.h.2.attn.c_proj: 37.34% sparsity
Pruned transformer.h.2.mlp.w1: 36.67% sparsity
Pruned transformer.h.2.mlp.w2: 36.85% sparsity
Pruned transformer.h.2.mlp.c_proj: 36.86% sparsity
Pruned transformer.h.3.attn.c_attn: 42.56% sparsity
Pruned transformer.h.3.attn.c_proj: 37.33% sparsity
Pruned transformer.h.3.mlp.w1: 36.67% sparsity
Pruned transformer.h.3.mlp.w2: 36.65% sparsity
Pruned tran

Calculating Perplexity: 100%|██████████| 100/100 [00:03<00:00, 31.89it/s]


Pruned Perplexity: 30.47
Model Sparsity: 0.3484


In [5]:
pruned_model_one_shot = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map=DEVICE_MAP,
    torch_dtype=torch.float16,
    #quantization_config={"load_in_4bit": True},
    trust_remote_code=True
)
pruned_model_one_shot.load_state_dict(torch.load('pruned_qwen_fp16.pth'))

# Test the pruned model
pruned_model_one_shot.eval()
total_log_likelihood = 0
total_tokens = 0

test_text = texts[:100]  # Limit to first 100 texts for testing

with torch.no_grad():
    for text in tqdm(test_text, desc="Calculating Perplexity"):
        # Tokenize text
        inputs = tokenizer(text, return_tensors="pt", truncation=True).to(pruned_model_one_shot.device)
        
        # Forward pass to get loss
        outputs = pruned_model_one_shot(**inputs, labels=inputs["input_ids"])
        loss = outputs.loss.item()
        
        # Update metrics
        total_log_likelihood += loss * inputs["input_ids"].size(1)
        total_tokens += inputs["input_ids"].size(1)

# Final perplexity calculation
perplexity = torch.exp(torch.tensor(total_log_likelihood / total_tokens)).item()
print(f"Perplexity: {perplexity:.2f}")

finetune_args = TrainingArguments(
    output_dir="./tmp_finetuned_bert_pruned",
    per_device_train_batch_size=16,
    num_train_epochs=1,
    learning_rate=2e-5,
    logging_steps=1000,
    save_strategy="no",
    report_to=[]
)

trainer = Trainer(
    model=pruned_model_one_shot,
    args=finetune_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_test_dataset,
    tokenizer=tokenizer,
)
print("\n=== Training Pruned Model ===")
trainer.train()

# Test the pruned model
pruned_model_one_shot.eval()
total_log_likelihood = 0
total_tokens = 0

test_text = texts[:100]  # Limit to first 100 texts for testing

with torch.no_grad():
    for text in tqdm(test_text, desc="Calculating Perplexity"):
        # Tokenize text
        inputs = tokenizer(text, return_tensors="pt", truncation=True).to(pruned_model_one_shot.device)
        
        # Forward pass to get loss
        outputs = pruned_model_one_shot(**inputs, labels=inputs["input_ids"])
        loss = outputs.loss.item()
        
        # Update metrics
        total_log_likelihood += loss * inputs["input_ids"].size(1)
        total_tokens += inputs["input_ids"].size(1)

# Final perplexity calculation
perplexity = torch.exp(torch.tensor(total_log_likelihood / total_tokens)).item()
print(f"Perplexity: {perplexity:.2f}")

sparsity = model_sparsity(pruned_Qwen)
print(f'Sparsity of the pruned model after fine-tuning: {sparsity:.4f}')

The model is automatically converting to bf16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".
Try importing flash-attention for faster inference...


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

Calculating Perplexity: 100%|██████████| 100/100 [00:03<00:00, 29.78it/s]


Perplexity: 30.47


NameError: name 'tokenized_train_dataset' is not defined

In [7]:
model_for_pruning

QWenLMHeadModel(
  (transformer): QWenModel(
    (wte): Embedding(151936, 4096)
    (drop): Dropout(p=0.0, inplace=False)
    (rotary_emb): RotaryEmbedding()
    (h): ModuleList(
      (0-31): 32 x QWenBlock(
        (ln_1): RMSNorm()
        (attn): QWenAttention(
          (c_attn): Linear(in_features=4096, out_features=12288, bias=True)
          (c_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): RMSNorm()
        (mlp): QWenMLP(
          (w1): Linear(in_features=4096, out_features=11008, bias=False)
          (w2): Linear(in_features=4096, out_features=11008, bias=False)
          (c_proj): Linear(in_features=11008, out_features=4096, bias=False)
        )
      )
    )
    (ln_f): RMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=151936, bias=False)
)

In [7]:
def prune_model_FFN(model, renormalize=False):
    """
    Prune only the final lm_head linear layer
    """
    pruned_layers = []
    
    for name, module in model.named_modules():
        # Only prune the lm_head layer
        if name == 'lm_head' and isinstance(module, nn.Linear):
            try:
                # Create mask
                mask = get_linear_mask(module).to(module.weight.device)
                
                # Apply pruning in-place
                with torch.no_grad():
                    module.weight.data *= mask.float()
                    
                    if renormalize:
                        row_sum = module.weight.data.abs().sum(dim=1, keepdim=True).clamp(min=1e-8)
                        module.weight.data.div_(row_sum)
                    
                    pruned_layers.append(name)
                    sparsity = (module.weight.data == 0).float().mean().item()
                    print(f"Pruned {name}: {sparsity:.2%} sparsity")
                    
            except Exception as e:
                print(f"Skipping {name}: {e}")
                continue
            
        if name == 'mlp' and isinstance(module, nn.Linear):
            try:
                # Create mask for MLP layers
                mask = get_linear_mask(module).to(module.weight.device)
                
                # Apply pruning in-place
                with torch.no_grad():
                    module.weight.data *= mask.float()
                    
                    if renormalize:
                        row_sum = module.weight.data.abs().sum(dim=1, keepdim=True).clamp(min=1e-8)
                        module.weight.data.div_(row_sum)
                    
                    pruned_layers.append(name)
                    sparsity = (module.weight.data == 0).float().mean().item()
                    print(f"Pruned {name}: {sparsity:.2%} sparsity")
                    
            except Exception as e:
                print(f"Skipping {name}: {e}")
                continue
    
    if len(pruned_layers) == 0:
        print("Warning: No lm_head layer found to prune!")
        # Let's check what layers are available
        print("Available layers:")
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                print(f"  - {name}: {module.weight.shape}")
    else:
        print(f"Successfully pruned {len(pruned_layers)} layers")
    
    return model

In [8]:
print("Pruning model...")
pruned_model = prune_model_FFN(model_for_pruning, renormalize=False)

# Step 4: Calculate pruned perplexity
print("Calculating pruned perplexity...")
pruned_perplexity = calculate_perplexity(pruned_model, tokenizer, texts[:100])
print(f"Pruned Perplexity: {pruned_perplexity:.2f}")

# Step 5: Calculate sparsity
sparsity = model_sparsity(pruned_model)
print(f'Model Sparsity: {sparsity:.4f}')

# Step 6: Save the pruned model (optional)
torch.save(pruned_model.state_dict(), 'pruned_qwen_FFN.pth')

Pruning model...
Pruned lm_head: 39.97% sparsity
Successfully pruned 1 layers
Calculating pruned perplexity...


Calculating Perplexity: 100%|██████████| 100/100 [01:11<00:00,  1.41it/s]


Pruned Perplexity: 18.52
Model Sparsity: 0.0322


test only


In [None]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import bitsandbytes as bnb
from tqdm import tqdm

# Use your HF token
hf_token = "your personal token"

# Load tokenizer and 8-bit model (for your 5090 GPU)
model_id = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_8bit=True,
    device_map="auto",
    torch_dtype=torch.float16,
    token=hf_token
)
model.eval()

# ------------------ N_eff pruning (per row) ------------------

def ls_recover_weights(layer, mask, tol=1e-5):
    """
    Least-squares recovery: after pruning, project original weight matrix
    onto the space of preserved rows to minimize reconstruction error.

    Args:
        layer: nn.Linear module (with pruned weights).
        mask: binary mask (same shape as layer.weight).
        tol: tolerance for pseudo-inverse stability.
    """
    W = layer.weight.data.float()
    W_orig = W.clone()
    M = mask.float()

    # Skip if layer is too small
    if W.shape[0] < 4 or W.shape[1] < 4:
        return

    try:
        # Reconstruct W using only preserved weights
        W_preserved = W * M
        # Project W_orig onto the row space of W_preserved
        # Equivalent to: W_hat = argmin ||W_hat - W_orig||_F s.t. support(W_hat) = M
        # Closed-form: W_hat = M ⊙ (W_orig @ V @ V.T) if W_preserved ≈ U S V.T
        U, S, Vh = torch.linalg.svd(W_preserved, full_matrices=False)
        S_inv = torch.diag(1.0 / (S + tol))
        W_hat = (U @ S_inv @ (U.T @ W_orig))
        W_hat = W_hat * M
        layer.weight.data.copy_(W_hat.to(layer.weight.dtype))

    except Exception as e:
        print(f"[WARNING] SVD failed on layer {layer}, skipping LS correction: {e}")


def get_neff_mask_per_row(weight: torch.Tensor) -> torch.Tensor:
    """
    Compute N_eff and create mask per row.
    weight: [out_dim, in_dim]
    Returns a boolean mask of same shape.
    """
    x = weight.abs().float()
    x_norm = x / x.sum(dim=1, keepdim=True).clamp(min=1e-8)
    neff = torch.floor(1.0 / (x_norm ** 2).sum(dim=1, keepdim=True)).int()

    sorted_vals, sorted_idx = torch.sort(x_norm, dim=1, descending=True)
    mask = torch.zeros_like(x, dtype=torch.bool)

    for i in range(x.size(0)):  # row-wise
        k = neff[i].item()
        mask[i, sorted_idx[i, :k]] = True

    return mask

def prune_linear_layer_per_row(layer: nn.Linear):
    if not hasattr(layer, "weight"):
        return 0, 0
    with torch.no_grad():
        weight = layer.weight.data.clone()
        mask = get_neff_mask_per_row(weight)
        layer.weight.data *= mask.to(layer.weight.device)
    return mask.sum().item(), mask.numel()

def prune_llama_per_row(model):
    kept, total = 0, 0
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(k in name for k in [
            "q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"
        ]):
            k, t = prune_linear_layer_per_row(module)
            kept += k
            total += t
    sparsity = 1 - kept / total
    print(f"[PRUNE] Kept {kept} / {total} weights → Sparsity: {sparsity:.4f}")
    return model

# ------------------ Evaluate perplexity ------------------

def compute_perplexity(model, tokenizer, dataset, max_length=512):
    losses = []
    for sample in tqdm(dataset, desc="Evaluating PPL"):
        encoded = tokenizer(
            sample["text"],
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            add_special_tokens=False
        )
        input_ids = encoded.input_ids

        # Manually prepend BOS token if model has one
        if tokenizer.bos_token_id is not None:
            bos = torch.tensor([[tokenizer.bos_token_id]], device=input_ids.device)
            input_ids = torch.cat([bos, input_ids.to(input_ids.device)], dim=-1)

        if input_ids.shape[-1] < 4:
            continue  # skip too-short sequences

        input_ids = input_ids.to(model.device)

        with torch.no_grad():
            outputs = model(input_ids, labels=input_ids)
            loss = outputs.loss.item()
            losses.append(loss)
    ppl = torch.exp(torch.tensor(losses).mean())
    print(f"[RESULT] Perplexity: {ppl:.2f}")
    return ppl.item()

# ------------------ Run test ------------------

# Load small test subset of WikiText2
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:2%]")

# Evaluate original
print("==> Original model:")
ppl_before = compute_perplexity(model, tokenizer, dataset)

# Apply pruning
print("\n==> Pruning model (per row)...")
model = prune_llama_per_row(model)

# Evaluate pruned model
print("\n==> Pruned model:")
ppl_after = compute_perplexity(model, tokenizer, dataset)

print("\n==> Summary:")
print(f"Original PPL: {ppl_before:.2f}")
print(f"Pruned   PPL: {ppl_after:.2f}")




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Filter:   0%|          | 0/4358 [00:00<?, ? examples/s]

Map:   0%|          | 0/1760 [00:00<?, ? examples/s]

==> Evaluating original model...


 33%|███▎      | 1169/3519 [07:35<15:14,  2.57it/s]


KeyboardInterrupt: 

In [None]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import bitsandbytes as bnb
from tqdm import tqdm

# HF access token
hf_token = "your personal token"

# Load LLaMA-2 tokenizer and model
model_id = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_8bit=True,
    device_map="auto",
    torch_dtype=torch.float16,
    token=hf_token
)
model.eval()

# ------------------ N_eff masks ------------------

def get_neff_mask_per_row(weight: torch.Tensor) -> torch.Tensor:
    x = weight.abs().float()
    x_norm = x / x.sum(dim=1, keepdim=True).clamp(min=1e-8)
    neff = torch.floor(1.0 / (x_norm ** 2).sum(dim=1, keepdim=True)).int()
    sorted_vals, sorted_idx = torch.sort(x_norm, dim=1, descending=True)
    mask = torch.zeros_like(x, dtype=torch.bool)
    for i in range(x.size(0)):
        k = neff[i].item()
        mask[i, sorted_idx[i, :k]] = True
    return mask

def get_neff_mask_per_column(weight: torch.Tensor) -> torch.Tensor:
    x = weight.abs().float()
    x_norm = x / x.sum(dim=0, keepdim=True).clamp(min=1e-8)
    neff = torch.floor(1.0 / (x_norm ** 2).sum(dim=0, keepdim=True)).int()
    sorted_vals, sorted_idx = torch.sort(x_norm, dim=0, descending=True)
    mask = torch.zeros_like(x, dtype=torch.bool)
    for j in range(x.size(1)):
        k = neff[0, j].item()
        mask[sorted_idx[:k, j], j] = True
    return mask

# ------------------ Pruning ------------------

def prune_linear_layer(layer: nn.Linear, mask: torch.Tensor):
    with torch.no_grad():
        layer.weight.data *= mask.to(layer.weight.device)

def prune_llama(model, mask_fn):
    kept, total = 0, 0
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(k in name for k in [
            "q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"
        ]):
            weight = module.weight.data.clone()
            mask = mask_fn(weight)
            prune_linear_layer(module, mask)
            kept += mask.sum().item()
            total += mask.numel()
    sparsity = 1 - kept / total
    print(f"[PRUNE] Kept {kept} / {total} weights → Sparsity: {sparsity:.4f}")
    return model

# ------------------ Evaluation ------------------

def compute_perplexity(model, tokenizer, dataset, max_length=512):
    losses = []
    for sample in tqdm(dataset, desc="Evaluating PPL"):
        encoded = tokenizer(
            sample["text"],
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            add_special_tokens=False
        )
        input_ids = encoded.input_ids
        if tokenizer.bos_token_id is not None:
            bos = torch.tensor([[tokenizer.bos_token_id]], device=input_ids.device)
            input_ids = torch.cat([bos, input_ids.to(input_ids.device)], dim=-1)
        if input_ids.shape[-1] < 4:
            continue
        input_ids = input_ids.to(model.device)
        with torch.no_grad():
            outputs = model(input_ids, labels=input_ids)
            loss = outputs.loss.item()
            losses.append(loss)
    ppl = torch.exp(torch.tensor(losses).mean())
    print(f"[RESULT] Perplexity: {ppl:.2f}")
    return ppl.item()

# ------------------ One-shot retraining ------------------

def one_shot_retrain(model, tokenizer, dataset, lr=5e-5):
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    for sample in dataset.select(range(2)):  # 2 samples for one-shot
        encoded = tokenizer(
            sample["text"],
            return_tensors="pt",
            truncation=True,
            max_length=512,
            add_special_tokens=False
        )
        input_ids = encoded.input_ids.to(model.device)
        if tokenizer.bos_token_id is not None:
            bos = torch.tensor([[tokenizer.bos_token_id]], device=input_ids.device)
            input_ids = torch.cat([bos, input_ids], dim=-1)

        if input_ids.shape[-1] < 4:
            continue

        outputs = model(input_ids, labels=input_ids)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        break  # Only one shot
    model.eval()
    return model


# ------------------ Main ------------------

# Load subset of wikitext-2
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:2%]")

# Clone the original model
import copy
model_row = copy.deepcopy(model)
model_col = copy.deepcopy(model)

print("==> Original model:")
ppl_orig = compute_perplexity(model, tokenizer, dataset)

print("\n==> Row-pruned model:")
prune_llama(model_row, get_neff_mask_per_row)
ppl_row = compute_perplexity(model_row, tokenizer, dataset)

print("\n==> Row-pruned + retrain:")
model_row = one_shot_retrain(model_row, tokenizer, dataset)
ppl_row_retrain = compute_perplexity(model_row, tokenizer, dataset)

print("\n==> Column-pruned model:")
prune_llama(model_col, get_neff_mask_per_column)
ppl_col = compute_perplexity(model_col, tokenizer, dataset)

print("\n==> Column-pruned + retrain:")
model_col = one_shot_retrain(model_col, tokenizer, dataset)
ppl_col_retrain = compute_perplexity(model_col, tokenizer, dataset)

print("\n==> Summary:")
print(f"Original            : {ppl_orig:.2f}")
print(f"Row-pruned          : {ppl_row:.2f}")
print(f"Row-pruned + retrain: {ppl_row_retrain:.2f}")
print(f"Column-pruned       : {ppl_col:.2f}")
print(f"Column-pruned + retrain: {ppl_col_retrain:.2f}")

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

==> Original model:


Evaluating PPL: 100%|██████████| 87/87 [00:05<00:00, 16.12it/s]


[RESULT] Perplexity: 24.23

==> Row-pruned model:
[PRUNE] Kept 4090322544 / 6476005376 weights → Sparsity: 0.3684


Evaluating PPL: 100%|██████████| 87/87 [00:05<00:00, 16.30it/s]
  scaler = torch.cuda.amp.GradScaler()  # for mixed precision
  with torch.cuda.amp.autocast():


[RESULT] Perplexity: 59.15

==> Row-pruned + retrain:


ValueError: Attempting to unscale FP16 gradients.