In [None]:
# NOTE(review): unpinned install — pin versions (and prefer `%pip` so it targets the kernel env) for reproducible runs.
!pip3 install torch torchvision torchaudio

Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)
Collecting nvidia-curand-cu12==10.3.2.106 (from torch)
  Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)
Collectin

In [None]:
# Mount Google Drive at /content/drive (google.colab is only available on Colab).
# NOTE(review): nothing in the visible notebook reads from /content/drive — confirm it is still needed.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Define VGG-16-BN model

In [None]:
import time
import torch
import torch.nn as nn
from collections import OrderedDict

# VGG-16 layer spec consumed by VGG._make_layers: an int is a 3x3 conv
# (+BatchNorm+ReLU) with that many output channels, "M" is a 2x2 / stride-2
# max-pool. 13 conv layers, 4 pools; the last entry also sizes the classifier.
defaultcfg = [
    64, 64, "M",
    128, 128, "M",
    256, 256, 256, "M",
    512, 512, 512, "M",
    512, 512, 512,
]


class VGG(nn.Module):
    """VGG-style network whose conv widths can be shrunk layer by layer.

    Args:
        compress_rate: per-conv-layer fraction of filters to remove, in order;
            conv layer ``k`` gets ``int(width * (1 - compress_rate[k]))`` output
            channels. A shallow copy is stored, so the caller's list is untouched.
        cfg: layer spec; an int is a 3x3 conv (+BatchNorm2d+ReLU) with that many
            output channels, ``"M"`` is a 2x2 max-pool. Defaults to ``defaultcfg``.
        num_classes: output size of the final linear layer.

    Module names (``features.conv{i}``, ``features.norm{i}``, ``classifier.linear1``
    ...) are part of the checkpoint format and must not change.
    """

    def __init__(self, compress_rate=[0.0] * 13, cfg=None, num_classes=10):
        super(VGG, self).__init__()
        cfg = defaultcfg if cfg is None else cfg

        # Must be set before _make_layers, which reads it.
        self.compress_rate = compress_rate[:]

        self.features = self._make_layers(cfg)
        # features ends with conv -> norm -> relu, so [-3] is the last conv.
        # NOTE: linear1 expects the pooled map to be 1x1 (e.g. 32x32 input with defaultcfg).
        final_conv = self.features[-3]
        hidden = cfg[-1]
        self.classifier = nn.Sequential(OrderedDict([
            ("linear1", nn.Linear(final_conv.out_channels, hidden)),
            ("norm1", nn.BatchNorm1d(hidden)),
            ("relu1", nn.ReLU(inplace=True)),
            ("linear2", nn.Linear(hidden, num_classes)),
        ]))

    def _make_layers(self, cfg):
        """Build the feature extractor; module names use each entry's index in ``cfg``."""
        stack = nn.Sequential()
        channels_in = 3
        conv_idx = 0

        for pos, spec in enumerate(cfg):
            if spec == "M":
                stack.add_module(f"pool{pos}", nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            width = int(spec * (1 - self.compress_rate[conv_idx]))
            conv_idx += 1
            stack.add_module(f"conv{pos}", nn.Conv2d(channels_in, width, kernel_size=3, padding=1))
            stack.add_module(f"norm{pos}", nn.BatchNorm2d(width))
            stack.add_module(f"relu{pos}", nn.ReLU(inplace=True))
            channels_in = width

        return stack

    def forward(self, x):
        pooled = nn.AvgPool2d(2)(self.features(x))
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)


def vgg_16_bn(compress_rate=[0.0] * 13):
    """Return a 10-class VGG-16-BN built from ``defaultcfg``.

    ``compress_rate`` lists, for each of the 13 conv layers, the fraction of
    filters to remove (all zeros gives the unpruned network).
    """
    net = VGG(compress_rate=compress_rate)
    return net

Helper functions

In [None]:
import re

def get_cpr(compress_rate):
    """Expand a compression-rate spec string into a flat list of floats.

    The spec is ``+``-separated segments, each holding one decimal rate and an
    optional ``*N`` repeat count, e.g. ``"[0.1]*2+[0.25]"`` -> ``[0.1, 0.1, 0.25]``.
    Rates must contain a decimal point (``0.`` or ``0.25``). Raises
    AssertionError if a segment has more than one repeat count or not exactly one rate.
    """
    rate_pattern = re.compile(r"\d+\.\d*")
    repeat_pattern = re.compile(r"\*\d+")
    rates = []
    for segment in compress_rate.split("+"):
        count = 1
        repeats = re.findall(repeat_pattern, segment)
        if repeats:
            assert len(repeats) == 1
            count = int(repeats[0].replace("*", ""))
        values = re.findall(rate_pattern, segment)
        assert len(values) == 1
        rates.extend([float(values[0])] * count)
    return rates

In [None]:

import os
import sys
import shutil
import time, datetime
import logging
import numpy as np
from PIL import Image
from pathlib import Path

import torch
import torch.nn as nn
import torch.utils


'''record configurations'''
class record_config():
    """Record a run's arguments in ``<args.job_dir>/config.txt``.

    Creates ``job_dir`` if needed. Writes a timestamp line, then one
    ``name: value`` line per attribute of ``args``. When ``args.resume`` is truthy
    the block is appended (earlier records are kept); otherwise the file is overwritten.

    Attributes:
        args: the namespace passed in.
        job_dir: ``Path(args.job_dir)``.
    """
    def __init__(self, args):
        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

        self.args = args
        self.job_dir = Path(args.job_dir)
        self.job_dir.mkdir(parents=True, exist_ok=True)

        config_dir = self.job_dir / 'config.txt'
        # Append on resume so the history of configurations is preserved.
        mode = 'a' if args.resume else 'w'
        with open(config_dir, mode) as f:
            f.write(now + '\n\n')
            for arg in vars(args):
                f.write('{}: {}\n'.format(arg, getattr(args, arg)))
            f.write('\n')


def get_logger(file_path):
    """Return the shared 'gal' logger, writing INFO records to ``file_path`` and the console.

    Records are formatted as ``'<MM/DD HH:MM:SS AM/PM> | <message>'``. Calling this
    again (e.g. re-running the cell) replaces the handlers from the previous call
    instead of stacking new ones, so each record is emitted once.
    """
    logger = logging.getLogger('gal')

    # logging.getLogger returns the same object on every call; drop (and close)
    # handlers attached by earlier calls to avoid duplicated output / leaked files.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    log_format = '%(asctime)s | %(message)s'
    formatter = logging.Formatter(log_format, datefmt='%m/%d %I:%M:%S %p')
    file_handler = logging.FileHandler(file_path)
    file_handler.setFormatter(formatter)
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    logger.setLevel(logging.INFO)

    return logger

#label smooth
class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy against label-smoothed targets.

    The one-hot target is blended with a uniform distribution,
    ``(1 - epsilon) * one_hot + epsilon / num_classes``, and the loss is the
    per-sample cross-entropy averaged over the batch. ``targets`` are class indices.
    """

    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        one_hot = torch.zeros_like(log_probs)
        one_hot.scatter_(1, targets.unsqueeze(1), 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        per_class = -smoothed * log_probs
        # Mean over the batch, then sum over classes == batch mean of per-sample CE.
        return per_class.mean(0).sum()


class AverageMeter(object):
    """Keep the latest value and the count-weighted running mean of a metric.

    ``str(meter)`` renders ``'<name> <val> (<avg>)'`` using the ``fmt`` spec.
    """
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = f"{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})"
        return template.format(**self.__dict__)


class ProgressMeter(object):
    """Print one progress line: ``<prefix>[batch/total]`` followed by each meter's ``str``."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        head = self.prefix + self.batch_fmtstr.format(batch)
        print(' '.join([head] + [str(m) for m in self.meters]))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the batch index to the width of the total, e.g. "[  7/100]".
        width = len(str(num_batches // 1))
        field = '{:%dd}' % width
        return f"[{field}/{field.format(num_batches)}]"


def save_checkpoint(state, is_best, save):
    """Save ``state`` to ``<save>/checkpoint.pth.tar`` (creating ``save`` if missing).

    When ``is_best`` is true, the file is also copied to ``<save>/model_best.pth.tar``.
    """
    if not os.path.exists(save):
        os.makedirs(save)
    ckpt_path = os.path.join(save, 'checkpoint.pth.tar')
    torch.save(state, ckpt_path)
    if not is_best:
        return
    shutil.copyfile(ckpt_path, os.path.join(save, 'model_best.pth.tar'))


def adjust_learning_rate(optimizer, epoch, args):
    """Step decay: set every param group's lr to ``args.lr * 0.1 ** (epoch // 30)``."""
    decayed = args.lr * 0.1 ** (epoch // 30)
    for group in optimizer.param_groups:
        group['lr'] = decayed


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy in percent for each ``k`` in ``topk``.

    ``output`` holds per-class scores with classes on dim 1; ``target`` holds class
    indices. Returns a list of 1-element tensors, one per ``k``, in ``topk`` order.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n = target.size(0)

        # Row j of `top_indices` holds every sample's j-th best class.
        top_indices = output.topk(largest_k, 1, True, True).indices.t()
        hits = top_indices.eq(target.view(1, -1).expand_as(top_indices))

        scale = 100.0 / n
        return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale) for k in topk]



def progress_bar(current, total, msg=None):
    """Draw (and redraw in place) a one-line text progress bar on stdout.

    Args:
        current: 0-based index of the step just finished; ``current == 0`` starts a new bar.
        total: number of steps.
        msg: optional text shown after the step / total timings.

    The time of the previous call and the start of the current bar are kept as
    attributes on this function, so "Step" and "Tot" report real elapsed time
    (previously both were reset on every call and always read ~0ms).
    """
    # `stty size` fails when stdout is not a terminal (Jupyter/Colab) and the old
    # tuple unpacking then raised ValueError; fall back to 80 columns instead.
    term_width = shutil.get_terminal_size(fallback=(80, 24)).columns

    TOTAL_BAR_LENGTH = 65.

    now = time.time()
    if not hasattr(progress_bar, '_last_time'):
        progress_bar._last_time = now
        progress_bar._begin_time = now
    if current == 0:
        progress_bar._begin_time = now  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH * current / total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [' + '=' * cur_len + '>' + '.' * rest_len + ']')

    cur_time = time.time()
    step_time = cur_time - progress_bar._last_time
    progress_bar._last_time = cur_time
    tot_time = cur_time - progress_bar._begin_time

    parts = ['  Step: %s' % format_time(step_time), ' | Tot: %s' % format_time(tot_time)]
    if msg:
        parts.append(' | ' + msg)
    msg = ''.join(parts)
    sys.stdout.write(msg)

    # Pad to the terminal width (a negative count writes nothing).
    sys.stdout.write(' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3))

    # Go back to the center of the bar.
    sys.stdout.write('\b' * (term_width - int(TOTAL_BAR_LENGTH / 2) + 2))
    sys.stdout.write(' %d/%d ' % (current + 1, total))

    sys.stdout.write('\r' if current < total - 1 else '\n')
    sys.stdout.flush()


def format_time(seconds):
    """Format a duration using at most its two largest non-zero units.

    Units are days (``D``), hours (``h``), minutes (``m``), seconds (``s``) and
    milliseconds (``ms``). For example, 3661.5 -> ``'1h1m'``. Returns ``'0ms'``
    when no unit is positive (this includes negative durations).
    """
    days = int(seconds / 3600 / 24)
    seconds = seconds - days * 3600 * 24
    hours = int(seconds / 3600)
    seconds = seconds - hours * 3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes * 60
    whole_seconds = int(seconds)
    millis = int((seconds - whole_seconds) * 1000)

    units = ((days, 'D'), (hours, 'h'), (minutes, 'm'), (whole_seconds, 's'), (millis, 'ms'))
    shown = [f"{value}{suffix}" for value, suffix in units if value > 0][:2]
    return ''.join(shown) or '0ms'

In [None]:
def train(epoch, train_loader, model, criterion, optimizer, scheduler):
    """Train ``model`` for one epoch over ``train_loader``, then call ``scheduler.step()`` once.

    Args:
        epoch: epoch index, used only in the log lines.
        train_loader: iterable of ``(images, target)`` batches.
        model, criterion, optimizer, scheduler: standard PyTorch objects.

    Returns:
        ``(loss, top1, top5)`` averaged over all samples of the epoch (top-k in %).
    """
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    model.train()
    # Send batches to the model's device instead of assuming the current CUDA device.
    device = next(model.parameters()).device

    cur_lr = optimizer.param_groups[0]['lr']
    print('learning_rate: ' + str(cur_lr))

    num_iter = len(train_loader)
    # Log about 10 times per epoch; clamp to 1 so loaders with < 10 batches
    # don't evaluate `i % 0`.
    print_freq = max(1, num_iter // 10)
    for i, (images, target) in enumerate(train_loader):
        images = images.to(device)
        target = target.to(device)

        # compute output
        logits = model(images)
        loss = criterion(logits, target)

        # measure accuracy and record loss (meters are weighted by batch size)
        prec1, prec5 = accuracy(logits, target, topk=(1, 5))
        n = images.size(0)
        losses.update(loss.item(), n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % print_freq == 0:
            print(
                'Epoch[{0}]({1}/{2}): '
                'Loss {loss.avg:.4f} '
                'Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f} '
                'Lr {cur_lr:.4f}'.format(
                    epoch, i, num_iter, loss=losses,
                    top1=top1, top5=top5, cur_lr=cur_lr))
    scheduler.step()

    return losses.avg, top1.avg, top5.avg


def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` in eval mode without gradients.

    Prints the top-1 / top-5 accuracy and returns ``(loss, top1, top5)`` averaged
    over all samples, as Python floats (top-k in %).
    """
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    # switch to evaluation mode
    model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():
        for images, target in val_loader:
            images = images.to(device)
            target = target.to(device)

            # compute output
            logits = model(images)
            loss = criterion(logits, target)

            # measure accuracy and record loss
            pred1, pred5 = accuracy(logits, target, topk=(1, 5))
            n = images.size(0)
            losses.update(loss.item(), n)
            # `.item()` (as in train) keeps the meters as Python floats; previously
            # they accumulated 0-dim device tensors and returned tensors.
            top1.update(pred1.item(), n)
            top5.update(pred5.item(), n)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return losses.avg, top1.avg, top5.avg

In [None]:
import torchvision
from torchvision import datasets, transforms

def load_data(batch_size=128, root="./", num_workers=2):
    """Build CIFAR-10 train / test DataLoaders (downloading the data if missing).

    Args:
        batch_size: batch size for both loaders.
        root: directory holding (or receiving) the CIFAR-10 files.
        num_workers: worker processes for each DataLoader.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles and applies
        random crop (padding 4) + horizontal flip; the test loader does neither.
    """
    # The baseline checkpoint is evaluated with these constants; changing them would
    # shift the input statistics the pretrained weights see.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=True,
                                            transform=transform_train)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True,
                                               num_workers=num_workers)
    testset = torchvision.datasets.CIFAR10(root=root, train=False, download=True,
                                           transform=transform_test)
    val_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False,
                                             num_workers=num_workers)

    return train_loader, val_loader

In [None]:
# parameters
# Fine-tuning hyper-parameters; `finetune` reads these as module-level globals.
epochs = 100  # NOTE(review): the visible finetune calls pass epochs=1 explicitly instead
lr_warmup_epochs=5  # LinearLR warm-up length (total_iters) and the SequentialLR milestone
lr=0.01  # base SGD learning rate (reached after warm-up, then cosine-annealed)
momentum=0.9  # SGD momentum
weight_decay=5e-4  # SGD weight decay
lr_warmup_decay=0.01  # LinearLR start_factor: warm-up starts at lr * lr_warmup_decay

In [None]:
def finetune(model, train_loader, val_loader, epochs, criterion):
    """Fine-tune ``model`` for ``epochs`` epochs and restore its best validation weights.

    Uses SGD with the module-level ``lr``, ``momentum`` and ``weight_decay``, a
    LinearLR warm-up (``lr_warmup_decay`` start factor, ``lr_warmup_epochs`` long)
    followed by cosine annealing, stepped once per epoch by ``train``.

    Returns:
        ``model`` loaded with the weights of the best validation top-1 accuracy seen,
        counting the pre-finetuning weights (kept if no epoch improves on them).
    """
    # Local import: this function previously relied on a later cell importing `copy`.
    import copy

    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum,
                                weight_decay=weight_decay)
    # Length of the cosine phase. Clamp to >= 1: with epochs == lr_warmup_epochs the
    # old T_max of 0 divided by zero when SequentialLR switched to this scheduler,
    # and a negative T_max (epochs < warm-up) is meaningless.
    main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(1, epochs - lr_warmup_epochs))
    warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=lr_warmup_decay, total_iters=lr_warmup_epochs)
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler],
        milestones=[lr_warmup_epochs])

    _, best_top1_acc, _ = validate(val_loader, model, criterion)
    best_model_state = copy.deepcopy(model.state_dict())
    for epoch in range(epochs):
        train(epoch, train_loader, model, criterion, optimizer, scheduler)
        _, valid_top1_acc, _ = validate(val_loader, model, criterion)

        if valid_top1_acc > best_top1_acc:
            best_top1_acc = valid_top1_acc
            best_model_state = copy.deepcopy(model.state_dict())

        print('=>Best accuracy {:.3f}'.format(best_top1_acc))

    model.load_state_dict(best_model_state)

    return model

# **Section 3: Load the pretrained baseline model**

In [None]:
# Download the pretrained VGG-16-BN baseline (CORING v0.1.0 release) into the working directory.
!wget https://github.com/pvtien96/CORING/releases/download/v0.1.0/vgg_16_bn.pt

--2024-06-07 21:54:21--  https://github.com/pvtien96/CORING/releases/download/v0.1.0/vgg_16_bn.pt
Resolving github.com (github.com)... 140.82.121.3
Connecting to github.com (github.com)|140.82.121.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/572465934/6bb9aca3-1335-40ce-8a25-df08be78e4eb?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20240607%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240607T215422Z&X-Amz-Expires=300&X-Amz-Signature=8a0967615e46d1faa698e14ad0bc80d8521081574b7bc21b495c2b478933e307&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=572465934&response-content-disposition=attachment%3B%20filename%3Dvgg_16_bn.pt&response-content-type=application%2Foctet-stream [following]
--2024-06-07 21:54:22--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/572465934/6bb9aca3-1335-40ce-8a25-df08be78e4eb?X-Amz-Algorit

In [None]:
 import copy


 # initialize model
model_ori = vgg_16_bn(compress_rate=[0.0]*13).cuda()  # unpruned baseline architecture
print(model_ori)

# load training data
train_loader, val_loader = load_data()
criterion = nn.CrossEntropyLoss().cuda()

# load the baseline model
# NOTE(review): torch.load unpickles the downloaded file — only load checkpoints from a trusted source.
checkpoint = torch.load("./vgg_16_bn.pt", map_location=torch.device('cuda:0'))
model_ori.load_state_dict(checkpoint['state_dict'])



VGG(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu3): ReLU(inplace=True)
    (conv4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu4): ReLU(inplace=True)
    (pool5): MaxPool2d(kernel_size=2, stride=2, padding=0, di

100%|██████████| 170498071/170498071 [00:05<00:00, 29847215.20it/s]


Extracting ./cifar-10-python.tar.gz to ./
Files already downloaded and verified


<All keys matched successfully>

In [None]:
# Top-1 accuracy (%) of the unpruned baseline on the CIFAR-10 test split.
print("Evaluating the baseline model:")
_, accuracy_model_ori, _ = validate(val_loader, model_ori, criterion)
print(f"This model's accuracy is {accuracy_model_ori}")

Evaluating the baseline model:
 * Acc@1 93.960 Acc@5 99.730
This model's accuracy is 93.95999908447266


In [None]:
# ptflops provides get_model_complexity_info (used below).
# NOTE(review): consider pinning the version (this run installed ptflops 0.7.3).
! pip install ptflops

Collecting ptflops
  Downloading ptflops-0.7.3-py3-none-any.whl (18 kB)
Installing collected packages: ptflops
Successfully installed ptflops-0.7.3


In [None]:
from ptflops import get_model_complexity_info
# MACs and parameter count for one 3x32x32 input; as_strings=False returns plain numbers.
with torch.cuda.device(0):
  macs, params = get_model_complexity_info(model_ori, (3, 32, 32), as_strings=False, print_per_layer_stat=True, verbose=False)

VGG(
  14.99 M, 100.000% Params, 314.69 MMac, 99.872% MACs, 
  (features): Sequential(
    14.72 M, 98.207% Params, 314.43 MMac, 99.787% MACs, 
    (conv0): Conv2d(1.79 k, 0.012% Params, 1.84 MMac, 0.582% MACs, 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm0): BatchNorm2d(128, 0.001% Params, 131.07 KMac, 0.042% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(0, 0.000% Params, 65.54 KMac, 0.021% MACs, inplace=True)
    (conv1): Conv2d(36.93 k, 0.246% Params, 37.81 MMac, 12.001% MACs, 64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm1): BatchNorm2d(128, 0.001% Params, 131.07 KMac, 0.042% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(0, 0.000% Params, 65.54 KMac, 0.021% MACs, inplace=True)
    (pool2): MaxPool2d(0, 0.000% Params, 65.54 KMac, 0.021% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv3): Conv2d(73.86 k, 0.493% Params,

In [None]:
# Baseline size, for comparison with macs_prune / params_prune of the pruned model.
print(f"The number of parameter and MACs of this model are {params} and {macs}, respectively.")

The number of parameter and MACs of this model are 14991946 and 315096586, respectively.


# **Three methods to prune the model**

*   Random
*   Norm-based
*   Distance-based



In [None]:
compress_rate = [0.25]*13 # prune 25% of all layers
# NOTE(review): model_prune is reused by every pruning method below. The visible prune_*
# helpers only overwrite Conv2d weights, so BatchNorm/classifier values carry over from
# earlier experiments; re-create model_prune before each method for a like-for-like comparison.
model_prune = vgg_16_bn(compress_rate=compress_rate).cuda()
print(model_prune)


VGG(
  (features): Sequential(
    (conv0): Conv2d(3, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm0): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (conv1): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv3): Conv2d(48, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu3): ReLU(inplace=True)
    (conv4): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm4): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu4): ReLU(inplace=True)
    (pool5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilatio

## **Random**


In [None]:
def prune_random(model, model_ori, seed=None):
    """Fill ``model``'s (narrower) conv weights with randomly chosen filters of ``model_ori``.

    Walks the Conv2d layers of ``model`` in ``named_modules()`` order. For a layer
    with fewer filters than the same-named layer of ``model_ori``, a random subset
    of the original filters is kept (in ascending index order). Every layer's input
    channels are restricted to the filters kept by the previous conv layer.

    Only Conv2d ``.weight`` tensors are written. Conv biases, BatchNorm parameters /
    running statistics and the classifier of ``model`` keep the values they had.

    Args:
        model: pruned network; modified in place.
        model_ori: unpruned network with the same module names.
        seed: optional int; when given, filters are drawn from
            ``np.random.default_rng(seed)`` for a reproducible selection. When
            ``None`` (default) the global ``np.random`` state is used.
    """
    oristate_dict = model_ori.state_dict()
    state_dict = model.state_dict()
    rng = np.random.default_rng(seed) if seed is not None else None
    last_select_index = None  # filters kept by the previous conv layer (None = all of them)

    cnt = 0
    for name, module in model.named_modules():
        name = name.replace('module.', '')
        if not isinstance(module, nn.Conv2d):
            continue

        cnt += 1
        key = name + '.weight'
        oriweight = oristate_dict[key]
        curweight = state_dict[key]  # shares storage with the model parameter
        orifilter_num = oriweight.size(0)
        currentfilter_num = curweight.size(0)
        print(f"Processing layer {cnt}, original layer has {orifilter_num} filters, pruning model has {currentfilter_num} filters")

        if orifilter_num != currentfilter_num:
            #************ rank the filter's importance here
            rank = np.arange(1, orifilter_num + 1)
            if rng is None:
                np.random.shuffle(rank)
            else:
                rng.shuffle(rank)
            #********************
            print(f"rank {rank}")
            # Keep the filters with the largest scores; ascending ids.
            select_index = np.sort(np.argsort(rank)[orifilter_num - currentfilter_num:])  # preserved filter id
            kept = oriweight[torch.as_tensor(select_index, device=oriweight.device)]

            if last_select_index is None:
                curweight.copy_(kept)
            else:
                in_idx = torch.as_tensor(last_select_index, device=oriweight.device)
                curweight[:, :len(last_select_index)] = kept[:, in_idx]

            last_select_index = select_index

        elif last_select_index is not None:
            in_idx = torch.as_tensor(last_select_index, device=oriweight.device)
            curweight[:, :len(last_select_index)] = oriweight[:, in_idx]
            # This layer keeps all of its filters, so the next layer sees every channel.
            # (Previously the stale selection leaked into the next layer.)
            last_select_index = None

        else:
            state_dict[key] = oriweight
            last_select_index = None

    model.load_state_dict(state_dict)

In [None]:
# Copy a random subset of each conv layer's filters from model_ori into model_prune.
prune_random(model_prune, model_ori)

Processing layer 1, original layer has 64 filters, pruning model has 48 filters
rank [59 56 62  8 44 43 27 45 29  1 39 22 16 20 24 10  6 26 36 31 41  3 23 37
 55 57  4  9 58 21 13 49 40  2 42 14 46 28  7 30 60 52 50 18 61  5 48 53
 15 12 64 11 32 38 35 51 33 63 25 54 47 17 34 19]
Processing layer 2, original layer has 64 filters, pruning model has 48 filters
rank [41 13 17 23 20 21 44 49 35 63 56 61 10 47 12 53  8 42 31 32 48 39 58 29
 27 62 37 52 22 33 16 55 43 50  1 11 64 40 30 34 38  2 51  9 59  5 15 25
 60 26 36 14 19 24  7 18 57 28  4 45 54  3  6 46]
Processing layer 3, original layer has 128 filters, pruning model has 96 filters
rank [ 88  41 107  37  67  95  76  32   2  26  89 124 105  18  59  25 118  16
  23  14  15  83 104  82  96  13 109  80 123 119  63  43  34  92 127  84
  62  17 102  53  55  61 113  29  44  94  78 101 116  19  77  98  99 121
  49   9  28  33  81  24  70  86  42  75  11   5  69  90 125  31 117   6
  40  35  60  72  66  20  85  57  74  36  79  21   4 106 115

In [None]:
# NOTE(review): with epochs=1 < lr_warmup_epochs the whole run stays in the LinearLR warm-up,
# i.e. it trains at lr * lr_warmup_decay (logged as 0.0001).
finetune(model_prune, train_loader, val_loader, epochs=1, criterion=criterion)

 * Acc@1 10.000 Acc@5 50.000
learning_rate: 0.0001
Epoch[0](0/391): Loss 2.4927 Prec@1(1,5) 8.59, 46.09 Lr 0.0001
Epoch[0](39/391): Loss 2.1735 Prec@1(1,5) 27.79, 64.36 Lr 0.0001
Epoch[0](78/391): Loss 1.8697 Prec@1(1,5) 44.94, 77.78 Lr 0.0001
Epoch[0](117/391): Loss 1.6595 Prec@1(1,5) 53.45, 83.73 Lr 0.0001
Epoch[0](156/391): Loss 1.5107 Prec@1(1,5) 58.09, 86.97 Lr 0.0001
Epoch[0](195/391): Loss 1.3916 Prec@1(1,5) 61.68, 89.06 Lr 0.0001
Epoch[0](234/391): Loss 1.2990 Prec@1(1,5) 64.31, 90.55 Lr 0.0001
Epoch[0](273/391): Loss 1.2259 Prec@1(1,5) 66.35, 91.60 Lr 0.0001
Epoch[0](312/391): Loss 1.1641 Prec@1(1,5) 67.92, 92.44 Lr 0.0001
Epoch[0](351/391): Loss 1.1081 Prec@1(1,5) 69.45, 93.12 Lr 0.0001
Epoch[0](390/391): Loss 1.0614 Prec@1(1,5) 70.69, 93.69 Lr 0.0001
 * Acc@1 80.220 Acc@5 98.540
=>Best accuracy 80.220


VGG(
  (features): Sequential(
    (conv0): Conv2d(3, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm0): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (conv1): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv3): Conv2d(48, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu3): ReLU(inplace=True)
    (conv4): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm4): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu4): ReLU(inplace=True)
    (pool5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilatio

In [None]:
# MACs / parameters of the pruned model (compare with macs / params of model_ori).
with torch.cuda.device(0):
  macs_prune, params_prune = get_model_complexity_info(model_prune, (3, 32, 32), as_strings=False, print_per_layer_stat=True, verbose=False)

VGG(
  8.49 M, 100.000% Params, 177.63 MMac, 99.831% MACs, 
  (features): Sequential(
    8.28 M, 97.605% Params, 177.43 MMac, 99.716% MACs, 
    (conv0): Conv2d(1.34 k, 0.016% Params, 1.38 MMac, 0.773% MACs, 3, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm0): BatchNorm2d(96, 0.001% Params, 98.3 KMac, 0.055% MACs, 48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(0, 0.000% Params, 49.15 KMac, 0.028% MACs, inplace=True)
    (conv1): Conv2d(20.78 k, 0.245% Params, 21.28 MMac, 11.961% MACs, 48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm1): BatchNorm2d(96, 0.001% Params, 98.3 KMac, 0.055% MACs, 48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(0, 0.000% Params, 49.15 KMac, 0.028% MACs, inplace=True)
    (pool2): MaxPool2d(0, 0.000% Params, 49.15 KMac, 0.028% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv3): Conv2d(41.57 k, 0.490% Params, 10.64 M

## **Norm**

In [None]:
def prune_norm(model, model_ori):
    """Fill ``model``'s (narrower) conv weights with the largest-norm filters of ``model_ori``.

    Walks the Conv2d layers of ``model`` in ``named_modules()`` order. For a layer
    with fewer filters than the same-named layer of ``model_ori``, each original
    filter is scored by the L2 norm of its flattened weights and the
    ``currentfilter_num`` highest-scoring filters are kept (in ascending index
    order). Every layer's input channels are restricted to the filters kept by
    the previous conv layer.

    Only Conv2d ``.weight`` tensors are written. Conv biases, BatchNorm parameters /
    running statistics and the classifier of ``model`` keep the values they had.

    Args:
        model: pruned network; modified in place.
        model_ori: unpruned network with the same module names.
    """
    oristate_dict = model_ori.state_dict()
    state_dict = model.state_dict()
    last_select_index = None  # filters kept by the previous conv layer (None = all of them)

    cnt = 0
    for name, module in model.named_modules():
        name = name.replace('module.', '')
        if not isinstance(module, nn.Conv2d):
            continue

        cnt += 1
        key = name + '.weight'
        oriweight = oristate_dict[key]
        curweight = state_dict[key]  # shares storage with the model parameter
        orifilter_num = oriweight.size(0)
        currentfilter_num = curweight.size(0)
        print(f"Processing layer {cnt}, original layer has {orifilter_num} filters, pruning model has {currentfilter_num} filters")

        if orifilter_num != currentfilter_num:
            #************ rank the filter's importance here
            print(oriweight.shape)
            norms = torch.norm(oriweight.reshape(orifilter_num, -1), dim=1)  # one L2 norm per filter
            print(norms)
            # `rank` must hold one score per filter (higher = more important), because the
            # selection below keeps the largest scores. The descending argsort of the norms
            # is a permutation of filter ids, not a score, and selected the wrong filters.
            rank = norms.cpu().numpy()
            #********************
            print(f"rank {np.argsort(-rank)}")  # filter ids, largest norm first
            select_index = np.sort(np.argsort(rank)[orifilter_num - currentfilter_num:])  # preserved filter id
            kept = oriweight[torch.as_tensor(select_index, device=oriweight.device)]

            if last_select_index is None:
                curweight.copy_(kept)
            else:
                in_idx = torch.as_tensor(last_select_index, device=oriweight.device)
                curweight[:, :len(last_select_index)] = kept[:, in_idx]

            last_select_index = select_index

        elif last_select_index is not None:
            in_idx = torch.as_tensor(last_select_index, device=oriweight.device)
            curweight[:, :len(last_select_index)] = oriweight[:, in_idx]
            # This layer keeps all of its filters, so the next layer sees every channel.
            # (Previously the stale selection leaked into the next layer.)
            last_select_index = None

        else:
            state_dict[key] = oriweight
            last_select_index = None

    model.load_state_dict(state_dict)

In [None]:
# NOTE(review): model_prune already went through prune_random + finetune above; prune_norm
# only replaces its conv weights, so BatchNorm/classifier values come from that earlier run.
prune_norm(model_prune, model_ori)


Processing layer 1, original layer has 64 filters, pruning model has 48 filters
torch.Size([64, 3, 3, 3])
tensor([1.5891e+00, 8.8759e-01, 5.5363e-01, 3.3376e-01, 2.2699e-01, 7.8806e-01,
        3.8824e-01, 3.9319e-01, 3.0613e-02, 6.7065e-01, 3.2999e-01, 1.5340e+00,
        4.5140e-02, 5.1688e-04, 1.4907e-01, 2.3611e-01, 1.9442e-01, 1.1847e+00,
        2.6520e-01, 1.0161e-01, 2.3032e-01, 6.9368e-01, 1.2299e-01, 1.0715e+00,
        5.8923e-01, 5.8216e-01, 2.8645e-01, 3.3000e-01, 1.6501e+00, 1.1890e+00,
        9.7305e-01, 5.8581e-01, 1.9635e+00, 7.2750e-04, 1.9415e-01, 3.9707e-04,
        6.7574e-01, 1.8930e-03, 7.1398e-01, 1.0104e+00, 1.0242e+00, 6.8070e-01,
        9.9616e-01, 5.0565e-01, 4.1435e-01, 3.0730e-01, 3.4557e-01, 1.3315e+00,
        8.0923e-04, 1.0445e+00, 6.0952e-04, 9.9565e-02, 7.0488e-01, 7.0627e-04,
        1.7248e+00, 3.9991e-01, 2.0378e-01, 1.7900e+00, 7.4258e-01, 5.9122e-01,
        1.4910e-01, 5.9228e-01, 1.1001e-01, 4.1941e-01], device='cuda:0')
rank [32 57 54 28  0

In [None]:
# NOTE(review): this is not a from-scratch "norm pruning without finetuning" number — BatchNorm
# statistics and the classifier were fine-tuned after random pruning and are not replaced by prune_norm.
print("Evaluating the model after pruning, without finetuning:")
_, accuracy_model_prune, _ = validate(val_loader, model_prune, criterion)
print(f"This model's accuracy is {accuracy_model_prune}")

Evaluating the model after pruning, without finetuning:
 * Acc@1 79.580 Acc@5 98.380
This model's accuracy is 79.57999420166016


In [None]:
# NOTE(review): with epochs=1 < lr_warmup_epochs the whole run stays in the LinearLR warm-up,
# i.e. it trains at lr * lr_warmup_decay (logged as 0.0001).
finetune(model_prune, train_loader, val_loader, epochs=1, criterion=criterion)

 * Acc@1 10.000 Acc@5 50.000
learning_rate: 0.0001
Epoch[0](0/391): Loss 2.0290 Prec@1(1,5) 32.03, 78.91 Lr 0.0001
Epoch[0](39/391): Loss 1.8653 Prec@1(1,5) 43.65, 87.09 Lr 0.0001
Epoch[0](78/391): Loss 1.6173 Prec@1(1,5) 55.03, 91.25 Lr 0.0001
Epoch[0](117/391): Loss 1.4574 Prec@1(1,5) 60.12, 93.00 Lr 0.0001
Epoch[0](156/391): Loss 1.3378 Prec@1(1,5) 63.39, 94.09 Lr 0.0001
Epoch[0](195/391): Loss 1.2434 Prec@1(1,5) 66.09, 94.85 Lr 0.0001
Epoch[0](234/391): Loss 1.1696 Prec@1(1,5) 67.98, 95.40 Lr 0.0001
Epoch[0](273/391): Loss 1.1052 Prec@1(1,5) 69.68, 95.82 Lr 0.0001
Epoch[0](312/391): Loss 1.0546 Prec@1(1,5) 70.97, 96.18 Lr 0.0001
Epoch[0](351/391): Loss 1.0121 Prec@1(1,5) 72.01, 96.39 Lr 0.0001
Epoch[0](390/391): Loss 0.9718 Prec@1(1,5) 73.06, 96.61 Lr 0.0001
 * Acc@1 79.580 Acc@5 98.380
=>Best accuracy 79.580


VGG(
  (features): Sequential(
    (conv0): Conv2d(3, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm0): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (conv1): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv3): Conv2d(48, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu3): ReLU(inplace=True)
    (conv4): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm4): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu4): ReLU(inplace=True)
    (pool5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilatio

## **Similarity**


In [None]:
def prune_similarity(model, model_ori, verbose=False):
    """Initialise a pruned ``model`` from ``model_ori`` using filter similarity.

    The function handles each ``nn.Conv2d`` whose number of output filters
    differs between the two models. It ranks the filters of the *original*
    layer by the sum of their Euclidean (L2) distances to all other filters
    in that layer. The ``currentfilter_num`` filters with the largest summed
    distance (the least redundant ones) are kept.

    Parameters and buffers are then copied from ``model_ori`` into ``model``
    while walking ``model.named_modules()`` in order:

    * ``Conv2d`` weight/bias: output filters sliced to this layer's kept
      filters. Input channels are sliced to the previous conv's kept filters,
      but only when that previous conv was pruned.
    * ``BatchNorm1d``/``BatchNorm2d`` weight, bias, running mean/var: sliced
      to the kept filters of the conv that precedes them.
      ``num_batches_tracked`` is copied as-is.
    * ``Linear`` weight: input columns are sliced to the preceding conv's
      kept filters when ``in_features`` equals their count. Otherwise the
      weight is copied verbatim. The bias is copied verbatim.

    BatchNorm/Linear tensors whose shapes still do not match are left
    untouched (best effort). Conv2d mismatches raise.

    Args:
        model: pruned network; its parameters/buffers are overwritten in place.
        model_ori: original network providing the source weights.
        verbose: if True, also print each pruned layer's pairwise distance
            matrix.

    Raises:
        ValueError: if a pruned conv has more filters than the original one.
        RuntimeError: if a Conv2d tensor cannot be mapped onto the pruned shape.
    """
    oristate_dict = model_ori.state_dict()
    state_dict = model.state_dict()
    # Kept output-filter indices of the most recent Conv2d. None when that
    # conv was not pruned, so the next layer sees its full width.
    last_select_index = None

    def _transfer(key, out_index=None, in_index=None, strict=True):
        # Copy oristate_dict[key] into state_dict[key], selecting along
        # dim 0 (out_index) and/or dim 1 (in_index) when given.
        if key not in oristate_dict or key not in state_dict:
            return
        src = oristate_dict[key]
        if out_index is not None:
            src = src.index_select(0, out_index.to(src.device))
        if in_index is not None:
            src = src.index_select(1, in_index.to(src.device))
        if src.shape != state_dict[key].shape:
            if strict:
                raise RuntimeError(
                    f"Cannot copy {key}: source shape {tuple(src.shape)} "
                    f"!= target shape {tuple(state_dict[key].shape)}")
            return
        state_dict[key] = src.clone()

    cnt = 0
    for name, module in model.named_modules():
        name = name.replace('module.', '')

        if isinstance(module, nn.Conv2d):
            cnt += 1
            oriweight = oristate_dict[name + '.weight']
            curweight = state_dict[name + '.weight']
            orifilter_num = oriweight.size(0)
            currentfilter_num = curweight.size(0)
            print(f"Processing layer {cnt}, original layer has {orifilter_num} filters, pruning model has {currentfilter_num} filters")

            select_index = None  # None => every filter of this layer is kept
            if orifilter_num != currentfilter_num:
                if currentfilter_num > orifilter_num:
                    raise ValueError(
                        f"{name}: pruned layer has more filters "
                        f"({currentfilter_num}) than the original "
                        f"({orifilter_num})")
                # Pairwise L2 distances between flattened filters. The
                # non-matmul mode computes each distance directly, like
                # torch.dist, instead of via the ||a||^2 + ||b||^2 - 2ab
                # shortcut.
                flat = oriweight.detach().flatten(1).double()
                distances = torch.cdist(
                    flat, flat, compute_mode='donot_use_mm_for_euclid_dist')
                if verbose:
                    print(distances.cpu().numpy())
                # A filter far from all the others is treated as less redundant.
                rank = distances.sum(dim=1)
                select_index = torch.argsort(rank)[orifilter_num - currentfilter_num:]
                select_index = torch.sort(select_index).values.cpu()

            _transfer(name + '.weight', select_index, last_select_index)
            _transfer(name + '.bias', select_index)
            # Becomes None when this conv was not pruned, so the next conv
            # does not slice its full-width input with a stale selection.
            last_select_index = select_index

        elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
            for suffix in ('.weight', '.bias', '.running_mean', '.running_var'):
                _transfer(name + suffix, last_select_index, strict=False)
            _transfer(name + '.num_batches_tracked', strict=False)

        elif isinstance(module, nn.Linear):
            in_index = None
            if (last_select_index is not None
                    and module.in_features == len(last_select_index)):
                in_index = last_select_index
            _transfer(name + '.weight', in_index=in_index, strict=False)
            _transfer(name + '.bias', strict=False)
            # Linear outputs are not pruned; later layers use their full width.
            last_select_index = None

    model.load_state_dict(state_dict)

In [None]:
# Initialise the pruned network from the original one, keeping the filters
# with the largest summed distance to the rest of their layer.
prune_similarity(model=model_prune, model_ori=model_ori)

Processing layer 1, original layer has 64 filters, pruning model has 48 filters
[[0.         1.74139035 1.59087288 ... 1.72806299 1.57414818 1.61237991]
 [1.74139035 0.         1.0449996  ... 1.0066036  0.88404632 0.85878301]
 [1.59087288 1.0449996  0.         ... 0.66931009 0.51807326 0.4106819 ]
 ...
 [1.72806299 1.0066036  0.66931009 ... 0.         0.57510459 0.59303981]
 [1.57414818 0.88404632 0.51807326 ... 0.57510459 0.         0.43026134]
 [1.61237991 0.85878301 0.4106819  ... 0.59303981 0.43026134 0.        ]]
Processing layer 2, original layer has 64 filters, pruning model has 48 filters
[[0.         1.18950224 1.54022479 ... 1.15816128 1.24496245 1.15282464]
 [1.18950224 0.         1.11939645 ... 1.11556244 0.97027659 1.26536036]
 [1.54022479 1.11939645 0.         ... 1.34018302 1.27003932 1.45868313]
 ...
 [1.15816128 1.11556244 1.34018302 ... 0.         1.18896043 1.27326846]
 [1.24496245 0.97027659 1.27003932 ... 1.18896043 0.         1.27268112]
 [1.15282464 1.26536036 1.

In [None]:
# Baseline: accuracy of the freshly pruned network before any finetuning.
# The second element returned by validate() is reported as the accuracy.
print("Evaluating the model after pruning, without finetuning:")
(_, accuracy_model_prune, _) = validate(val_loader, model_prune, criterion)
print(f"This model's accuracy is {accuracy_model_prune}")

Evaluating the model after pruning, without finetuning:
 * Acc@1 10.000 Acc@5 50.020
This model's accuracy is 10.0


In [None]:
# Fine-tune the similarity-pruned network for a single epoch.
finetune(
    model_prune,
    train_loader,
    val_loader,
    criterion=criterion,
    epochs=1,
)

 * Acc@1 10.000 Acc@5 50.020
learning_rate: 0.0001
Epoch[0](0/391): Loss 2.1867 Prec@1(1,5) 20.31, 67.19 Lr 0.0001
Epoch[0](39/391): Loss 1.9700 Prec@1(1,5) 33.50, 78.67 Lr 0.0001
Epoch[0](78/391): Loss 1.7097 Prec@1(1,5) 49.00, 85.36 Lr 0.0001
Epoch[0](117/391): Loss 1.5341 Prec@1(1,5) 55.89, 88.79 Lr 0.0001
Epoch[0](156/391): Loss 1.4014 Prec@1(1,5) 60.56, 90.80 Lr 0.0001
Epoch[0](195/391): Loss 1.2996 Prec@1(1,5) 63.79, 92.12 Lr 0.0001
Epoch[0](234/391): Loss 1.2187 Prec@1(1,5) 66.18, 93.06 Lr 0.0001
Epoch[0](273/391): Loss 1.1466 Prec@1(1,5) 68.31, 93.80 Lr 0.0001
Epoch[0](312/391): Loss 1.0864 Prec@1(1,5) 69.98, 94.41 Lr 0.0001
Epoch[0](351/391): Loss 1.0340 Prec@1(1,5) 71.46, 94.85 Lr 0.0001
Epoch[0](390/391): Loss 0.9885 Prec@1(1,5) 72.69, 95.26 Lr 0.0001
 * Acc@1 80.640 Acc@5 98.350
=>Best accuracy 80.640


VGG(
  (features): Sequential(
    (conv0): Conv2d(3, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm0): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (conv1): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv3): Conv2d(48, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu3): ReLU(inplace=True)
    (conv4): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm4): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu4): ReLU(inplace=True)
    (pool5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilatio

In [None]:
!pip install torch torchvision opencv-python




In [None]:
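# NOTE(review): this cell is fully commented out and does not run. The
# `NameError: name 'model_ori' is not defined` shown below it is presumably
# stale output from an earlier run in which `model_ori` was not yet defined.
# Before re-enabling, check the following:
#  - OpenCV decodes video frames as BGR, but the Normalize mean/std below are
#    the ImageNet RGB statistics. Convert with cv2.cvtColor first.
#  - `vgg16_bn(weights=None)` is an untrained network, not this notebook's
#    pruned model.
#  - `cv2.imshow` (second half) does not display in Colab. Use `cv2_imshow`.
#  - The FPS timing should call torch.cuda.synchronize() before reading the
#    clock, because CUDA kernels launch asynchronously.
# import cv2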
# import numpy as np
# import torch
# import torch.nn as nn
# import torchvision.transforms as transforms
# from torchvision.models import vgg16_bn
# from google.colab.patches import cv2_imshow

# class CustomVGG16(nn.Module):
#     def __init__(self, original_vgg16):
#         super(CustomVGG16, self).__init__()
#         self.features = original_vgg16.features
#         self.avgpool = original_vgg16.avgpool
#         self.classifier = original_vgg16.classifier

#     def forward(self, x):
#         x = self.features(x)
#         x = self.avgpool(x)
#         x = torch.flatten(x, 1)
#         x = self.classifier(x)
#         return x

# # Load pre-trained VGG16 models
# model_ori = vgg16_bn(weights="IMAGENET1K_V1").cuda()
# model_prune = vgg16_bn(weights=None).cuda()  # Assuming the pruned model definition is the same

# # Wrap the models in CustomVGG16 to ensure proper handling of input dimensions
# model_ori = CustomVGG16(model_ori)
# model_prune = CustomVGG16(model_prune)

# # Load the video
# video_path = '/content/video.mp4'
# cap = cv2.VideoCapture(video_path)

# # Video writer to save the outputs
# fourcc = cv2.VideoWriter_fourcc(*'XVID')
# out_original = cv2.VideoWriter('original_output.avi', fourcc, 20.0, (640, 480))
# out_pruned = cv2.VideoWriter('pruned_output.avi', fourcc, 20.0, (640, 480))

# # Preprocessing transformation
# transform = transforms.Compose([
#     transforms.ToPILImage(),
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])

# def preprocess_frame(frame):
#     frame_tensor = transform(frame).unsqueeze(0).cuda()
#     return frame_tensor

# # Function to run inference
# def run_inference(model, frame):
#     model.eval()
#     with torch.no_grad():
#         output = model(frame)
#     return output

# # Inference and comparison loop
# while cap.isOpened():
#     ret, frame = cap.read()
#     if not ret:
#         break

#     # Preprocess the frame
#     preprocessed_frame = preprocess_frame(frame)

#     # Run inference
#     original_output = run_inference(model_ori, preprocessed_frame)
#     pruned_output = run_inference(model_prune, preprocessed_frame)

#     # Assuming the output is a classification, get the predicted class
#     _, original_pred = torch.max(original_output, 1)
#     _, pruned_pred = torch.max(pruned_output, 1)

#     # Overlay predictions on the frame (dummy example)
#     frame_original = frame.copy()
#     frame_pruned = frame.copy()

#     # Display the predictions (dummy text overlay)
#     cv2.putText(frame_original, f'Original: {original_pred.item()}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)
#     cv2.putText(frame_pruned, f'Pruned: {pruned_pred.item()}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)

#     # Write the frames with detections
#     out_original.write(frame_original)
#     out_pruned.write(frame_pruned)

#     # Display the frames using cv2_imshow
#     cv2_imshow(frame_original)
#     cv2_imshow(frame_pruned)

#     # Optional delay between frames
#     if cv2.waitKey(1) & 0xFF == ord('q'):
#         break

# cap.release()
# out_original.release()
# out_pruned.release()
# cv2.destroyAllWindows()




# import cv2
# import torch
# import torchvision.transforms as transforms
# from PIL import Image
# import numpy as np
# import time

# # Function to extract frames from video
# def extract_frames(video_path, resize_shape=(224, 224)):
#     cap = cv2.VideoCapture(video_path)
#     frames = []
#     while cap.isOpened():
#         ret, frame = cap.read()
#         if not ret:
#             break
#         frame = cv2.resize(frame, resize_shape)
#         frames.append(frame)
#     cap.release()
#     return frames

# # Function to run inference and calculate FPS
# def run_inference(frames, model, device):
#     transform = transforms.Compose([
#         transforms.ToPILImage(),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
#     ])
#     fps_list = []
#     results = []
#     model.eval()
#     with torch.no_grad():
#         for frame in frames:
#             image = transform(frame).unsqueeze(0).to(device)
#             start_time = time.time()
#             outputs = model(image)
#             end_time = time.time()
#             fps = 1 / (end_time - start_time)
#             fps_list.append(fps)
#             results.append(outputs.cpu().numpy())
#     avg_fps = np.mean(fps_list)
#     return results, avg_fps

# # Function to visualize results
# def visualize_results(frames, results, fps, title):
#     for i, frame in enumerate(frames):
#         result = results[i]
#         # Add code to draw bounding boxes and labels based on the results
#         # For simplicity, we'll just display the frame and FPS here
#         cv2.putText(frame, f"FPS: {fps:.2f}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
#         cv2.imshow(title, frame)
#         if cv2.waitKey(25) & 0xFF == ord('q'):
#             break
#     cv2.destroyAllWindows()

# # Load the video
# video_path = '/content/video.mp4'
# frames = extract_frames(video_path)

# # Ensure device is set
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model_ori = model_ori.to(device)
# model_prune = model_prune.to(device)

# # Run inference and measure FPS
# results_ori, fps_ori = run_inference(frames, model_ori, device)
# results_prune, fps_prune = run_inference(frames, model_prune, device)

# # Visualize results
# visualize_results(frames, results_ori, fps_ori, "Original Model")
# visualize_results(frames, results_prune, fps_prune, "Pruned Model")

# print(f"Original Model FPS: {fps_ori}")
# print(f"Pruned Model FPS: {fps_prune}")



NameError: name 'model_ori' is not defined