<a href="https://colab.research.google.com/github/ilsilfverskiold/smaller-models-docs/blob/main/computer-vision/cook/image-classification/ViT_Huggingface_Custom_Trainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Image classification using ViT/Swin Transformer with a Custom Hugging Face Trainer**

---

The pre-trained model we'll fine-tune here is a ViT model (google/vit-base-patch16-224), but the notebook should also work for Swin Transformer and ConvNeXT checkpoints. I've customized the trainer to handle class imbalance with a weighted loss. Use it as a starting point.

The batch size is 32 and the number of epochs is 5; change these values to suit your needs.

**Make sure you change the dataset to the one you need.** The dataset I've used has both a training and a validation split, so adjust the code accordingly if you don't have a validation set.

In [None]:
!pip install -q transformers datasets accelerate evaluate scikit-learn

In [None]:
dataset_url = "ilsilfverskiold/traffic-camera-norway-images" # public dataset
model_checkpoint = "google/vit-base-patch16-224" # decide on your pre-trained model - see the huggingface hub
new_model_name = 'traffic-image-classification'
learning_rate = 5e-5
weight_decay = 0.01
epochs = 5
batch_size = 32

Fetch a dataset from the Hugging Face Hub or import one from somewhere else. Make sure it has been properly processed first, so that the images are in PIL format. If this is new to you, see the cookbook for processing and pushing a custom image dataset.

In [None]:
from datasets import load_dataset

dataset = load_dataset(dataset_url) # to fetch a private dataset use token=your_token

dataset

Inspect the labels and build the label2id and id2label mappings, which we pass to the pre-trained model so that its classification head matches our classes.

In [None]:
labels = dataset["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = i
    id2label[i] = label

id2label[2]
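The loop above can also be written with dict comprehensions. Below is a minimal sketch using hypothetical label names (the real ones come from `dataset["train"].features["label"].names`). It also checks that the two mappings are inverses of each other, which matters because the model config uses `id2label` to turn predicted class ids back into names.

```python
# Hypothetical label names for illustration only
example_names = ["class_a", "class_b", "class_c", "class_d"]

# Name -> integer id, and integer id -> name
example_label2id = {name: i for i, name in enumerate(example_names)}
example_id2label = {i: name for i, name in enumerate(example_names)}

# The two mappings must be inverses of each other
assert all(example_label2id[example_id2label[i]] == i for i in example_id2label)
print(example_id2label[2])  # class_c
```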

To preprocess the dataset for fine-tuning a ViT, ConvNeXT or Swin Transformer model, we load the checkpoint's image processor. It records what the model expects from every input image (input image size and pixel value range), including the normalization mean and standard deviation we use below.

In [None]:
from transformers import AutoImageProcessor

image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)
image_processor

The code below defines the image transformations applied to the data. Both the training and validation transforms resize, center-crop and normalize each image into the format the model expects; the training transform also applies a random horizontal flip as light augmentation to improve model robustness.

In [None]:
from torchvision.transforms import (
    Compose,
    Resize,
    Normalize,
    CenterCrop,
    RandomHorizontalFlip,
    ToTensor,
)

# Use the checkpoint's own normalization statistics
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# 224 matches the input size of google/vit-base-patch16-224; adjust it for other checkpoints
train_transform = Compose([
    Resize(256),
    CenterCrop(224),
    RandomHorizontalFlip(),  # light augmentation, training only
    ToTensor(),
    normalize,
])

val_transform = Compose([
    Resize(256),
    CenterCrop(224),
    ToTensor(),
    normalize,
])

def apply_transform(examples, transform):
    # Convert to RGB first so grayscale or RGBA images also end up with 3 channels
    examples['pixel_values'] = [transform(image.convert('RGB')) for image in examples['image']]
    return examples

def set_dataset_transform(dataset, transform):
    # set_transform is lazy: the function runs each time items are accessed
    dataset.set_transform(lambda examples: apply_transform(examples, transform))

In [None]:
set_dataset_transform(dataset['train'], train_transform)
set_dataset_transform(dataset['validation'], val_transform)

Check that each item now has an additional pixel_values field.

In [None]:
dataset['train'][0]

Below we load the pre-trained model with the labels we set up earlier from the dataset. Setting ignore_mismatched_sizes=True lets it discard the classification head trained on the original labels and initialize a new one sized for our classes.

In [None]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,  # replace the original 1000-class ImageNet head with a new one for our labels
)

We'll set up a function that computes a weight for each class from its frequency in the training set, so that rarer classes get larger weights.

In [None]:
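import numpy as np
import torch

def compute_class_weights(dataset, num_classes):
    # with_format(None) drops the image transform set above, so only the label column is read
    # instead of decoding and transforming every image
    train_labels = np.array(list(dataset.with_format(None)["label"]))
    class_counts = np.bincount(train_labels, minlength=num_classes)
    total_samples = len(train_labels)
    # Inverse-frequency weights: total / (num_classes * count); max(count, 1) avoids division by zero
    class_weights = total_samples / (num_classes * np.maximum(class_counts, 1))
    return torch.tensor(class_weights, dtype=torch.float32, device=model.device)

class_weights = compute_class_weights(dataset['train'], num_classes=len(id2label))

print("Class Weights:", class_weights)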

Class Weights: tensor([1.2344, 0.4786, 1.2344, 3.4443])
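To see what the inverse-frequency formula does, here is a worked example with made-up class counts (not the real dataset). Each weight is `total / (num_classes * count)`, so `count * weight = total / num_classes` for every class. In other words, after weighting, every class contributes the same total amount to the loss, however many samples it has.

```python
import numpy as np

# Hypothetical label counts for a 4-class problem
demo_counts = np.array([200, 500, 200, 70])
demo_total = demo_counts.sum()  # 970

# Same inverse-frequency formula as compute_class_weights above
demo_weights = demo_total / (len(demo_counts) * np.maximum(demo_counts, 1))

# Each class's total weight: count * weight = total / num_classes
per_class_mass = demo_counts * demo_weights
print(demo_weights.round(4))
print(per_class_mass)
```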


Set up your training metrics below.

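**Accuracy** indicates overall correctness, **precision** measures how reliable the model's positive predictions are, **recall** measures how many of the actual positive samples the model finds, and the **F1 score** (the harmonic mean of precision and recall) balances the two, which matters in cases of class imbalance. Below, precision, recall and F1 are macro-averaged: they are computed per class and then averaged with equal weight per class.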

For example, relatively high precision means that when the model predicts a class, it is likely to be correct. If recall is lower, however, the model misses a significant portion of the actual instances of that class.

To put it into perspective, for complex tasks, **especially those involving highly imbalanced datasets** or classes that are inherently hard to tell apart, an F1 score around 0.75 to 0.80 can be considered reasonable.

If you remove some of the metrics, keep accuracy: the training arguments below use it as metric_for_best_model.

In [None]:
import numpy as np
import evaluate  # datasets.load_metric was removed in datasets 3.0; metrics now live in evaluate

accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)

    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)
    precision = precision_metric.compute(predictions=predictions, references=labels, average='macro')
    recall = recall_metric.compute(predictions=predictions, references=labels, average='macro')
    f1 = f1_metric.compute(predictions=predictions, references=labels, average='macro')

    metrics = {
        "accuracy": accuracy['accuracy'],
        "precision": precision['precision'],
        "recall": recall['recall'],
        "f1": f1['f1']
    }
    return metrics
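To make `average='macro'` concrete, the sketch below recomputes the same four numbers from scratch in NumPy for illustration (it is not how `evaluate` is implemented). Undefined ratios are treated as 0, which matches scikit-learn's default behavior (it also emits a warning). Note that macro F1 is the mean of the per-class F1 scores, not the F1 of the macro precision and recall.

```python
import numpy as np

def macro_scores(predictions, references, num_classes):
    """Accuracy plus precision/recall/F1 computed per class and averaged with equal weight."""
    predictions = np.asarray(predictions)
    references = np.asarray(references)
    precisions, recalls, f1s = [], [], []
    for c in range(num_classes):
        tp = np.sum((predictions == c) & (references == c))  # predicted c, really c
        fp = np.sum((predictions == c) & (references != c))  # predicted c, really another class
        fn = np.sum((predictions != c) & (references == c))  # missed an actual c
        p = tp / (tp + fp) if tp + fp > 0 else 0.0
        r = tp / (tp + fn) if tp + fn > 0 else 0.0
        f = 2 * p * r / (p + r) if p + r > 0 else 0.0
        precisions.append(p)
        recalls.append(r)
        f1s.append(f)
    return {
        "accuracy": float(np.mean(predictions == references)),
        "precision": float(np.mean(precisions)),
        "recall": float(np.mean(recalls)),
        "f1": float(np.mean(f1s)),
    }

# Toy example with 3 classes (made-up predictions)
demo_scores = macro_scores([0, 1, 1, 1, 2, 2], [0, 0, 1, 1, 1, 2], num_classes=3)
print(demo_scores)
```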

The collate_fn function below controls how a list of samples from the dataset is merged into a single batch. It stacks each example's pixel_values tensor into one batch tensor and collects the integer labels into a tensor under the key labels, which is the argument name the model's forward method expects for computing the loss.

In [None]:
import torch

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

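Set up your training arguments for the Hugging Face Trainer. Most of them can stay as they are; in particular, remove_unused_columns=False is required here, because otherwise the Trainer drops the image column that the on-the-fly transform needs. You may want to experiment with the learning rate, batch size and number of epochs (set at the start of this notebook). Note that with gradient_accumulation_steps=4, each optimizer step effectively uses 4 × 32 = 128 images.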

In [None]:
args = TrainingArguments(
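    new_model_name,  # output directory (also the default Hub repo name)
    remove_unused_columns=False,  # keep the 'image' column for the on-the-fly transform
    eval_strategy="epoch",  # named evaluation_strategy in older transformers versions
    save_strategy="epoch",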
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=epochs,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)
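Here is a back-of-the-envelope sketch of how these arguments translate into optimizer steps. It assumes a single GPU and a hypothetical training-set size; replace that with `len(dataset['train'])`. It follows the general counting (batches per epoch divided by the accumulation steps, warmup as a fraction of total steps), but treat it as an approximation rather than Trainer's exact internals.

```python
import math

# Values from the TrainingArguments above
per_device_train_batch_size = 32
gradient_accumulation_steps = 4
num_train_epochs = 5
warmup_ratio = 0.1

# Assumptions: one GPU and a hypothetical number of training images
num_devices = 1
num_train_examples = 7_350

# Images contributing to each optimizer step
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_devices
# Dataloader batches per epoch (the last batch may be partial)
batches_per_epoch = math.ceil(num_train_examples / (per_device_train_batch_size * num_devices))
# One optimizer step per gradient_accumulation_steps batches
optimizer_steps_per_epoch = batches_per_epoch // gradient_accumulation_steps
total_steps = optimizer_steps_per_epoch * num_train_epochs
# warmup_ratio becomes a number of learning-rate warmup steps
warmup_steps = math.ceil(total_steps * warmup_ratio)

print(effective_batch_size, batches_per_epoch, optimizer_steps_per_epoch, total_steps, warmup_steps)
# 128 230 57 285 29
```

With multiple GPUs, per_device_train_batch_size applies to each device, so the effective batch size grows accordingly.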

**The CustomTrainer** below modifies the standard training process to use a weighted loss that handles class imbalance in the dataset. It extends the Trainer class, so it keeps all of Trainer's functionality, and overrides only compute_loss: the default cross-entropy is replaced with a class-weighted cross-entropy, so that mistakes on underrepresented classes cost more and the model pays proper attention to them.

In [None]:
from torch import nn

class CustomTrainer(Trainer):
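    # Recent transformers versions also pass num_items_in_batch; accept it so the override doesn't break
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")  # pop so the model doesn't also compute its own unweighted loss
        outputs = model(**inputs)
        logits = outputs.get('logits')
        # Weighted cross-entropy using the class weights computed earlier, on the logits' device
        loss_fct = nn.CrossEntropyLoss(weight=class_weights.to(logits.device))
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss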

trainer = CustomTrainer(
    model,
    args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    processing_class=image_processor,  # named tokenizer in older transformers versions
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)
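To make the weighting concrete: with `weight` set and the default `reduction='mean'`, `nn.CrossEntropyLoss` multiplies each sample's negative log-likelihood by the weight of its true class and divides the sum by the total weight of the targets in the batch (not by the batch size). The NumPy function below re-derives that formula for illustration only; the trainer does not use it. The logits are made up, and the weights are the ones printed earlier.

```python
import numpy as np

def weighted_cross_entropy(logits, labels, class_weights):
    """NumPy re-derivation of nn.CrossEntropyLoss(weight=...) with reduction='mean'."""
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    w = np.asarray(class_weights, dtype=np.float64)
    # Numerically stable log-softmax over the class dimension
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Negative log-likelihood of each sample's true class
    nll = -log_probs[np.arange(len(labels)), labels]
    # Scale each sample by its class weight, then normalize by the summed weights
    sample_w = w[labels]
    return float((sample_w * nll).sum() / sample_w.sum())

demo_logits = [[2.0, 0.5, 0.1, 0.3],
               [0.2, 1.5, 0.3, 0.1],
               [0.1, 0.2, 0.4, 2.2]]
demo_labels = [0, 1, 3]
uniform = weighted_cross_entropy(demo_logits, demo_labels, [1.0, 1.0, 1.0, 1.0])
weighted = weighted_cross_entropy(demo_logits, demo_labels, [1.2344, 0.4786, 1.2344, 3.4443])
print(uniform, weighted)
```

Because the sum is divided by the total weight of the targets, the weights change how much each sample counts relative to the others rather than the overall scale of the loss; multiplying all weights by a constant leaves the loss unchanged.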


**Train the model.** Watch the training and validation loss: both should go down over time. If the validation loss keeps going up while the training loss keeps going down, the model is probably overfitting. Accuracy should go up; if it only improves marginally from epoch to epoch, you may have reached the limit of what this setup can achieve.

Don't worry too much if the metrics fluctuate a bit, and try the model afterwards to see how it does. Remember that if the dataset is not good, no training process will make the model good, so go back and review the dataset if the model performs poorly.
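The overfitting rule of thumb above can be written as a tiny check. `flag_overfitting` is a hypothetical helper, not part of transformers. It takes per-epoch rows in a made-up format (epoch, train_loss, val_loss), such as the values in the training table below, and flags epochs where the validation loss rose while the training loss fell compared with the previous row.

```python
def flag_overfitting(history):
    """Return epochs where validation loss went up while training loss went down."""
    flagged = []
    for prev, curr in zip(history, history[1:]):
        if curr["val_loss"] > prev["val_loss"] and curr["train_loss"] < prev["train_loss"]:
            flagged.append(curr["epoch"])
    return flagged

# Rows copied from the training table in this notebook
history = [
    {"epoch": 0, "train_loss": 0.9021, "val_loss": 0.771184},
    {"epoch": 2, "train_loss": 0.3954, "val_loss": 0.491489},
    {"epoch": 4, "train_loss": 0.3139, "val_loss": 0.434835},
]
print(flag_overfitting(history))  # []
```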

In [None]:
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 0 | 0.9021 | 0.771184 | 0.683374 | 0.623743 | 0.714295 | 0.649003 |
| 2 | 0.3954 | 0.491489 | 0.798289 | 0.752837 | 0.797007 | 0.771331 |
| 4 | 0.3139 | 0.434835 | 0.810513 | 0.764115 | 0.811026 | 0.782088 |


***** train metrics *****
  epoch                    =       4.9565
  total_flos               = 6019499350GF
  train_loss               =       0.5675
  train_runtime            =   0:27:21.70
  train_samples_per_second =       22.407
  train_steps_per_second   =        0.174


In [None]:
metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** eval metrics *****
  epoch                   =     4.9565
  eval_accuracy           =     0.8105
  eval_f1                 =     0.7821
  eval_loss               =     0.4348
  eval_precision          =     0.7641
  eval_recall             =      0.811
  eval_runtime            = 0:00:15.92
  eval_samples_per_second =     51.369
  eval_steps_per_second   =      1.633


We'll now save the model so we can do some inference on it.

In [None]:
trainer.save_model('new_model')

In [None]:
from transformers import pipeline

pipe = pipeline('image-classification', model='new_model')

I'm importing images from my Google Drive to test against. This is completely optional.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
from PIL import Image

image_path = '/content/drive/MyDrive/my_image_I_want_to_test.png'

image = Image.open(image_path).convert('RGB')  # PNGs can be RGBA; the model expects 3 channels

results = pipe(image)
results

(Optional) I'll also run the model on the first 100 images of the validation set to see how it does.

In [None]:
from PIL import Image

val_split = dataset['validation']
for i in range(min(100, len(val_split))):
    example = val_split[i]  # fetch each example once (the transform also adds pixel_values, unused here)
    image = example['image']
    if not isinstance(image, Image.Image):  # e.g. a file path instead of a decoded image
        image = Image.open(image)

    results = pipe(image.convert('RGB'))

    print(f"Results for image {i+1}:")
    print(results)
    print("Actual label:", id2label[example['label']])
    print("----------------------------------")
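Reading 100 printouts gets tedious. Here is a minimal sketch of summarizing them instead. It assumes you collect each `pipe(image)` result and the matching `id2label[example['label']]` into two lists inside the loop above. The image-classification pipeline returns, per image, a list of `{'label': ..., 'score': ...}` dicts, and the hypothetical helper below (not a library function) counts how often the highest-scoring label is correct. The label names in the usage example are made up.

```python
def top1_accuracy(pipeline_outputs, true_labels):
    """pipeline_outputs: one list of {'label': str, 'score': float} dicts per image."""
    if not true_labels:
        return 0.0
    correct = 0
    for preds, true_label in zip(pipeline_outputs, true_labels):
        best = max(preds, key=lambda p: p["score"])  # don't rely on the list being sorted
        correct += best["label"] == true_label
    return correct / len(true_labels)

demo_outputs = [
    [{"label": "class_b", "score": 0.81}, {"label": "class_a", "score": 0.12}],
    [{"label": "class_a", "score": 0.55}, {"label": "class_c", "score": 0.40}],
    [{"label": "class_d", "score": 0.70}, {"label": "class_b", "score": 0.20}],
]
demo_truth = ["class_b", "class_c", "class_d"]
print(top1_accuracy(demo_outputs, demo_truth))  # 0.6666666666666666
```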

If you're satisfied, log in to Hugging Face with an access token, which you can create under Settings in your Hugging Face account. The token needs write access. You'll be asked for it below.

In [None]:
!huggingface-cli login

In [None]:
trainer.push_to_hub(commit_message="End of training")  # repo name comes from the output dir (new_model_name); set hub_private_repo=True in TrainingArguments for a private repo