In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import pandas as pd
import csv

from models import *
from models.vit import ViT, channel_selection
from models.vit_slim import ViT_slim
from utils import progress_bar

In [17]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cudnn.benchmark = True

In [18]:
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


In [19]:
transform_test

Compose(
    ToTensor()
    Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.201))
)

In [20]:
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)


In [21]:
model = ViT(
    image_size = 32,
    patch_size = 4,
    num_classes = 10,
    dim = 512,                  # 512
    depth = 6,
    heads = 8,
    mlp_dim = 512,
    dropout = 0.1,
    emb_dropout = 0.1
    )
model = model.to(device)

In [22]:
model_path = "checkpoint/vit-4-ckpt_512.t7"
print("=> loading checkpoint '{}'".format(model_path))
checkpoint = torch.load(model_path)
start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['acc']
model.load_state_dict(checkpoint['net'])
print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}".format(model_path, checkpoint['epoch'], best_prec1))


=> loading checkpoint 'checkpoint/vit-4-ckpt_512.t7'
=> loaded checkpoint 'checkpoint/vit-4-ckpt_512.t7' (epoch 40) Prec1: 80.350000


In [23]:
def test(model):
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        print('Acc: %.3f%% (%d/%d)' % (100.*correct/total, correct, total))

    
    

Pruning Experiments

In [24]:
import psutil
import time
import os
from thop import profile
def calculate_accuracy(model, dataloader):
    correct = total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            # print(images.shape)
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

# --- 2. Inference Speed and Latency ---
def benchmark_inference(model, input_shape=(1, 3, 32, 32), runs=100):
    dummy_input = torch.randn(input_shape).to(device)
    torch.cuda.synchronize()
    start = time.time()
    for i in range(runs):
        output = model(dummy_input)
    torch.cuda.synchronize()
    total_time = time.time() - start
    latency = total_time / runs
    throughput = runs / total_time
    return latency, throughput

# --- 3. Model Size ---
def get_model_size(model, temp_path='temp.pth'):
    torch.save(model.state_dict(), temp_path)
    size_mb = os.path.getsize(temp_path) / 1e6
    os.remove(temp_path)
    return size_mb

# --- 4. Memory Usage (estimated by RAM during execution) ---
def get_memory_usage():
    process = psutil.Process(os.getpid())
    mem = process.memory_info().rss / 1e6  # in MB
    return mem

# --- 5. FLOPs and Parameters ---
def get_flops(model, input_shape=(1, 3, 32, 32)):
    dummy_input = torch.randn(input_shape).to(device)
    flops, params = profile(model, inputs=(dummy_input,), verbose=False)
    return flops / 1e9, params / 1e6  # GFLOPs and MParams

# --- 6. Estimate Power Usage ---
def estimate_power(flops, latency):
    # Rough estimate: 1 GFLOP = ~0.1 Watt-sec (example heuristic)
    energy = flops * 0.1  # Watt-seconds
    power = energy / latency  # Watts
    return power

def compute_metrics(model,testloader):
    # --- Run all metrics ---
    accuracy = calculate_accuracy(model, testloader)
    latency, speed = benchmark_inference(model)
    model_size = get_model_size(model)
    mem_usage = get_memory_usage()
    flops, params = get_flops(model)
    power = estimate_power(flops, latency)


    # --- Print results ---
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Inference Latency: {latency*1000:.2f} ms")
    print(f"Inference Speed: {speed:.2f} samples/sec")
    print(f"Model Size: {model_size:.2f} MB")
    print(f"Memory Usage (runtime): {mem_usage:.2f} MB")
    print(f"FLOPs: {flops:.2f} GFLOPs")
    print(f"Parameters: {params:.2f} Million")
    print(f"Estimated Power: {power:.2f} Watts")
    

