# PerceiverIO autoencoder model

An example of multimodal (audio, video, and label) autoencoding.

```
@misc{jaegle2021perceiver,
      title={Perceiver IO: A General Architecture for Structured Inputs & Outputs}, 
      author={Andrew Jaegle and Sebastian Borgeaud and Jean-Baptiste Alayrac and Carl Doersch and Catalin Ionescu and David Ding and Skanda Koppula and Daniel Zoran and Andrew Brock and Evan Shelhamer and Olivier Hénaff and Matthew M. Botvinick and Andrew Zisserman and Oriol Vinyals and João Carreira},
      year={2021},
      eprint={2107.14795},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```

In [None]:
%%bash

## Packages used by the data preprocessing helpers.
pip install --upgrade imageio opencv-python scipy

## Haiku is needed to convert the weights of the original model.
pip install --upgrade dm-haiku

In [2]:
from functools import partial
import jax
from jax import numpy as jnp
from flax import linen as nn
from flax_extra import random
from flax_extra import operator as op
from flax_extra.layer import (
    KVQAttention,
    Encoding,
    Decoding,
    TrainablePositionalEncoding,
    FourierPositionEncoding,
)
from flax_extra.layer.io import (
    input_encoding,
    query_encoding,
    output_decoding,
)
from flax_extra.model import PerceiverIO
from util.data import (
    load_audio,
    load_video,
    output_positions,
    LABELS,
)
from util.original_model import variables

Array = jnp.ndarray

In [3]:
N_VIDEO_FRAMES = 16         # video frames per example
AUDIO_SAMPLE_RATE = 48000   # audio samples per second (matches `-ar 48000`
                            # used when the example audio is extracted)
VIDEO_FRAME_RATE = 25       # video frames per second
# Audio samples per example: the samples covering one video frame times
# the number of video frames.
N_AUDIO_FRAMES = AUDIO_SAMPLE_RATE // VIDEO_FRAME_RATE * N_VIDEO_FRAMES
N_PACKED_AUDIO_FRAMES = 16  # audio samples per sequence position
                            # (audio is grouped into packets)
# Number of patches along each spatial axis of a frame. Despite the name this
# is the side of the patch grid (it is used as the positional-encoding
# sequence shape), not the size of a patch in pixels.
VIDEO_PATCH_SIZE = 56
N_CHUNKS = 128              # number of chunks the outputs are decoded in
N_CLASSES = 700             # number of label classes

## Audio.
# One sequence position per packet of N_PACKED_AUDIO_FRAMES samples.
_N_AUDIO_POSITIONS = N_AUDIO_FRAMES // N_PACKED_AUDIO_FRAMES
AudioPositionalEncoding = partial(
    FourierPositionEncoding, seqshape=(_N_AUDIO_POSITIONS,), n_bands=192
)

# Folds every `dt` consecutive samples (and their channels) into the feature
# axis, so the time axis shrinks by a factor of N_PACKED_AUDIO_FRAMES.
_pack_audio = partial(
    op.Rearrange,
    pattern="b (t dt) dc -> b t (dt dc)",
    bindings={"dt": N_PACKED_AUDIO_FRAMES},
)
AudioEncoding = partial(
    Encoding,
    preprocessing=_pack_audio,
    positional_encoding=AudioPositionalEncoding,
)

# Projects each output position to one packet of samples.
# NOTE(review): ReshapeBatch(shape=(-1,)) presumably flattens everything but
# the batch axis - confirm against flax_extra.
AudioDecoding = partial(
    Decoding,
    embedding_decoding=partial(nn.Dense, features=N_PACKED_AUDIO_FRAMES),
    postprocessing=partial(op.ReshapeBatch, shape=(-1,)),
)

## Video.
# Positions form a (frames, patch rows, patch columns) grid.
VideoPositionalEncoding = partial(
    FourierPositionEncoding,
    seqshape=(N_VIDEO_FRAMES,) + (VIDEO_PATCH_SIZE,) * 2,
    n_bands=32,
)

