In [1]:
!pip install --upgrade pip



Collecting pip
  Downloading pip-24.0-py3-none-any.whl.metadata (3.6 kB)


Downloading pip-24.0-py3-none-any.whl (2.1 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.1 MB[0m [31m?[0m eta [36m-:--:--[0m


[2K   [91m━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.8/2.1 MB[0m [31m23.7 MB/s[0m eta [36m0:00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m31.4 MB/s[0m eta [36m0:00:00[0m
[?25h

Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 23.3.2
    Uninstalling pip-23.3.2:


      Successfully uninstalled pip-23.3.2


Successfully installed pip-24.0


In [2]:
!pip install timm













In [3]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [4]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from timm import create_model  # Import for Swin Transformer

# Model and pruning configuration (adjust as needed)
model_name = "swin_tiny_patch4_window7_224"  # Choose the desired Swin Transformer model
pruning_schedule = [0.1, 0.2, 0.3, 0.4]  # Start with lower pruning ratios
num_training_epochs = 10  # Train longer before pruning
num_finetuning_epochs = 20  # Fine-tune longer to recover
learning_rate = 0.003  # Slightly lower learning rate
learning_rate_finetune = 0.001  # Fine-tuning learning rate
threshold = 0.4  # Adjust threshold for IHT if needed

# Load dataset and apply appropriate transformations
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2, shuffle=False)

def prune_model_with_svd_and_iht(model, pruning_ratio, num_iterations=10):
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            weight = module.weight.data

            # Apply SVD
            u, s, v = torch.svd(weight)
            s_pruned = torch.zeros_like(s)
            s_pruned[:int(s.size(0) * (1 - pruning_ratio))] = s[:int(s.size(0) * (1 - pruning_ratio))]
            weight_pruned_svd = torch.mm(u, torch.mm(torch.diag(s_pruned), v.t()))

            # Apply IHT
            weight_pruned_iht = iht(weights=weight_pruned_svd, pruning_ratio=pruning_ratio,
                                   num_iterations=num_iterations, threshold=threshold, module=module)

            # Combine pruning methods (e.g., average)
            weight_pruned = (weight_pruned_svd + weight_pruned_iht) / 2

            module.weight.data = weight_pruned  # Apply the pruned weights


def iht(weights, pruning_ratio, num_iterations, threshold, module):
    for _ in range(num_iterations):
        # Calculate gradients using torch.autograd.grad
        gradients = torch.autograd.grad(module.weight, module.weight, grad_outputs=weights)[0]

        # Update weights using calculated gradients
        weights = torch.clamp(weights - learning_rate * gradients, -threshold, threshold)

    return weights

# Load the model once before the loop
model = create_model(model_name, pretrained=True)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)  # Define the optimizer
criterion = nn.CrossEntropyLoss()  # Define the loss function


# Initialize overall accuracy tracking
overall_accuracy = 0.0
num_iterations = 0

for pruning_ratio in pruning_schedule:
    # Train for a few epochs before pruning
    for epoch in range(num_training_epochs):
        train_loss = 0.0
        train_acc = 0.0
        for images, labels in train_loader:
            optimizer.zero_grad()  # Clear gradients for the current step
            outputs = model(images)
            logits = outputs # Access logits for Swin Transformers
            loss = criterion(logits, labels)
            loss.backward()  # Backpropagate gradients
            optimizer.step()  # Update model parameters based on gradients
            train_loss += loss.item()  # Accumulate training loss

            # Calculate training accuracy
            _, preds = torch.max(outputs, 1)
            train_acc += torch.sum(preds == labels).item() / len(labels)

        # Print training progress
        print(f"Epoch [{epoch+1}/{num_training_epochs}], Pruning Ratio: {pruning_ratio:.2f}, "
              f"Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {train_acc/len(train_loader):.4f}")

    # Apply pruning with the current pruning ratio
    prune_model_with_svd_and_iht(model, pruning_ratio=pruning_ratio)

    # Fine-tune the pruned model
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate_finetune)  # Adjust learning rate
for epoch in range(num_finetuning_epochs):
    train_loss = 0.0
    train_acc = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()  # Clear gradients for the current step
        outputs = model(images)
        logits = outputs  # Access logits for Swin Transformers
        loss = criterion(logits, labels)
        loss.backward()  # Backpropagate gradients
        optimizer.step()  # Update model parameters based on gradients
        train_loss += loss.item()  # Accumulate training loss

        # Calculate training accuracy
        _, preds = torch.max(outputs, 1)
        train_acc += torch.sum(preds == labels).item() / len(labels)

    # Print training progress
    print(f"Fine-tuning Epoch [{epoch + 1}/{num_finetuning_epochs}], Pruning Ratio: {pruning_ratio:.2f}, "
          f"Train Loss: {train_loss / len(train_loader):.4f}, Train Acc: {train_acc / len(train_loader):.4f}")


    # Evaluate pruned model accuracy on the test set
correct = 0
total = 0
with torch.no_grad():
  for images, labels in test_loader:
    outputs = model(images)
    _, predicted = torch.max(outputs.logits.data, 1)  # Corrected indentation
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
pruned_accuracy = 100 * correct / total
print(f"Accuracy after pruning {pruning_ratio:.2f}: {pruned_accuracy:.2f}%")


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz



  0%|          | 0/170498071 [00:00<?, ?it/s]


  0%|          | 65536/170498071 [00:00<04:44, 598226.52it/s]


  0%|          | 229376/170498071 [00:00<02:28, 1146432.83it/s]


  1%|          | 917504/170498071 [00:00<00:48, 3496321.33it/s]


  2%|▏         | 3670016/170498071 [00:00<00:13, 12023151.98it/s]


  5%|▍         | 7766016/170498071 [00:00<00:07, 21154486.06it/s]


  7%|▋         | 11796480/170498071 [00:00<00:05, 27265894.93it/s]


  9%|▉         | 15532032/170498071 [00:00<00:05, 30421103.86it/s]


 11%|█         | 19136512/170498071 [00:00<00:04, 31994926.92it/s]


 13%|█▎        | 22872064/170498071 [00:00<00:04, 33611812.66it/s]


 16%|█▌        | 26509312/170498071 [00:01<00:04, 34447053.75it/s]


 18%|█▊        | 30343168/170498071 [00:01<00:03, 35578871.99it/s]


 20%|██        | 34340864/170498071 [00:01<00:03, 36157933.59it/s]


 23%|██▎       | 38371328/170498071 [00:01<00:03, 36444862.65it/s]


 25%|██▍       | 42401792/170498071 [00:01<00:03, 36641987.41it/s]


 27%|██▋       | 46530560/170498071 [00:01<00:03, 36977141.24it/s]


 30%|██▉       | 50659328/170498071 [00:01<00:03, 37168094.48it/s]


 32%|███▏      | 54755328/170498071 [00:01<00:03, 37534347.09it/s]


 34%|███▍      | 58785792/170498071 [00:01<00:02, 38291823.09it/s]


 37%|███▋      | 62619648/170498071 [00:02<00:02, 37927266.76it/s]


 39%|███▉      | 66420736/170498071 [00:02<00:02, 37398883.52it/s]


 41%|████      | 70189056/170498071 [00:02<00:02, 37470874.14it/s]


 43%|████▎     | 73957376/170498071 [00:02<00:02, 37153685.84it/s]


 46%|████▌     | 77692928/170498071 [00:02<00:02, 36961876.75it/s]


 48%|████▊     | 81690624/170498071 [00:02<00:02, 37727716.49it/s]


 50%|█████     | 85557248/170498071 [00:02<00:02, 37806460.12it/s]


 52%|█████▏    | 89456640/170498071 [00:02<00:02, 38125975.18it/s]


 55%|█████▍    | 93290496/170498071 [00:02<00:02, 37636367.33it/s]


 57%|█████▋    | 97058816/170498071 [00:02<00:01, 37229276.02it/s]


 59%|█████▉    | 100859904/170498071 [00:03<00:01, 37444502.54it/s]


 61%|██████▏   | 104628224/170498071 [00:03<00:01, 37020270.01it/s]


 64%|██████▎   | 108494848/170498071 [00:03<00:01, 37284271.91it/s]


 66%|██████▌   | 112394240/170498071 [00:03<00:01, 37648518.90it/s]


 68%|██████▊   | 116326400/170498071 [00:03<00:01, 38129252.29it/s]


 70%|███████   | 120160256/170498071 [00:03<00:01, 37913228.43it/s]


 73%|███████▎  | 123961344/170498071 [00:03<00:01, 37362655.18it/s]


 75%|███████▍  | 127729664/170498071 [00:03<00:01, 37333639.10it/s]


 77%|███████▋  | 131465216/170498071 [00:03<00:01, 37211635.61it/s]


 79%|███████▉  | 135200768/170498071 [00:03<00:00, 37082092.87it/s]


 82%|████████▏ | 139132928/170498071 [00:04<00:00, 37694792.42it/s]


 84%|████████▍ | 142999552/170498071 [00:04<00:00, 37619123.62it/s]


 86%|████████▌ | 146833408/170498071 [00:04<00:00, 37690527.30it/s]


 88%|████████▊ | 150634496/170498071 [00:04<00:00, 37719671.01it/s]


 91%|█████████ | 154435584/170498071 [00:04<00:00, 37405287.06it/s]


 93%|█████████▎| 158203904/170498071 [00:04<00:00, 37395334.54it/s]


 95%|█████████▍| 161972224/170498071 [00:04<00:00, 37263285.53it/s]


 97%|█████████▋| 165740544/170498071 [00:04<00:00, 36773712.95it/s]


100%|█████████▉| 169771008/170498071 [00:04<00:00, 37515161.14it/s]


100%|██████████| 170498071/170498071 [00:04<00:00, 34895954.72it/s]




Extracting ./data/cifar-10-python.tar.gz to ./data


Files already downloaded and verified


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

Epoch [1/10], Pruning Ratio: 0.10, Train Loss: 0.5344, Train Acc: 0.8258


Epoch [2/10], Pruning Ratio: 0.10, Train Loss: 0.1451, Train Acc: 0.9502


Epoch [3/10], Pruning Ratio: 0.10, Train Loss: 0.0985, Train Acc: 0.9664


Epoch [4/10], Pruning Ratio: 0.10, Train Loss: 0.0749, Train Acc: 0.9747
