# ViT Fine-Tuning Demo

*Based on: https://huggingface.co/blog/fine-tune-vit*

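Before running, download the [TrashNet dataset](https://github.com/garythung/trashnet/raw/master/data/dataset-resized.zip) and extract it to the 'dataset' folder. The *Dataset Processing* cell below reads the images from the path given as `data_dir`; change that path so it points at your extracted `dataset-resized` directory.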

In [1]:
%pip install huggingface
%pip install datasets
%pip install pillow
%pip install transformers
%pip install scikit-learn
%pip install transformers[torch]
%pip install tensorboardX
# PyTorch + CUDA should be installed manually



Note: you may need to restart the kernel to use updated packages.














Note: you may need to restart the kernel to use updated packages.




Note: you may need to restart the kernel to use updated packages.












Note: you may need to restart the kernel to use updated packages.




Note: you may need to restart the kernel to use updated packages.


















Note: you may need to restart the kernel to use updated packages.




Note: you may need to restart the kernel to use updated packages.


## Dataset Processing

In [2]:
from datasets import load_dataset

# Location of the extracted TrashNet "dataset-resized" folder (one sub-folder
# per class). Set the TRASHNET_DATA_DIR environment variable to run on another
# machine without editing this cell; the default is the original path.
import os

DATA_DIR = os.environ.get("TRASHNET_DATA_DIR", "/home/schakkera/CSE512/dataset-resized")

# The imagefolder builder exposes everything under a single "train" split, so
# an 80/20 train/test split is carved out here with a fixed seed.
ds = load_dataset("imagefolder", data_dir=DATA_DIR, split="train")
ds = ds.train_test_split(test_size=0.2, seed=512)

Resolving data files:   0%|          | 0/2527 [00:00<?, ?it/s]

In [3]:
from transformers import ViTImageProcessor, ViTConfig
from torchvision.transforms import Compose, RandomResizedCrop, RandomHorizontalFlip, ToTensor, Normalize, Resize, CenterCrop, ConvertImageDtype, InterpolationMode, ColorJitter
import torch

# Checkpoint whose config (image size) and image-processor statistics
# (mean / std) drive the preprocessing pipelines below.
model_name_or_path = 'google/vit-base-patch16-224'
config = ViTConfig.from_pretrained(model_name_or_path)
processor = ViTImageProcessor.from_pretrained(model_name_or_path)

# Training pipeline: random crop, horizontal flip and colour jitter for
# augmentation, then tensor conversion and normalisation with the processor's
# mean and std.
_train_transform = Compose([
    RandomResizedCrop(
        config.image_size,
        interpolation=InterpolationMode.BICUBIC,
        antialias=None,
    ),
    RandomHorizontalFlip(),
    ColorJitter(hue=0.1, saturation=0.4, contrast=0.4, brightness=0.4),
    ToTensor(),
    ConvertImageDtype(torch.float),
    Normalize(processor.image_mean, processor.image_std),
])

# Evaluation pipeline: no random augmentation, only resize + centre crop to
# the model's image size followed by the same tensor conversion / normalisation.
_val_transform = Compose([
    Resize(
        config.image_size,
        interpolation=InterpolationMode.BICUBIC,
        antialias=None,
    ),
    CenterCrop(config.image_size),
    ToTensor(),
    ConvertImageDtype(torch.float),
    Normalize(processor.image_mean, processor.image_std),
])

def train_transform(example_batch):
    """Add augmented ``pixel_values`` to a batch of training examples.

    Each entry of ``example_batch['image']`` is converted to RGB and passed
    through ``_train_transform``. The batch dict is updated in place and the
    same object is returned.
    """
    augmented = []
    for img in example_batch['image']:
        augmented.append(_train_transform(img.convert('RGB')))
    example_batch['pixel_values'] = augmented
    return example_batch

def val_transform(example_batch):
    """Add deterministic ``pixel_values`` to a batch of evaluation examples.

    Each entry of ``example_batch['image']`` is converted to RGB and passed
    through ``_val_transform``. The batch dict is updated in place and the
    same object is returned.
    """
    processed = []
    for img in example_batch['image']:
        processed.append(_val_transform(img.convert('RGB')))
    example_batch['pixel_values'] = processed
    return example_batch

# Attach the on-the-fly preprocessing: augmentation for the training split,
# deterministic resize/crop for the test split.
for split_name, split_transform in (('train', train_transform), ('test', val_transform)):
    ds[split_name].set_transform(split_transform)

## Training

In [4]:
import torch

def collate_fn(batch):
    """Collate a list of examples into the tensors passed to the model.

    Args:
        batch: sequence of dicts, each with a ``'pixel_values'`` tensor and a
            ``'label'`` value.

    Returns:
        Dict with ``'pixel_values'`` (the per-example tensors stacked along a
        new leading dimension) and ``'labels'`` (a tensor built from the
        per-example labels).
    """
    pixel_tensors = []
    label_ids = []
    for example in batch:
        pixel_tensors.append(example['pixel_values'])
        label_ids.append(example['label'])
    return {
        'pixel_values': torch.stack(pixel_tensors),
        'labels': torch.tensor(label_ids),
    }

In [5]:
import numpy as np
from datasets import load_metric

# Accuracy metric object shared by the metric callbacks in this cell.
metric = load_metric("accuracy")
def compute_metrics(p):
    """Accuracy when the predicted class is the index of the largest score.

    ``p`` must expose ``predictions`` (scores with classes along axis 1) and
    ``label_ids`` (reference labels).
    """
    predicted_ids = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=predicted_ids, references=p.label_ids)

def ova_compute_metrics(p):
    """Accuracy when the predicted class is the index of the smallest score.

    ``p`` must expose ``predictions`` (scores with classes along axis 1) and
    ``label_ids`` (reference labels). Unlike ``compute_metrics`` this takes
    the ``argmin``, i.e. the lowest score wins.
    """
    predicted_ids = np.argmin(p.predictions, axis=1)
    return metric.compute(predictions=predicted_ids, references=p.label_ids)

  metric = load_metric("accuracy")


In [6]:
from transformers import ViTForImageClassification

# Class names come from the dataset's label feature.
labels = ds['train'].features['label'].names

# The classification head is sized to the dataset's classes;
# ignore_mismatched_sizes is passed because the checkpoint's head has a
# different shape (see the loading message printed below).
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label=dict((str(idx), name) for idx, name in enumerate(labels)),
    label2id=dict((name, str(idx)) for idx, name in enumerate(labels)),
    ignore_mismatched_sizes=True,
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([6]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([6, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
from transformers import TrainingArguments
from transformers.trainer_utils import get_last_checkpoint
import os

# Single source of truth for the run directory, so the checkpoint lookup and
# the TrainingArguments below cannot drift apart.
OUTPUT_DIR = "./vit-base-trash-ovareg"

# Resume from the newest checkpoint of a previous run if there is one.
# isdir (rather than exists) means a stray file with that name is treated as
# "no checkpoint" instead of being handed to get_last_checkpoint.
last_checkpoint = get_last_checkpoint(OUTPUT_DIR) if os.path.isdir(OUTPUT_DIR) else None

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=32,
    fp16=True,
    # Evaluate and checkpoint on the same 100-step cadence.
    save_steps=100,
    eval_steps=100,
    logging_steps=20,
    learning_rate=3e-5,
    save_total_limit=1,
    # Must stay False: the raw 'image' column is needed by the on-the-fly
    # transforms that build 'pixel_values'.
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
    resume_from_checkpoint=last_checkpoint,
)


# if os.path.exists('./vit-base-trash-csce'):
#     last_checkpoint = get_last_checkpoint("./vit-base-trash-csce")
# else:
#     last_checkpoint = None

# training_args = TrainingArguments(
#   output_dir="./vit-base-trash-csce",
#   per_device_train_batch_size=16,
#   evaluation_strategy="steps",
#   num_train_epochs=32,
#   fp16=True,
#   save_steps=100,
#   eval_steps=100,
#   logging_steps=20,
#   learning_rate=3e-5, #3e-5
#   save_total_limit=1,
#   remove_unused_columns=False,
#   push_to_hub=False,
#   report_to='tensorboard',
#   load_best_model_at_end=True,
#   resume_from_checkpoint=last_checkpoint
# )

In [8]:
from transformers import Trainer
import torch.nn as nn
import torch.nn.functional as F
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class WeightedCrossEntropyTrainer(Trainer):
    """Trainer whose loss is a cost-weighted softmax cross-entropy.

    For a sample with true label ``y`` the weighted probability of class ``j``
    is ``w[y, j] * exp(logit_j) / sum_k(w[y, k] * exp(logit_k))`` where ``w``
    is ``self.cost_matrix``. The loss is the batch mean of the negative log of
    the true-class weighted probability.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Rows are indexed by the true label, columns by class. All entries
        # must be strictly positive because their logarithm is taken in
        # compute_loss.
        # NOTE(review): these weights look like OVARegressionTrainer's costs
        # divided by 50, except row 4 / column 1 (0.4 here vs 5. there) —
        # confirm which value is intended.
        self.cost_matrix = torch.tensor([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
                                         [0.1, 0.1, 0.4, 0.1, 0.3, 0.1],
                                         [0.1, 0.4, 0.1, 0.1, 0.1, 0.2],
                                         [0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
                                         [0.1, 0.4, 0.5, 0.1, 0.1, 0.2],
                                         [0.4, 0.1, 0.2, 0.4, 0.1, 0.1]]).to(device)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the cost-weighted cross-entropy for one batch.

        Args:
            model: model called as ``model(**inputs)``; its output must expose
                ``"logits"`` via ``.get``.
            inputs: batch dict; ``"labels"`` is popped from it before the
                forward pass.
            return_outputs: when true, return ``(loss, outputs)`` instead of
                just the loss.
            num_items_in_batch: accepted and ignored, so the override also
                works when the caller passes this keyword.

        Returns:
            The scalar loss, or ``(loss, outputs)`` if ``return_outputs``.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")

        # One weight row per sample, selected by its true label. Moving the
        # matrix to the logits' device is a no-op when it is already there.
        batch_spec_weights = self.cost_matrix.to(logits.device)[labels, :]

        # w * exp(l) == exp(l + log w), so the weighted softmax is an ordinary
        # softmax over shifted logits. cross_entropy evaluates it through
        # log-softmax, which avoids the overflow (and the 0 * -inf = NaN in
        # the one-hot masking) that explicit exp / log can produce.
        weighted_logits = logits + torch.log(batch_spec_weights)
        loss = F.cross_entropy(weighted_logits, labels)
        return (loss, outputs) if return_outputs else loss

class OVARegressionTrainer(Trainer):
    """Trainer with a cost-sensitive one-vs-all logistic loss.

    For a sample with true label ``y`` and each class ``j`` the loss term is
    ``softplus(z_j * (logit_j - cost[y, j]))`` with ``z_j = +1`` when
    ``j == y`` and ``-1`` otherwise. Minimising it pushes the true-class logit
    down and every other logit up past its cost, so the predicted class is the
    ``argmin`` of the logits (which is what ``ova_compute_metrics`` uses).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Rows are indexed by the true label, columns by class; the diagonal
        # (correct prediction) costs 0.
        # NOTE(review): these costs look like WeightedCrossEntropyTrainer's
        # weights times 50, except row 4 / column 1 (5. here vs 0.4 there,
        # which would scale to 20.) — confirm which value is intended.
        self.cost_matrix = torch.tensor([[0.0, 5., 5., 5., 5., 5.],
                                         [5., 0.0, 20., 5., 15., 5.],
                                         [5., 20., 0.0, 5., 5., 10.],
                                         [5., 5., 5., 0.0, 5., 5.],
                                         [5., 5., 25., 5., 0.0, 10.],
                                         [20., 5., 10., 20., 5., 0.0]]).to(device)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the one-vs-all cost-regression loss for one batch.

        Args:
            model: model called as ``model(**inputs)``; its output must expose
                ``"logits"`` via ``.get``.
            inputs: batch dict; ``"labels"`` is popped from it before the
                forward pass.
            return_outputs: when true, return ``(loss, outputs)`` instead of
                just the loss.
            num_items_in_batch: accepted and ignored, so the override also
                works when the caller passes this keyword.

        Returns:
            The scalar loss, or ``(loss, outputs)`` if ``return_outputs``.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")

        cost_matrix = self.cost_matrix.to(logits.device)  # no-op if already there
        onehot = F.one_hot(labels, cost_matrix.size(0))
        zval = (onehot * 2) - 1  # +1 for the true class, -1 for every other class
        batch_spec_weights = cost_matrix[labels, :]  # one cost row per sample
        zprod = zval * (logits - batch_spec_weights)

        # softplus(x) == log(1 + exp(x)) but does not overflow for large x,
        # which the explicit exp / log form does.
        loss = torch.mean(torch.sum(F.softplus(zprod), dim=1))
        return (loss, outputs) if return_outputs else loss

In [9]:
from transformers import Trainer

# trainer = Trainer(
#     model=model,
#     args=training_args,
#     data_collator=collate_fn,
#     compute_metrics=compute_metrics,
#     train_dataset=ds["train"],
#     eval_dataset=ds["test"],
#     tokenizer=processor,
# )
# Trainer for the one-vs-all cost-regression run. ova_compute_metrics is
# paired with OVARegressionTrainer because that metric takes the argmin of the
# model outputs.
trainer = OVARegressionTrainer(
    args=training_args,
    model=model,
    tokenizer=processor,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    data_collator=collate_fn,
    compute_metrics=ova_compute_metrics,
)
# trainer = WeightedCrossEntropyTrainer(
#     model=model,
#     args=training_args,
#     data_collator=collate_fn,
#     compute_metrics=compute_metrics,
#     train_dataset=ds["train"],
#     eval_dataset=ds["test"],
#     tokenizer=processor,
# )

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [10]:
# Run fine-tuning. last_checkpoint is either None or the checkpoint found
# under the output directory, so an interrupted run is resumed.
train_results = trainer.train(resume_from_checkpoint=last_checkpoint)
# Persist the model, then log and save the training metrics and the trainer
# state.
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

Step,Training Loss,Validation Loss,Accuracy
100,17.7477,17.487024,0.535573
200,13.212,13.199078,0.804348
300,11.1894,10.415431,0.857708
400,8.9667,8.52471,0.895257
500,7.2102,7.171139,0.907115
600,5.98,5.86706,0.903162
700,5.0633,4.94689,0.918972
800,3.7769,3.779551,0.928854
900,3.3792,3.444514,0.928854
1000,2.8676,2.961747,0.93083


***** train metrics *****
  epoch                    =         32.0
  total_flos               = 4667548109GF
  train_loss               =       2.9419
  train_runtime            =   0:30:07.29
  train_samples_per_second =       35.784
  train_steps_per_second   =        2.249


In [11]:
# Evaluate on the held-out test split (the 20% set aside by train_test_split)
# and log / save the resulting metrics under the "eval" name.
metrics = trainer.evaluate(ds['test'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** eval metrics *****
  epoch                   =       32.0
  eval_accuracy           =     0.9565
  eval_loss               =     1.3741
  eval_runtime            = 0:00:07.44
  eval_samples_per_second =     67.926
  eval_steps_per_second   =      8.591
