In [1]:
import os, sys
sys.path += [os.path.dirname('../scripts/')]

In [2]:
model_ckpt = 'CL_valaro_z64_bs512'

In [3]:
from transformers import ViTFeatureExtractor, ViTForImageClassification

feature_extractor = ViTFeatureExtractor.from_pretrained('../outputs/' + model_ckpt + '/model')

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from dataset import AlternatingDataset, AffectNetDatasetForSupConWithValenceArousal
from torchaffectnet import AffectNetDatasetForSupCon
from torchvision.transforms import (Compose,
                                    Normalize,
                                    Resize,
                                    RandomResizedCrop,
                                    RandomHorizontalFlip,
                                    RandomApply,
                                    ColorJitter,
                                    RandomGrayscale,
                                    ToTensor,
                                    RandomAffine)

normalize = Normalize(mean=feature_extractor.image_mean,
                          std=feature_extractor.image_std)

transform1 = Compose([
    RandomAffine(30),
    Resize(tuple(feature_extractor.size.values())),
    ToTensor(),
    normalize,
])

transform2 = Compose([
    RandomResizedCrop(size=tuple(
        feature_extractor.size.values()), scale=(0.2, 1.)),
    RandomHorizontalFlip(),
    RandomApply([
        ColorJitter(0.4, 0.4, 0.4, 0.1)
    ], p=0.8),
    ToTensor(),
    normalize
])

valaro_dataset = AffectNetDatasetForSupConWithValenceArousal('../../Affectnet/validation.csv',
                                                             '../../Affectnet/Manually_Annotated/Manually_Annotated_Images/',
                                                             transform1=transform1,
                                                             transform2=transform2,
                                                             exclude_label=[8,9,10])
expression_dataset = AffectNetDatasetForSupCon('../../Affectnet/validation.csv',
                                               '../../Affectnet/Manually_Annotated/Manually_Annotated_Images/',
                                               transform1=transform1,
                                               transform2=transform2,
                                               exclude_label=[8,9,10])

dataset = AlternatingDataset(valaro_dataset, expression_dataset)
dataset

<dataset.AlternatingDataset at 0x7fd4910b5630>

In [6]:
dataset[0]

(((tensor([[[-1., -1., -1.,  ..., -1., -1., -1.],
            [-1., -1., -1.,  ..., -1., -1., -1.],
            [-1., -1., -1.,  ..., -1., -1., -1.],
            ...,
            [-1., -1., -1.,  ..., -1., -1., -1.],
            [-1., -1., -1.,  ..., -1., -1., -1.],
            [-1., -1., -1.,  ..., -1., -1., -1.]],
   
           [[-1., -1., -1.,  ..., -1., -1., -1.],
            [-1., -1., -1.,  ..., -1., -1., -1.],
            [-1., -1., -1.,  ..., -1., -1., -1.],
            ...,
            [-1., -1., -1.,  ..., -1., -1., -1.],
            [-1., -1., -1.,  ..., -1., -1., -1.],
            [-1., -1., -1.,  ..., -1., -1., -1.]],
   
           [[-1., -1., -1.,  ..., -1., -1., -1.],
            [-1., -1., -1.,  ..., -1., -1., -1.],
            [-1., -1., -1.,  ..., -1., -1., -1.],
            ...,
            [-1., -1., -1.,  ..., -1., -1., -1.],
            [-1., -1., -1.,  ..., -1., -1., -1.],
            [-1., -1., -1.,  ..., -1., -1., -1.]]]),
   tensor([[[-0.1294, -0.1294, -0.10

In [11]:
from typing import Any, Dict
from torchaffectnet.collators import Collator
import torch


class AlternatingCollator(Collator):
    def __init__(self, return_labels=[True, True]) -> None:
        super().__init__()
        self.return_labels1 = return_labels[0]
        self.return_labels2 = return_labels[1]
    
    def collate_fn(self, examples) -> Dict[str, Any]:
        data1, data2 = zip(*examples)
        if self.return_labels1:
            data1_imgs, data1_targets = zip(*data1)
            data1_targets = torch.stack(data1_targets)
        else:
            data1_imgs = data1
        
        if self.return_labels2:
            data2_imgs, data2_targets = zip(*data2)
            data2_targets = torch.stack(data2_targets)
        else:
            data2_imgs = data2
        
        data1_imgs1, data1_imgs2 = zip(*data1_imgs)
        data1_imgs1 = torch.stack(data1_imgs1)
        data1_imgs2 = torch.stack(data1_imgs2)
        
        data2_imgs1, data2_imgs2 = zip(*data2_imgs)
        data2_imgs1 = torch.stack(data2_imgs1)
        data2_imgs2 = torch.stack(data2_imgs2)
        
        pixel_values1 = torch.cat([data1_imgs1, data1_imgs2])
        pixel_values2 = torch.cat([data2_imgs1, data2_imgs2])
        
        output = []
        if self.return_labels1:
            output.append(
                {'pixel_values': pixel_values1, 'labels': data1_targets}
            )
        else:
            output.append({'pixel_values': pixel_values1})
        if self.return_labels2:
            output.append(
                {'pixel_values': pixel_values2, 'labels': data2_targets}
            )
        else:
            output.append({'pixel_values': pixel_values2})
        
        return output

In [15]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(dataset, collate_fn=AlternatingCollator(), batch_size=4)

batch = next(iter(train_dataloader))
print(batch[0]['pixel_values'].shape)
print(batch[0]['labels'].shape)
print(batch[1]['pixel_values'].shape)
print(batch[1]['labels'].shape)

torch.Size([8, 3, 224, 224])
torch.Size([4, 2])
torch.Size([8, 3, 224, 224])
torch.Size([4])
