In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [3]:
!pip install thop
!pip install timm



In [None]:
from torchvision.datasets import ImageFolder
from torchvision.transforms import Resize, RandomHorizontalFlip, ToTensor, Normalize
from torch.utils.data import DataLoader
from torch.optim import SGD
from torch.nn import CrossEntropyLoss
from timm import create_model
from torch import nn, autograd, max, sum
import torch
from torchvision import transforms
from thop import profile
import os

# Model and pruning configuration (adjust as needed)
model_name = "swin_tiny_patch4_window7_224"  # Choose the desired Swin Transformer model
pruning_ratio = 0.45  # Start with lower pruning ratios
num_training_epochs = 10 # Train longer before pruning
num_finetuning_epochs = 2  # Fine-tune longer to recover
learning_rate = 0.002  # Slightly lower learning rate
learning_rate_finetune = 0.001  # Fine-tuning learning rate
threshold = 0.3  

# Load dataset and apply appropriate transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset_root = "/kaggle/input/caltech-101"

# Load the entire dataset
full_dataset = ImageFolder(root=dataset_root, transform=transform)

# Split the dataset into train and test subsets (e.g., 80% train, 20% test)
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

def get_original_model_size(model: nn.Module, model_path: str = "/tmp/model.pt") -> int:
    """
    Gets the size of the PyTorch model in bytes.

    Args:
        model (nn.Module): The PyTorch model.
        model_path (str, optional): The path to save the model. Defaults to "/tmp/model.pt".

    Returns:
        int: The size of the model in bytes.
    """
    torch.save(model.state_dict(), model_path)
    return os.path.getsize(model_path)


def get_pruned_model_size(model: nn.Module, model_path: str = "/tmp/model.pt") -> int:
    """
    Gets the size of the PyTorch model in bytes, considering pruned weights.

    Args:
        model (nn.Module): The PyTorch model.
        model_path (str, optional): The path to save the model. Defaults to "/tmp/model.pt".

    Returns:
        int: The size of the model in bytes.
    """
    # Save the pruned model state
    torch.save(model.state_dict(), model_path)
    num_elements = 0
    # Iterate over values in the state dict (weights and biases)
    for param in model.state_dict().values():
        num_elements += param.numel()
    # Calculate size assuming 4 bytes per element (common for floats)
    return num_elements * 4


def count_model_parameters(model: nn.Module) -> int:
    """
    Counts the total number of non-zero parameters in a PyTorch model.

    Args:
        model (nn.Module): The PyTorch model to count the non-zero parameters of.

    Returns:
        int: The total number of non-zero parameters in the model.
    """

    total_non_zero_params = 0
    for param in model.parameters():
        # Count non-zero elements using torch.count_nonzero for efficiency
        total_non_zero_params += torch.count_nonzero(param).item()

    return total_non_zero_params


def count_flops(model: nn.Module) -> int:
  """
  Counts the approximate number of floating-point operations (FLOPs) for a PyTorch model,
  considering the sparsity of weights due to pruning.

  Args:
    model (nn.Module): The PyTorch model.

  Returns:
    int: The estimated number of FLOPs.
  """
  total_flops = 0
  for name, module in model.named_modules():
    if isinstance(module, (nn.Conv2d, nn.Linear)):
      if isinstance(module, nn.Conv2d):
        in_features = module.in_channels
        out_features = module.out_channels
        kernel_size = module.kernel_size
      else:
        in_features = module.in_features
        out_features = module.out_features
        kernel_size = (1, 1)  # Kernel size is 1 for Linear layers

      # Consider sparsity by multiplying with the number of non-zero elements in weights
      num_non_zero_weights = torch.count_nonzero(module.weight).item()
      flops_per_output = in_features * out_features * kernel_size[0] * kernel_size[1]
      total_flops += num_non_zero_weights * flops_per_output

  return total_flops


def prune_heads(model, pruning_ratio):
    for name, module in model.named_modules():
        if isinstance(module, nn.MultiheadAttention):
            print("Processing attention layer:", name)
            num_heads = module.num_heads
            head_importance_scores = torch.norm(module.in_proj_weight, dim=0)  # Calculate L2 norm for each head
            num_heads_to_prune = int(num_heads * pruning_ratio)
            _, sorted_head_indices = torch.topk(head_importance_scores, num_heads_to_prune)

            # Prune the least important heads
            module.num_heads -= num_heads_to_prune
            module.in_proj_weight = nn.Parameter(module.in_proj_weight[:, sorted_head_indices])
            module.out_proj.weight = nn.Parameter(module.out_proj.weight[sorted_head_indices])

            print(f"Pruned {num_heads_to_prune} heads. Remaining heads: {module.num_heads}")