# Cuts every frame into 4x4 pixel patches (one frame deep, dt=1) and folds
# each patch together with its channels into the feature axis.
# NOTE(review): assumes the resulting h and w both equal VIDEO_PATCH_SIZE so
# that they match the positional encoding grid - confirm for the input size.
_patch_video = partial(
    op.Rearrange,
    pattern="b (t dt) (h dh) (w dw) dc -> b t h w (dt dh dw dc)",
    bindings={"dt": 1, "dh": 4, "dw": 4},
)
VideoEncoding = partial(
    Encoding,
    preprocessing=_patch_video,
    positional_encoding=VideoPositionalEncoding,
)

# Projects each output position to 3 features; no postprocessing is applied.
VideoDecoding = partial(
    Decoding,
    embedding_decoding=partial(nn.Dense, features=3),
)

## Labels.
# A single trainable position. It is only passed to the decoder query
# encoding; LabelEncoding itself is built without a positional encoding.
LabelPositionalEncoding = partial(
    TrainablePositionalEncoding, seqlen=1, dimension=1024
)

# Inserts a sequence axis of length one: (batch, classes) -> (batch, 1, classes).
_expand_label = partial(op.Rearrange, pattern="b dc -> b 1 dc", bindings={})
LabelEncoding = partial(Encoding, preprocessing=_expand_label)

# Projects the label output to N_CLASSES values.
# NOTE(review): ReshapeBatch(shape=(-1,)) presumably flattens everything but
# the batch axis - confirm against flax_extra.
LabelDecoding = partial(
    Decoding,
    embedding_decoding=partial(nn.Dense, features=N_CLASSES),
    postprocessing=partial(op.ReshapeBatch, shape=(-1,)),
)

# The four sub-configurations are built first, in the same order in which
# they are passed to PerceiverIO below.

# Modalities are always listed in the order audio, video, label.
# NOTE(review): mask_rates=[0, 0, 1] presumably masks the label input
# entirely, so the label has to be predicted - confirm in flax_extra.
_multimodal_input_encoding = input_encoding(
    AudioEncoding,
    VideoEncoding,
    LabelEncoding,
    mask_rates=[0.0, 0.0, 1.0],
    d_reserved=4,
)
# Trainable query of 784 positions with 512 features for the encoder.
_encoder_query_encoding = query_encoding(
    TrainablePositionalEncoding,
    seqlen=784,
    dimension=512,
)
# Decoder queries are built from the positional encodings of each modality.
_decoder_query_encoding = query_encoding(
    AudioPositionalEncoding,
    VideoPositionalEncoding,
    LabelPositionalEncoding,
    d_reserved=2,
)
_multimodal_output_decoding = output_decoding(
    AudioDecoding,
    VideoDecoding,
    LabelDecoding,
    multimodal_embedding_decoding=partial(nn.Dense, features=512),
)

# Single-head attention is used on both the encoder and the decoder side.
_single_head_attention = partial(KVQAttention, n_heads=1)

model = PerceiverIO(
    input_encoding=_multimodal_input_encoding,
    encoder_query_encoding=_encoder_query_encoding,
    decoder_query_encoding=_decoder_query_encoding,
    output_decoding=_multimodal_output_decoding,
    n_processor_shards=1,
    n_processor_blocks=8,
    encoder_attention=_single_head_attention,
    decoder_attention=_single_head_attention,
)
# Bound methods kept under short names for the cells below.
model_init, model_apply = model.init, model.apply

In [None]:
%%bash

## Checkpoint of the original model and an example video clip.
wget -cO "/tmp/perceiver_autoencoding.pickle" "https://storage.googleapis.com/perceiver_io/video_autoencoding_checkpoint.pystate"
wget --check-certificate=quiet -cO "/tmp/perceiver_autoencoding_video_example.avi" "https://www.crcv.ucf.edu/THUMOS14/UCF101/UCF101/v_ApplyEyeMakeup_g01_c01.avi"