In [25]:
def get_prune_model(model,prune_percent):
    total = 0
    for m in model.modules():
        if isinstance(m, channel_selection):
            total += m.indexes.data.shape[0]
    
    bn = torch.zeros(total)
    index = 0
    for m in model.modules():
        if isinstance(m, channel_selection):
            size = m.indexes.data.shape[0]
            bn[index:(index+size)] = m.indexes.data.abs().clone()
            index += size
    
    y, i = torch.sort(bn)
    thre_index = int(total * prune_percent)
    thre = y[thre_index]

    pruned = 0
    cfg = []
    cfg_mask = []
    for k, m in enumerate(model.modules()):
        if isinstance(m, channel_selection):
            # print(k)
            # print(m)
            if k in [16,40,64,88,112,136]:
                weight_copy = m.indexes.data.abs().clone()
                mask = weight_copy.gt(thre).float().cuda()
                thre_ = thre.clone()
                while (torch.sum(mask)%8 !=0):                       # heads
                    thre_ = thre_ - 0.0001
                    mask = weight_copy.gt(thre_).float().cuda()
            else:
                weight_copy = m.indexes.data.abs().clone()
                mask = weight_copy.gt(thre).float().cuda()
            pruned = pruned + mask.shape[0] - torch.sum(mask)
            m.indexes.data.mul_(mask)
            cfg.append(int(torch.sum(mask)))
            cfg_mask.append(mask.clone())
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
                format(k, mask.shape[0], int(torch.sum(mask))))
    pruned_ratio = pruned/total
    print('Pre-processing Successful!')
    print("Pruned Ratio:",pruned_ratio)
    print(cfg)

    test(model)
    cfg_prune = []
    for i in range(len(cfg)):
        if i%2!=0:
            cfg_prune.append([cfg[i-1],cfg[i]])
    print(cfg_prune)

    newmodel = ViT_slim(image_size = 32,
    patch_size = 4,
    num_classes = 10,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 512,
    dropout = 0.1,
    emb_dropout = 0.1,
    cfg=cfg_prune)
    newmodel.to(device)
    # num_parameters = sum([param.nelement() for param in newmodel.parameters()])

    newmodel_dict = newmodel.state_dict().copy()

    i = 0
    newdict = {}
    for k,v in model.state_dict().items():
        if 'net1.0.weight' in k:
            # print(k)
            # print(v.size())
            # print('----------')
            idx = np.squeeze(np.argwhere(np.asarray(cfg_mask[i].cpu().numpy())))
            newdict[k] = v[idx.tolist()].clone()
        elif 'net1.0.bias' in k:
            # print(k)
            # print(v.size())
            # print('----------')
            idx = np.squeeze(np.argwhere(np.asarray(cfg_mask[i].cpu().numpy())))
            newdict[k] = v[idx.tolist()].clone()
        elif 'to_q' in k or 'to_k' in k or 'to_v' in k:
            # print(k)
            # print(v.size())
            # print('----------')
            idx = np.squeeze(np.argwhere(np.asarray(cfg_mask[i].cpu().numpy())))
            newdict[k] = v[idx.tolist()].clone()
        elif 'net2.0.weight' in k:
            # print(k)
            # print(v.size())
            # print('----------')
            idx = np.squeeze(np.argwhere(np.asarray(cfg_mask[i].cpu().numpy())))
            newdict[k] = v[:,idx.tolist()].clone()
            i = i + 1
        elif 'to_out.0.weight' in k:
            # print(k)
            # print(v.size())
            # print('----------')
            idx = np.squeeze(np.argwhere(np.asarray(cfg_mask[i].cpu().numpy())))
            newdict[k] = v[:,idx.tolist()].clone()
            i = i + 1

        elif k in newmodel.state_dict():
            newdict[k] = v

    newmodel_dict.update(newdict)
    newmodel.load_state_dict(newmodel_dict)

    return newmodel

In [58]:
compute_metrics(model,testloader)

Accuracy: 0.7833
Inference Latency: 2.17 ms
Inference Speed: 461.39 samples/sec
Model Size: 39.20 MB
Memory Usage (runtime): 1546.10 MB
FLOPs: 0.62 GFLOPs
Parameters: 9.75 Million
Estimated Power: 28.46 Watts


In [59]:
prune_20 = get_prune_model(model,0.2)

