In [1]:
import torch
import random
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data.dataset import Subset
from torch.utils.data import DataLoader
from torch.utils.data import RandomSampler
from torch.utils.data import SequentialSampler
import torch.nn as nn
import timm
import torch.nn.functional as F

In [2]:
def get_dataloaders(input_img_size, num_classes, num_tasks, classes_per_task, batch_size, shuffle, device):
    scale = (0.05, 1.0)
    ratio = (3./4., 4./3.)
    size = int((256/224) * input_img_size)
    pin = True if device == 'cuda' else False

    transforms_train =  transforms.Compose([
        transforms.RandomResizedCrop(input_img_size, scale=scale, ratio=ratio),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
    ])

    transforms_test = transforms.Compose([
        transforms.Resize(size, interpolation=3),
        transforms.CenterCrop(input_img_size),
        transforms.ToTensor(),
    ])

    dataset_train = datasets.CIFAR100(root='./data', train=True, download=True, transform=transforms_train)
    dataset_test = datasets.CIFAR100(root='./data', train=False, download=True, transform=transforms_test)

    split_datasets = list()
    class_mask = list()
    dataloaders_train = list()
    dataloaders_test = list()
    labels = [i for i in range(num_classes)]

    if shuffle:
        random.shuffle(labels)

    for i in range(num_tasks):
        scope = labels[:classes_per_task]
        labels = labels[classes_per_task:]

        train_split_indices = []
        test_split_indices = []

        for j in range(len(dataset_train.targets)):
            if int(dataset_train.targets[j]) in scope:
                train_split_indices.append(j)
        
        for j in range(len(dataset_test.targets)):
            if int(dataset_test.targets[j]) in scope:
                test_split_indices.append(j)

        train_subset = Subset(dataset_train, train_split_indices)
        test_subset = Subset(dataset_test, test_split_indices)

        split_datasets.append([train_subset, test_subset])
        class_mask.append(scope)

    for i in range(num_tasks):
        dataset_train_temp, dataset_test_temp = split_datasets[i]

        train_sampler = RandomSampler(dataset_train_temp)
        test_sampler = SequentialSampler(dataset_test_temp)

        dataloader_train_temp = DataLoader(dataset_train_temp, sampler=train_sampler, batch_size=batch_size, pin_memory=pin)
        dataloader_test_temp = DataLoader(dataset_test_temp, sampler=test_sampler, batch_size=batch_size, pin_memory=pin)

        dataloaders_train.append(dataloader_train_temp)
        dataloaders_test.append(dataloader_test_temp)

    return dataloaders_train, dataloaders_test, class_mask

In [3]:
class CustomLoss(nn.Module):
    def __init__(self, constant):
        super(CustomLoss, self).__init__()
        self.constant = constant
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, outputs, targets, value):
        value = value.to(outputs.device)

        ce_loss = self.cross_entropy(outputs, targets)
        cu_term = self.constant * value

        total_loss = ce_loss - cu_term

        return total_loss

In [4]:
class Prompt(nn.Module):
    def __init__(self, pool_size, keys_dim, prompt_dim, embed_dim, top_k):
        super(Prompt, self).__init__()
        self.pool_size = pool_size
        self.keys = nn.Parameter(torch.randn(pool_size, keys_dim)) # M x Embed_Dim
        self.prompts = nn.Parameter(torch.randn(pool_size, prompt_dim, embed_dim)) # M x Lp x Embed_Dim
        nn.init.uniform_(self.keys, -1, 1)
        nn.init.uniform_(self.prompts, -1, 1)
        self.k = top_k

    def select_prompt(self, queries): # Shape of Q (N x Embed_Dim)
        N, embed_dim = queries.shape[0], self.prompts.shape[2]

        queries_normalized = F.normalize(queries) # N x Embed_Dim
        keys_normalized = F.normalize(self.keys) # M x Embed_Dim

        cosine_similarity = torch.mm(queries_normalized, keys_normalized.T) # N x M

        _, indices = torch.topk(cosine_similarity, self.k, dim=1) # N x K

        prompt_idx, prompt_cnt = torch.unique(indices, sorted=True, return_counts=True) # 1D Tensor

        if prompt_idx.shape[0] < self.pool_size:
            min_value = torch.min(indices.flatten()).item()
            padding_size = self.pool_size - prompt_idx.shape[0]
            prompt_idx = torch.cat([prompt_idx, torch.full((padding_size,), min_value, device=prompt_idx.device)]) # 1D Tensor of size pool_size 
            prompt_cnt = torch.cat([prompt_cnt, torch.zeros((padding_size,), device=prompt_cnt.device)]) # 1D Tensor of size pool_size

        _, idx = torch.topk(prompt_cnt, self.k) # 1D Tensor of size to k

        final_indices = prompt_idx[idx].expand(N, -1) # N x K
        
        selected_prompts = self.prompts[final_indices].view(N, -1, embed_dim) # N x (K * Lp) x Embed_Dim

        selected_keys = keys_normalized[final_indices] # N x K x Embed_Dim
        queries_normalized = queries_normalized.unsqueeze(1) # N x 1 x Embed_Dim
        loss_term = selected_keys * queries_normalized # N x K x Emebd_Dim
        sum = torch.sum(loss_term) / N # 1 x 1

        return selected_prompts, sum # N x (K * Lp) x Embed_Dim and 1 x 1