## Extract the audio track of the clip, resampled to 48 kHz.
## The previous command line contained `-c copy -f wav -map 0:a pcm_f32le`:
## without `-c:a` in front of it, ffmpeg parsed `pcm_f32le` as the name of an
## additional output file and wrote a stray copy of the audio stream into the
## working directory. Those options never applied to the WAV file below, so
## dropping them leaves the extracted audio unchanged (default WAV encoding).
## `-y` replaces `yes |` for overwriting; `-nostdin` keeps ffmpeg from
## reading the cell's standard input.
ffmpeg -nostdin -y -i "/tmp/perceiver_autoencoding_video_example.avi" -ar 48000 "/tmp/perceiver_autoencoding_audio_example.wav"


In [7]:
def autoencode(variables, inputs, rng):
    """Autoencode one multimodal batch, decoding the outputs chunk by chunk.

    The model is applied ``N_CHUNKS`` times to the same inputs; every pass
    requests a different chunk of audio and video output positions. The
    per-chunk outputs are joined along axis 1 and reshaped to the shapes of
    the corresponding inputs.

    Reads the module-level names ``N_CHUNKS``, ``N_PACKED_AUDIO_FRAMES``,
    ``collections``, ``model_apply`` and ``output_positions``.

    Args:
        variables: model variables passed to ``model_apply``.
        inputs: sequence of exactly three arrays: audio, video and label
            inputs, in that order.
        rng: iterator of random keys; one key is drawn per chunk.

    Returns:
        A list ``[audio_outputs, video_outputs, label_outputs]``. Audio and
        video outputs have the shapes of their inputs; the label outputs are
        the ones returned by the last chunk's model pass.
    """
    audio_inputs, video_inputs, _ = inputs

    # Chunks are collected and concatenated once after the loop; growing an
    # array with `jnp.concatenate` on every iteration re-copies all earlier
    # chunks each time.
    audio_chunks = []
    video_chunks = []
    label_outputs = None
    for chunk_index in range(N_CHUNKS):
        audio_chunk, video_chunk, label_outputs = model_apply(
            variables,
            inputs=inputs,
            output_positions=[
                output_positions(
                    chunk_shape=audio_inputs.shape,
                    chunk_index=chunk_index,
                    n_chunks=N_CHUNKS,
                    n_frames=N_PACKED_AUDIO_FRAMES,
                ),
                output_positions(
                    chunk_shape=video_inputs.shape,
                    chunk_index=chunk_index,
                    n_chunks=N_CHUNKS,
                ),
                # No output positions are requested for the label modality.
                None,
            ],
            rngs=random.into_collection(key=next(rng), labels=collections),
        )
        audio_chunks.append(audio_chunk)
        video_chunks.append(video_chunk)

    audio_outputs = jnp.reshape(
        jnp.concatenate(audio_chunks, axis=1),
        audio_inputs.shape,
    )
    video_outputs = jnp.reshape(
        jnp.concatenate(video_chunks, axis=1),
        video_inputs.shape,
    )
    return [audio_outputs, video_outputs, label_outputs]

# Key source handed to `autoencode`, which draws one key per chunk.
rng = random.sequence(seed=0)
# Collection labels the keys are issued for; `autoencode` reads this global.
collections = ["params"]
# Variables loaded from the downloaded checkpoint of the original model.
initial_variables = variables("/tmp/perceiver_autoencoding.pickle")

audio = load_audio("/tmp/perceiver_autoencoding_audio_example.wav")
video = load_video("/tmp/perceiver_autoencoding_video_example.avi")
# A batch axis is added to audio and video, which are trimmed to the
# configured number of frames; only the first audio channel is kept. The
# label input is an all-zero placeholder.
inputs = [
    audio[None, :N_AUDIO_FRAMES, :1],
    video[None, :N_VIDEO_FRAMES],
    jnp.zeros((1, N_CLASSES)),
]

audio_outputs, video_outputs, label_outputs = autoencode(
    initial_variables,
    inputs,
    rng,
)

# Kinetics 700 Labels.
# Show the five highest-scoring classes of the (single) example in the batch.
label_probabilities = jax.nn.softmax(label_outputs)
scores, indices = jax.lax.top_k(label_probabilities, 5)
for score, index in zip(scores[0], indices[0]):
    print("%s: %s" % (LABELS[index], score))



trimming or shaving beard: 0.21497257
dyeing hair: 0.19800161
raising eyebrows: 0.09644758
winking: 0.09643903
playing harmonica: 0.083919466