layer index: 16 	 total channel: 512 	 remaining channel: 504
layer index: 28 	 total channel: 512 	 remaining channel: 488
layer index: 40 	 total channel: 512 	 remaining channel: 512
layer index: 52 	 total channel: 512 	 remaining channel: 418
layer index: 64 	 total channel: 512 	 remaining channel: 512
layer index: 76 	 total channel: 512 	 remaining channel: 310
layer index: 88 	 total channel: 512 	 remaining channel: 512
layer index: 100 	 total channel: 512 	 remaining channel: 198
layer index: 112 	 total channel: 512 	 remaining channel: 512
layer index: 124 	 total channel: 512 	 remaining channel: 197
layer index: 136 	 total channel: 512 	 remaining channel: 512
layer index: 148 	 total channel: 512 	 remaining channel: 251
Pre-processing Successful!
Pruned Ratio: tensor(0.1982, device='cuda:0')
[504, 488, 512, 418, 512, 310, 512, 198, 512, 197, 512, 251]
Acc: 79.230% (7923/10000)
[[504, 488], [512, 418], [512, 310], [512, 198], [512, 197], [512, 251]]


In [60]:
compute_metrics(prune_20,testloader)

Accuracy: 0.7150
Inference Latency: 1.99 ms
Inference Speed: 501.33 samples/sec
Model Size: 34.15 MB
Memory Usage (runtime): 1546.28 MB
FLOPs: 0.54 GFLOPs
Parameters: 8.50 Million
Estimated Power: 26.83 Watts


In [61]:
prune_30 = get_prune_model(model,0.3)

layer index: 16 	 total channel: 512 	 remaining channel: 472
layer index: 28 	 total channel: 512 	 remaining channel: 432
layer index: 40 	 total channel: 512 	 remaining channel: 496
layer index: 52 	 total channel: 512 	 remaining channel: 335
layer index: 64 	 total channel: 512 	 remaining channel: 512
layer index: 76 	 total channel: 512 	 remaining channel: 202
layer index: 88 	 total channel: 512 	 remaining channel: 512
layer index: 100 	 total channel: 512 	 remaining channel: 107
layer index: 112 	 total channel: 512 	 remaining channel: 512
layer index: 124 	 total channel: 512 	 remaining channel: 86
layer index: 136 	 total channel: 512 	 remaining channel: 512
layer index: 148 	 total channel: 512 	 remaining channel: 148
Pre-processing Successful!
Pruned Ratio: tensor(0.2959, device='cuda:0')
[472, 432, 496, 335, 512, 202, 512, 107, 512, 86, 512, 148]
Acc: 77.960% (7796/10000)
[[472, 432], [496, 335], [512, 202], [512, 107], [512, 86], [512, 148]]


In [62]:
compute_metrics(prune_30,testloader)

Accuracy: 0.7093
Inference Latency: 1.95 ms
Inference Speed: 511.57 samples/sec
Model Size: 31.49 MB
Memory Usage (runtime): 1559.09 MB
FLOPs: 0.49 GFLOPs
Parameters: 7.83 Million
Estimated Power: 25.18 Watts


In [14]:
prune_40 = get_prune_model(model,0.4)

layer index: 16 	 total channel: 512 	 remaining channel: 416
layer index: 28 	 total channel: 512 	 remaining channel: 350
layer index: 40 	 total channel: 512 	 remaining channel: 448
layer index: 52 	 total channel: 512 	 remaining channel: 251
layer index: 64 	 total channel: 512 	 remaining channel: 488
layer index: 76 	 total channel: 512 	 remaining channel: 127
layer index: 88 	 total channel: 512 	 remaining channel: 488
layer index: 100 	 total channel: 512 	 remaining channel: 42
layer index: 112 	 total channel: 512 	 remaining channel: 496
layer index: 124 	 total channel: 512 	 remaining channel: 29
layer index: 136 	 total channel: 512 	 remaining channel: 504
layer index: 148 	 total channel: 512 	 remaining channel: 69
Pre-processing Successful!
Pruned Ratio: tensor(0.3965, device='cuda:0')
[416, 350, 448, 251, 488, 127, 488, 42, 496, 29, 504, 69]
Acc: 75.850% (7585/10000)
[[416, 350], [448, 251], [488, 127], [488, 42], [496, 29], [504, 69]]


