In this notebook, we are going to show how to fine-tune the Perceiver for image classification.

For more info regarding the Perceiver, I refer to:

* the Transformers docs: https://huggingface.co/docs/transformers/model_doc/perceiver
* the blog post: https://huggingface.co/blog/perceiver

## Set-up environment

We first install HuggingFace Transformers & Datasets.

In [1]:
!pip install -q transformers datasets

## Load data

Here we load a small portion of the CIFAR-10 dataset, for demonstration purposes.

In [2]:
from datasets import load_dataset

# load cifar10 (only small portion for demonstration purposes) 
train_ds, test_ds = load_dataset('cifar10', split=['train[:1000]', 'test[:100]'])
# split up training into training + validation
splits = train_ds.train_test_split(test_size=0.1)
train_ds = splits['train']
val_ds = splits['test']



  0%|          | 0/2 [00:00<?, ?it/s]

We'll define the id2label and label2id dictionaries, as these will be useful when doing inference.

In [3]:
id2label = {idx:label for idx,label in enumerate(train_ds.features['label'].names)}
label2id = {label:idx for idx, label in id2label.items()}
print(id2label)
print(label2id)

{0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', 5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
{'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}


We can prepare the data for the model using the feature extractor.

Note that this feature extractor is fairly basic: it will just do center cropping + resizing + normalizing of the color channels.

One should actually add several data augmentations (available in libraries like [torchvision](https://pytorch.org/vision/stable/transforms.html) and [albumentations](https://albumentations.ai/) to achieve greater results. I refer to my [ViT notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_the_%F0%9F%A4%97_Trainer.ipynb) for an example.

In [4]:
from transformers import PerceiverFeatureExtractor

feature_extractor = PerceiverFeatureExtractor()

Note that HuggingFace Datasets has an Image feature, meaning that every image is a PIL (Pillow) image by default. The feature extractor will turn each Pillow image into a PyTorch tensor of shape (3, 224, 224).

Note that Apache Arrow (which HuggingFace Datasets uses as a back-end) doesn't know PyTorch Tensors, but we can escape it by using the `set_transform` method on the Dataset, which allows to only prepare images when we need them (i.e. on-the-fly). This is awesome as it saves memory! Refer to the [docs](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.set_transform) for more information.

In [5]:
import numpy as np

def preprocess_images(examples):
    examples['pixel_values'] = feature_extractor(examples['img'], return_tensors="pt").pixel_values
    return examples

In [6]:
# Set the transforms
train_ds.set_transform(preprocess_images)
val_ds.set_transform(preprocess_images)
test_ds.set_transform(preprocess_images)

We can now load preprocessed images (on-the-fly) as follows:

In [7]:
train_ds[:2]

{'img': [<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32 at 0x7F2702040750>,
  <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32 at 0x7F2701FA4850>],
 'label': [7, 9],
 'pixel_values': tensor([[[[ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
           [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
           [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
           ...,
           [-1.5357, -1.5357, -1.5357,  ...,  2.2318,  2.2318,  2.2318],
           [-1.5357, -1.5357, -1.5357,  ...,  2.2318,  2.2318,  2.2318],
           [-1.5528, -1.5528, -1.5528,  ...,  2.2318,  2.2318,  2.2318]],
 
          [[ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
           [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
           [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
           ...,
           [-1.4055, -1.4055, -1.4055,  ...,  2.4111,  2.4111,  2.4111],
           [-1.4055, -1.4055, -

It's very easy to create corresponding PyTorch DataLoaders, like so:

In [8]:
from torch.utils.data import DataLoader
import torch

device = torch.device('cuda')

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

train_batch_size = 10
eval_batch_size = 10

train_dataloader = DataLoader(train_ds, shuffle=True, collate_fn=collate_fn, batch_size=train_batch_size)
val_dataloader = DataLoader(val_ds, collate_fn=collate_fn, batch_size=eval_batch_size)
test_dataloader = DataLoader(test_ds, collate_fn=collate_fn, batch_size=eval_batch_size)

We can verify our data a bit:

In [9]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  if isinstance(v, torch.Tensor):
    print(k, v.shape)

pixel_values torch.Size([10, 3, 224, 224])
labels torch.Size([10])


Some more verification:

In [10]:
assert batch['pixel_values'].shape == (train_batch_size, 3, 224, 224)
assert batch['labels'].shape == (train_batch_size,)

In [11]:
next(iter(val_dataloader))['pixel_values'].shape

torch.Size([10, 3, 224, 224])

## Define model

Here we only replace the final projection layer of the decoder (`PerceiverClassificationDecoder`) of the checkpoint that was trained on ImageNet. This means that we will use the same (learned) output queries as before, hence the cross-attention operation will give the same output. However, the final projection layer has 1000 output neurons during pre-training, while we only have 10.

NOTE: note that the Perceiver has 3 variants for image classification:
* PerceiverForImageClassificationLearned
* PerceiverForImageClassificationFourier
* PerceiverForImageClassificationConvProcessing.

Here I'm using the first one, which adds learned 1D position embeddings to the pixel values. Note that the best results will be obtained with the latter.

For in-depth understanding on how the Perceiver works, I refer to my [blog post](https://huggingface.co/blog/perceiver).

We can use the handy `ignore_mismatched_sizes` to replace the head. We also set the `id2label` and `label2id` mappings we defined earlier (which will be handy when doing inference).

In [12]:
from transformers import PerceiverConfig, PerceiverTokenizer, PerceiverFeatureExtractor, PerceiverModel
from transformers.models.perceiver.modeling_perceiver import (
    PerceiverTextPreprocessor,
    PerceiverImagePreprocessor,
    PerceiverClassificationDecoder,
)
import torch
import requests
from PIL import Image
device = torch.device('cuda')
# ("cuda" if torch.cuda.is_available() else "cpu")

# model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned",
#                                                                num_labels=10,
#                                                                id2label=id2label,
#                                                                label2id=label2id,
#                                                                ignore_mismatched_sizes=True)


# EXAMPLE 2: using the Perceiver to classify images
# - we define an ImagePreprocessor, which can be used to embed images
config = PerceiverConfig(image_size=224,
                          use_labels=True,
                          num_labels=10,
                          id2label=id2label,
                          label2id=label2id,
                          ignore_mismatched_sizes=True
                         )
preprocessor = PerceiverImagePreprocessor(
    config,
    prep_type="conv1x1",
    spatial_downsample=1,
    out_channels=256,
    position_encoding_type="trainable",
    concat_or_add_pos="concat",
    project_pos_dim=256,
    trainable_position_encoding_kwargs=dict(
        num_channels=256,
        index_dims=config.image_size**2,
    ),
)

model = PerceiverModel(
    config,
    input_preprocessor=preprocessor,
    decoder=PerceiverClassificationDecoder(
        config,
        num_channels=config.d_latents,
        trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
        use_query_residual=True,
    ),
)
# .from_pretrained("deepmind/vision-perceiver-learned",
#                                                                num_labels=10,
#                                                                id2label=id2label,
#                                                                label2id=label2id,
#                                                                ignore_mismatched_sizes=True)

model.to(device)

PerceiverModel(
  (input_preprocessor): PerceiverImagePreprocessor(
    (convnet_1x1): Conv2d(3, 256, kernel_size=(1, 1), stride=(1, 1))
    (position_embeddings): PerceiverTrainablePositionEncoding()
    (positions_projection): Linear(in_features=256, out_features=256, bias=True)
    (conv_after_patches): Identity()
  )
  (embeddings): PerceiverEmbeddings()
  (encoder): PerceiverEncoder(
    (cross_attention): PerceiverLayer(
      (attention): PerceiverAttention(
        (self): PerceiverSelfAttention(
          (layernorm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (layernorm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (query): Linear(in_features=1280, out_features=512, bias=True)
          (key): Linear(in_features=512, out_features=512, bias=True)
          (value): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (output): PerceiverSelfOutput(
          (dense):

## Train the model

Here we train the model using native PyTorch.

In [13]:
from IPython.display import clear_output

from transformers import AdamW
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

# torch.cuda.empty_cache()
# print(torch.cuda.memory_summary(device=None, abbreviated=False))

optimizer = AdamW(model.parameters(), lr=5e-5)

model.train()
for epoch in range(10):  # loop over the dataset multiple times
    print("Epoch:", epoch)
    for batch in tqdm(train_dataloader):
         # get the inputs; 
         inputs = batch["pixel_values"].to(device)
         labels = batch["labels"].to(device)

         # zero the parameter gradients
         optimizer.zero_grad()

         # forward + backward + optimize
         outputs = model(inputs=inputs)
         logits = outputs.logits

        #  print('logits: ', logits.shape)
        #  print('logits shape: ', logits)
        #  print('outputs: ', outputs)
        #  print('labels: ', labels)
         # [1, 2]

         # to train, one can train the model using standard cross-entropy:
         criterion = torch.nn.CrossEntropyLoss()

         # labels = torch.tensor([1,1])
         loss = criterion(logits, labels)
        #  print('loss: ', loss)
         #  print(outputs)


         
         loss.backward()
         optimizer.step()

        #  clear_output(wait=True)
        #  predictions = outputs.logits.argmax(-1).cpu().detach().numpy()
        #  accuracy = accuracy_score(y_true=batch["labels"].numpy(), y_pred=predictions)
        #  print(f"Loss: {loss.item()}, Accuracy: {accuracy}")

        #  # evaluate
    predictions = outputs.logits.argmax(-1).cpu().detach().numpy()
    accuracy = accuracy_score(y_true=batch["labels"].numpy(), y_pred=predictions)
    print(f"Loss: {loss.item()}, Accuracy: {accuracy}")
    
    

Epoch: 0




  0%|          | 0/90 [00:00<?, ?it/s]

Loss: 2.20711612701416, Accuracy: 0.1
Epoch: 1


  0%|          | 0/90 [00:00<?, ?it/s]

Loss: 2.7028424739837646, Accuracy: 0.0
Epoch: 2


  0%|          | 0/90 [00:00<?, ?it/s]

Loss: 1.9801855087280273, Accuracy: 0.4
Epoch: 3


  0%|          | 0/90 [00:00<?, ?it/s]

Loss: 1.9687011241912842, Accuracy: 0.5
Epoch: 4


  0%|          | 0/90 [00:00<?, ?it/s]

Loss: 2.119823694229126, Accuracy: 0.2
Epoch: 5


  0%|          | 0/90 [00:00<?, ?it/s]

Loss: 1.8747456073760986, Accuracy: 0.5
Epoch: 6


  0%|          | 0/90 [00:00<?, ?it/s]

Loss: 2.4442341327667236, Accuracy: 0.1
Epoch: 7


  0%|          | 0/90 [00:00<?, ?it/s]

Loss: 2.175804615020752, Accuracy: 0.2
Epoch: 8


  0%|          | 0/90 [00:00<?, ?it/s]

Loss: 2.0868992805480957, Accuracy: 0.5
Epoch: 9


  0%|          | 0/90 [00:00<?, ?it/s]

Loss: 1.8617359399795532, Accuracy: 0.4


## Evaluate the model

Finally, we evaluate the model on the test set. We use the Datasets library to compute the accuracy.

On some runs, I got 78%, then 66%. Of course, one would need to train on the entire dataset to achieve great results.

In [14]:
from tqdm.notebook import tqdm
from datasets import load_metric

accuracy = load_metric("accuracy")

model.eval()
for batch in tqdm(val_dataloader):
      # get the inputs; 
      inputs = batch["pixel_values"].to(device)
      labels = batch["labels"].to(device)

      # forward pass
      outputs = model(inputs=inputs)
      logits = outputs.logits 
      predictions = logits.argmax(-1).cpu().detach().numpy()
      references = batch["labels"].numpy()
      accuracy.add_batch(predictions=predictions, references=references)

final_score = accuracy.compute()
print("Accuracy on test set:", final_score)

  after removing the cwd from sys.path.


  0%|          | 0/10 [00:00<?, ?it/s]

Accuracy on test set: {'accuracy': 0.22}


In [15]:
# from transformers import PerceiverConfig, PerceiverTokenizer, PerceiverFeatureExtractor, PerceiverModel, PerceiverForImageClassificationLearned
# from transformers.models.perceiver.modeling_perceiver import (
#     PerceiverTextPreprocessor,
#     PerceiverImagePreprocessor,
#     PerceiverClassificationDecoder,
# )
# import torch
# import requests
# from PIL import Image
# device = torch.device('cuda')
# # ("cuda" if torch.cuda.is_available() else "cpu")

# model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned",
#                                                                num_labels=10,
#                                                                id2label=id2label,
#                                                                label2id=label2id,
#                                                                ignore_mismatched_sizes=True)

# model

In [16]:


# # EXAMPLE 2: using the Perceiver to classify images
# # - we define an ImagePreprocessor, which can be used to embed images
# config = PerceiverConfig(image_size=224,
#                           use_labels=True,
#                           num_labels=10,
#                           id2label=id2label,
#                           label2id=label2id,
#                           ignore_mismatched_sizes=True,
                         
#                          )
# preprocessor = PerceiverImagePreprocessor(
#     config,
#     prep_type="conv1x1",
#     spatial_downsample=1,
#     out_channels=256,
#     position_encoding_type="trainable",
#     concat_or_add_pos="concat",
#     project_pos_dim=256,
#     trainable_position_encoding_kwargs=dict(
#         num_channels=256,
#         index_dims=config.image_size**2,
#     ),
# )

# model = PerceiverModel(
#     config,
#     input_preprocessor=preprocessor,
#     decoder=PerceiverClassificationDecoder(
#         config,
#         num_channels=config.d_latents,
#         trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
#         use_query_residual=True,
#     ),
# )

# model