<center><h1>Fine-tuning Image Transformers using Learnable Memory</h1></center>

In [1]:
%matplotlib inline
# Built-in IPython extension to reload modules when updated.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import transformers
from transformers import ViTModel, ViTConfig, ViTForImageClassification
from tqdm.auto import tqdm
import os
import matplotlib.pyplot as plt
from IPython.display import clear_output
from copy import deepcopy
import random
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR

device = "cuda:0"
seed = 42
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ["PYTHONHASHSEED"] = str(seed)
print(f"Random seed set as {seed}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directories for cache and datasets
home_dir = "/hdd/ege"
cache_dir = os.path.join(home_dir, "ceng502")
datasets_dir = os.path.join(home_dir, "datasets")

# Uncomment if you don't want to see warnings
# transformers.logging.set_verbosity_error()

2023-05-24 10:40:25.776483: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-05-24 10:40:26.296891: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2023-05-24 10:40:26.297043: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory


Random seed set as 42


In [3]:
def pretrained_model(
        model_name = 'google/vit-base-patch32-224-in21k',
        num_classes = 100,
        ):
    config = ViTConfig.from_pretrained(model_name, num_labels=num_classes, cache_dir=cache_dir)
    model = ViTForImageClassification.from_pretrained(model_name, config=config, cache_dir=cache_dir)
    model = model.to(device)
    return model

# Datasets

In [4]:
# Load and preprocess the dataset
transform = transforms.Compose([
    transforms.RandomResizedCrop((224)),
    transforms.ToTensor(),
    #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

batch_size = 64

# CIFAR100
cifar100_train_dataset = datasets.CIFAR100(root=datasets_dir, train=True, transform=transform, download=True)
cifar100_train_loader = DataLoader(dataset=cifar100_train_dataset, batch_size=batch_size, shuffle=True)
cifar100_validation_dataset = datasets.CIFAR100(root=datasets_dir, train=False, transform=transform, download=True)
cifar100_validation_loader = DataLoader(dataset=cifar100_validation_dataset, batch_size=batch_size, shuffle=False)

# MNIST
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
places_train_dataset = datasets.Places365(root=datasets_dir,small=True, split="train-standard", transform=transform)#, download=True)
places_train_loader = DataLoader(dataset=places_train_dataset, batch_size=batch_size, shuffle=True)
places_validation_dataset = datasets.Places365(root=datasets_dir,small=True, split="val", transform=transform)#, download=True)
places_validation_loader = DataLoader(dataset=places_validation_dataset, batch_size=batch_size, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Define your loss function
criterion = nn.CrossEntropyLoss()

# Define number of steps and warmup steps
total_steps = 20
warmup_steps = 5

# Training and Validation Code

In [6]:
# Linear warmup
def warmup_linear(step):
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))
    return 1.0


def train(model, parameters,
          dataloader, valid_dataloader,
          output_head=None,
         total_steps = 20):
    
    model.train()
    
    # SGD with Momentum optimizer
    optimizer = optim.SGD(parameters, lr=0.1, momentum=0.9)
    # Cosine learning rate schedule
    cosine_scheduler = CosineAnnealingLR(optimizer, T_max=total_steps)
    
    warmup_scheduler = LambdaLR(optimizer, warmup_linear)

    train_losses, valid_accuracy = [], []
    for step in tqdm(range(total_steps), leave=False):
        train_loss = 0.0
        for batch_idx, (data, targets) in enumerate(tqdm(dataloader, leave=False)):
            data = data.to(device)
            targets = targets.to(device)

            # Forward pass
            outputs = model(data)
            if output_head is not None:
                outputs = outputs[output_head]
            loss = criterion(outputs.logits, targets)
            loss_val = loss.detach().cpu().item()
            train_loss += loss_val
    
            optimizer.zero_grad()
            loss.backward()
            
            # Gradient clipping
            clip_grad_norm_(parameters, max_norm=1.0)
            
            optimizer.step()
            
            # Update learning rate
            if step < warmup_steps:
                warmup_scheduler.step()
            else:
                cosine_scheduler.step()
                
        epoch_loss = train_loss/len(dataloader)
        print(f"step {step} loss is {epoch_loss:.4f}")
        train_losses.append(epoch_loss)
        
        valid_acc = validate(model, valid_dataloader, output_head)
        print(f"step {step} valid acc is {valid_acc:.2f}")
        valid_accuracy.append(valid_acc)

        #print(f"Epoch {epoch + 1}/{num_epochs} loss: {train_loss}")
    return train_losses, valid_accuracy
        

In [7]:
def validate(model, dataloader, output_head=None):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for data, targets in tqdm(dataloader, leave=False):
            data = data.to(device)
            targets = targets.to(device)

            outputs = model(data)
            if output_head is not None:
                outputs = outputs[output_head]
            _, predicted = torch.max(outputs.logits, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    
    return correct / total

# Full Finetuning

plt.plot(losses)
        plt.show()

In [8]:
model = pretrained_model()
full_train, full_val = train(model, model.parameters(), cifar100_train_loader, cifar100_validation_loader)

Some weights of the model checkpoint at google/vit-base-patch32-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch32-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

step 0 loss is 4.5839


  0%|          | 0/782 [00:00<?, ?it/s]

step 0 valid acc is 0.03


  0%|          | 0/782 [00:00<?, ?it/s]

step 1 loss is 3.4103


  0%|          | 0/782 [00:00<?, ?it/s]

step 1 valid acc is 0.37


  0%|          | 0/782 [00:00<?, ?it/s]

step 2 loss is 2.0501


  0%|          | 0/782 [00:00<?, ?it/s]

step 2 valid acc is 0.52


  0%|          | 0/782 [00:00<?, ?it/s]

step 3 loss is 1.7157


  0%|          | 0/782 [00:00<?, ?it/s]

step 3 valid acc is 0.58


  0%|          | 0/782 [00:00<?, ?it/s]

step 4 loss is 1.5989


  0%|          | 0/782 [00:00<?, ?it/s]

step 5 loss is 1.1882


  0%|          | 0/782 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



step 10 loss is 0.8615


  0%|          | 0/782 [00:00<?, ?it/s]

step 10 valid acc is 0.76


  0%|          | 0/782 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



step 15 valid acc is 0.82


  0%|          | 0/782 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [9]:
val_full = validate(model, cifar100_validation_loader)

  0%|          | 0/157 [00:00<?, ?it/s]

# Class+Head Only Finetuning


In [8]:
model = pretrained_model()
parameters = [model.vit.embeddings.cls_token] + list(model.classifier.parameters())

Some weights of the model checkpoint at google/vit-base-patch32-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch32-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


for name, param in model.named_parameters():
    if ("classifier" not in name) and ("cls_token" not in name) :
        param.requires_grad = False

In [9]:
train_losses_classhead, val_acc_classhead = train(model, parameters, cifar100_train_loader, cifar100_validation_loader)


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

step 0 loss is 1.8996


  0%|          | 0/157 [00:00<?, ?it/s]

step 0 valid acc is 0.64


  0%|          | 0/782 [00:00<?, ?it/s]

step 1 loss is 1.3328


  0%|          | 0/157 [00:00<?, ?it/s]

step 1 valid acc is 0.65


  0%|          | 0/782 [00:00<?, ?it/s]

step 2 loss is 1.1957


  0%|          | 0/157 [00:00<?, ?it/s]

step 2 valid acc is 0.66


  0%|          | 0/782 [00:00<?, ?it/s]

step 3 loss is 1.1122


  0%|          | 0/157 [00:00<?, ?it/s]

step 3 valid acc is 0.67


  0%|          | 0/782 [00:00<?, ?it/s]

step 4 loss is 1.0499


  0%|          | 0/157 [00:00<?, ?it/s]

step 4 valid acc is 0.67


  0%|          | 0/782 [00:00<?, ?it/s]

step 5 loss is 0.9880


  0%|          | 0/157 [00:00<?, ?it/s]

step 5 valid acc is 0.67


  0%|          | 0/782 [00:00<?, ?it/s]

step 6 loss is 0.9666


  0%|          | 0/157 [00:00<?, ?it/s]

step 6 valid acc is 0.67


  0%|          | 0/782 [00:00<?, ?it/s]

step 7 loss is 0.9413


  0%|          | 0/157 [00:00<?, ?it/s]

step 7 valid acc is 0.67


  0%|          | 0/782 [00:00<?, ?it/s]

step 8 loss is 0.9246


  0%|          | 0/157 [00:00<?, ?it/s]

step 8 valid acc is 0.67


  0%|          | 0/782 [00:00<?, ?it/s]

step 9 loss is 0.9046


  0%|          | 0/157 [00:00<?, ?it/s]

step 9 valid acc is 0.67


  0%|          | 0/782 [00:00<?, ?it/s]

step 10 loss is 0.8908


  0%|          | 0/157 [00:00<?, ?it/s]

step 10 valid acc is 0.67


  0%|          | 0/782 [00:00<?, ?it/s]

step 11 loss is 0.8746


  0%|          | 0/157 [00:00<?, ?it/s]

step 11 valid acc is 0.67


  0%|          | 0/782 [00:00<?, ?it/s]

step 12 loss is 0.8616


  0%|          | 0/157 [00:00<?, ?it/s]

step 12 valid acc is 0.67


  0%|          | 0/782 [00:00<?, ?it/s]

step 13 loss is 0.8477


  0%|          | 0/157 [00:00<?, ?it/s]

step 13 valid acc is 0.67


  0%|          | 0/782 [00:00<?, ?it/s]

step 14 loss is 0.8341


  0%|          | 0/157 [00:00<?, ?it/s]

step 14 valid acc is 0.67


  0%|          | 0/782 [00:00<?, ?it/s]

step 15 loss is 0.8231


  0%|          | 0/157 [00:00<?, ?it/s]

step 15 valid acc is 0.67


  0%|          | 0/782 [00:00<?, ?it/s]

step 16 loss is 0.8100


  0%|          | 0/157 [00:00<?, ?it/s]

step 16 valid acc is 0.67


  0%|          | 0/782 [00:00<?, ?it/s]

step 17 loss is 0.8027


  0%|          | 0/157 [00:00<?, ?it/s]

step 17 valid acc is 0.67


  0%|          | 0/782 [00:00<?, ?it/s]

step 18 loss is 0.7898


  0%|          | 0/157 [00:00<?, ?it/s]

step 18 valid acc is 0.67


  0%|          | 0/782 [00:00<?, ?it/s]

step 19 loss is 0.7812


  0%|          | 0/157 [00:00<?, ?it/s]

step 19 valid acc is 0.67


In [10]:
val_acc_classhead

[0.6383,
 0.6482,
 0.6596,
 0.6656,
 0.6683,
 0.6715,
 0.6699,
 0.6742,
 0.6716,
 0.6734,
 0.6723,
 0.6735,
 0.6723,
 0.6714,
 0.672,
 0.6723,
 0.67,
 0.6711,
 0.6702,
 0.6718]

In [11]:
class_val = validate(model, cifar100_validation_loader)

  0%|          | 0/157 [00:00<?, ?it/s]

In [12]:
class_val

0.6718

This will be used as a base model for the next experiment.

In [13]:
base_model = model

# Memory Token

First we convert our model to a `MemoryCapableViT`. This makes it possible to add new classification heads with memory tokens. It also takes care of the attention masking.

In [14]:
from vit import MemoryCapableViT
model = MemoryCapableViT(deepcopy(base_model))

Random seed set as 42


By default, wrapping a `ViTForImageClassification` into `MemoryCapableViT` doesn't change anything apart from some under-the-hood modifications (e.g. class token is inserted at the end instead of the beginning).

Let's check whether they are actually equivalent by running the validation again. Note that since `MemoryCapableViT` can have multiple heads, we need to specify which head's output to use. Since we currently have only one head, its index is 0.

In [15]:
memory_val = validate(model, cifar100_validation_loader, output_head=0)

  0%|          | 0/157 [00:00<?, ?it/s]

In [19]:
memory_val_2 = validate(model, cifar100_validation_loader, output_head=0) #after new head

  0%|          | 0/157 [00:00<?, ?it/s]

In [16]:
memory_val

0.6718

In [20]:
memory_val_2

0.6718

As expected, this accuracy value is the same as before.

For a more rigorous verification, there's a unit test in `test_vit.py` which checks the value of output. You can run all available unit test with `pytest`.

## Add new classification head with memory

We can now add a new classification head to our model.

We will train the new head for CIFAR100 which has 100 classes. There will be 4 memory tokens in each self-attention layer.

In [17]:
new_parameters = model.add_head(memory_tokens=1, num_classes=100)

The new parameters are returned as a list:
- The new class token. Shape: `[1, 1, 768]`.
- All memory tokens, one for each self-attention layer. We have 12 self-attention layers. Each has shape `[1, memory_tokens, 768]`.
- Weights (of size `[100, 768]`) and biases (of size `[100]`) of the new classifier head.

Let's print shapes of these parameters:

In [18]:
[p.size() for p in new_parameters]

[torch.Size([1, 1, 768]),
 torch.Size([1, 1, 768]),
 torch.Size([1, 1, 768]),
 torch.Size([1, 1, 768]),
 torch.Size([1, 1, 768]),
 torch.Size([1, 1, 768]),
 torch.Size([1, 1, 768]),
 torch.Size([1, 1, 768]),
 torch.Size([1, 1, 768]),
 torch.Size([1, 1, 768]),
 torch.Size([1, 1, 768]),
 torch.Size([1, 1, 768]),
 torch.Size([1, 1, 768]),
 torch.Size([100, 768]),
 torch.Size([100])]

After calling the `add_head` method, the attention mask will be updated automatically. This makes sure that the old class tokens don't interact with the new class and memory tokens.

<img src="images/attention-mask.png" alt="Attention Mask Figure" style="width: 500px;"/>

Let's check whether our attention mask matches the table above. Note that we currently have $\text{INP}$, $\text{CLS}$, $C_1$ and $M_1$ in our network.

In [21]:
model.vit.encoder.layer[0].attention.attention.attention_mask[0, 48:, 48:]

tensor([[0., 0., -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., 0.]], device='cuda:0')

This attention mask is added to the computed attention scores before the softmax is applied.

In self-attention layer, we don't insert the memory tokens while calculating the query. Therefore, memory tokens will not attend to other tokens and they won't be present in the output of the self-attention. This also ensures that the attention scores matrix has the same shape as the attention mask.

Let's train the new parameters:

In [22]:
memory_train, memory_validate = train(model, new_parameters, cifar100_train_loader,cifar100_validation_loader, output_head=1)


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

step 0 loss is 1.8027


  0%|          | 0/157 [00:00<?, ?it/s]

step 0 valid acc is 0.66


  0%|          | 0/782 [00:00<?, ?it/s]

step 1 loss is 1.1985


  0%|          | 0/157 [00:00<?, ?it/s]

step 1 valid acc is 0.67


  0%|          | 0/782 [00:00<?, ?it/s]

step 2 loss is 1.0663


  0%|          | 0/157 [00:00<?, ?it/s]

step 2 valid acc is 0.68


  0%|          | 0/782 [00:00<?, ?it/s]

step 3 loss is 0.9816


  0%|          | 0/157 [00:00<?, ?it/s]

step 3 valid acc is 0.69


  0%|          | 0/782 [00:00<?, ?it/s]

step 4 loss is 0.9188


  0%|          | 0/157 [00:00<?, ?it/s]

step 4 valid acc is 0.69


  0%|          | 0/782 [00:00<?, ?it/s]

step 5 loss is 0.8497


  0%|          | 0/157 [00:00<?, ?it/s]

step 5 valid acc is 0.69


  0%|          | 0/782 [00:00<?, ?it/s]

step 6 loss is 0.8276


  0%|          | 0/157 [00:00<?, ?it/s]

step 6 valid acc is 0.69


  0%|          | 0/782 [00:00<?, ?it/s]

step 7 loss is 0.8059


  0%|          | 0/157 [00:00<?, ?it/s]

step 7 valid acc is 0.69


  0%|          | 0/782 [00:00<?, ?it/s]

step 8 loss is 0.7882


  0%|          | 0/157 [00:00<?, ?it/s]

step 8 valid acc is 0.69


  0%|          | 0/782 [00:00<?, ?it/s]

step 9 loss is 0.7699


  0%|          | 0/157 [00:00<?, ?it/s]

step 9 valid acc is 0.69


  0%|          | 0/782 [00:00<?, ?it/s]

step 10 loss is 0.7542


  0%|          | 0/157 [00:00<?, ?it/s]

step 10 valid acc is 0.69


  0%|          | 0/782 [00:00<?, ?it/s]

step 11 loss is 0.7388


  0%|          | 0/157 [00:00<?, ?it/s]

step 11 valid acc is 0.69


  0%|          | 0/782 [00:00<?, ?it/s]

step 12 loss is 0.7244


  0%|          | 0/157 [00:00<?, ?it/s]

step 12 valid acc is 0.69


  0%|          | 0/782 [00:00<?, ?it/s]

step 13 loss is 0.7113


  0%|          | 0/157 [00:00<?, ?it/s]

step 13 valid acc is 0.69


  0%|          | 0/782 [00:00<?, ?it/s]

step 14 loss is 0.6979


  0%|          | 0/157 [00:00<?, ?it/s]

step 14 valid acc is 0.69


  0%|          | 0/782 [00:00<?, ?it/s]

step 15 loss is 0.6866


  0%|          | 0/157 [00:00<?, ?it/s]

step 15 valid acc is 0.69


  0%|          | 0/782 [00:00<?, ?it/s]

step 16 loss is 0.6742


  0%|          | 0/157 [00:00<?, ?it/s]

step 16 valid acc is 0.69


  0%|          | 0/782 [00:00<?, ?it/s]

step 17 loss is 0.6641


  0%|          | 0/157 [00:00<?, ?it/s]

step 17 valid acc is 0.69


  0%|          | 0/782 [00:00<?, ?it/s]

step 18 loss is 0.6529


  0%|          | 0/157 [00:00<?, ?it/s]

step 18 valid acc is 0.69


  0%|          | 0/782 [00:00<?, ?it/s]

step 19 loss is 0.6434


  0%|          | 0/157 [00:00<?, ?it/s]

step 19 valid acc is 0.69


In [23]:
mem_val = validate(model, cifar100_validation_loader, output_head=1)

  0%|          | 0/157 [00:00<?, ?it/s]

In [24]:
mem_val

0.6854

The performance of the previous head should not be affected thanks to attention masking.

In [25]:
check_val = validate(model, cifar100_validation_loader, output_head=0)

  0%|          | 0/157 [00:00<?, ?it/s]

In [26]:
check_val

0.6718

# Model Concatenation

Suppose that someone else took the same pretrained network and fine-tuned it on another dataset with memory tokens.

In [None]:
model2 = MemoryCapableViT(deepcopy(base_model))
new_parameters = model2.add_head(memory_tokens=1, num_classes=365)


In [29]:
memory_train_places, memory_val_places = train(model2, new_parameters, places_train_loader, output_head=1, total_steps = 19)


  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/28180 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [20]:
validate(model2, mnist_validation_loader, output_head=1)

  0%|          | 0/157 [00:00<?, ?it/s]

0.9281

Normally, models exhibit lower performance on the previous dataset after finetuning on a different dataset and separately finetuned models cannot be combined. However, we can achieve all of these with learnable memory method!

<img src="images/model-concat.png" alt="Model Concatenation Figure" style="width: 600px;"/>

`MemoryCapableViT` offers `concatenate` method. It merges two separately finetuned models. This method operates on the model in-place to use less memory.

In [21]:
model.concatenate(model2)

The combined model has 3 heads: the first one is the original, trained on CIFAR10; the second one trained on CIFAR100; and finally the third one for MNIST. We can now accoplish all of these tasks with a single model without any performance penalty.

In [22]:
validate(model, cifar10_validation_loader, output_head=0)

  0%|          | 0/157 [00:00<?, ?it/s]

0.9305

In [23]:
validate(model, cifar100_validation_loader, output_head=1)

  0%|          | 0/157 [00:00<?, ?it/s]

0.7922

In [25]:
validate(model, mnist_validation_loader, output_head=2)

  0%|          | 0/157 [00:00<?, ?it/s]

0.9281

If we inspect the attention mask of the combined model, we should see Table 1 with $\text{INP}$, $\text{CLS}$, $C_1$, $M_1$, $C_2$ and $M_2$ columns.

In [26]:
model.vit.encoder.layer[0].attention.attention.attention_mask[0, 48:, 48:]

tensor([[0., 0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, 0., -inf],
        [0., 0., -inf, 0., -inf, 0.]], device='cuda:0')