In [1]:
%env CUDA_VISIBLE_DEVICES=0,1,3,4,5

env: CUDA_VISIBLE_DEVICES=0,1,3,4,5


In [13]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
from datasets import load_dataset
from transformers import PerceiverFeatureExtractor
from torch.utils.data import DataLoader
import torch
import numpy as np
print(f"CUDA: {torch.cuda.is_available()}")
print(torch.cuda.device_count())
device = torch.device('cuda')

CUDA: True
5


In [None]:
from datasets import load_from_disk

train_ds = load_from_disk("train_ds")
val_ds = load_from_disk("val_ds")
test_ds = load_from_disk("test_ds")

In [20]:


id2label = {idx:label for idx,label in enumerate(train_ds.features['label'].names)}
label2id = {label:idx for idx, label in id2label.items()}

feature_extractor = PerceiverFeatureExtractor()

def preprocess_images(examples):
    examples['pixel_values'] = feature_extractor(examples['img'], return_tensors="pt").pixel_values
    return examples

# Set the transforms
train_ds.set_transform(preprocess_images)
val_ds.set_transform(preprocess_images)
test_ds.set_transform(preprocess_images)

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

train_batch_size = 10
eval_batch_size = 10

train_dataloader = DataLoader(train_ds, shuffle=True, collate_fn=collate_fn, batch_size=train_batch_size)
val_dataloader = DataLoader(val_ds, collate_fn=collate_fn, batch_size=eval_batch_size)
test_dataloader = DataLoader(test_ds, collate_fn=collate_fn, batch_size=eval_batch_size)

Found cached dataset cifar10 (/home/jovyan/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4)


  0%|          | 0/2 [00:00<?, ?it/s]

In [21]:
# preprocessor we customized to use the tagkop encoder
from tagkop_encoding_functions import (
    PerceiverImagePreprocessor
)
# perceiver modules from hugging face
from transformers.models.perceiver.modeling_perceiver import (
    PerceiverConfig,
    PerceiverModel,
    PerceiverClassificationDecoder
)

In [27]:
config = PerceiverConfig(image_size=224,
                        
                         num_self_attends_per_block=13, #26
                         num_self_attention_heads=4, #8
                         num_cross_attention_heads=4,
                        use_labels=True,
                        num_labels=10,
                        num_latents=256,
                        id2label=id2label,
                        label2id=label2id,
                        ignore_mismatched_sizes=True
                        )

preprocessor = PerceiverImagePreprocessor(
    config,
    in_channels=3,
    prep_type="conv1x1",
    spatial_downsample=1,
    out_channels=64,
    position_encoding_type="tagkop",
    concat_or_add_pos="concat",
    project_pos_dim=64,
    tagkop_position_encoding_kwargs=dict(
        num_channels=64,
        index_dims=config.image_size**2,
    ),
)

modely = PerceiverModel(
    config,
    input_preprocessor=preprocessor,
    decoder=PerceiverClassificationDecoder(
        config,
        num_channels=config.d_latents,
        trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
        use_query_residual=True,
    ),
)

modely.to(device)
"Yay"
#print('model: ', modely)

tensor([[-0.1091, -0.1850, -0.2187,  ..., -0.0017, -0.3517, -0.2417],
        [-0.2510, -0.1296,  0.0905,  ...,  0.3434,  0.2693, -0.2364],
        [ 0.1538, -0.3202, -0.4974,  ..., -0.1070,  0.3269,  0.1330],
        ...,
        [ 0.4540, -0.5924,  0.3716,  ...,  0.3881, -0.0989,  0.3629],
        [ 0.5648,  0.0660, -0.0612,  ...,  0.5917,  0.0563, -0.0271],
        [ 0.2382, -0.3071, -0.1294,  ...,  0.0531, -0.1182,  0.0993]])


'Yay'

In [None]:
from torch.optim import AdamW
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

optimizer = AdamW(modely.parameters(), lr=5e-5)

