In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from timm import create_model  # Import for Swin Transformer

# Model and pruning configuration (adjust as needed)
model_name = "swin_tiny_patch4_window7_224"  # Choose the desired Swin Transformer model
pruning_ratio = 0.4  # Start with lower pruning ratios
num_training_epochs = 2  # Train longer before pruning
num_finetuning_epochs = 2  # Fine-tune longer to recover
learning_rate = 0.003  # Slightly lower learning rate
learning_rate_finetune = 0.001  # Fine-tuning learning rate
threshold = 0.4  

# Load dataset and apply appropriate transformations
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)

def count_model_parameters(model):
    """
    This function counts the total number of parameters in a PyTorch model.
    """
    total_params = 0
    for param in model.parameters():
        if param.requires_grad:
            total_params += torch.count_nonzero(param)
    return total_params

def prune_model_with_svd_and_iht(model, pruning_ratio, num_iterations=10):
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            weight = module.weight.data

            # Apply SVD
            u, s, v = torch.svd(weight)
            s_pruned = torch.zeros_like(s)
            s_pruned[:int(s.size(0) * (1 - pruning_ratio))] = s[:int(s.size(0) * (1 - pruning_ratio))]
            weight_pruned_svd = torch.mm(u, torch.mm(torch.diag(s_pruned), v.t()))

            # Apply IHT
            weight_pruned_iht = iht(weights=weight_pruned_svd, pruning_ratio=pruning_ratio,
                                   num_iterations=num_iterations, threshold=threshold, module=module)

            module.weight.data = weight_pruned_iht  # Apply the pruned weights


def iht(weights, pruning_ratio, num_iterations, threshold, module):
    for _ in range(num_iterations):
        # Calculate gradients using torch.autograd.grad
        gradients = torch.autograd.grad(module.weight, module.weight, grad_outputs=weights)[0]

        # Update weights using calculated gradients
        weights = torch.clamp(weights - learning_rate * gradients, -threshold, threshold)

    return weights

# Load the model once before the loop
model = create_model(model_name, pretrained=True)
num_params_before = count_model_parameters(model)
print(f"Number of parameters before pruning: {num_params_before}")

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)  # Define the optimizer
criterion = nn.CrossEntropyLoss()  # Define the loss function


# Initialize overall accuracy tracking
overall_accuracy = 0.0
num_iterations = 0

for epoch in range(num_training_epochs):
    train_loss = 0.0
    train_acc = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()  # Clear gradients for the current step
        outputs = model(images)
        logits = outputs # Access logits for Swin Transformers
        loss = criterion(logits, labels)
        loss.backward()  # Backpropagate gradients
        optimizer.step()  # Update model parameters based on gradients
        train_loss += loss.item()  # Accumulate training loss

        # Calculate training accuracy
        _, preds = torch.max(outputs, 1)
        train_acc += torch.sum(preds == labels).item() / len(labels)

    # Print training progress
    print(f"Epoch [{epoch+1}/{num_training_epochs}], Pruning Ratio: {pruning_ratio:.2f}, "
            f"Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {train_acc/len(train_loader):.4f}")


# Apply pruning with the current pruning ratio
prune_model_with_svd_and_iht(model, pruning_ratio=pruning_ratio)

# Print parameter count after pruning
num_params_after = count_model_parameters(model)
print(f"Number of parameters after pruning with ratio {pruning_ratio:.2f}: {num_params_after}")

    # Fine-tune the pruned model
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate_finetune)  # Adjust learning rate
for epoch in range(num_finetuning_epochs):
    train_loss = 0.0
    train_acc = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()  # Clear gradients for the current step
        outputs = model(images)
        logits = outputs  # Access logits for Swin Transformers
        loss = criterion(logits, labels)
        loss.backward()  # Backpropagate gradients
        optimizer.step()  # Update model parameters based on gradients
        train_loss += loss.item()  # Accumulate training loss

        # Calculate training accuracy
        _, preds = torch.max(outputs, 1)
        train_acc += torch.sum(preds == labels).item() / len(labels)

    # Print training progress
    print(f"Fine-tuning Epoch [{epoch + 1}/{num_finetuning_epochs}], Pruning Ratio: {pruning_ratio:.2f}, "
          f"Train Loss: {train_loss / len(train_loader):.4f}, Train Acc: {train_acc / len(train_loader):.4f}")


# Evaluate pruned model accuracy on the test set
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)  # Corrected indentation
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    pruned_accuracy = 100 * correct / total
    print(f"Accuracy after pruning {pruning_ratio:.2f}: {pruned_accuracy:.2f}%")


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:10<00:00, 16122659.18it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

Number of parameters before pruning: 28288354
Epoch [1/2], Pruning Ratio: 0.40, Train Loss: 0.5846, Train Acc: 0.8070
Epoch [2/2], Pruning Ratio: 0.40, Train Loss: 0.1518, Train Acc: 0.9491
Number of parameters after pruning with ratio 0.40: 28288354
Fine-tuning Epoch [1/2], Pruning Ratio: 0.40, Train Loss: 0.1614, Train Acc: 0.9460
Fine-tuning Epoch [2/2], Pruning Ratio: 0.40, Train Loss: 0.1134, Train Acc: 0.9612
Accuracy after pruning 0.40: 95.27%