In [5]:
class Model(nn.Module):
    def __init__(self, pool_size, prompt_dim, top_k, num_classes):
        super(Model, self).__init__()
        self.query_function = timm.create_model('vit_base_patch16_224.augreg2_in21k_ft_in1k', pretrained=True)
        self.embed_dim = self.query_function.embed_dim # 768
        self.prompt_pool = Prompt(pool_size, self.embed_dim, prompt_dim, self.embed_dim, top_k)
        self.vit = timm.create_model('vit_base_patch16_224.augreg2_in21k_ft_in1k', pretrained=True, num_classes=num_classes)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.total_prompt_dim = top_k * prompt_dim

        # Freeze the ViT model (both query_function and vit)
        for param in self.query_function.parameters():
            param.requires_grad = False
        for param in self.vit.parameters():
            param.requires_grad = False
        for param in self.vit.head.parameters():
            param.requires_grad = True

    def forward(self, x): # N x channels x H x W 
        queries = self.query_function.forward_features(x)[:, 0] # CLS Token (N x Embed_Dim)
        prompts, distance_sum = self.prompt_pool.select_prompt(queries) # N x (K * Lp) x Embed_Dim

        patch_embeddings = self.vit.patch_embed(x) # N x Token_Len x Embed_Dim
        cls_token = self.vit.cls_token.expand(patch_embeddings.shape[0], -1, -1) # N x 1 x Embed_Dim
        patch_embeddings = torch.cat((cls_token, patch_embeddings), dim=1) # N x (1 + Token_Len) x Embed_Dim
        embeddings = patch_embeddings + self.vit.pos_embed # N x (1 + Token_Len) x Embed_Dim

        tokens = torch.cat((prompts, embeddings), dim=1) # N x (K * Lp + 1 + Token_Len) x Embed_Dim

        tokens = self.vit.blocks(tokens) # N x (K * Lp + 1 + Token_Len) x Embed_Dim
        tokens = self.vit.norm(tokens)

        extracted_prompts = tokens[:, 0:self.total_prompt_dim, :] # N x K * Lp x Embed_Dim
        extracted_prompts = extracted_prompts.permute(0, 2, 1) # N x Embed_Dim x K * Lp

        pooled_prompts = self.pool(extracted_prompts).squeeze(-1) # N x Embed_Dim

        outputs = self.vit.head(pooled_prompts) # N x num_classes

        return outputs, distance_sum

In [6]:
INPUT_SIZE = 224
NUM_CLASSES = 100
NUM_TASKS = 10
CLASSES_PER_TASK = 10
BATCH_SIZE = 16
EPOCHS_PER_TASK = 5
LEARNING_RATE = 0.001875
POOL_SIZE = 10
PROMPT_DIM = 5
TOP_K = 5
CONSTANT = 0.1
MAX_NORM = 1.0
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [13]:
dataloaders_train, dataloaders_test, class_mask = get_dataloaders(INPUT_SIZE, NUM_CLASSES, NUM_TASKS, CLASSES_PER_TASK, BATCH_SIZE, False, device)
model = Model(POOL_SIZE, PROMPT_DIM, TOP_K, NUM_CLASSES).to(device)
criterion = CustomLoss(CONSTANT).to(device)