modely.train()
for epoch in range(10):
    for batch in tqdm(train_dataloader):
        # get the inputs; 
        inputs = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = modely(inputs=inputs)
        logits = outputs.logits

        criterion = torch.nn.CrossEntropyLoss()

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        predictions = outputs.logits.argmax(-1).cpu().detach().numpy()
        accuracy = accuracy_score(y_true=batch["labels"].numpy(), y_pred=predictions)
        print(f"Loss: {loss.item()}, Accuracy: {accuracy}")

    predictions = outputs.logits.argmax(-1).cpu().detach().numpy()
    accuracy = accuracy_score(y_true=batch["labels"].numpy(), y_pred=predictions)
    print(f"Loss: {loss.item()}, Accuracy: {accuracy}")
    
    

Loss: 1.8415215015411377, Accuracy: 0.2
Loss: 1.389484167098999, Accuracy: 0.4
Loss: 1.3805886507034302, Accuracy: 0.5
Loss: 1.7398860454559326, Accuracy: 0.3
Loss: 1.3848365545272827, Accuracy: 0.5
Loss: 1.814871072769165, Accuracy: 0.4
Loss: 1.305907130241394, Accuracy: 0.6
Loss: 1.5927965641021729, Accuracy: 0.5
Loss: 1.833509087562561, Accuracy: 0.3
Loss: 1.638188123703003, Accuracy: 0.4
Loss: 2.0216116905212402, Accuracy: 0.4
Loss: 1.457200527191162, Accuracy: 0.5
Loss: 1.088021993637085, Accuracy: 0.5
Loss: 1.7653058767318726, Accuracy: 0.4
Loss: 1.6695318222045898, Accuracy: 0.4
Loss: 1.6956520080566406, Accuracy: 0.3
Loss: 1.674496054649353, Accuracy: 0.6
Loss: 1.761444091796875, Accuracy: 0.3
Loss: 2.1755051612854004, Accuracy: 0.2
Loss: 1.7349202632904053, Accuracy: 0.3
Loss: 1.9460941553115845, Accuracy: 0.2
Loss: 1.459187388420105, Accuracy: 0.4
Loss: 2.0537240505218506, Accuracy: 0.2
Loss: 1.5740034580230713, Accuracy: 0.4
Loss: 1.4206186532974243, Accuracy: 0.3
Loss: 1.98

  0%|          | 0/4500 [00:00<?, ?it/s]

Loss: 1.5868393182754517, Accuracy: 0.5
Loss: 1.3866225481033325, Accuracy: 0.5
Loss: 1.6558997631072998, Accuracy: 0.5
Loss: 1.4872163534164429, Accuracy: 0.6
Loss: 1.9771442413330078, Accuracy: 0.3
Loss: 1.7501121759414673, Accuracy: 0.4
Loss: 1.102379560470581, Accuracy: 0.6
Loss: 1.7696377038955688, Accuracy: 0.1
Loss: 1.535346269607544, Accuracy: 0.4
Loss: 1.518048882484436, Accuracy: 0.3
Loss: 1.7184178829193115, Accuracy: 0.2
Loss: 1.4421888589859009, Accuracy: 0.3
Loss: 1.5465033054351807, Accuracy: 0.3
Loss: 1.2427157163619995, Accuracy: 0.6
Loss: 1.545809268951416, Accuracy: 0.3
Loss: 1.7708022594451904, Accuracy: 0.4
Loss: 1.280522108078003, Accuracy: 0.5
Loss: 1.255613088607788, Accuracy: 0.7
Loss: 1.3406115770339966, Accuracy: 0.4
Loss: 1.5129401683807373, Accuracy: 0.4
Loss: 1.1900205612182617, Accuracy: 0.7
Loss: 1.4860560894012451, Accuracy: 0.4
Loss: 1.7710659503936768, Accuracy: 0.3
Loss: 1.4863897562026978, Accuracy: 0.5
Loss: 1.781369924545288, Accuracy: 0.3
Loss: 1