In [None]:
import os

import torch
import numpy as np
from transformers import PerceiverFeatureExtractor, PerceiverTokenizer, PerceiverForMultimodalAutoencoding
from datasets import load_dataset

In [3]:
# Create synthetic multimodal inputs: random image and audio tensors plus all-zero labels.
# A dedicated, seeded generator makes the random tensors identical after every kernel restart
# without touching torch's global RNG state.
input_rng = torch.Generator().manual_seed(0)
# NOTE(review): layout is presumably (batch, frames, channels, height, width) — confirm.
images = torch.randn((1, 16, 3, 224, 224), generator=input_rng)
audio = torch.randn((1, 30720, 1), generator=input_rng)
# One zero-filled label row per batch element, 700 entries wide.
# NOTE(review): 700 is presumably the checkpoint's number of classes — confirm against the model config.
inputs = dict(image=images, audio=audio, label=torch.zeros((images.shape[0], 700)))

In [None]:
# Load the "wikipedia" dataset in its "20220301.en" configuration.
# The cache location honours the HF_DATASETS_CACHE environment variable so the notebook is not
# tied to one machine's drive layout; when the variable is unset the original local path is used.
wikipedia_cache_dir = os.environ.get("HF_DATASETS_CACHE", "E:/Datasets/")
wikipedia = load_dataset("wikipedia", "20220301.en", cache_dir=wikipedia_cache_dir)

In [10]:
# Number of rows in the "train" split of the loaded dataset.
train_split = wikipedia["train"]
len(train_split)

6458670

In [4]:
# Load the pretrained multimodal autoencoding Perceiver by checkpoint name.
checkpoint_name = "deepmind/multimodal-perceiver"
model = PerceiverForMultimodalAutoencoding.from_pretrained(checkpoint_name)

In [5]:
# In the Perceiver IO paper, videos are auto-encoded in chunks; each chunk subsamples different
# index dimensions of the image and audio modality decoder queries.
nchunks = 128
# Derive the image index size from the actual input tensor instead of repeating its dimensions as
# literals, so the chunk size stays in sync with `images`.
# NOTE(review): assumes `images` is laid out as (batch, frames, channels, height, width) — this
# matches the (16, 224, 224) literals that were used here before.
num_frames, height, width = images.shape[1], images.shape[3], images.shape[4]
image_chunk_size = np.prod((num_frames, height, width)) // nchunks
audio_chunk_size = audio.shape[1] // model.config.samples_per_patch // nchunks
# Process the first chunk.
chunk_idx = 0
subsampling = {
    "image": torch.arange(image_chunk_size * chunk_idx, image_chunk_size * (chunk_idx + 1)),
    "audio": torch.arange(audio_chunk_size * chunk_idx, audio_chunk_size * (chunk_idx + 1)),
    "label": None,
}

In [None]:
# Display the loaded model's configuration object.
model.config

In [14]:
# Run only the model's input preprocessor on the raw modality dict.
# The dict is rebuilt from `images`/`audio` instead of being read from `inputs`: this cell rebinds
# `inputs` to the preprocessed tensor, so reading it back would fail on a second run of the cell.
# `subsampling` is intentionally not passed: the preprocessor's second positional parameter is
# `pos`, and the subsampled indices are meant for the decoder queries, not for the inputs.
raw_inputs = dict(image=images, audio=audio, label=torch.zeros((images.shape[0], 700)))
inputs, _, _ = model.perceiver.input_preprocessor(raw_inputs)

In [16]:
# Print the shape of the preprocessed inputs tensor.
print(inputs.shape)

torch.Size([1, 52097, 704])


In [18]:
# Unpack the three dimensions of the preprocessed tensor; the trailing one is discarded.
batch_size, seq_length, _ = inputs.shape

In [19]:
# Ask the Perceiver's embeddings module for the latent array for this batch size.
latent_embeddings = model.perceiver.embeddings
latent_array = latent_embeddings(batch_size=batch_size)

In [20]:
# Shape of the latent array.
latent_array.shape

torch.Size([1, 784, 512])

In [21]:
# Feed the latent array and the preprocessed inputs through the Perceiver encoder on its own.
perceiver_encoder = model.perceiver.encoder
encoded = perceiver_encoder(hidden_states=latent_array, inputs=inputs)

In [25]:
# Shape of the encoder's final hidden state.
encoded.last_hidden_state.shape

torch.Size([1, 784, 512])

In [20]:
# Full forward pass through the model.
# The model runs its own input preprocessor, so it must receive the raw modality dict. `inputs`
# may already have been rebound to a preprocessed tensor by an earlier cell, so the dict is rebuilt
# from `images` and `audio` here to keep a top-to-bottom run working.
raw_inputs = dict(image=images, audio=audio, label=torch.zeros((images.shape[0], 700)))
outputs = model(inputs=raw_inputs, subsampled_output_points=subsampling, output_hidden_states=True)

In [63]:
# Display the `d_latents` value from the model config.
# NOTE(review): presumably the latent width — it equals the last dimension of the latent array.
model.config.d_latents

512

In [None]:
# `outputs.hidden_states` is a tuple holding one tensor per layer, not a single tensor, so take
# the last entry (the final hidden state) before applying tensor operations to it.
y = outputs.hidden_states[-1]

In [51]:
# Apply instance normalisation to `y` and show the shape of the result.
normalized = torch.nn.functional.instance_norm(y)
normalized.shape

torch.Size([1, 784, 512])

In [35]:
# Shorthand handle to the model's input preprocessor, used by the cells that follow.
preprocessor = model.perceiver.input_preprocessor

In [48]:
# Run only the image sub-preprocessor and show the shape of the first element it returns.
image_outputs = preprocessor.modalities["image"](images)
image_outputs[0].shape

torch.Size([1, 50176, 243])

In [49]:
# Run only the audio sub-preprocessor and show the shape of the first element it returns.
audio_outputs = preprocessor.modalities["audio"](audio)
audio_outputs[0].shape

torch.Size([1, 1920, 401])

In [64]:
# Preprocess all modalities in one call and keep all three returned values.
# NOTE(review): `inputs_prcessed` looks like a typo for `inputs_processed`; the name is kept
# because cells outside this view may refer to it.
modal_inputs = {
    "image": images,
    "audio": audio,
    "label": torch.zeros((images.shape[0], 700)),
}
inputs_prcessed, modality_sizes, inputs_without_pos = preprocessor(modal_inputs)