In [15]:
compute_metrics(prune_40,testloader)

Accuracy: 0.6941
Inference Latency: 2.20 ms
Inference Speed: 453.77 samples/sec
Model Size: 28.24 MB
Memory Usage (runtime): 1136.01 MB
FLOPs: 0.44 GFLOPs
Parameters: 7.02 Million
Estimated Power: 19.93 Watts


Knowledge distillation experiments

In [26]:
# Define distillation loss
class DistillLoss(nn.Module):
    def __init__(self, temperature=4.0, alpha=0.7):
        super().__init__()
        self.T = temperature
        self.alpha = alpha
        self.kl_div = nn.KLDivLoss(reduction="batchmean")
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, targets):
        distill = self.kl_div(
            F.log_softmax(student_logits / self.T, dim=1),
            F.softmax(teacher_logits / self.T, dim=1),
        ) * (self.T ** 2)

        ce = self.ce_loss(student_logits, targets)
        return self.alpha * distill + (1 - self.alpha) * ce


In [27]:
from warmup_scheduler import GradualWarmupScheduler
import albumentations

In [28]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=8)


In [29]:
teacher = ViT(
    image_size=32, 
    patch_size=4, 
    num_classes=10, 
    dim=512,
    depth=6, 
    heads=8, 
    mlp_dim=512, 
    dropout=0.1, 
    emb_dropout=0.1
)

In [30]:
teacher.load_state_dict(torch.load("checkpoint/vit-4-ckpt_512.t7")["net"])
teacher.eval()

