In this notebook, we are going to show how to fine-tune the Perceiver for image classification.

For more info regarding the Perceiver, I refer to:

* the Transformers docs: https://huggingface.co/docs/transformers/model_doc/perceiver
* the blog post: https://huggingface.co/blog/perceiver

## Set-up environment

We first install HuggingFace Transformers & Datasets.

In [1]:
!pip install -q transformers datasets

## Load data

Here we load a small portion of the CIFAR-10 dataset, for demonstration purposes.

In [4]:
from datasets import load_dataset

# load cifar10 (only small portion for demonstration purposes) 
train_ds, test_ds = load_dataset('cifar10', split=['train[:1000]', 'test[:100]'])
# split up training into training + validation
splits = train_ds.train_test_split(test_size=0.1)
train_ds = splits['train']
val_ds = splits['test']

Found cached dataset cifar10 (/Users/david/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4)


  0%|          | 0/2 [00:00<?, ?it/s]

We'll define the id2label and label2id dictionaries, as these will be useful when doing inference.

In [5]:
id2label = {idx:label for idx,label in enumerate(train_ds.features['label'].names)}
label2id = {label:idx for idx, label in id2label.items()}
print(id2label)
print(label2id)

{0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', 5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
{'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}


We can prepare the data for the model using the feature extractor.

Note that this feature extractor is fairly basic: it will just do center cropping + resizing + normalizing of the color channels.

One should actually add several data augmentations (available in libraries like [torchvision](https://pytorch.org/vision/stable/transforms.html) and [albumentations](https://albumentations.ai/) to achieve greater results. I refer to my [ViT notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_the_%F0%9F%A4%97_Trainer.ipynb) for an example.

In [6]:
from transformers import PerceiverFeatureExtractor

feature_extractor = PerceiverFeatureExtractor()

Note that HuggingFace Datasets has an Image feature, meaning that every image is a PIL (Pillow) image by default. The feature extractor will turn each Pillow image into a PyTorch tensor of shape (3, 224, 224).

Note that Apache Arrow (which HuggingFace Datasets uses as a back-end) doesn't know PyTorch Tensors, but we can escape it by using the `set_transform` method on the Dataset, which allows to only prepare images when we need them (i.e. on-the-fly). This is awesome as it saves memory! Refer to the [docs](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.set_transform) for more information.

In [7]:
import numpy as np

def preprocess_images(examples):
    examples['pixel_values'] = feature_extractor(examples['img'], return_tensors="pt").pixel_values
    return examples

In [8]:
# Set the transforms
train_ds.set_transform(preprocess_images)
val_ds.set_transform(preprocess_images)
test_ds.set_transform(preprocess_images)

We can now load preprocessed images (on-the-fly) as follows:

In [9]:
train_ds[:2]

{'img': [<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,
  <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>],
 'label': [6, 7],
 'pixel_values': tensor([[[[-0.9216, -0.9216, -0.9216,  ..., -0.8824, -0.8824, -0.8824],
           [-0.9216, -0.9216, -0.9216,  ..., -0.8824, -0.8824, -0.8824],
           [-0.9216, -0.9216, -0.9216,  ..., -0.8824, -0.8824, -0.8824],
           ...,
           [ 0.4588,  0.4588,  0.4588,  ..., -0.2549, -0.2549, -0.2549],
           [ 0.4588,  0.4588,  0.4588,  ..., -0.2549, -0.2549, -0.2549],
           [ 0.4588,  0.4588,  0.4588,  ..., -0.2549, -0.2549, -0.2549]],
 
          [[-0.9216, -0.9216, -0.9216,  ..., -0.8824, -0.8824, -0.8824],
           [-0.9216, -0.9216, -0.9216,  ..., -0.8824, -0.8824, -0.8824],
           [-0.9216, -0.9216, -0.9216,  ..., -0.8824, -0.8824, -0.8824],
           ...,
           [ 0.2314,  0.2314,  0.2314,  ..., -0.4353, -0.4353, -0.4353],
           [ 0.2314,  0.2314,  0.2314,  ..., -0.4353, -0.4353, -0.4

It's very easy to create corresponding PyTorch DataLoaders, like so:

In [10]:
from torch.utils.data import DataLoader
import torch

device = torch.device('cuda')

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

train_batch_size = 10
eval_batch_size = 10

train_dataloader = DataLoader(train_ds, shuffle=True, collate_fn=collate_fn, batch_size=train_batch_size)
val_dataloader = DataLoader(val_ds, collate_fn=collate_fn, batch_size=eval_batch_size)
test_dataloader = DataLoader(test_ds, collate_fn=collate_fn, batch_size=eval_batch_size)

We can verify our data a bit:

In [11]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  if isinstance(v, torch.Tensor):
    print(k, v.shape)

pixel_values torch.Size([10, 3, 224, 224])
labels torch.Size([10])


Some more verification:

In [12]:
assert batch['pixel_values'].shape == (train_batch_size, 3, 224, 224)
assert batch['labels'].shape == (train_batch_size,)

In [13]:
next(iter(val_dataloader))['pixel_values'].shape

torch.Size([10, 3, 224, 224])

## Define model

Here we only replace the final projection layer of the decoder (`PerceiverClassificationDecoder`) of the checkpoint that was trained on ImageNet. This means that we will use the same (learned) output queries as before, hence the cross-attention operation will give the same output. However, the final projection layer has 1000 output neurons during pre-training, while we only have 10.

NOTE: note that the Perceiver has 3 variants for image classification:
* PerceiverForImageClassificationLearned
* PerceiverForImageClassificationFourier
* PerceiverForImageClassificationConvProcessing.

Here I'm using the first one, which adds learned 1D position embeddings to the pixel values. Note that the best results will be obtained with the latter.

For in-depth understanding on how the Perceiver works, I refer to my [blog post](https://huggingface.co/blog/perceiver).

We can use the handy `ignore_mismatched_sizes` to replace the head. We also set the `id2label` and `label2id` mappings we defined earlier (which will be handy when doing inference).

In [12]:
from transformers import PerceiverConfig, PerceiverTokenizer, PerceiverFeatureExtractor, PerceiverModel
from transformers.models.perceiver.modeling_perceiver import (
    PerceiverTextPreprocessor,
    PerceiverImagePreprocessor,
    PerceiverClassificationDecoder,
)
import torch
import requests
from PIL import Image
device = torch.device('cuda')
# ("cuda" if torch.cuda.is_available() else "cpu")

# model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned",
#                                                                num_labels=10,
#                                                                id2label=id2label,
#                                                                label2id=label2id,
#                                                                ignore_mismatched_sizes=True)


# EXAMPLE 2: using the Perceiver to classify images
# - we define an ImagePreprocessor, which can be used to embed images
config = PerceiverConfig(image_size=224,
                          use_labels=True,
                          num_labels=10,
                          id2label=id2label,
                          label2id=label2id,
                          ignore_mismatched_sizes=True
                         )
preprocessor = PerceiverImagePreprocessor(
    config,
    prep_type="conv1x1",
    spatial_downsample=1,
    out_channels=256,
    position_encoding_type="trainable",
    concat_or_add_pos="concat",
    project_pos_dim=256,
    trainable_position_encoding_kwargs=dict(
        num_channels=256,
        index_dims=config.image_size**2,
    ),
)

model = PerceiverModel(
    config,
    input_preprocessor=preprocessor,
    decoder=PerceiverClassificationDecoder(
        config,
        num_channels=config.d_latents,
        trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
        use_query_residual=True,
    ),
)
# .from_pretrained("deepmind/vision-perceiver-learned",
#                                                                num_labels=10,
#                                                                id2label=id2label,
#                                                                label2id=label2id,
#                                                                ignore_mismatched_sizes=True)

model.to(device)

PerceiverModel(
  (input_preprocessor): PerceiverImagePreprocessor(
    (convnet_1x1): Conv2d(3, 256, kernel_size=(1, 1), stride=(1, 1))
    (position_embeddings): PerceiverTrainablePositionEncoding()
    (positions_projection): Linear(in_features=256, out_features=256, bias=True)
    (conv_after_patches): Identity()
  )
  (embeddings): PerceiverEmbeddings()
  (encoder): PerceiverEncoder(
    (cross_attention): PerceiverLayer(
      (attention): PerceiverAttention(
        (self): PerceiverSelfAttention(
          (layernorm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (layernorm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (query): Linear(in_features=1280, out_features=512, bias=True)
          (key): Linear(in_features=512, out_features=512, bias=True)
          (value): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (output): PerceiverSelfOutput(
          (dense):

## Train the model

Here we train the model using native PyTorch.

In [13]:
from IPython.display import clear_output

from transformers import AdamW
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

# torch.cuda.empty_cache()
# print(torch.cuda.memory_summary(device=None, abbreviated=False))

optimizer = AdamW(model.parameters(), lr=5e-5)

model.train()
for epoch in range(10):  # loop over the dataset multiple times
    print("Epoch:", epoch)
    for batch in tqdm(train_dataloader):
         # get the inputs; 
         inputs = batch["pixel_values"].to(device)
         labels = batch["labels"].to(device)

         # zero the parameter gradients
         optimizer.zero_grad()

         # forward + backward + optimize
         outputs = model(inputs=inputs)
         logits = outputs.logits

        #  print('logits: ', logits.shape)
        #  print('logits shape: ', logits)
        #  print('outputs: ', outputs)
        #  print('labels: ', labels)
         # [1, 2]

         # to train, one can train the model using standard cross-entropy:
         criterion = torch.nn.CrossEntropyLoss()

         # labels = torch.tensor([1,1])
         loss = criterion(logits, labels)
        #  print('loss: ', loss)
         #  print(outputs)


         
         loss.backward()
         optimizer.step()

        #  clear_output(wait=True)
        #  predictions = outputs.logits.argmax(-1).cpu().detach().numpy()
        #  accuracy = accuracy_score(y_true=batch["labels"].numpy(), y_pred=predictions)
        #  print(f"Loss: {loss.item()}, Accuracy: {accuracy}")

        #  # evaluate
    predictions = outputs.logits.argmax(-1).cpu().detach().numpy()
    accuracy = accuracy_score(y_true=batch["labels"].numpy(), y_pred=predictions)
    print(f"Loss: {loss.item()}, Accuracy: {accuracy}")
    
    

Epoch: 0




  0%|          | 0/90 [00:00<?, ?it/s]

Loss: 2.20711612701416, Accuracy: 0.1
Epoch: 1


  0%|          | 0/90 [00:00<?, ?it/s]

Loss: 2.7028424739837646, Accuracy: 0.0
Epoch: 2


  0%|          | 0/90 [00:00<?, ?it/s]

Loss: 1.9801855087280273, Accuracy: 0.4
Epoch: 3


  0%|          | 0/90 [00:00<?, ?it/s]

Loss: 1.9687011241912842, Accuracy: 0.5
Epoch: 4


  0%|          | 0/90 [00:00<?, ?it/s]

Loss: 2.119823694229126, Accuracy: 0.2
Epoch: 5


  0%|          | 0/90 [00:00<?, ?it/s]

Loss: 1.8747456073760986, Accuracy: 0.5
Epoch: 6


  0%|          | 0/90 [00:00<?, ?it/s]

Loss: 2.4442341327667236, Accuracy: 0.1
Epoch: 7


  0%|          | 0/90 [00:00<?, ?it/s]

Loss: 2.175804615020752, Accuracy: 0.2
Epoch: 8


  0%|          | 0/90 [00:00<?, ?it/s]

Loss: 2.0868992805480957, Accuracy: 0.5
Epoch: 9


  0%|          | 0/90 [00:00<?, ?it/s]

Loss: 1.8617359399795532, Accuracy: 0.4


## Evaluate the model

Finally, we evaluate the model on the test set. We use the Datasets library to compute the accuracy.

On some runs, I got 78%, then 66%. Of course, one would need to train on the entire dataset to achieve great results.

In [14]:
from tqdm.notebook import tqdm
from datasets import load_metric

accuracy = load_metric("accuracy")

model.eval()
for batch in tqdm(val_dataloader):
      # get the inputs; 
      inputs = batch["pixel_values"].to(device)
      labels = batch["labels"].to(device)

      # forward pass
      outputs = model(inputs=inputs)
      logits = outputs.logits 
      predictions = logits.argmax(-1).cpu().detach().numpy()
      references = batch["labels"].numpy()
      accuracy.add_batch(predictions=predictions, references=references)

final_score = accuracy.compute()
print("Accuracy on test set:", final_score)

  after removing the cwd from sys.path.


  0%|          | 0/10 [00:00<?, ?it/s]

Accuracy on test set: {'accuracy': 0.22}


### Overwriting ImagePreprocessor to Include Tagkop Encodings

In [17]:
from transformers import PerceiverConfig, PerceiverTokenizer, PerceiverFeatureExtractor, PerceiverModel

import torch
import requests
from PIL import Image
device = torch.device('mps')
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

from torch import nn
from transformers.models.perceiver.modeling_perceiver import (
    PerceiverTextPreprocessor,
    PerceiverImagePreprocessor,
    PerceiverClassificationDecoder,
    AbstractPreprocessor,
    PerceiverAbstractPositionEncoding
)

import torch

    
class PerceiverTagopPositionEncoding(PerceiverAbstractPositionEncoding):
    """Trainable position encoding."""

    def __init__(self, index_dims, num_channels=128):
        super().__init__()
        self._num_channels = num_channels
        self._index_dims = index_dims
        index_dim = np.prod(index_dims)
        self.position_embeddings = nn.Parameter(torch.randn(index_dim, num_channels))
        self.position_embeddings_previous = self.position_embeddings
    @property
    def num_dimensions(self) -> int:
        if isinstance(self._index_dims, int):
            return 1
        return len(self._index_dims)

    def output_size(self, *args, **kwargs) -> int:
        return self._num_channels

    def forward(self, batch_size: int) -> torch.Tensor:
        position_embeddings = self.position_embeddings
        
        print('equal?:::', torch.all(torch.eq(self.position_embeddings_previous, position_embeddings)))
        
        
        if batch_size is not None:
            position_embeddings = position_embeddings.expand(batch_size, -1, -1)
        
        print(position_embeddings)
        return position_embeddings
        
def build_position_encoding2(
    position_encoding_type,
    out_channels=None,
    project_pos_dim=-1,
    trainable_position_encoding_kwargs=None,
    fourier_position_encoding_kwargs=None,
    tagop_position_encoding_kwargs=None,
):
    """
    Builds the position encoding.
    Args:
    - out_channels: refers to the number of channels of the position encodings.
    - project_pos_dim: if specified, will project the position encodings to this dimension.
    """
#     raise ValueError("MERP!")
    if position_encoding_type == "trainable":
        if not trainable_position_encoding_kwargs:
            raise ValueError("Make sure to pass trainable_position_encoding_kwargs")
        output_pos_enc = PerceiverTrainablePositionEncoding(**trainable_position_encoding_kwargs)
    elif position_encoding_type == "fourier":
        # We don't use the index_dims argument, as this is only known during the forward pass
        if not fourier_position_encoding_kwargs:
            raise ValueError("Make sure to pass fourier_position_encoding_kwargs")
        output_pos_enc = PerceiverFourierPositionEncoding(**fourier_position_encoding_kwargs)
    elif position_encoding_type == "tagop":
        if not tagop_position_encoding_kwargs:
            raise ValueError("Make sure to pass tagop_position_encoding_kwargs")
        output_pos_enc = PerceiverTagopPositionEncoding(**tagop_position_encoding_kwargs)
    else:
        raise ValueError(f"Unknown position encoding type: {position_encoding_type}.")

    # Optionally, project the position encoding to a target dimension:
    positions_projection = nn.Linear(out_channels, project_pos_dim) if project_pos_dim > 0 else nn.Identity()

    return output_pos_enc, positions_projection

class PerceiverImagePreprocessor2(AbstractPreprocessor):
    """
    Image preprocessing for Perceiver Encoder.
    Note: the *out_channels* argument refers to the output channels of a convolutional layer, if *prep_type* is set to
    "conv1x1" or "conv". If one adds absolute position embeddings, one must make sure the *num_channels* of the
    position encoding kwargs are set equal to the *out_channels*.
    Args:
        config ([*PerceiverConfig*]):
            Model configuration.
        prep_type (`str`, *optional*, defaults to `"conv"`):
            Preprocessing type. Can be "conv1x1", "conv", "patches", "pixels".
        spatial_downsample (`int`, *optional*, defaults to 4):
            Spatial downsampling factor.
        temporal_downsample (`int`, *optional*, defaults to 1):
            Temporal downsampling factor (only relevant in case a time dimension is present).
        position_encoding_type (`str`, *optional*, defaults to `"fourier"`):
            Position encoding type. Can be "fourier" or "trainable".
        in_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input.
        out_channels (`int`, *optional*, defaults to 64):
            Number of channels in the output.
        conv_after_patching (`bool`, *optional*, defaults to `False`):
            Whether to apply a convolutional layer after patching.
        conv_after_patching_in_channels (`int`, *optional*, defaults to 54):
            Number of channels in the input of the convolutional layer after patching.
        conv2d_use_batchnorm (`bool`, *optional*, defaults to `True`):
            Whether to use batch normalization in the convolutional layer.
        concat_or_add_pos (`str`, *optional*, defaults to `"concat"`):
            How to concatenate the position encoding to the input. Can be "concat" or "add".
        project_pos_dim (`int`, *optional*, defaults to -1):
            Dimension of the position encoding to project to. If -1, no projection is applied.
        **position_encoding_kwargs (`Dict`, *optional*):
            Keyword arguments for the position encoding.
    """

    def __init__(
        self,
        config,
        prep_type="conv",
        spatial_downsample: int = 4,
        temporal_downsample: int = 1,
        position_encoding_type: str = "fourier",
        in_channels: int = 3,
        out_channels: int = 64,
        conv_after_patching: bool = False,
        conv_after_patching_in_channels: int = 54,  # only relevant when conv_after_patching = True
        conv2d_use_batchnorm: bool = True,
        concat_or_add_pos: str = "concat",
        project_pos_dim: int = -1,
        **position_encoding_kwargs,
    ):
        super().__init__()
        self.config = config

        if prep_type not in ("conv", "patches", "pixels", "conv1x1"):
            raise ValueError(f"Prep_type {prep_type} is invalid")

        if concat_or_add_pos not in ["concat", "add"]:
            raise ValueError(f"Invalid value {concat_or_add_pos} for concat_or_add_pos.")

        self.in_channels = in_channels
        self.prep_type = prep_type
        self.spatial_downsample = spatial_downsample
        self.temporal_downsample = temporal_downsample
        self.position_encoding_type = position_encoding_type
        self.concat_or_add_pos = concat_or_add_pos
        self.conv_after_patching = conv_after_patching
        self.out_channels = out_channels

        if self.prep_type == "conv":
            # Downsampling with conv is currently restricted
            convnet_num_layers = math.log(spatial_downsample, 4)
            convnet_num_layers_is_int = convnet_num_layers == np.round(convnet_num_layers)
            if not convnet_num_layers_is_int or temporal_downsample != 1:
                raise ValueError(
                    "Only powers of 4 expected for spatial and 1 expected for temporal downsampling with conv."
                )
            self.convnet = Conv2DDownsample(
                in_channels=in_channels,
                num_layers=int(convnet_num_layers),
                out_channels=out_channels,
                use_batchnorm=conv2d_use_batchnorm,
            )

        elif self.prep_type == "conv1x1":
            if temporal_downsample != 1:
                raise ValueError("Conv1x1 does not downsample in time.")
            self.convnet_1x1 = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                # spatial_downsample is unconstrained for 1x1 convolutions.
                stride=(spatial_downsample, spatial_downsample),
            )

        # Position embeddings
        self.project_pos_dim = project_pos_dim
        self.position_embeddings, self.positions_projection = build_position_encoding2(
            position_encoding_type=position_encoding_type,
            out_channels=out_channels,
            project_pos_dim=project_pos_dim,
            **position_encoding_kwargs,
        )

        # Optional convolutional layer after patches.
        self.conv_after_patches = (
            nn.Linear(conv_after_patching_in_channels, self.out_channels) if conv_after_patching else nn.Identity()
        )

    @property
    def num_channels(self) -> int:
        # Let's assume that the number of resolutions (in the context of image preprocessing)
        # of the input data is 2 or 3 depending on whether we are processing image or video respectively.
        # In this case, for convenience, we will declare is_temporal variable,
        # which will show whether the data has a temporal dimension or not.
        is_temporal = self.position_embeddings.num_dimensions > 2

        # position embedding
        if self.project_pos_dim > 0:
            pos_dim = self.project_pos_dim
        else:
            pos_dim = self.position_embeddings.output_size()
        if self.concat_or_add_pos == "add":
            return pos_dim

        # inputs
        if self.conv_after_patching or self.prep_type in ("conv1x1", "conv"):
            inp_dim = self.out_channels
        elif self.prep_type == "pixels":
            inp_dim = self.in_channels
            if not is_temporal:
                inp_dim = math.ceil(inp_dim / self.spatial_downsample)
        elif self.prep_type == "patches":
            if self.conv_after_patching:
                inp_dim = self.out_channels
            else:
                inp_dim = self.in_channels * self.spatial_downsample**2
                if is_temporal:
                    inp_dim *= self.temporal_downsample

        return inp_dim + pos_dim

    def _build_network_inputs(self, inputs: torch.Tensor, network_input_is_1d: bool = True):
        """
        Construct the final input, including position encoding.
        This method expects the inputs to always have channels as last dimension.
        """
        batch_size = inputs.shape[0]
        index_dims = inputs.shape[1:-1]
        indices = np.prod(index_dims)

        # Flatten input features to a 1D index dimension if necessary.
        if len(inputs.shape) > 3 and network_input_is_1d:
            inputs = torch.reshape(inputs, [batch_size, indices, -1])

        # Construct the position encoding.
        if self.position_encoding_type == "trainable":
            pos_enc = self.position_embeddings(batch_size)
        elif self.position_encoding_type == "tagop":
            pos_enc = self.position_embeddings(batch_size)
        elif self.position_encoding_type == "fourier":
            pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device)

        # Optionally project them to a target dimension.
        pos_enc = self.positions_projection(pos_enc)

        if not network_input_is_1d:
            # Reshape pos to match the input feature shape
            # if the network takes non-1D inputs
            sh = inputs.shape
            pos_enc = torch.reshape(pos_enc, list(sh)[:-1] + [-1])
        if self.concat_or_add_pos == "concat":
            inputs_with_pos = torch.cat([inputs, pos_enc], dim=-1)
        elif self.concat_or_add_pos == "add":
            inputs_with_pos = inputs + pos_enc
        return inputs_with_pos, inputs

    def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
        if self.prep_type == "conv":
            # Convnet image featurization.
            # Downsamples spatially by a factor of 4
            inputs = self.convnet(inputs)

        elif self.prep_type == "conv1x1":
            # map inputs to self.out_channels
            inputs = self.convnet_1x1(inputs)

        elif self.prep_type == "pixels":
            # if requested, downsamples in the crudest way
            if inputs.ndim == 4:
                inputs = inputs[:: self.spatial_downsample, :: self.spatial_downsample]
            elif inputs.ndim == 5:
                inputs = inputs[
                    :, :: self.temporal_downsample, :, :: self.spatial_downsample, :: self.spatial_downsample
                ]
            else:
                raise ValueError("Unsupported data format for pixels.")

        elif self.prep_type == "patches":
            # Space2depth featurization.
            # Video: B x T x C x H x W
            inputs = space_to_depth(
                inputs, temporal_block_size=self.temporal_downsample, spatial_block_size=self.spatial_downsample
            )

            if inputs.ndim == 5 and inputs.shape[1] == 1:
                # for flow
                inputs = inputs.squeeze(dim=1)

            # Optionally apply conv layer.
            inputs = self.conv_after_patches(inputs)

        if self.prep_type != "patches":
            # move channels to last dimension, as the _build_network_inputs method below expects this
            if inputs.ndim == 4:
                inputs = torch.permute(inputs, (0, 2, 3, 1))
            elif inputs.ndim == 5:
                inputs = torch.permute(inputs, (0, 1, 3, 4, 2))
            else:
                raise ValueError("Unsupported data format for conv1x1.")

        inputs, inputs_without_pos = self._build_network_inputs(inputs, network_input_is_1d)
        modality_sizes = None  # Size for each modality, only needed for multimodal

        return inputs, modality_sizes, inputs_without_pos
    

#### Define Model

In [18]:
from transformers import PerceiverConfig, PerceiverTokenizer, PerceiverFeatureExtractor, PerceiverModel
from transformers.models.perceiver.modeling_perceiver import (
    PerceiverTextPreprocessor,
    PerceiverImagePreprocessor,
    PerceiverClassificationDecoder,
)
import torch
import requests
from PIL import Image
device = torch.device('mps')
# ("cuda" if torch.cuda.is_available() else "cpu")

# model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned",
#                                                                num_labels=10,
#                                                                id2label=id2label,
#                                                                label2id=label2id,
#                                                                ignore_mismatched_sizes=True)


# EXAMPLE 2: using the Perceiver to classify images
# - we define an ImagePreprocessor, which can be used to embed images
config = PerceiverConfig(image_size=224,
                          num_self_attends_per_block = 1,
                         num_cross_attention_heads = 1,
                          use_labels=True,
                          num_labels=10,
                         num_latents=10,
                          id2label=id2label,
                          label2id=label2id,
                          ignore_mismatched_sizes=True
                         )

preprocessor = PerceiverImagePreprocessor2(
    config,
    prep_type="conv1x1",
    spatial_downsample=1,
    out_channels=256,
    position_encoding_type="tagop",
    concat_or_add_pos="concat",
    project_pos_dim=256,
    tagop_position_encoding_kwargs=dict(
        num_channels=256,
        index_dims=config.image_size**2,
    ),
)

modely = PerceiverModel(
    config,
    input_preprocessor=preprocessor,
    decoder=PerceiverClassificationDecoder(
        config,
        num_channels=config.d_latents,
        trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
        use_query_residual=True,
    ),
)

modely.to(device)

print('model: ', modely)

model:  PerceiverModel(
  (input_preprocessor): PerceiverImagePreprocessor2(
    (convnet_1x1): Conv2d(3, 256, kernel_size=(1, 1), stride=(1, 1))
    (position_embeddings): PerceiverTagopPositionEncoding()
    (positions_projection): Linear(in_features=256, out_features=256, bias=True)
    (conv_after_patches): Identity()
  )
  (embeddings): PerceiverEmbeddings()
  (encoder): PerceiverEncoder(
    (cross_attention): PerceiverLayer(
      (attention): PerceiverAttention(
        (self): PerceiverSelfAttention(
          (layernorm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (layernorm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (query): Linear(in_features=1280, out_features=512, bias=True)
          (key): Linear(in_features=512, out_features=512, bias=True)
          (value): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (output): PerceiverSelfOutput(
          (de

#### Train Model

In [None]:
from IPython.display import clear_output

from transformers import AdamW
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

# torch.cuda.empty_cache()
# print(torch.cuda.memory_summary(device=None, abbreviated=False))

optimizer = AdamW(modely.parameters(), lr=5e-5)

modely.train()
for epoch in range(10):  # loop over the dataset multiple times
    print("Epoch:", epoch)
    for batch in tqdm(train_dataloader):
         # get the inputs; 
         inputs = batch["pixel_values"].to(device)
         labels = batch["labels"].to(device)

         # zero the parameter gradients
         optimizer.zero_grad()

         # forward + backward + optimize
         outputs = modely(inputs=inputs)
         logits = outputs.logits

        #  print('logits: ', logits.shape)
        #  print('logits shape: ', logits)
        #  print('outputs: ', outputs)
        #  print('labels: ', labels)
         # [1, 2]

         # to train, one can train the model using standard cross-entropy:
         criterion = torch.nn.CrossEntropyLoss()

         # labels = torch.tensor([1,1])
         loss = criterion(logits, labels)
        #  print('loss: ', loss)
         #  print(outputs)


         
         loss.backward()
         optimizer.step()

#          clear_output(wait=True)
         predictions = outputs.logits.argmax(-1).cpu().detach().numpy()
         accuracy = accuracy_score(y_true=batch["labels"].numpy(), y_pred=predictions)
         print(f"Loss: {loss.item()}, Accuracy: {accuracy}")

        #  # evaluate
    predictions = outputs.logits.argmax(-1).cpu().detach().numpy()
    accuracy = accuracy_score(y_true=batch["labels"].numpy(), y_pred=predictions)
    print(f"Loss: {loss.item()}, Accuracy: {accuracy}")
    
    

Epoch: 0




  0%|          | 0/90 [00:00<?, ?it/s]

equal?::: tensor(True, device='mps:0')


  nonzero_finite_vals = torch.masked_select(


tensor([[[-0.0244, -0.3288, -0.7838,  ..., -0.9368,  1.0210,  1.2446],
         [ 0.9389,  0.0347,  0.6975,  ...,  0.5370,  0.4642, -0.0948],
         [ 1.0442,  0.5288, -0.9988,  ...,  0.9579,  0.4188,  0.7296],
         ...,
         [ 1.3882,  1.0510,  1.0907,  ...,  0.8141, -0.2886, -0.3021],
         [ 0.7254,  0.8176,  0.2260,  ...,  0.6200,  0.5051, -1.0945],
         [-0.7169, -0.2190, -1.8206,  ..., -1.6109,  1.3579,  0.0768]],

        [[-0.0244, -0.3288, -0.7838,  ..., -0.9368,  1.0210,  1.2446],
         [ 0.9389,  0.0347,  0.6975,  ...,  0.5370,  0.4642, -0.0948],
         [ 1.0442,  0.5288, -0.9988,  ...,  0.9579,  0.4188,  0.7296],
         ...,
         [ 1.3882,  1.0510,  1.0907,  ...,  0.8141, -0.2886, -0.3021],
         [ 0.7254,  0.8176,  0.2260,  ...,  0.6200,  0.5051, -1.0945],
         [-0.7169, -0.2190, -1.8206,  ..., -1.6109,  1.3579,  0.0768]],

        [[-0.0244, -0.3288, -0.7838,  ..., -0.9368,  1.0210,  1.2446],
         [ 0.9389,  0.0347,  0.6975,  ...,  0