## 0. Load some necessary packages

In [None]:
import numpy as np
import os
import copy
import shutil
import pickle

import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import BeitImageProcessor
from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor)
from torch.utils.data import DataLoader
from transformers import TrainingArguments, Trainer
from sklearn.metrics import accuracy_score

In [None]:
# Check whether PyTorch can see a CUDA-capable GPU (True/False).
# This call only reports availability; it does not move anything to the GPU.
# To restrict the run to one specific device, uncomment the line below
# (index "1" selects the second GPU) before any CUDA work is started.
#os.environ["CUDA_VISIBLE_DEVICES"]="1"
torch.cuda.is_available()

## 1. Face detection and cropping

## Load the cropped dataset (TUFTS)

In [None]:
from datasets import load_dataset

# Load the cropped TUFTS images from the local "crop_TUFTS" image folder.
full_ds = load_dataset("imagefolder", data_dir="crop_TUFTS", split="train")

# Hold out part of the data for evaluation. Later cells use `val_ds`
# (set_transform and the Trainer's eval_dataset), so it has to exist before
# they run. The fixed seed keeps the split reproducible; adjust `test_size`
# to change the fraction reserved for validation.
splits = full_ds.train_test_split(test_size=0.1, seed=42)
train_ds = splits["train"]
val_ds = splits["test"]

# (num_rows, num_columns) of the training and validation splits.
train_ds.shape, val_ds.shape

In [None]:
# Show a summary of the dataset, followed by its feature schema in more detail.
for caption, value in (
    ('Dataset info:', train_ds),
    ('Dataset features: ', train_ds.features),
):
    print(caption, value)

## 2. Preprocessing the data

In [None]:
from transformers import BeitImageProcessor

# Load the image processor (the preprocessing configuration, not the model
# weights) that belongs to the BEiT checkpoint named below.
# Per the checkpoint description, that model was pre-trained in a self-supervised
# way on ImageNet-22k (14 million images, 21,841 classes) at resolution 224x224
# and fine-tuned on the same dataset at resolution 224x224.
# Its image_mean, image_std and size are read later to build the torchvision transforms.
processor = BeitImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")

In [None]:
from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor)

# Normalisation statistics and target resolution are taken from the image processor.
image_mean = processor.image_mean
image_std = processor.image_std
size = processor.size["height"]

# Channel-wise normalisation shared by the training and validation pipelines.
normalize = Normalize(mean=image_mean, std=image_std)

# Training pipeline: random resized crop + random horizontal flip as augmentation,
# then conversion to a normalised tensor.
_train_transforms = Compose([RandomResizedCrop(size), RandomHorizontalFlip(), ToTensor(), normalize])

# Validation pipeline: deterministic resize + centre crop (no augmentation),
# then conversion to a normalised tensor.
_val_transforms = Compose([Resize(size), CenterCrop(size), ToTensor(), normalize])

def train_transforms(examples):
    """Add training-time ``pixel_values`` to a batch of examples.

    Args:
        examples: Batch dict whose ``'image'`` entry is an iterable of images
            exposing ``.convert("RGB")`` (presumably PIL images).

    Returns:
        The same ``examples`` dict, with a new ``'pixel_values'`` list holding
        the output of ``_train_transforms`` for every image.
    """
    augmented = []
    for picture in examples['image']:
        # Force three channels so the transforms always see an RGB image.
        rgb_picture = picture.convert("RGB")
        augmented.append(_train_transforms(rgb_picture))
    examples['pixel_values'] = augmented
    return examples

def val_transforms(examples):
    """Add evaluation-time ``pixel_values`` to a batch of examples.

    Args:
        examples: Batch dict whose ``'image'`` entry is an iterable of images
            exposing ``.convert("RGB")`` (presumably PIL images).

    Returns:
        The same ``examples`` dict, with a new ``'pixel_values'`` list holding
        the output of ``_val_transforms`` for every image.
    """
    processed = []
    for picture in examples['image']:
        # Force three channels so the transforms always see an RGB image.
        rgb_picture = picture.convert("RGB")
        processed.append(_val_transforms(rgb_picture))
    examples['pixel_values'] = processed
    return examples

