In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input mount and print the full path of every file found beneath it.
for dirname, _, filenames in os.walk('/kaggle/input'):
    full_paths = [os.path.join(dirname, filename) for filename in filenames]
    for full_path in full_paths:
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
!pip install thop
!pip install timm

Collecting thop
  Downloading thop-0.1.1.post2209072238-py3-none-any.whl.metadata (2.7 kB)
Downloading thop-0.1.1.post2209072238-py3-none-any.whl (15 kB)
Installing collected packages: thop
Successfully installed thop-0.1.1.post2209072238


In [None]:
import os
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter  # Import for TensorBoard visualization (optional)
import torch
from torch import nn
from timm import create_model
from torch import nn
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from timm import create_model  # Import for Swin Transformer

# Model and pruning configuration (adjust as needed)
model_name = "swin_tiny_patch4_window7_224"  # timm model identifier handed to create_model()
pruning_ratio = 0.3  # Fraction of singular values dropped from each nn.Linear weight during SVD pruning
num_training_epochs = 3  # Number of training epochs run before the model is pruned
num_finetuning_epochs = 2  # Number of fine-tuning epochs run after pruning
learning_rate = 0.002  # SGD learning rate for the pre-pruning training loop; iht() also reads this global as its update step
learning_rate_finetune = 0.001  # SGD learning rate for the post-pruning fine-tuning loop
threshold = 0.3  # Absolute weight-magnitude cut-off, read as a global by prune_model_with_svd_and_iht() and passed on to iht()


# Load CIFAR-10 and build the input pipelines.
# Both splits are resized to 224 pixels and normalized identically. Only the training
# split receives random augmentation, so evaluating on the test split is deterministic.
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Training-time preprocessing (includes the random horizontal flip).
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Evaluation-time preprocessing: same as above minus the random flip.
eval_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    normalize,
])

train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=eval_transform)

# Single batch size shared by both loaders; only the training loader is shuffled.
batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

import torch
from torch import nn
from timm import create_model





    

def count_model_parameters(model: nn.Module) -> int:
    """
    Return how many individual parameter values in ``model`` are non-zero.

    Every tensor yielded by ``model.parameters()`` is inspected and the number
    of its entries that differ from zero is added to the tally.

    Args:
        model (nn.Module): Model whose parameters are inspected.

    Returns:
        int: Sum of the non-zero entry counts over all parameter tensors
            (0 for a model without parameters).
    """
    return sum(torch.count_nonzero(tensor).item() for tensor in model.parameters())



def prune_model_with_svd_and_iht(model, pruning_ratio, num_iterations=10):
    """Prune the weight matrix of every ``nn.Linear`` layer of ``model`` in place.

    For each linear layer the weight is
      1. replaced by a low-rank reconstruction that keeps only the largest
         ``int(k * (1 - pruning_ratio))`` of its ``k`` singular values, and
      2. passed through ``iht`` for magnitude thresholding.
    The result is written back into the layer and any stale gradient on the weight
    is dropped. Biases and all non-``nn.Linear`` modules are left untouched.

    Args:
        model: Module whose ``nn.Linear`` sub-modules are pruned in place.
        pruning_ratio: Fraction of singular values to discard per layer; must lie
            in ``[0, 1]``.
        num_iterations: Iteration count forwarded to ``iht``.

    Raises:
        ValueError: If ``pruning_ratio`` is outside ``[0, 1]``.

    Note:
        The magnitude cut-off handed to ``iht`` is the notebook-level ``threshold``.
    """
    if not 0 <= pruning_ratio <= 1:
        raise ValueError(f"pruning_ratio must be within [0, 1], got {pruning_ratio}")

    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue

        print("Processing layer:", name)
        weight = module.weight.detach()

        # Low-rank truncation. torch.linalg.svd supersedes the deprecated torch.svd and
        # returns Vh (already transposed), so the reconstruction is (U * S) @ Vh.
        print("Applying SVD...")
        u, s, vh = torch.linalg.svd(weight, full_matrices=False)
        rank_kept = int(s.size(0) * (1 - pruning_ratio))
        s_pruned = torch.zeros_like(s)
        s_pruned[:rank_kept] = s[:rank_kept]
        weight_pruned_svd = (u * s_pruned) @ vh
        # Report the retained rank rather than a whole-model non-zero count: at this
        # point the new weights are not installed yet, so a model-wide count cannot
        # reflect this layer and would cost a full parameter scan per layer.
        print(f"Singular values kept by SVD pruning: {rank_kept} of {s.size(0)}")

        # Magnitude thresholding of the low-rank reconstruction.
        print("Applying IHT...")
        weight_pruned_iht = iht(weights=weight_pruned_svd, pruning_ratio=pruning_ratio,
                                num_iterations=num_iterations, threshold=threshold, module=module)

        # Install the pruned weights without recording the write in autograd, and
        # clear the gradient so a stale value is not reused by the next optimizer step.
        with torch.no_grad():
            module.weight.copy_(weight_pruned_iht)
        module.weight.grad = None

