In [3]:
import torch
import numpy as np
from transformers import PerceiverFeatureExtractor, PerceiverTokenizer, PerceiverForMultimodalAutoencoding
from datasets import load_dataset

In [2]:
# create multimodal inputs
images = torch.randn((1, 16, 3, 224, 224))
audio = torch.randn((1, 30720, 1))
inputs = dict(image=images, audio=audio, label=torch.zeros((images.shape[0], 700)))

In [4]:
wikipedia = load_dataset("wikipedia", "20220301.en", cache_dir="E:/Datasets/")

Downloading builder script: 35.9kB [00:00, 8.97MB/s]                   
Downloading metadata: 30.4kB [00:00, 3.04MB/s]                   
Reusing dataset wikipedia (E:/Datasets/wikipedia\20220301.en\2.0.0\aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559)
100%|██████████| 1/1 [03:09<00:00, 189.21s/it]


In [10]:
len(wikipedia['train'])

6458670

In [11]:
model = PerceiverForMultimodalAutoencoding.from_pretrained("deepmind/multimodal-perceiver")

In [14]:
# in the Perceiver IO paper, videos are auto-encoded in chunks
# each chunk subsamples different index dimensions of the image and audio modality decoder queries
nchunks = 128
image_chunk_size = np.prod((16, 224, 224)) // nchunks
audio_chunk_size = audio.shape[1] // model.config.samples_per_patch // nchunks
# process the first chunk
chunk_idx = 0
subsampling = {
    "image": torch.arange(image_chunk_size * chunk_idx, image_chunk_size * (chunk_idx + 1)),
    "audio": torch.arange(audio_chunk_size * chunk_idx, audio_chunk_size * (chunk_idx + 1)),
    "label": None,
}

In [20]:
outputs = model(inputs=inputs, subsampled_output_points=subsampling, output_hidden_states=True)

In [63]:
model.config.d_latents

512

In [91]:
y = outputs.hidden_states

In [51]:
torch.nn.functional.instance_norm(y).size()

torch.Size([1, 784, 512])

In [92]:
regression_head = torch.nn.Sequential(
    torch.nn.Linear(model.config.d_latents, model.config.d_latents * 2), 
    torch.nn.GELU(), 
    torch.nn.Linear(model.config.d_latents * 2, model.config.d_latents)
    )

In [93]:
# https://www.baeldung.com/cs/instance-vs-batch-normalization


with torch.no_grad():
    y = y[-5:]  # take the last k transformer layers
    # Follow the same layer normalization procedure for text and vision
    y = [torch.layer_norm(tl.float(), tl.shape[-1:]) for tl in y]
    y = sum(y) / len(y)
    if True: # noralize targets
        y = torch.layer_norm(y.float(), y.shape[-1:])

    # # # Use instance normalization for audio
    # y = [torch.nn.functional.instance_norm(tl.float()) for tl in y]
    # y = sum(y) / len(y)
    # if True: # normalize targets
    #     y = torch.nn.functional.instance_norm(y.transpose(1, 2)).transpose(1, 2)

print(outputs.hidden_states[-1:][0].size())
x = regression_head(outputs.hidden_states[-1:][0])
print(x.size())
print(y.size())

torch.Size([1, 784, 512])
torch.Size([1, 784, 512])
torch.Size([1, 784, 512])


In [83]:
l1 = torch.nn.SmoothL1Loss()
mse = torch.nn.MSELoss()

In [88]:
l1(x, y)

tensor(4.0948, grad_fn=<SmoothL1LossBackward0>)

In [89]:
l1(x.float(), y.float()).sum(dim=-1).sum().div(x.size(0))

tensor(4.0948, grad_fn=<DivBackward0>)

In [90]:
mse(x, y)

tensor(38.8993, grad_fn=<MseLossBackward0>)

In [35]:
preprocessor = model.perceiver.input_preprocessor

In [48]:
preprocessor.modalities["image"](images)[0].size()

torch.Size([1, 50176, 243])

In [49]:
preprocessor.modalities["audio"](audio)[0].size()

torch.Size([1, 1920, 401])

In [64]:
inputs_prcessed, modality_sizes, inputs_without_pos = preprocessor({"image": images, "audio": audio, "label": torch.zeros((images.shape[0], 700))})