In [None]:
import os
# Expose only GPUs 0-3 to this process. This is set before `import torch` so
# the restriction is already in the environment when CUDA is first initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2,3"

import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor

In [None]:
# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

In [None]:
# Preprocessing configuration and backbone both come from the same checkpoint;
# the classification head is sized for 8 labels.
feature_extractor = ViTFeatureExtractor.from_pretrained(
    pretrained_model_name_or_path='google/vit-base-patch16-224-in21k',
)
model = ViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path='google/vit-base-patch16-224-in21k',
    num_labels=8,
)

In [None]:
from torchaffectnet import AffectNetDataset
from torchvision.transforms import (Compose,
                                    Normalize,
                                    Resize,
                                    RandomResizedCrop,
                                    RandomHorizontalFlip,
                                    RandomApply,
                                    ColorJitter,
                                    RandomGrayscale,
                                    ToTensor,
                                    RandomAffine)

# Channel-wise normalisation built from the feature extractor's statistics.
# NOTE(review): `normalize` is not part of the active `transform` below (and is
# commented out in the disabled alternative) — confirm that this is intended.
normalize = Normalize(
    mean=feature_extractor.image_mean,
    std=feature_extractor.image_std,
)

# Active preprocessing: resize to the size stored on the feature extractor,
# then convert the image to a tensor.
transform = Compose(transforms=[
    Resize(size=tuple(feature_extractor.size.values())),
    ToTensor(),
])

# Alternative augmentation pipeline, currently disabled:
# transform = Compose([
#         RandomResizedCrop(size=tuple(
#             feature_extractor.size.values()), scale=(0.2, 1.)),
#         RandomHorizontalFlip(),
#         RandomApply([
#             ColorJitter(0.4, 0.4, 0.4, 0.1)
#         ], p=0.8),
#         ToTensor(),
#         # normalize
#     ])

# NOTE(review): this dataset is later passed to the Trainer as the training
# set, yet it is read from `validation.csv` — confirm the split is deliberate.
emotion_dataset = AffectNetDataset(
    '../../Affectnet/validation.csv',
    '../../Affectnet/Manually_Annotated/Manually_Annotated_Images/',
    mode='classification',
    transform=transform,
)

In [None]:
from transformers import TrainingArguments, Trainer

# Training configuration; checkpoints are written to ./test once per epoch.
# NOTE(review): `no_cuda=True` keeps the Trainer off the GPU even though
# CUDA_VISIBLE_DEVICES is set and later cells move the model to `device` —
# confirm this is a deliberate debugging setting.
args = TrainingArguments(
    output_dir="test",
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=5,
    per_device_eval_batch_size=1,
    save_strategy="epoch",
    logging_dir='logs',
    remove_unused_columns=False,
    no_cuda=True,
)
args.device

In [None]:
from torchaffectnet.collators import Collator

# Wire the model, the training arguments and the dataset together. Batches are
# assembled by torchaffectnet's Collator, and the feature extractor is passed
# in the Trainer's `tokenizer` slot.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=Collator(),
    train_dataset=emotion_dataset,
    tokenizer=feature_extractor,
)

In [None]:
# Run fine-tuning with the configuration above; as the last expression of the
# cell, the value returned by `train()` is displayed as the cell output.
trainer.train()

In [None]:
from tqdm import tqdm

def head_outputs(model, dataset, device):
    """Run ``model`` over ``dataset`` one sample at a time and collect its logits.

    The model is switched to evaluation mode for the duration of the call and
    its previous train/eval mode is restored afterwards. The model itself is
    not moved; only each image is sent to ``device``.

    Args:
        model: ``torch.nn.Module`` whose forward output exposes a ``.logits``
            tensor.
        dataset: Iterable of ``(img, label)`` pairs. ``img`` is a tensor, or a
            tuple whose first element is the tensor to use.
        device: Device each image is moved to before the forward pass.

    Returns:
        A ``(features, labels)`` tuple. ``features`` holds the per-sample
        logits concatenated along dim 0 on the CPU, so the leading axis is
        always the sample axis (also for a single-sample dataset). ``labels``
        is ``torch.tensor`` of the collected labels. For an empty dataset both
        tensors are empty.
    """
    features = []
    labels = []

    # Logits must be computed in eval mode; remember the current mode so the
    # caller's model is left as it was found.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for img, label in tqdm(dataset):
                # Some datasets yield a tuple of views; only the first is used.
                if isinstance(img, tuple):
                    img = img[0]
                logits = model(img.unsqueeze(0).to(device)).logits
                features.append(logits.cpu())
                labels.append(label)
    finally:
        model.train(was_training)

    if not features:
        return torch.empty((0, 0)), torch.tensor(labels)

    # Each entry already carries a leading batch axis of size 1, so
    # concatenating keeps an explicit sample axis. A blanket squeeze() would
    # drop it whenever the dataset (or the logit vector) has length 1.
    return torch.cat(features, dim=0), torch.tensor(labels)

In [None]:
# `emotion_dataset` is the only dataset this notebook builds; the previous
# bare `dataset` name was never defined and raised NameError on a fresh kernel.
# NOTE(review): if a different dataset was intended here, define it in an
# earlier cell and pass it instead.
features, labels = head_outputs(model.to(device), emotion_dataset, device)
# features, labels = CLS_tokens(model.to(device), emotion_dataset, device)

In [None]:
from utils import exclude_id

# Human-readable names for the category ids, passed to the plotting helper.
# NOTE(review): `plot_tokens_category` is neither defined nor imported in the
# cells shown here (only `exclude_id` is imported from `utils`) — confirm
# where it comes from before relying on Restart & Run All.
# NOTE(review): three valence buckets are listed although the classifier is
# created with num_labels=8 — confirm these labels match `labels`.
id2label = dict(enumerate([
    'valence < -0.5',
    '-0.5 <= valence <= 0.5',
    '0.5 < valence',
]))
fig = plot_tokens_category(features, labels, 20, id2label, 0)