In [1]:
import os, sys

# Make the project's local helper modules importable (presumably `dataset`,
# imported in a later cell). Relative sys.path entries are resolved against
# the kernel's current working directory, not against this notebook file.
SCRIPTS_DIR = os.path.dirname('../scripts/')  # -> '../scripts'
# Guard so re-running this cell does not keep appending duplicate entries.
if SCRIPTS_DIR not in sys.path:
    sys.path.append(SCRIPTS_DIR)

In [2]:
# Run / checkpoint identifier; not referenced in the cells shown here.
# NOTE(review): the name appears to encode the training setup (contrastive
# learning on valence/arousal, z=64, batch size 512) — confirm.
model_ckpt = "CL_valaro_z64_bs512"

In [3]:
from transformers import ViTFeatureExtractor
from transformers import ViTForImageClassification  # not used in the cells shown here

# Pretrained ViT checkpoint whose preprocessing config (image_mean, image_std,
# size) parameterizes the augmentation pipelines built in the next cell.
VIT_CHECKPOINT = 'google/vit-base-patch16-224-in21k'
# NOTE(review): recent Transformers releases expose this preprocessor as
# ViTImageProcessor; ViTFeatureExtractor is kept unchanged here — confirm
# before upgrading the library.
feature_extractor = ViTFeatureExtractor.from_pretrained(VIT_CHECKPOINT)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from dataset import AlternatingDataset, AffectNetDatasetForSupConWithValenceArousal
from torchaffectnet import AffectNetDatasetForSupCon
from torchvision.transforms import (Compose,
                                    Normalize,
                                    Resize,
                                    RandomResizedCrop,
                                    RandomHorizontalFlip,
                                    RandomApply,
                                    ColorJitter,
                                    RandomGrayscale,
                                    ToTensor,
                                    RandomAffine)

# AffectNet inputs shared by both datasets below (relative to the kernel's
# working directory).
AFFECTNET_CSV = '../../Affectnet/validation.csv'
AFFECTNET_IMAGE_DIR = '../../Affectnet/Manually_Annotated/Manually_Annotated_Images/'
# Label ids dropped from both datasets. NOTE(review): presumably AffectNet's
# non-emotion categories (None / Uncertain / Non-Face) — confirm.
EXCLUDED_LABELS = (8, 9, 10)

# Target (height, width) for the ViT input. Read by key instead of
# tuple(size.values()) so the result does not depend on the dict's key order,
# and a processor with a different size schema fails loudly with a KeyError
# instead of silently producing a wrong-length tuple.
IMAGE_SIZE = (feature_extractor.size['height'], feature_extractor.size['width'])

normalize = Normalize(mean=feature_extractor.image_mean,
                      std=feature_extractor.image_std)

# First view: random affine (up to 30 degrees rotation), then a fixed resize.
transform1 = Compose([
    RandomAffine(30),
    Resize(IMAGE_SIZE),
    ToTensor(),
    normalize,
])

# Second view: random resized crop (20-100% area), horizontal flip, and
# colour jitter applied with probability 0.8.
transform2 = Compose([
    RandomResizedCrop(size=IMAGE_SIZE, scale=(0.2, 1.)),
    RandomHorizontalFlip(),
    RandomApply([
        ColorJitter(0.4, 0.4, 0.4, 0.1)
    ], p=0.8),
    ToTensor(),
    normalize
])

# Each dataset gets its own fresh list, matching the original separate literals.
valaro_dataset = AffectNetDatasetForSupConWithValenceArousal(AFFECTNET_CSV,
                                                             AFFECTNET_IMAGE_DIR,
                                                             transform1=transform1,
                                                             transform2=transform2,
                                                             exclude_label=list(EXCLUDED_LABELS))
expression_dataset = AffectNetDatasetForSupCon(AFFECTNET_CSV,
                                               AFFECTNET_IMAGE_DIR,
                                               transform1=transform1,
                                               transform2=transform2,
                                               exclude_label=list(EXCLUDED_LABELS))

# NOTE: this variable has the same name as the local `dataset` module. Later
# `from dataset import ...` statements still work (they resolve via
# sys.modules), but a plain `import dataset` would rebind this name.
dataset = AlternatingDataset(valaro_dataset, expression_dataset)
dataset

<dataset.AlternatingDataset at 0x7f464990ae60>

In [5]:
from torch.utils.data import DataLoader
from dataset import AlternatingCollator

# Peek at one collated batch to sanity-check tensor shapes.
# NOTE(review): despite the `train_` prefix, `dataset` is built from
# validation.csv in the previous cell.
train_dataloader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=AlternatingCollator(),
)

batch = next(iter(train_dataloader))
# batch[0] / batch[1] are the two sub-batches returned by AlternatingCollator.
# Judging from the recorded output, the first carries (B, 2) labels (presumably
# valence/arousal) and the second (B,) labels (presumably expression classes);
# pixel_values has 2*B rows (presumably the two augmented views) — confirm.
for part_idx in (0, 1):
    for field in ('pixel_values', 'labels'):
        print(batch[part_idx][field].shape)

torch.Size([8, 3, 224, 224])
torch.Size([4, 2])
torch.Size([8, 3, 224, 224])
torch.Size([4])
