# PerceiverIO classification model

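An example of image classification: a Perceiver IO model with Fourier position encoding is built, weights converted from the original model are loaded, and the top-5 labels are predicted for a single example image.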

```
@misc{jaegle2021perceiver,
      title={Perceiver IO: A General Architecture for Structured Inputs & Outputs}, 
      author={Andrew Jaegle and Sebastian Borgeaud and Jean-Baptiste Alayrac and Carl Doersch and Catalin Ionescu and David Ding and Skanda Koppula and Daniel Zoran and Andrew Brock and Evan Shelhamer and Olivier Hénaff and Matthew M. Botvinick and Andrew Zisserman and Oriol Vinyals and João Carreira},
      year={2021},
      eprint={2107.14795},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```

In [None]:
%%bash

# Packages needed for data preprocessing.
pip install --upgrade imageio opencv-python

# Haiku is required to convert the weights of the original model.
pip install --upgrade dm-haiku

In [2]:
from functools import partial
import jax
from jax import numpy as jnp
from flax import linen as nn
from flax_extra import random
from flax_extra.operator import ReshapeBatch
from flax_extra.layer import (
    KVQAttention,
    Encoding,
    Decoding,
    FourierPositionEncoding,
    TrainablePositionalEncoding,
)
from flax_extra.layer.io import (
    input_encoding,
    query_encoding,
    output_decoding,
)
from flax_extra.model import PerceiverIO
from util.data import (
    load_image,
    normalize,
    resize_and_center_crop,
    LABELS,
)
from util.original_model import variables

In [3]:
# Each sub-module factory is built up front (in the same order the original
# call evaluated them) and then handed to the PerceiverIO constructor.

# Input encoding: Fourier position features over a 224x224 sequence shape
# with 64 bands.
image_encoding = input_encoding(
    Encoding,
    positional_encoding=partial(
        FourierPositionEncoding,
        seqshape=(224, 224),
        n_bands=64,
    ),
)

# Trainable query used by the encoder: 512 positions of dimension 1024
# (presumably the latent array -- confirm against flax_extra).
encoder_query = query_encoding(
    TrainablePositionalEncoding,
    seqlen=512,
    dimension=1024,
)

# Trainable query used by the decoder: a single position of dimension 1024.
decoder_query = query_encoding(
    TrainablePositionalEncoding,
    seqlen=1,
    dimension=1024,
)

# Output decoding: dense projection to 1000 features, followed by a
# ReshapeBatch with shape (-1,) (presumably flattening each batch item to a
# vector of class logits -- confirm against flax_extra).
class_decoding = output_decoding(
    Decoding,
    embedding_decoding=partial(nn.Dense, features=1000),
    postprocessing=partial(ReshapeBatch, shape=(-1,)),
)

# Encoder and decoder attention both use a single head.
single_head_attention_encoder = partial(KVQAttention, n_heads=1)
single_head_attention_decoder = partial(KVQAttention, n_heads=1)

model = PerceiverIO(
    input_encoding=image_encoding,
    encoder_query_encoding=encoder_query,
    decoder_query_encoding=decoder_query,
    output_decoding=class_decoding,
    encoder_attention=single_head_attention_encoder,
    decoder_attention=single_head_attention_decoder,
    use_decoder_q_residual=True,
)

# Short aliases for the model's init/apply entry points.
model_init, model_apply = model.init, model.apply

In [None]:
%%bash

# Fetch the model state file (resuming a partial download if one exists).
wget --continue \
     --output-document="/tmp/perceiver_classification_fourier_position_encoding.pystate" \
     "https://storage.googleapis.com/perceiver_io/imagenet_fourier_position_encoding.pystate"

# Fetch the example image used for inference.
wget --continue \
     --output-document="/tmp/perceiver_classification_image_example.jpg" \
     "https://storage.googleapis.com/perceiver_io/dalmation.jpg"

In [5]:
# Seeded key sequence; a key is handed out for each listed collection.
rng = random.sequence(seed=0)
collections = ["params"]

# Model variables read from the downloaded state file.
initial_variables = variables("/tmp/perceiver_classification_fourier_position_encoding.pystate")

# Load the example image, resize/center-crop and normalize it, then add a
# leading axis with [None] (presumably the batch axis -- confirm).
image = load_image("/tmp/perceiver_classification_image_example.jpg")
centered_image = resize_and_center_crop(image)
inputs = normalize(centered_image)[None]

# Forward pass with the loaded variables.
apply_rngs = random.into_collection(key=next(rng), labels=collections)
outputs = model_apply(initial_variables, inputs=inputs, rngs=apply_rngs)

# Work on the first (and only) item of the output.
logits = outputs[0]
probs = jax.nn.softmax(logits)
_, indices = jax.lax.top_k(logits, 5)

# Report the five highest-scoring labels with their softmax probabilities.
print('Top 5 labels:')
for i in list(indices):
    print(f'{LABELS[i]}: {probs[i]}')



Top 5 labels:
dalmatian, coach dog, carriage dog: 0.8736159801483154
Great Dane: 0.01089583057910204
English setter: 0.002538368571549654
muzzle: 0.0010286346077919006
American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier: 0.0007839840836822987
