In [1]:
from datasets import load_dataset
from transformers import PerceiverFeatureExtractor
from torch.utils.data import DataLoader
import torch
device = torch.device("cpu")

# load cifar10 (only small portion for demonstration purposes) 
train_ds, test_ds = load_dataset('cifar10', split=['train[:1000]', 'test[:100]'])
# split up training into training + validation
splits = train_ds.train_test_split(test_size=0.1)
train_ds = splits['train']
val_ds = splits['test']

id2label = {idx:label for idx,label in enumerate(train_ds.features['label'].names)}
label2id = {label:idx for idx, label in id2label.items()}

feature_extractor = PerceiverFeatureExtractor()

def preprocess_images(examples):
    examples['pixel_values'] = feature_extractor(examples['img'], return_tensors="pt").pixel_values
    return examples

# Set the transforms
train_ds.set_transform(preprocess_images)
val_ds.set_transform(preprocess_images)
test_ds.set_transform(preprocess_images)

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

train_batch_size = 10
eval_batch_size = 10

train_dataloader = DataLoader(train_ds, shuffle=True, collate_fn=collate_fn, batch_size=train_batch_size)
val_dataloader = DataLoader(val_ds, collate_fn=collate_fn, batch_size=eval_batch_size)
test_dataloader = DataLoader(test_ds, collate_fn=collate_fn, batch_size=eval_batch_size)

Found cached dataset cifar10 (/Users/stefanbroecker/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4)


  0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
import numpy as np

# t = next(iter(train_dataloader))["pixel_values"]

# torch.mean(t, dim=1).unsqueeze(1).shape

# torch.flatten(torch.mean(t, dim=1).unsqueeze(1), start_dim=2).shape

224


In [11]:
# preprocessor we customized to use the tagkop encoder
from tagkop_encoding_functions import (
    PerceiverImagePreprocessor
)
# perceiver modules from hugging face
from transformers.models.perceiver.modeling_perceiver import (
    PerceiverConfig,
    PerceiverModel,
    PerceiverClassificationDecoder
)
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
config = PerceiverConfig(image_size=224,
                        num_self_attends_per_block = 1,
                        num_cross_attention_heads = 1,
                        use_labels=True,
                        num_labels=10,
                        num_latents=64,
                        id2label=id2label,
                        label2id=label2id,
                        ignore_mismatched_sizes=True
                        )

preprocessor = PerceiverImagePreprocessor(
    config,
    prep_type="conv1x1",
    spatial_downsample=1,
    out_channels=64,
    position_encoding_type="tagkop",
    concat_or_add_pos="concat",
    project_pos_dim=64,
    tagkop_position_encoding_kwargs=dict(
        num_channels=64,
        index_dims=config.image_size**2,
    ),
)

modely = PerceiverModel(
    config,
    input_preprocessor=preprocessor,
    decoder=PerceiverClassificationDecoder(
        config,
        num_channels=config.d_latents,
        trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
        use_query_residual=True,
    ),
)

modely.to(device)

print('model: ', modely)

model:  PerceiverModel(
  (input_preprocessor): PerceiverImagePreprocessor(
    (convnet_1x1): Conv2d(1, 64, kernel_size=(1, 1), stride=(1, 1))
    (position_embeddings): PerceiverTagkopPositionEncoding()
    (positions_projection): Linear(in_features=64, out_features=64, bias=True)
    (conv_after_patches): Identity()
  )
  (embeddings): PerceiverEmbeddings()
  (encoder): PerceiverEncoder(
    (cross_attention): PerceiverLayer(
      (attention): PerceiverAttention(
        (self): PerceiverSelfAttention(
          (layernorm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (layernorm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (query): Linear(in_features=1280, out_features=128, bias=True)
          (key): Linear(in_features=128, out_features=128, bias=True)
          (value): Linear(in_features=128, out_features=128, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (output): PerceiverSelfOutput(
          (dense

In [13]:
from torch.optim import AdamW
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

optimizer = AdamW(modely.parameters(), lr=5e-5)

modely.train()
for epoch in range(10):
    for batch in tqdm(train_dataloader):
        # get the inputs; 
        inputs = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        # average the three image channels into a single channel
        # add the channel dimension back in because the model expects it
        outputs = modely(inputs=torch.mean(inputs, dim=1).unsqueeze(1))
        logits = outputs.logits

        criterion = torch.nn.CrossEntropyLoss()

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        predictions = outputs.logits.argmax(-1).cpu().detach().numpy()
        accuracy = accuracy_score(y_true=batch["labels"].numpy(), y_pred=predictions)
        print(f"Loss: {loss.item()}, Accuracy: {accuracy}")

    predictions = outputs.logits.argmax(-1).cpu().detach().numpy()
    accuracy = accuracy_score(y_true=batch["labels"].numpy(), y_pred=predictions)
    print(f"Loss: {loss.item()}, Accuracy: {accuracy}")
    
    

Epoch: 0


  0%|          | 0/90 [00:00<?, ?it/s]

Shit, something went wrong in generate_tagkop_features
torch.Size([50176, 64])
Loss: nan, Accuracy: 0.3
Shit, something went wrong in generate_tagkop_features
torch.Size([50176, 64])
Loss: nan, Accuracy: 0.0
Shit, something went wrong in generate_tagkop_features
torch.Size([50176, 64])
Loss: nan, Accuracy: 0.0
Shit, something went wrong in generate_tagkop_features
torch.Size([50176, 64])
Loss: nan, Accuracy: 0.0
Shit, something went wrong in generate_tagkop_features
torch.Size([50176, 64])


KeyboardInterrupt: 