In [None]:
# Set the transforms: register the augmenting pipeline on the training data
# and the deterministic pipeline on the validation data.
# Both `train_ds` and `val_ds` must already exist when this cell runs.
train_ds.set_transform(train_transforms)
val_ds.set_transform(val_transforms)

In [None]:
from torch.utils.data import DataLoader
import torch

# Create a corresponding PyTorch DataLoader
def collate_fn(examples):
    """Merge individual examples into one model-ready batch.

    Args:
        examples: Sequence of dicts, each with a ``"pixel_values"`` tensor and
            a ``"label"`` value. All ``"pixel_values"`` tensors must share the
            same shape so they can be stacked.

    Returns:
        dict: ``"pixel_values"`` is the tensors stacked along a new leading
        batch dimension; ``"labels"`` is a tensor built from the labels, in
        the same order as ``examples``.
    """
    batch = {}
    batch["pixel_values"] = torch.stack([sample["pixel_values"] for sample in examples])
    batch["labels"] = torch.tensor([sample["label"] for sample in examples])
    return batch

# Stand-alone loader over the training set that batches 4 examples at a time
# with `collate_fn`. It is not handed to the Trainer further down, which receives
# `train_ds` and `collate_fn` directly and uses the batch sizes from `args`.
train_dataloader = DataLoader(train_ds, collate_fn=collate_fn, batch_size=4)

## 3. Define the model

In [None]:
from transformers import BeitForImageClassification

# Initialise the classifier from the weights stored under 'trainer'.
# NOTE(review): 'trainer' is presumably a local directory written by an earlier
# Trainer.save_model(...) run — confirm it exists next to this notebook and that
# its classification head matches the labels of the dataset loaded above.
model = BeitForImageClassification.from_pretrained('trainer') # Pre-trained BEFiT-V

In [None]:
from transformers import TrainingArguments, Trainer

# Metric used to select the best checkpoint; it must match the key of the
# dict returned by `compute_metrics`.
metric_name = "accuracy"

# `TrainingArguments` gathers every setting that customises the training run.
# Only the output folder (where model checkpoints are saved) is required;
# all other arguments are optional.
args = TrainingArguments(
    "test_transformers_thermal",
    # Save a checkpoint and run evaluation at the end of every epoch, so that
    # the best checkpoint can be restored once training finishes.
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=2e-3,
    # Training and evaluation batch sizes (per device).
    per_device_train_batch_size=32,
    per_device_eval_batch_size=16,
    num_train_epochs=150,
    # Regularization strength: penalizes large weights (applied by the
    # optimizer) and keeps them from growing too large.
    weight_decay=0.05,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    logging_dir='logs_thermal',
    # Keep every dataset column: otherwise the "image" column would be removed,
    # and the data transformations need it.
    remove_unused_columns=False,
)

In [None]:
from sklearn.metrics import accuracy_score

def compute_metrics(eval_pred):
    """Compute the accuracy reported at every evaluation.

    Args:
        eval_pred: Pair ``(predictions, labels)``. ``predictions`` holds
            per-class scores with the classes along axis 1 (the arg-max over
            that axis is taken as the predicted class); ``labels`` holds the
            true class ids.

    Returns:
        dict: ``{"accuracy": <fraction of correctly classified samples>}``.
    """
    class_scores, labels = eval_pred
    predicted_classes = np.argmax(class_scores, axis=1)
    # scikit-learn's signature is accuracy_score(y_true, y_pred). Accuracy is
    # symmetric, so the value is the same either way, but the documented order
    # keeps this correct if the metric is ever swapped for an asymmetric one.
    return {"accuracy": accuracy_score(labels, predicted_classes)}

In [None]:
# Bundle the model, the training configuration, both datasets, the batching
# function and the metric callback into a single Trainer.
trainer_thermal = Trainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=processor,  # the image processor goes in the `tokenizer` slot
    compute_metrics=compute_metrics,
)

In [None]:
# We can now fine-tune our model by just calling the `train` method.
# The run uses the settings collected in `args` (epochs, batch sizes, per-epoch
# evaluation and checkpointing into "test_transformers_thermal").
trainer_thermal.train()

In [None]:
# Persist the fine-tuned model under "trainer_thermal".
# NOTE(review): with load_best_model_at_end=True in `args`, this is presumably the
# checkpoint with the best evaluation accuracy rather than the last epoch — confirm.
trainer_thermal.save_model("trainer_thermal")