In [None]:
import torch
from torch import nn
from torch.optim import AdamW
from torch.nn import functional as F

from avalanche.evaluation.metrics.accuracy import Accuracy

from tqdm import tqdm

import timm
from timm.models import create_model
from timm.models.layers import DropPath
from timm.scheduler.cosine_lr import CosineLRScheduler

import math
import random
import os
import numpy as np
import time

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split, Dataset

# Import custom Dataset class from vtab folder
from vtab.Cifar import CifarDataPytorch

# Import Convpass function for model manipulation
from convpass.convbyppass import set_Convpass

# Import custom utility functions
from utils import *

In [None]:
def train(model, dl, test_dl, opt, scheduler, method, dataset, epoch = 100, eval_every = 10):
    """Train ``model`` on ``dl`` and checkpoint the best validation result.

    Args:
        model: Network to optimize; it is moved to the GPU for training and
            back to the CPU before returning.
        dl: Iterable of batches where ``batch[0]`` is the input and
            ``batch[1]`` the target passed to ``F.cross_entropy``.
        test_dl: Loader handed to ``test`` for the periodic evaluation.
        opt: Optimizer stepped once per batch.
        scheduler: Optional LR scheduler; ``scheduler.step(ep)`` is called
            once per epoch when it is not ``None``.
        method: Forwarded unchanged to ``save``.
        dataset: Forwarded unchanged to ``save``.
        epoch: Number of training epochs.
        eval_every: Evaluate every ``eval_every`` epochs. The final epoch is
            always evaluated as well, so short runs still produce a
            checkpoint and a meaningful ``best_acc``.

    Returns:
        ``(model, best_acc)`` with the model on the CPU and the best accuracy
        reported by ``test`` (``0`` if no evaluation improved on it).

    Raises:
        ValueError: If ``eval_every`` is smaller than 1.
    """
    if eval_every < 1:
        raise ValueError("eval_every must be a positive integer")

    model.train()
    model = model.cuda()
    best_acc = 0
    for ep in tqdm(range(epoch)):
        # Re-assert train mode and device every epoch: test() below switches
        # the model to eval mode, and save() lives in utils, so its effect on
        # device placement is not visible from here.
        model.train()
        model = model.cuda()
        for batch in dl:
            x, y = batch[0].cuda(), batch[1].cuda()
            out = model(x)
            loss = F.cross_entropy(out, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
        if scheduler is not None:
            scheduler.step(ep)

        # Evaluate on the regular cadence and unconditionally after the last
        # epoch, otherwise the tail of a run is never measured or saved.
        is_last_epoch = ep == epoch - 1
        if ep % eval_every == eval_every - 1 or is_last_epoch:
            acc, _, _ = test(model, test_dl)
            if acc > best_acc:
                best_acc = acc
                save(method, dataset, model, acc, ep)
    model = model.cpu()
    return model, best_acc


@torch.no_grad()
def test(model, dl):
    """Evaluate ``model`` on ``dl`` and report accuracy and forward-pass time.

    Args:
        model: Network to evaluate; it is switched to eval mode and moved to
            the GPU.
        dl: Iterable of batches where ``batch[0]`` is the input and
            ``batch[1]`` the integer class targets.

    Returns:
        ``(top1, mean_inference_time, top5_acc)`` where ``top1`` is whatever
        ``Accuracy.result()`` returns, ``mean_inference_time`` is the mean
        wall-clock seconds per batch spent in the forward pass, and
        ``top5_acc`` is a Python float (fraction of samples whose target is
        among the ``min(5, num_classes)`` highest-scoring classes). Both
        averages are ``0.0`` when ``dl`` yields no batches.
    """
    model.eval()
    model = model.cuda()
    acc = Accuracy()
    total_time = 0.0
    top5_hits, total, num_batches = 0, 0, 0

    for batch in dl:
        x, y = batch[0].cuda(), batch[1].cuda()

        # CUDA kernels are launched asynchronously, so synchronize on both
        # sides of the forward pass; otherwise the timer mostly measures
        # launch overhead rather than the actual compute.
        torch.cuda.synchronize()
        start_time = time.time()
        out = model(x)
        torch.cuda.synchronize()
        total_time += time.time() - start_time

        # Cap k so models with fewer than five outputs do not crash topk().
        k = min(5, out.size(1))
        _, pred = out.topk(k, 1, True, True)
        # A sample is a hit when its label equals any of its top-k predictions.
        top5_hits += pred.eq(y.view(-1, 1)).any(dim=1).sum().item()
        total += y.size(0)
        num_batches += 1

        acc.update(out.argmax(dim=1).view(-1), y)

    print(acc.result())
    top5_acc = top5_hits / total if total else 0.0
    mean_inference_time = total_time / num_batches if num_batches else 0.0

    return acc.result(), mean_inference_time, top5_acc

def count_finetuned_params(model):
    """Return how many scalar parameters of ``model`` have ``requires_grad`` set."""
    trainable_count = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_count += param.numel()
    return trainable_count

def count_total_params(model):
    """Return the total number of scalar parameters in ``model``, trainable or not."""
    param_count = 0
    for param in model.parameters():
        param_count += param.numel()
    return param_count

In [None]:
# Make sure the checkpoint and data directories exist before anything writes
# to them. exist_ok=True makes the cell idempotent and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs('./models/convpass', exist_ok=True)
os.makedirs('./data', exist_ok=True)

In [None]:
# Experiment identity: both names are passed to train()/load() further down.
dataset_name, method_name = 'cifar100', 'convpass'

# Number of output classes for the classifier head.
class_num = 100

# Optimizer settings (AdamW learning rate and weight decay).
lr, wd = 1e-3, 1e-4

# Number of training epochs.
epoch = 100

In [None]:
# CIFAR100
# Take the class count from the config cell (class_num) instead of repeating
# the literal 100, so the dataset and the classifier head cannot drift apart.
cifar100 = CifarDataPytorch(
    num_classes=class_num,
    data_dir='./data/cifar100',
    train_split_percent=80,
    batch_size=64,
)
train_loader_cifar, val_loader_cifar, test_loader_cifar = cifar100.get_loaders()

In [None]:
# Dataset-agnostic aliases used by the training and evaluation cells below.
train_loader, val_loader, test_loader = (
    train_loader_cifar,
    val_loader_cifar,
    test_loader_cifar,
)

# Show the number of samples in the train / validation / test splits.
print(*(len(loader.dataset) for loader in (train_loader, val_loader, test_loader)))

Convpass Model

In [None]:
# Load the timm model 'deit_tiny_distilled_patch16_224' with pretrained weights
# and let set_Convpass modify it in place.
model = create_model('deit_tiny_distilled_patch16_224', pretrained=True, drop_path_rate=0.1)
set_Convpass(model, 'convpass', dim=8, s=0.1, xavier_init=False, distilled=True)

# Replace the classifier so it outputs class_num logits.
model.reset_classifier(class_num)

# Optimize only parameters whose names contain 'adapter' or 'head'; every other
# parameter is frozen. (Presumably 'adapter' matches the modules inserted by
# set_Convpass -- verify against convpass/convbyppass.py.)
trainable = []
for n, p in model.named_parameters():
    if 'adapter' in n or 'head' in n:
        trainable.append(p)
    else:
        p.requires_grad = False

opt = AdamW(trainable, lr=lr, weight_decay=wd)
# Tie the cosine schedule length to the configured epoch count instead of a
# hard-coded 100, so changing `epoch` in the config cell actually takes effect.
scheduler = CosineLRScheduler(opt, t_initial=epoch,
                              warmup_t=10, lr_min=1e-5, warmup_lr_init=1e-6)

# Pass `epoch` explicitly; otherwise train() silently uses its own default.
model, acc = train(model, train_loader, val_loader,
                   opt, scheduler, method_name, dataset_name, epoch=epoch)

In [None]:
# load() comes from utils; presumably it restores the checkpoint written by
# save() during training into `model` and returns it -- verify against utils.
model_tuned = load(method_name, dataset_name, model)

num_finetuned_params = count_finetuned_params(model_tuned)
num_total_params = count_total_params(model_tuned)
# Evaluate the model returned by load(), not the pre-load `model` reference, so
# the reported numbers always belong to the restored checkpoint.
acc, inference_mean, top5_acc = test(model_tuned, test_loader)

print(f"Number of parameters fine-tuned: {num_finetuned_params}")
print(f"Total number of parameters: {num_total_params}")
print(f"Share: {num_finetuned_params/num_total_params}")

print('Accuracy:', acc)
print(f"Mean inference time per batch: {inference_mean:.4f} seconds")
print(f'Top 5 Acc: {top5_acc}')

Fully-tuned model

In [None]:
# Baseline: fine-tune every parameter of the pretrained model.
model_fully_tuned = create_model('deit_tiny_distilled_patch16_224', pretrained=True,
                                 drop_path_rate=0.1)
model_fully_tuned.reset_classifier(class_num)

# Make every parameter trainable.
model_fully_tuned.requires_grad_(True)

opt = AdamW(model_fully_tuned.parameters(), lr=lr, weight_decay=wd)
# Schedule length follows the configured epoch count rather than a literal 100.
scheduler = CosineLRScheduler(opt, t_initial=epoch,
                              warmup_t=10, lr_min=1e-5, warmup_lr_init=1e-6)

# Use a dedicated method name for this baseline. Reusing `method_name` would
# hand train()/save() the same (method, dataset) pair as the Convpass run and
# most likely overwrite its checkpoint.
# NOTE(review): assumes save() writes under ./models/<method>/ -- confirm in utils.
full_method_name = 'full_tuning'
os.makedirs(f'./models/{full_method_name}', exist_ok=True)

model_deit_trained_fulltuned, acc = train(model_fully_tuned, train_loader, val_loader,
                                          opt, scheduler, full_method_name, dataset_name,
                                          epoch=epoch)

# Test-set metrics of the model as returned by train() (final-epoch weights).
acc, inference_mean, top5_acc = test(model_deit_trained_fulltuned, test_loader)

num_finetuned_params = count_finetuned_params(model_deit_trained_fulltuned)
num_total_params = count_total_params(model_deit_trained_fulltuned)

print(f"Number of parameters fine-tuned: {num_finetuned_params}")
print(f"Total number of parameters: {num_total_params}")
print(f"Share: {num_finetuned_params/num_total_params}")

print('Accuracy:', acc)
print(f"Mean inference time per batch: {inference_mean:.4f} seconds")
print(f'Top 5 Acc: {top5_acc}')

Linear head-tuned model

In [None]:
# Baseline: train only the classifier head of the pretrained model.
model_head_tuned = create_model('deit_tiny_distilled_patch16_224', pretrained=True,
                                drop_path_rate=0.1)
model_head_tuned.reset_classifier(class_num)

# Keep parameters whose names contain 'head' trainable and freeze the rest.
trainable = []
for n, p in model_head_tuned.named_parameters():
    if 'head' in n:
        trainable.append(p)
    else:
        p.requires_grad = False

opt = AdamW(trainable, lr=lr, weight_decay=wd)
# Schedule length follows the configured epoch count rather than a literal 100.
scheduler = CosineLRScheduler(opt, t_initial=epoch,
                              warmup_t=10, lr_min=1e-5, warmup_lr_init=1e-6)

# Use a dedicated method name for this baseline. Reusing `method_name` would
# hand train()/save() the same (method, dataset) pair as the Convpass run and
# most likely overwrite its checkpoint.
# NOTE(review): assumes save() writes under ./models/<method>/ -- confirm in utils.
head_method_name = 'head_tuning'
os.makedirs(f'./models/{head_method_name}', exist_ok=True)

model_deit_trained_headtuned, acc = train(model_head_tuned, train_loader, val_loader,
                                          opt, scheduler, head_method_name, dataset_name,
                                          epoch=epoch)

# Test-set metrics of the model as returned by train() (final-epoch weights).
acc, inference_mean, top5_acc = test(model_deit_trained_headtuned, test_loader)

num_finetuned_params = count_finetuned_params(model_deit_trained_headtuned)
num_total_params = count_total_params(model_deit_trained_headtuned)

print(f"Number of parameters fine-tuned: {num_finetuned_params}")
print(f"Total number of parameters: {num_total_params}")
print(f"Share: {num_finetuned_params/num_total_params}")

print('Accuracy:', acc)
print(f"Mean inference time per batch: {inference_mean:.4f} seconds")
print(f'Top 5 Acc: {top5_acc}')