def prune_model_with_svd_and_iht(model, pruning_ratio, num_iterations=10):
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            print("Processing layer:", name)
            weight = module.weight.data
            # Apply SVD
            print("Applying SVD...")
            u, s, v = torch.svd(weight)
            s_pruned = torch.zeros_like(s)
            s_pruned[:int(s.size(0) * (1 - pruning_ratio))] = s[:int(s.size(0) * (1 - pruning_ratio))]
            weight_pruned_svd = torch.mm(u, torch.mm(torch.diag(s_pruned), v.t()))
            num_params_after_svd = count_model_parameters(model)
            print("Number of parameters after SVD pruning:", count_model_parameters(model))

            print(f"Number of parameters after svd pruning: {num_params_after_svd}")
          # Apply IHT with masking for better pruning
            print("Applying IHT...")
            weight_pruned_iht = iht(weights=weight_pruned_svd, pruning_ratio=pruning_ratio,
                                    num_iterations=num_iterations, threshold=threshold, module=module)
            
            # Set the pruned weights and clear gradients (important for training)
            module.weight.data = weight_pruned_iht
            module.weight.grad = None
           

def iht(weights, pruning_ratio, num_iterations, threshold, module):
    for _ in range(num_iterations):
        # Calculate gradients using torch.autograd.grad with retain_graph=True to avoid recomputing previous gradients
        gradients = torch.autograd.grad(module.weight, module.weight, grad_outputs=weights, retain_graph=True)[0]

        # Update weights using calculated gradients, masking to avoid overwriting pruned values
        mask = torch.abs(weights) >= threshold
        weights[mask] = torch.clamp(weights[mask] - learning_rate * gradients[mask], -threshold, threshold)

    # Apply hard thresholding and clear gradients
    weights[torch.abs(weights) < threshold] = 0
    weights.grad = None
    num_params_after_svdandiht = count_model_parameters(model)
    print(f"Number of parameters after svdandiht pruning: {num_params_after_svdandiht}")

    return weights


# Load the model once before the loop
model = create_model(model_name, pretrained=True)

def optimizer_creator():
    return torch.optim.SGD(model.parameters(), lr=learning_rate)  # Create optimizer within a closure
# Load the model once before the loop
model = create_model(model_name, pretrained=True)
input_size = (16, 3, 224, 224) 



optimizer = optimizer_creator()
criterion = nn.CrossEntropyLoss()
for epoch in range(num_training_epochs):
    train_loss = 0.0
    train_acc = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()  # Clear gradients for the current step
        outputs = model(images)
        logits = outputs  # Access logits for Swin Transformers
        loss = criterion(logits, labels)
        loss.backward()  # Backpropagate gradients
        optimizer.step()  # Update model parameters based on gradients
        train_loss += loss.item()  # Accumulate training loss

        
        # Calculate training accuracy
        _, preds = torch.max(outputs, 1)
        train_acc += torch.sum(preds == labels).item() / len(labels)

    # Print training progress
    print(f"Epoch [{epoch+1}/{num_training_epochs}], Pruning Ratio: {pruning_ratio:.2f}, "
            f"Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {train_acc/len(train_loader):.4f}")


flops_before_pruning = count_flops(model)
print(f"Estimated FLOPs before pruning: {flops_before_pruning:.2e}")

# Print the model size before pruning
print(f"Model size before pruning: {get_original_model_size(model)} bytes")

num_params_before = count_model_parameters(model)
print(f"Number of parameters before pruning: {num_params_before}")

# Count the FLOPs before pruning
flops_before, _ = profile(model, inputs=(torch.randn(1, 3, 224, 224), ))
print(f"Number of FLOPs before pruning: {flops_before:.2f} GFLOPs")
prune_heads(model, pruning_ratio=pruning_ratio)

# Apply pruning with the current pruning ratio
prune_model_with_svd_and_iht(model, pruning_ratio=pruning_ratio)

# Calculate FLOPs after pruning
flops_after_pruning = count_flops(model)
print(f"Estimated FLOPs after pruning: {flops_after_pruning:.2e}")

# Calculate model size after pruning considering sparsity
pruned_model_size_sparse = get_pruned_model_size(model)
print(f"Model size after pruning (considering sparsity): {pruned_model_size_sparse} bytes")

# **Optional:** Save the entire state dictionary after pruning (original functionality)
pruned_model_size_all = get_original_model_size(model)  # Assuming you have this functionality
print(f"Model size after pruning (entire state dict): {pruned_model_size_all} bytes")



Estimated FLOPs before pruning: 3.69e+13
Model size before pruning: 113207542 bytes
Number of parameters before pruning: 28288354
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
Number of FLOPs before pruning: 4371851808.00 GFLOPs
Processing layer: layers.0.blocks.0.attn.qkv
Applying SVD...
Number of parameters after SVD pruning: 28288354
Number of parameters after svd pruning: 28288354
Applying IHT...
Number of parameters after svdandiht pruning: 28288354
Processing layer: layers.0.blocks.0.attn.proj
Applying SVD...
Number of parameters after SVD pruning: 2