In [1]:
# Importing libraries
from tqdm import tqdm
import torch
from sklearn.preprocessing import LabelEncoder
from transformers import (
    AutoModelForImageClassification,
    ViTFeatureExtractor,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
)
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
import evaluate
import numpy as np
import warnings
import os

from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
import pandas as pd
import json

# Importing the arg parser
from utils import parse_args, gather_metrics, plot_metrics

# Set warnings to ignore to keep output clean
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hyperparameters
dataset_type = "stanford-dogs"
train_num_classes = None
use_lora = True
lora_rank = 500
lora_alpha, lora_dropout = 2*lora_rank, 0.1
batch_size, eval_batch_size, gradient_accumulation_steps, max_steps = 32, 50, 1, 500

In [3]:
# Dataset selection
if dataset_type == "mnist":
    dataset = load_dataset("mnist")
    num_classes = 10
    label_column_name = 'label'
    image_col_name = "image"

elif dataset_type == "oxford-pet":
    dataset = load_dataset("visual-layer/oxford-iiit-pet-vl-enriched")
    num_classes = 37
    label_column_name = 'label_breed'
    image_col_name = "image"

elif dataset_type == "stanford-dogs":
    dataset = load_dataset("amaye15/stanford-dogs")
    num_classes = 120
    label_column_name = 'label'
    image_col_name = "pixel_values"

else:
    raise ValueError("Currently not supported -> You can add them now")

# Creating val/train split
dataset = dataset["train"].train_test_split(test_size=0.15, shuffle=True, seed=1)
train_dataset = dataset['train']
val_dataset = dataset['test']

In [4]:
# Preprocessing for the labels -> Only necessary for oxford-pet, not mnist or stanford-dogs
if dataset_type == "oxford-pet":
    label_encoder = LabelEncoder()

    def label_preprocessing(dataset):
        # Fit the encoder on the string labels and transform them to integer labels
        label_encoder.fit(dataset[label_column_name])
        encoded_labels = label_encoder.transform(dataset[label_column_name])

        # Add the encoded labels as a new column in the dataset
        return dataset.add_column('label', encoded_labels)

    # Apply preprocessing
    train_dataset = label_preprocessing(train_dataset)
    val_dataset = label_preprocessing(val_dataset)

In [5]:

# Filter classes if specified
if train_num_classes:
    selected_classes = train_num_classes

    def filter_classes(batch):
        return batch['label'] in selected_classes

    train_dataset = train_dataset.filter(filter_classes)
    val_dataset = val_dataset.filter(filter_classes)

    # Update num_classes to reflect the number of selected classes
    num_classes = len(selected_classes)

    # Preprocessing for the labels -> Once filtered the labels need to be set between (0, num(classes)-1)
    label_encoder = LabelEncoder()

    def label_preprocessing(dataset):
        # Fit the encoder only on the filtered labels
        label_encoder.fit(selected_classes)
        # Transform the dataset labels
        dataset = dataset.map(lambda batch: {'label': label_encoder.transform([batch['label']])[0]})
        return dataset

    # Apply preprocessing
    train_dataset = label_preprocessing(train_dataset)
    val_dataset = label_preprocessing(val_dataset)

else:
    selected_classes = [i for i in range(num_classes)]

# Preprocessing dataset to be compatible with ViT
transform = Compose([
    Resize((224, 224)),
    ToTensor()
])

# Combined function to resize, convert to RGB, and then to tensor
def preprocess_images(batch):
    batch['pixel_values'] = [transform(image.convert("RGB")) for image in batch[image_col_name]]
    if image_col_name!='pixel_values':
        del batch[image_col_name]
    return batch

# Apply resizing
train_dataset = train_dataset.map(preprocess_images, batched=True)
val_dataset = val_dataset.map(preprocess_images, batched=True)

In [6]:
# Setup LoRA
if use_lora:
    layers = ["query", "key", "value"]
    target_modules = [f"vit.encoder.layer.{i}.attention.attention.{layer}" for i in range(0, 12) for layer in layers]


    # Set up LoRA configuration
    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=lora_dropout,
        use_rslora=True,
    )

# Load model and tokenizer
model_name = "google/vit-base-patch16-224" 
model = AutoModelForImageClassification.from_pretrained(model_name, num_labels = num_classes, ignore_mismatched_sizes=True)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Apply LoRA to the model
if use_lora:
    model = get_peft_model(model, lora_config)

# Move model to GPU
model = model.to("cuda")

# Define accuracy metric
accuracy = evaluate.load("accuracy")

# Define the compute_metrics function to calculate accuracy
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = logits.argmax(axis=-1)
    return accuracy.compute(predictions=predictions, references=labels)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([120]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([120, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:

# Init run
run_name = "ViT"
training_args = TrainingArguments(
    output_dir=f"results/{run_name}",
    per_device_train_batch_size= batch_size,
    per_device_eval_batch_size= eval_batch_size,
    gradient_accumulation_steps= gradient_accumulation_steps,
    max_steps=max_steps,
    logging_steps=10,
    eval_steps=10,
    save_steps=10,
    save_total_limit=1,
    evaluation_strategy="steps"
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

# Start training
trainer.train()

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
max_steps is given, it will override any value given in num_train_epochs


Step,Training Loss,Validation Loss
10,4.7839,No log
20,4.3052,No log
30,3.7672,No log
40,3.3163,No log
50,2.9274,No log
60,2.6078,No log
70,2.1418,No log
80,2.0813,No log
90,1.8459,No log
100,1.8234,No log


KeyboardInterrupt: 

In [None]:
# Gather data from trainer and plot
data = gather_metrics(trainer)
plot_metrics(data)