ViT(
  (patch_to_embedding): Linear(in_features=48, out_features=512, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (transformer): Transformer(
    (layers): ModuleList(
      (0-5): 6 x ModuleList(
        (0): Residual(
          (fn): PreNorm(
            (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (fn): Attention(
              (to_q): Linear(in_features=512, out_features=512, bias=False)
              (to_k): Linear(in_features=512, out_features=512, bias=False)
              (to_v): Linear(in_features=512, out_features=512, bias=False)
              (to_out): Sequential(
                (0): Linear(in_features=512, out_features=512, bias=True)
                (1): Dropout(p=0.1, inplace=False)
              )
              (select1): channel_selection()
            )
          )
        )
        (1): Residual(
          (fn): PreNorm(
            (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (fn): FeedForward(
 

In [73]:
# Perform knowledge distillation on prune_40
# Train with distillation
student = prune_40
device = 'cuda' if torch.cuda.is_available() else 'cpu'
teacher = teacher.to(device)
student = student.to(device)

In [74]:
criterion = DistillLoss()
optimizer = torch.optim.Adam(student.parameters(), lr=1e-4)

In [75]:
from torch.optim import lr_scheduler
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, verbose=True, min_lr=1e-3*1e-5, factor=0.1)



In [76]:
def train(epoch):
    print('\nEpoch: %d' % epoch)
    student.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        with torch.no_grad():
            teacher_outputs = teacher(inputs)
        student_outputs = student(inputs)
        loss = criterion(student_outputs, teacher_outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = student_outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    return train_loss/(batch_idx+1)

In [77]:
best_acc = 0

In [None]:
def test(epoch):
    global best_acc
    student.eval()
    test_loss = 0
    correct = 0
    total = 0
    criterion_ce = nn.CrossEntropyLoss()
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = student(inputs)
            loss = criterion_ce(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
    
    # Update scheduler
    scheduler.step(test_loss)
    
    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': student.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/'+'vit-4-'+'ckpt_pruned_student.t7')
        best_acc = acc
    
    os.makedirs("log", exist_ok=True)
    content = time.ctime() + ' ' + f'Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, val loss: {test_loss:.5f}, acc: {(acc):.5f}'
    print(content)
    return test_loss, acc

In [79]:
list_loss = []
list_acc = []
for epoch in range(0,50):
    trainloss = train(epoch)
    val_loss, acc = test(epoch)
    
    list_loss.append(val_loss)
    list_acc.append(acc)


Epoch: 0
Saving..
Wed Apr  9 16:59:21 2025 Epoch 0, lr: 0.0001000, val loss: 72.90961, acc: 77.84000

Epoch: 1
Saving..
Wed Apr  9 16:59:58 2025 Epoch 1, lr: 0.0001000, val loss: 69.30963, acc: 78.40000

Epoch: 2
Saving..
Wed Apr  9 17:00:35 2025 Epoch 2, lr: 0.0001000, val loss: 70.77992, acc: 78.54000

Epoch: 3
Saving..
Wed Apr  9 17:01:13 2025 Epoch 3, lr: 0.0001000, val loss: 69.58380, acc: 78.64000

Epoch: 4
Saving..
Wed Apr  9 17:01:51 2025 Epoch 4, lr: 0.0001000, val loss: 67.64480, acc: 79.14000

Epoch: 5
Saving..
Wed Apr  9 17:02:26 2025 Epoch 5, lr: 0.0001000, val loss: 66.15179, acc: 79.52000

Epoch: 6
Wed Apr  9 17:03:01 2025 Epoch 6, lr: 0.0001000, val loss: 67.30012, acc: 79.25000

Epoch: 7
Wed Apr  9 17:03:37 2025 Epoch 7, lr: 0.0001000, val loss: 68.80582, acc: 78.66000

Epoch: 8
Saving..
Wed Apr  9 17:04:14 2025 Epoch 8, lr: 0.0001000, val loss: 64.70294, acc: 79.96000

Epoch: 9
Wed Apr  9 17:04:53 2025 Epoch 9, lr: 0.0001000, val loss: 66.97563, acc: 79.13000

Epoch:

In [91]:
cfg_prune = [[416, 350], [448, 251], [488, 127], [488, 42], [496, 29], [504, 69]]
import copy
pruned_student_model = copy.deepcopy(prune_40)

pruned_student_model = pruned_student_model.to(device)

In [92]:
model_path = "checkpoint/vit-4-ckpt_pruned_student.t7"
print("=> loading checkpoint '{}'".format(model_path))
checkpoint = torch.load(model_path)
start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['acc']
pruned_student_model.load_state_dict(checkpoint['net'])
print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}".format(model_path, checkpoint['epoch'], best_prec1))

=> loading checkpoint 'checkpoint/vit-4-ckpt_pruned_student.t7'
=> loaded checkpoint 'checkpoint/vit-4-ckpt_pruned_student.t7' (epoch 21) Prec1: 80.800000


In [93]:
compute_metrics(student,testloader)

Accuracy: 0.8067
Inference Latency: 1.88 ms
Inference Speed: 531.63 samples/sec
Model Size: 28.27 MB
Memory Usage (runtime): 1627.34 MB
FLOPs: 0.44 GFLOPs
Parameters: 7.02 Million
Estimated Power: 23.35 Watts


In [1]:
# # Perform knowledge distillation on prune_40
# # Train with distillation
# student = ViT(
#     image_size=32, 
#     patch_size=4, 
#     num_classes=10, 
#     dim=128,
#     depth=6, 
#     heads=4, 
#     mlp_dim=128, 
#     dropout=0.1, 
#     emb_dropout=0.1
# )
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# teacher = teacher.to(device)
# student = student.to(device)
# from torch.optim import lr_scheduler
# criterion = DistillLoss()
# optimizer = torch.optim.Adam(student.parameters(), lr=1e-4)
# scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, verbose=True, min_lr=1e-3*1e-5, factor=0.1)
# best_acc = 0

In [None]:
# baseline VIT
# Accuracy: 0.7833
# Inference Latency: 2.17 ms
# Inference Speed: 461.39 samples/sec
# Model Size: 39.20 MB
# Memory Usage (runtime): 1546.10 MB
# FLOPs: 0.62 GFLOPs
# Parameters: 9.75 Million
# Estimated Power: 28.46 Watts


# prune 20%
# Accuracy: 0.7150
# Inference Latency: 1.99 ms
# Inference Speed: 501.33 samples/sec
# Model Size: 34.15 MB
# Memory Usage (runtime): 1546.28 MB
# FLOPs: 0.54 GFLOPs
# Parameters: 8.50 Million
# Estimated Power: 26.83 Watts

# prune 30%
# Accuracy: 0.7093
# Inference Latency: 1.95 ms
# Inference Speed: 511.57 samples/sec
# Model Size: 31.49 MB
# Memory Usage (runtime): 1559.09 MB
# FLOPs: 0.49 GFLOPs
# Parameters: 7.83 Million
# Estimated Power: 25.18 Watts

# prune 40%
# Accuracy: 0.6925
# Inference Latency: 1.94 ms
# Inference Speed: 516.20 samples/sec
# Model Size: 28.24 MB
# Memory Usage (runtime): 1625.50 MB
# FLOPs: 0.44 GFLOPs
# Parameters: 7.02 Million
# Estimated Power: 22.68 Watts

# prune 40% knowledge distillation fine tuning
# Accuracy: 0.8067
# Inference Latency: 1.88 ms
# Inference Speed: 531.63 samples/sec
# Model Size: 28.27 MB
# Memory Usage (runtime): 1627.34 MB
# FLOPs: 0.44 GFLOPs
# Parameters: 7.02 Million
# Estimated Power: 18.94 Watts 

# Teacher-> student (reduced dim,depth,head) Knowledge distillation
# Accuracy: 0.7764
# Inference Latency: 1.32 ms
# Inference Speed: 755.03 samples/sec
# Model Size: 6.75 MB
# Memory Usage (runtime): 1290.65 MB
# FLOPs: 0.10 GFLOPs
# Parameters: 1.66 Million
# Estimated Power: 7.82 Watts





# prune 40% knowledge distillation fine tuning
# Accuracy: 0.8067
# Inference Latency: 2.32 ms
# Inference Speed: 431.19 samples/sec
# Model Size: 28.27 MB
# Memory Usage (runtime): 1625.94 MB
# FLOPs: 0.44 GFLOPs
# Parameters: 7.02 Million
# Estimated Power: 18.94 Watts


# ViT(
#     image_size=32, 
#     patch_size=4, 
#     num_classes=10, 
#     dim=128,
#     depth=6, 
#     heads=4, 
#     mlp_dim=128, 
#     dropout=0.1, 
#     emb_dropout=0.1
# )
# Accuracy: 0.7032
# Inference Latency: 2.19 ms
# Inference Speed: 457.61 samples/sec
# Model Size: 2.55 MB
# Memory Usage (runtime): 1289.88 MB
# FLOPs: 0.04 GFLOPs
# Parameters: 0.62 Million
# Estimated Power: 1.79 Watts

# student = ViT(
#     image_size=32, 
#     patch_size=4, 
#     num_classes=10, 
#     dim=128,
#     depth=6, 
#     heads=4, 
#     mlp_dim=128, 
#     dropout=0.1, 
#     emb_dropout=0.1
# )
# Accuracy: 0.6195
# Inference Latency: 1.54 ms
# Inference Speed: 649.42 samples/sec
# Model Size: 1.75 MB
# Memory Usage (runtime): 1292.56 MB
# FLOPs: 0.03 GFLOPs
# Parameters: 0.42 Million
# Estimated Power: 1.70 Watts


# student = ViT(
#     image_size=32, 
#     patch_size=4, 
#     num_classes=10, 
#     dim=128,
#     depth=6, 
#     heads=4, 
#     mlp_dim=128, 
#     dropout=0.1, 
#     emb_dropout=0.1
# )
# Accuracy: 0.7631
# Inference Latency: 2.20 ms
# Inference Speed: 454.32 samples/sec
# Model Size: 2.55 MB
# Memory Usage (runtime): 1291.46 MB
# FLOPs: 0.04 GFLOPs
# Parameters: 0.62 Million
# Estimated Power: 1.78 Watts


# student = ViT(
#     image_size=32, 
#     patch_size=4, 
#     num_classes=10, 
#     dim=256,
#     depth=4, 
#     heads=4, 
#     mlp_dim=256, 
#     dropout=0.1, 
#     emb_dropout=0.1
# )


# Accuracy: 0.7764
# Inference Latency: 1.32 ms
# Inference Speed: 755.03 samples/sec
# Model Size: 6.75 MB
# Memory Usage (runtime): 1290.65 MB
# FLOPs: 0.10 GFLOPs
# Parameters: 1.66 Million
# Estimated Power: 7.82 Watts

### 📊 ViT Performance Comparison: Baseline, Pruned, and Knowledge Distillation Variants

| Model                                              | Accuracy | Latency (ms) | Speed (samples/sec) | Model Size (MB) | Memory (MB) | FLOPs (GFLOPs) | Parameters (M) | Power (W) |
|---------------------------------------------------|----------|--------------|----------------------|------------------|-------------|----------------|----------------|-----------|
| **Baseline ViT**                                  | 78.33%   | 2.17         | 461.39               | 39.20            | 1546.10     | 0.62           | 9.75           | 28.46     |
| **Pruned 20%**                                     | 71.50%   | 1.99         | 501.33               | 34.15            | 1546.28     | 0.54           | 8.50           | 26.83     |
| **Pruned 30%**                                     | 70.93%   | 1.95         | 511.57               | 31.49            | 1559.09     | 0.49           | 7.83           | 25.18     |
| **Pruned 40%**                                     | 69.25%   | 1.94         | 516.20               | 28.24            | 1625.50     | 0.44           | 7.02           | 22.68     |
| **Pruned 40% + KD Fine-tuned**                    | 80.67%   | 1.88         | 531.63               | 28.27            | 1627.34     | 0.44           | 7.02           | 18.94     |
| **KD: Reduced Dim + Depth + Heads**    | 77.64%   | 1.32         | 755.03               | 6.75             | 1290.65     | 0.10           | 1.66           | 7.82      |

---

### ✅ Observations:
- 📈 **KD Fine-tuned Pruned model** surpasses baseline in accuracy with significantly lower compute cost.
- 🔄 **Progressive pruning** reduces model size and power usage at the cost of accuracy.
- ⚡ **Tiny ViT with KD** offers the best trade-off for real-time low-power environments.
- 🧠 **Knowledge Distillation** is highly effective across both pruned and compact models.


### 📉 Percentage Change Compared to Baseline ViT

| Model                                              | Accuracy Δ | Latency Δ | Speed Δ         | Model Size Δ | Memory Δ    | FLOPs Δ     | Params Δ     | Power Δ     |
|---------------------------------------------------|------------|-----------|------------------|---------------|--------------|--------------|---------------|-------------|
| **Baseline ViT**                                  | –          | –         | –                | –             | –            | –            | –             | –           |
| **Pruned 20%**                                     | 🔽 -8.71%  | 🔽 -8.29% | 🔼 +8.66%        | 🔽 -12.88%    | 🔼 +0.01%    | 🔽 -12.90%   | 🔽 -12.82%    | 🔽 -5.73%    |
| **Pruned 30%**                                     | 🔽 -9.44%  | 🔽 -10.14%| 🔼 +10.88%       | 🔽 -19.67%    | 🔼 +0.84%    | 🔽 -20.97%   | 🔽 -19.69%    | 🔽 -11.52%   |
| **Pruned 40%**                                     | 🔽 -11.61% | 🔽 -10.60%| 🔼 +11.88%       | 🔽 -27.96%    | 🔼 +5.14%    | 🔽 -29.03%   | 🔽 -27.90%    | 🔽 -20.32%   |
| **Pruned 40% + KD Fine-tuned**                    | 🔼 +2.99%  | 🔽 -13.36%| 🔼 +15.23%       | 🔽 -27.89%    | 🔼 +5.25%    | 🔽 -29.03%   | 🔽 -27.90%    | 🔽 -33.45%   |
| **KD: Reduced Dim + Depth + Heads**    | 🔽 -0.88%  | 🔽 -39.17%| 🔼 +63.68%       | 🔽 -82.78%    | 🔽 -16.51%   | 🔽 -83.87%   | 🔽 -82.97%    | 🔽 -72.51%   |

---

### 📌 Notes:
- **Teacher student with KD** gives massive gains in speed and size at minimal accuracy loss (only -0.88%).
- **Pruned + KD model** exceeds baseline accuracy (+2.99%) while cutting **power by ~33%** and **size by ~28%**.
- Every **pruned variant** consistently improves latency, throughput, and efficiency.
