# **Project Computer Vision**
## **Image Classification**

In [1]:
import io
from PIL import Image

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from torchvision import transforms as T

from huggingface_hub import notebook_login

from datasets import load_dataset, DatasetDict
import kagglehub

from transformers import AutoImageProcessor, ViTForImageClassification
from transformers import Trainer, TrainingArguments

import evaluate

In [None]:
# Log in to the Hugging Face Hub (required later to upload the trained model)
notebook_login()

## Load the dataset
---

In [None]:
# Download the animal image dataset from Kaggle; `path` is the local location of the downloaded files
path = kagglehub.dataset_download("iamsouravbanerjee/animal-image-dataset-90-different-animals")

print("Path to dataset files:", path)


In [None]:
# Load the downloaded image folders as a Hugging Face dataset.
# The "imagefolder" builder assigns each image the label of the folder it lives in.
dataset = load_dataset("imagefolder", data_dir=path)

print(dataset)

#### The dataset contains the following features:
- image: the image stored in PIL format
- label: the ID of the folder the image comes from. Each folder is named after an animal species and contains the images of that species.

In [None]:
# Load and print the first training example (its 'image' and 'label' fields)
example = dataset['train'][0]
print(example)

In [None]:
# Class overview: the distinct label IDs found in the training split and the
# species names stored on its 'label' feature, then report both.
unique_labels = dataset['train'].unique('label')
label_names = dataset['train'].features['label'].names

print("Anzahl Klassen:", len(unique_labels))
print("Label-IDs:", unique_labels)
print("Labelnamen:", label_names)


#### Some sample images of our dataset

In [None]:
# Visualize randomly selected training images together with their species names
def show_samples(ds, rows, cols, seed=42):
    """Plot a grid of randomly chosen images from ``ds`` titled with their class names.

    Args:
        ds: A Hugging Face ``Dataset`` with an ``'image'`` column and a
            ``ClassLabel`` ``'label'`` column.
        rows: Number of grid rows.
        cols: Number of grid columns.
        seed: Shuffle seed so the same images appear on every run
            (default 42, as before).

    If ``ds`` has fewer than ``rows * cols`` examples, only the available ones
    are drawn and the remaining grid cells stay empty instead of raising.
    """
    n_samples = min(rows * cols, len(ds))
    samples = ds.shuffle(seed=seed).select(range(n_samples))
    label_names = ds.features['label'].names

    # squeeze=False keeps `axes` 2-D even for a single row or column
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= n_samples:
            continue
        example = samples[i]  # fetch the row once instead of once per column
        ax.imshow(example['image'])
        ax.set_title(label_names[example['label']])

    fig.tight_layout()
    plt.show()


show_samples(dataset['train'], rows=3, cols=5)

## Preprocessing the Dataset
---

In [None]:
# Split the original training data into train / validation / test (80 / 10 / 10).
# Both splits are stratified on the ClassLabel 'label' column so every animal class
# keeps its share in each split and cannot silently drop out of the small val/test sets.
# (DatasetDict is already imported in the first cell.)
split_dataset = dataset['train'].train_test_split(
    test_size=0.2, seed=42, stratify_by_column='label'
)
eval_dataset = split_dataset['test'].train_test_split(
    test_size=0.5, seed=42, stratify_by_column='label'
)

our_dataset = DatasetDict({
    'train': split_dataset['train'],
    'validation': eval_dataset['train'],
    'test': eval_dataset['test'],
})

print(our_dataset)

#### Test for species in different sets

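Each class contains only a limited number of images, and the validation and test splits each hold just 10 % of the data. Some of the 90 animal classes could therefore end up missing from a split, so I check that every class is represented in the train, validation, and test splits.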

In [None]:
# Check that every animal class appears in each split, and name any classes that are missing
for split in ['train', 'validation', 'test']:
    split_ds = our_dataset[split]
    class_names = split_ds.features['label'].names
    labels_in_split = set(split_ds['label'])
    print(f"{split}: {len(labels_in_split)} Klassen")

    # Label IDs are 0..num_classes-1, so the complement reveals the absent classes
    missing = sorted(set(range(len(class_names))) - labels_in_split)
    if missing:
        print(f"  fehlende Klassen ({len(missing)}):", [class_names[i] for i in missing])

### Image Processor

In [None]:
# Load the Hugging Face image processor that matches the pretrained checkpoint
# (ViT expects 224x224 inputs plus normalization)
processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224')
processor

#### Data Augmentation

In [37]:
# Random augmentations for training images only; validation/test images are not augmented
train_augmentation = T.Compose(
    [
        T.RandomHorizontalFlip(p=0.5),  # mirror the image half of the time
        T.RandomRotation(degrees=30),  # rotate by a random angle in [-30, 30] degrees
        T.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.2,
        ),
    ]
)

In [38]:
def _to_rgb(images):
    """Return ``images`` as 3-channel RGB PIL images (RGB inputs are passed through).

    Grayscale ('L') or RGBA images would otherwise reach the augmentations and the
    image processor with 1 or 4 channels. The processor may fail to infer the
    channel layout of such arrays, so every image is normalized to RGB first.
    """
    return [img if img.mode == "RGB" else img.convert("RGB") for img in images]


# Transform function with augmentation (training only)
def train_transforms(batch):
    """Augment, preprocess and label a batch of training examples.

    Args:
        batch: dict with an ``'image'`` list of PIL images and a ``'label'`` list
            of class IDs, as passed in by ``Dataset.with_transform``.

    Returns:
        The processor output (``pixel_values`` as a PyTorch tensor) plus a
        ``'labels'`` entry copied from ``batch['label']``.
    """
    augmented_images = [train_augmentation(img) for img in _to_rgb(batch['image'])]
    inputs = processor(images=augmented_images, return_tensors="pt")
    inputs['labels'] = batch['label']
    return inputs


# Transform function without augmentation (validation/test)
def transforms(batch):
    """Preprocess and label a batch of validation/test examples without augmentation.

    Args:
        batch: dict with an ``'image'`` list of PIL images and a ``'label'`` list
            of class IDs.

    Returns:
        The processor output (``pixel_values`` as a PyTorch tensor) plus a
        ``'labels'`` entry copied from ``batch['label']``.
    """
    inputs = processor(images=_to_rgb(batch['image']), return_tensors="pt")
    inputs['labels'] = batch['label']
    return inputs

In [39]:
# Attach the on-the-fly preprocessing: augmented transforms for 'train', plain ones for the rest
processed_dataset = DatasetDict({
    split_name: split_ds.with_transform(
        train_transforms if split_name == 'train' else transforms
    )
    for split_name, split_ds in our_dataset.items()
})

In [None]:
# Display the split overview of the transformed dataset
processed_dataset

### Data collation

In [41]:
# Build the tensor batches the Trainer feeds to the model
def collate_fn(batch):
    """Stack a list of transformed examples into one batch.

    Args:
        batch: list of dicts, each with a ``'pixel_values'`` tensor and an
            integer ``'labels'`` value.

    Returns:
        dict with ``'pixel_values'`` (the per-example tensors stacked along a new
        first dimension) and ``'labels'`` (a 1-D tensor of the labels).
    """
    pixel_values = torch.stack([example['pixel_values'] for example in batch])
    labels = torch.tensor([example['labels'] for example in batch])
    return {'pixel_values': pixel_values, 'labels': labels}

## Calculate the Metrics
---

In [42]:
# Accuracy metric used to score the model's evaluation predictions
accuracy = evaluate.load('accuracy')


def compute_metrics(eval_preds):
    """Compute accuracy from the ``(logits, labels)`` pair the Trainer passes in.

    The predicted class of each example is the index of its largest logit along
    axis 1. Returns whatever ``accuracy.compute`` produces.
    """
    logits, labels = eval_preds
    return accuracy.compute(
        predictions=np.argmax(logits, axis=1),
        references=labels,
    )

## Loading the Model
---

In [43]:
# Map class IDs to species names and back, for the model config
label_names = our_dataset['train'].features['label'].names

id2label = dict(enumerate(label_names))
label2id = {name: idx for idx, name in id2label.items()}

In [None]:
# Load the pretrained ViT and give it a classification head sized for our classes.
# (ViTForImageClassification is already imported in the first cell.)
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    num_labels=len(label_names),
    id2label=id2label,
    label2id=label2id,
    # The checkpoint's classifier head has a different number of outputs
    # (presumably its original pretraining classes), so skip those weights and
    # initialize a fresh head instead of failing on the shape mismatch.
    ignore_mismatched_sizes=True,
)

In [None]:
# Display the model architecture
model

#### Freeze the parameters

In [46]:
# Freeze everything except the classification head: parameters whose name starts
# with 'classifier' keep their requires_grad flag, all others are switched off.
for param_name, param in model.named_parameters():
    if param_name.startswith('classifier'):
        continue
    param.requires_grad = False

#### Count the total and the trainable parameters

In [None]:
# Count all parameters and the subset that will actually be updated during training
num_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"{num_params = :,} | {trainable_params = :,}")

## Training
---

In [None]:
# Training configuration
training_args = TrainingArguments(
    output_dir="./vit-90-animals",
    per_device_train_batch_size=16,
    # `evaluation_strategy` was renamed to `eval_strategy` in transformers and the
    # old keyword was later removed, where it raises a TypeError.
    eval_strategy="epoch",
    save_strategy="epoch",  # same cadence as evaluation, needed for load_best_model_at_end
    logging_steps=100,
    num_train_epochs=5,
    learning_rate=3e-4,
    save_total_limit=2,
    # keep the 'image'/'label' columns so the on-the-fly transforms can read them
    remove_unused_columns=False,
    push_to_hub=True,
    report_to='tensorboard',
    load_best_model_at_end=True,
    run_name="vit-90-animals-transferlearning"
)

In [None]:
# Assemble the Trainer from model, arguments, data splits, batching and metrics
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=processed_dataset["train"],
    eval_dataset=processed_dataset["validation"],
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    # NOTE(review): the image processor is handed over via `tokenizer=`; newer
    # transformers releases deprecate this in favour of `processing_class=`.
    tokenizer=processor,
)

In [None]:
# Run training (evaluation and checkpointing are configured per epoch in the training arguments)
trainer.train()

#### Evaluating on the test dataset

In [None]:
# Evaluate the trained model on the held-out test split.
# metric_key_prefix="test" reports the results as test_* instead of the default
# eval_* keys, so they are not mistaken for (or logged as) validation metrics.
trainer.evaluate(processed_dataset['test'], metric_key_prefix="test")

#### Some predictions made by the trained model

In [None]:
%matplotlib inline

# Testbilder mit vorhergesagtem und tatsächlichem Label anzeigen
def show_predictions(rows, cols):
    samples = our_dataset['test'].shuffle(seed=42).select(np.arange(rows * cols))

    processed_samples = samples.with_transform(transforms)

    predictions = trainer.predict(processed_samples).predictions.argmax(axis=1)

    fig = plt.figure(figsize=(cols * 4, rows * 4))
    label_names = our_dataset['train'].features['label'].names

    for i in range(rows * cols):
        img = samples[i]['image']
        true_label = label_names[samples[i]['label']]
        pred_label = label_names[predictions[i]]
        label = f"label: {true_label}\npredicted: {pred_label}"

        fig.add_subplot(rows, cols, i + 1)
        plt.imshow(img)
        plt.title(label)
        plt.axis('off')

show_predictions(rows=5,cols=5)


## Save and push the model
---

In [None]:
# Metadata forwarded to trainer.push_to_hub (used for the generated model card)
kwargs = dict(
    finetuned_from=model.config._name_or_path,
    dataset="iamsouravbanerjee/animal-image-dataset-90-different-animals",
    tasks="image-classification",
    tags=["image-classification", "animals", "vision-transformer", "vit", "transfer-learning"],
)

In [None]:
# Save the final model to output_dir and upload it to the Hugging Face Hub.
trainer.save_model()
# NOTE: the first positional argument of Trainer.push_to_hub is the *commit message*,
# not the repository name. The target repository comes from TrainingArguments
# (hub_model_id, which defaults to the output_dir name). Passing the string by
# keyword makes that explicit; set hub_model_id if a different repo name is wanted.
trainer.push_to_hub(commit_message="vit-90-animals-da", **kwargs)