In [1]:
from datasets import load_dataset
from transformers import PerceiverTokenizer
from torch.utils.data import DataLoader
import torch
device = torch.device("cpu")

train_ds, test_ds = load_dataset("imdb", split=['train', 'test'])
# Split up training into training + validation
splits = train_ds.train_test_split(test_size=0.1)
train_ds = splits['train']
val_ds = splits['test']

tokenizer = PerceiverTokenizer.from_pretrained("deepmind/language-perceiver")
train_ds = train_ds.map(lambda examples: tokenizer(examples['text'], padding="max_length", truncation=True),
                        batched=True)
val_ds = val_ds.map(lambda examples: tokenizer(examples['text'], padding="max_length", truncation=True),
                    batched=True)
test_ds = test_ds.map(lambda examples: tokenizer(examples['text'], padding="max_length", truncation=True),
                        batched=True)

train_ds.set_format(type="torch", columns=['input_ids', 'label'])
val_ds.set_format(type="torch", columns=['input_ids', 'label'])
test_ds.set_format(type="torch", columns=['input_ids', 'label'])

train_batch_size = 10
eval_batch_size = 10

train_dataloader = DataLoader(train_ds, shuffle=True, batch_size=train_batch_size)
val_dataloader = DataLoader(val_ds, batch_size=eval_batch_size)
test_dataloader = DataLoader(test_ds, batch_size=eval_batch_size)

Found cached dataset imdb (/Users/stefanbroecker/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)


  0%|          | 0/2 [00:00<?, ?it/s]

Using unk_token, but it is not set yet.
Using unk_token, but it is not set yet.
Using unk_token, but it is not set yet.
Using unk_token, but it is not set yet.
Using unk_token, but it is not set yet.
Using unk_token, but it is not set yet.
Using unk_token, but it is not set yet.
Using unk_token, but it is not set yet.
Using unk_token, but it is not set yet.
Using unk_token, but it is not set yet.
Using unk_token, but it is not set yet.
Using unk_token, but it is not set yet.


  0%|          | 0/23 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

Loading cached processed dataset at /Users/stefanbroecker/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-bcb9ca5bd890d68e.arrow


In [2]:
import numpy as np

# t = next(iter(train_dataloader))["pixel_values"]

# torch.mean(t, dim=1).unsqueeze(1).shape

# torch.flatten(torch.mean(t, dim=1).unsqueeze(1), start_dim=2).shape

In [2]:
# preprocessor we customized to use the tagkop encoder
from tagkop_encoding_functions import (
    PerceiverImagePreprocessor
)
# perceiver modules from hugging face
from transformers.models.perceiver.modeling_perceiver import (
    PerceiverConfig,
    PerceiverModel,
    PerceiverClassificationDecoder
)
%load_ext autoreload
%autoreload 2

In [3]:
config = PerceiverConfig(image_size=2048,
                        num_self_attends_per_block = 1,
                        num_cross_attention_heads = 1,
                        use_labels=True,
                        num_labels=2,
                        num_latents=10,
                        ignore_mismatched_sizes=True
                        )

preprocessor = PerceiverImagePreprocessor(
    config,
    in_channels=1,
    prep_type="1d",
    spatial_downsample=1,
    out_channels=64,
    position_encoding_type="fourier",
    concat_or_add_pos="concat",
    project_pos_dim=64,
    tagkop_position_encoding_kwargs=dict(
        num_channels=64,
        index_dims=config.image_size**2,
        ds="imdb"
    ),
    trainable_position_encoding_kwargs=dict(
        num_channels=64,
        index_dims=2048,
    )
)

modely = PerceiverModel(
    config,
    input_preprocessor=preprocessor,
    decoder=PerceiverClassificationDecoder(
        config,
        num_channels=config.d_latents,
        trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=2048),
        use_query_residual=True,
    ),
)

modely.to(device)

print('model: ', modely)

ValueError: Make sure to pass fourier_position_encoding_kwargs

In [5]:
from torch.optim import AdamW
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

optimizer = AdamW(modely.parameters(), lr=5e-5)

modely.train()
for epoch in range(10):
    for batch in tqdm(train_dataloader):
        # get the inputs; 
        inputs = batch["input_ids"].to(device)
        labels = batch["label"].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = modely(inputs=inputs.unsqueeze(1))
        logits = outputs.logits

        criterion = torch.nn.CrossEntropyLoss()

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        predictions = outputs.logits.argmax(-1).cpu().detach().numpy()
        accuracy = accuracy_score(y_true=batch["labels"].numpy(), y_pred=predictions)
        print(f"Loss: {loss.item()}, Accuracy: {accuracy}")

    predictions = outputs.logits.argmax(-1).cpu().detach().numpy()
    accuracy = accuracy_score(y_true=batch["labels"].numpy(), y_pred=predictions)
    print(f"Loss: {loss.item()}, Accuracy: {accuracy}")
    
    

  0%|          | 0/2250 [00:00<?, ?it/s]

using imdb dataset


In [42]:
imdb_encodings = torch.from_numpy(torch.load("pos_embeddings_IMDB_tSNE_2048x64.pth"))
(imdb_encodings - torch.min(imdb_encodings, dim=0).values) / (torch.max(imdb_encodings, dim=0).values - torch.min(imdb_encodings, dim=0).values)

tensor([[0.2434, 0.1588, 0.2429,  ..., 0.6478, 0.3624, 0.6247],
        [0.6009, 0.5922, 0.6268,  ..., 0.4393, 0.5885, 0.2374],
        [0.4238, 0.3586, 0.3466,  ..., 0.6571, 0.4455, 0.5175],
        ...,
        [0.7835, 0.6469, 0.7396,  ..., 0.6279, 0.5293, 0.3062],
        [0.5141, 0.7340, 0.8714,  ..., 0.4092, 0.4067, 0.6409],
        [0.5145, 0.5653, 0.9037,  ..., 0.3732, 0.7541, 0.2014]])