Files already downloaded and verified
Files already downloaded and verified


In [14]:
print('TRAINING MODEL!!!')

model.train(True)

for i in range(NUM_TASKS):
    dataloader = dataloaders_train[i]
    mask = class_mask[i]
    total_task = 0.0
    correct_task = 0.0

    optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    for j in range(EPOCHS_PER_TASK):
        total_epoch = 0.0
        correct_epoch = 0.0

        for data in dataloader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            logits, loss_value = model(inputs)
            loss_value = loss_value.to(device)

            not_mask = np.setdiff1d(np.arange(NUM_CLASSES), mask)
            not_mask = torch.tensor(not_mask, dtype=torch.int64).to(device)
            logits = logits.index_fill(dim=1, index=not_mask, value=float('-inf'))

            loss = criterion(logits, labels, loss_value)
            optim.zero_grad()
            nn.utils.clip_grad_norm_(model.parameters(), MAX_NORM)
            loss.backward()
            optim.step()

            _, predicted = torch.max(logits.data, 1)

            total_epoch += labels.size(0)
            correct_epoch += (predicted == labels).sum().item()

            print(f'For Task{i+1} Epoch {j+1}/{EPOCHS_PER_TASK}, Accuracy: {correct_epoch/total_epoch * 100:.2f}%')

TRAINING MODEL!!!
For Task1 Epoch 1/5, Accuracy: 6.25%
For Task1 Epoch 1/5, Accuracy: 9.38%
For Task1 Epoch 1/5, Accuracy: 10.42%
For Task1 Epoch 1/5, Accuracy: 14.06%
For Task1 Epoch 1/5, Accuracy: 16.25%
For Task1 Epoch 1/5, Accuracy: 21.88%
For Task1 Epoch 1/5, Accuracy: 26.79%
For Task1 Epoch 1/5, Accuracy: 28.91%
For Task1 Epoch 1/5, Accuracy: 32.64%
For Task1 Epoch 1/5, Accuracy: 34.38%
For Task1 Epoch 1/5, Accuracy: 35.23%
For Task1 Epoch 1/5, Accuracy: 35.94%
For Task1 Epoch 1/5, Accuracy: 38.94%
For Task1 Epoch 1/5, Accuracy: 42.86%
For Task1 Epoch 1/5, Accuracy: 44.58%
For Task1 Epoch 1/5, Accuracy: 47.66%
For Task1 Epoch 1/5, Accuracy: 50.00%
For Task1 Epoch 1/5, Accuracy: 51.74%
For Task1 Epoch 1/5, Accuracy: 52.30%
For Task1 Epoch 1/5, Accuracy: 52.19%
For Task1 Epoch 1/5, Accuracy: 53.27%
For Task1 Epoch 1/5, Accuracy: 54.83%
For Task1 Epoch 1/5, Accuracy: 55.71%
For Task1 Epoch 1/5, Accuracy: 57.03%
For Task1 Epoch 1/5, Accuracy: 57.00%
For Task1 Epoch 1/5, Accuracy: 57.

In [15]:
print('TESTING MODEL!!!')

model.eval()
with torch.no_grad():
    total_total = 0.0
    correct_total = 0.0

    for i in range(NUM_TASKS):
        dataloader = dataloaders_test[i]
        total_task = 0.0
        correct_task = 0.0

        for data in dataloader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            outputs, distance_sum = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total_task += labels.size(0)
            correct_task += (predicted == labels).sum().item()

        
        print(f'For Task{i+1} Accuracy: {correct_task/total_task * 100:.2f}%')
        
        total_total += total_task
        correct_total += correct_task

    print(f'Overall Accuracy: {correct_total/total_total * 100:.2f}%')

TESTING MODEL!!!
For Task1 Accuracy: 75.30%
For Task2 Accuracy: 80.30%
For Task3 Accuracy: 85.90%
For Task4 Accuracy: 82.90%
For Task5 Accuracy: 85.60%
For Task6 Accuracy: 74.00%
For Task7 Accuracy: 80.00%
For Task8 Accuracy: 83.60%
For Task9 Accuracy: 90.40%
For Task10 Accuracy: 87.60%
Overall Accuracy: 82.56%