def iht(weights, pruning_ratio, num_iterations, threshold, module, step_size=None):
    """Sparsify ``weights`` by iterative hard thresholding at an absolute cut-off.

    Each iteration takes one shrinkage step ``w <- w - step_size * w`` on the entries
    whose magnitude is still at least ``threshold``; entries below the cut-off are
    never updated again. After the last iteration every entry with magnitude below
    ``threshold`` is set to zero. The "gradient" in the step is the weight itself
    (the derivative of ``0.5 * ||w||^2``), so no autograd call is needed and layers
    with ``requires_grad=False`` are handled as well.

    Args:
        weights: Tensor to sparsify. It is not modified; a new tensor is returned.
        pruning_ratio: Accepted for interface compatibility; the thresholding does
            not use it.
        num_iterations: Number of shrinkage steps before the final thresholding
            (0 gives plain hard thresholding).
        threshold: Absolute magnitude below which an entry is zeroed.
        module: Accepted for interface compatibility; not used.
        step_size: Shrinkage step. When omitted, the notebook-level ``learning_rate``
            is used.

    Returns:
        A tensor shaped like ``weights`` with all entries of magnitude below
        ``threshold`` set to zero.

    Note:
        NOTE(review): the cut-off is absolute and independent of ``pruning_ratio``;
        a layer whose weights are all smaller in magnitude than ``threshold`` comes
        back entirely zero. Confirm the cut-off matches the scale of the weights.
    """
    if step_size is None:
        # Backward-compatible default for callers that do not pass a step.
        step_size = learning_rate

    pruned = weights.detach().clone()
    for _ in range(num_iterations):
        survivors = torch.abs(pruned) >= threshold
        # Survivors are deliberately NOT clamped to [-threshold, threshold]: capping
        # them at the cut-off would make the next shrink step push every one of them
        # below it, and the final thresholding would then zero the whole tensor.
        pruned[survivors] = pruned[survivors] - step_size * pruned[survivors]

    # Final hard thresholding.
    pruned[torch.abs(pruned) < threshold] = 0

    # Report on the tensor actually being returned; the caller installs it afterwards.
    kept = torch.count_nonzero(pruned).item()
    print(f"Non-zero weights after svdandiht pruning: {kept} of {pruned.numel()}")

    return pruned




def optimizer_creator():
    """Return a new SGD optimizer over the notebook-level ``model``, using ``learning_rate``."""
    trainable = model.parameters()
    sgd = torch.optim.SGD(trainable, lr=learning_rate)
    return sgd

# Load the pretrained model once, before any training loop.
# num_classes asks timm to replace the pretrained classifier head with one sized to
# the dataset, so the logits line up with the CIFAR-10 labels used by the loss and
# by the accuracy computation.
model = create_model(model_name, pretrained=True, num_classes=len(train_dataset.classes))
# NOTE(review): input_size is not read anywhere in the visible cells - confirm whether
# it is still needed (e.g. for profiling).
input_size = (16, 3, 224, 224)

optimizer = optimizer_creator()
criterion = nn.CrossEntropyLoss()
# Train the unpruned model; once per epoch, log the mean batch loss and the mean
# per-batch accuracy.
for epoch in range(num_training_epochs):
    train_loss = 0.0
    train_acc = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()

        # Forward pass: the model output is used directly as the class logits.
        outputs = model(images)
        logits = outputs
        loss = criterion(logits, labels)

        # Backward pass followed by the parameter update.
        loss.backward()
        optimizer.step()

        # Running totals: batch loss and the fraction of this batch classified correctly.
        train_loss += loss.item()
        preds = torch.max(outputs, 1).indices
        train_acc += (preds == labels).sum().item() / labels.size(0)

    # Epoch summary
    print(f"Epoch [{epoch+1}/{num_training_epochs}], Pruning Ratio: {pruning_ratio:.2f}, "
            f"Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {train_acc/len(train_loader):.4f}")


# Record the non-zero parameter count of the trained model, then prune it in place.
num_params_before = count_model_parameters(model)
print(f"Number of parameters before pruning: {num_params_before}")

prune_model_with_svd_and_iht(model, pruning_ratio)


 # Fine-tune the pruned model
# New SGD optimizer for the pruned model, stepping at the fine-tuning learning rate.
# NOTE(review): nothing in this loop re-applies the pruning pattern after
# optimizer.step(), so weights zeroed by pruning can presumably become non-zero
# again during fine-tuning - confirm that this is intended.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate_finetune)
for epoch in range(num_finetuning_epochs):
    train_loss = 0.0
    train_acc = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()

        # Forward pass: the model output is used directly as the class logits.
        outputs = model(images)
        logits = outputs
        loss = criterion(logits, labels)

        # Backward pass followed by the parameter update.
        loss.backward()
        optimizer.step()

        # Running totals: batch loss and the fraction of this batch classified correctly.
        train_loss += loss.item()
        preds = torch.max(outputs, 1).indices
        train_acc += (preds == labels).sum().item() / labels.size(0)

    # Epoch summary
    print(f"Fine-tuning Epoch [{epoch + 1}/{num_finetuning_epochs}], Pruning Ratio: {pruning_ratio:.2f}, "
          f"Train Loss: {train_loss / len(train_loader):.4f}, Train Acc: {train_acc / len(train_loader):.4f}")


# Evaluate pruned model accuracy on the test set
model.eval()  # put train-only layers (e.g. dropout) into inference mode for scoring
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        # The prediction and both tallies must run once per batch, inside the loop;
        # placed after the loop they would score only the final batch.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
pruned_accuracy = 100 * correct / total
print(f"Accuracy after pruning {pruning_ratio:.2f}: {pruned_accuracy:.2f}%")
