In [1]:
import os
import torch
import numpy as np
import requests
from PIL import Image
from datasets import load_dataset
from torchvision.transforms import (
    Compose,
    Resize,
    CenterCrop,
    ToTensor,
    Normalize,
)
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer,
)
from peft import PeftModel, LoraConfig, get_peft_model
import evaluate

In [2]:
# --- Data and model identifiers ---
DATASET_NAME = "food101"
MODEL_CHECKPOINT = "google/vit-base-patch16-224-in21k"
# Directory the trained model is saved to and later loaded from for inference
MODEL_SAVE_PATH = "./lora-food-model"

# --- Training hyperparameters ---
BATCH_SIZE = 128  # per-device batch size, used for both training and evaluation
EPOCHS = 3

In [3]:
def print_model_size(path: str):
    """
    Print the total size of the regular files directly inside ``path``.

    Only the top level of the directory is scanned; subdirectories are skipped.

    Args:
        path: Directory to measure (e.g. the directory a model was saved to).

    Returns:
        The total size in bytes. It is also printed in megabytes (1 MB = 1e6 bytes).
    """
    # Use the scandir iterator as a context manager so its directory handle is
    # released deterministically instead of whenever it is garbage-collected.
    with os.scandir(path) as entries:
        size = sum(entry.stat().st_size for entry in entries if entry.is_file())
    print(f"Model size: {size / 1e6:.2f} MB")
    return size

In [4]:
def print_trainable_parameters(model: torch.nn.Module):
    """
    Print how many of ``model``'s parameter elements are trainable.

    A parameter counts as trainable when ``requires_grad`` is True.

    Args:
        model: Any ``torch.nn.Module``.

    Returns:
        A ``(trainable, total)`` tuple of parameter element counts.
    """
    total = trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    # A module with no parameters would otherwise divide by zero.
    percent = 100 * trainable / total if total else 0.0
    print(f"Trainable params: {trainable:,}/{total:,} ({percent:.2f}%)")
    return trainable, total

The Food101 dataset contains labels (food categories) represented as integers. However, for interpretability, it’s often useful to map these integers to human-readable names.

label2id: A dictionary that maps each label name (e.g., "pizza", "burger") to its corresponding integer ID.

id2label: A dictionary that maps each integer ID back to its label name.

In [5]:
# Dataset Preparation
def load_and_prepare_dataset(num_samples: int = 10_000, test_size: float = 0.1, seed: int = 42):
    """
    Load a slice of the dataset, split it, and build label mappings.

    1. Load the first ``num_samples`` rows of the DATASET_NAME train split
       (10,000 by default; pass ``None`` to load the whole split).
    2. Split into train/test, holding out ``test_size`` (90/10 by default).
       The shuffle uses a fixed ``seed`` so the split is reproducible; pass
       ``seed=None`` for a different random split on every call.
    3. Build ``label2id`` (name -> id) and ``id2label`` (id -> name) from the
       names of the ``label`` feature.

    Returns:
        (train_split, test_split, label2id, id2label)
    """
    split = "train" if num_samples is None else f"train[:{num_samples}]"
    dataset = load_dataset(DATASET_NAME, split=split)
    # Without a seed the split changes on every run, so the reported accuracy
    # could not be reproduced.
    dataset = dataset.train_test_split(test_size=test_size, seed=seed)

    # Create label mappings from a single lookup of the class names
    label_names = dataset["train"].features["label"].names
    label2id = {label: i for i, label in enumerate(label_names)}
    id2label = {i: label for i, label in enumerate(label_names)}

    return dataset["train"], dataset["test"], label2id, id2label

Initialize Image Processor:
* Loads a preconfigured AutoImageProcessor for the configured Vision Transformer (ViT) checkpoint (MODEL_CHECKPOINT, here "google/vit-base-patch16-224-in21k").
* The processor contains model-specific parameters such as:
  * the expected input size (processor.size["height"] and processor.size["width"]);
  * the normalization statistics (image_mean, image_std) used during the model’s original training.

Define Transformation Pipeline:
* Resize(processor.size["height"]): Resizes each image so its shorter edge matches the model’s input size (224 for ViT-Base), preserving the aspect ratio.
* CenterCrop(processor.size["height"]): Crops the center of the resized image to ensure a fixed size (e.g., 224x224) even if the original aspect ratio is different.
* ToTensor(): Converts the image from PIL format to a PyTorch tensor (with shape [C, H, W] and values in [0, 1]).
* Normalize(...): Normalizes the tensor using the model’s pretraining statistics (image_mean, image_std).

Batch Processing Logic:
* batch["image"]: Assumes the input batch contains a list of PIL images under the key "image".
* img.convert("RGB"): Ensures all images are in RGB format (3 channels), even if some are grayscale or have alpha channels.
* preprocess_pipeline(img): Applies the transformation pipeline (resize, crop, tensor conversion, normalization) to each image.
* batch["pixel_values"]: Stores the processed tensors under the key "pixel_values", which is the expected input format for ViT models in Hugging Face.

In [6]:
# Image Preprocessing
def create_preprocessing_pipeline():
    """
    Build the torchvision transform applied to every image.

    All parameters come from the MODEL_CHECKPOINT image processor:
    1. Resize so the shorter image edge equals the model's input size
       (the aspect ratio is preserved).
    2. Center-crop to a square of that size.
    3. Convert the PIL image to a tensor.
    4. Normalize with the processor's ``image_mean`` / ``image_std``.

    Returns:
        A ``torchvision.transforms.Compose`` that maps a PIL image to a tensor.
    """
    processor = AutoImageProcessor.from_pretrained(MODEL_CHECKPOINT)
    # ViT processors describe the input size as {"height", "width"}; other
    # processors use {"shortest_edge"} instead, so accept either form.
    size = processor.size
    edge = size["height"] if "height" in size else size["shortest_edge"]
    return Compose([
        Resize(edge),
        CenterCrop(edge),
        ToTensor(),
        Normalize(mean=processor.image_mean, std=processor.image_std),
    ])

def build_preprocess_batch(preprocess_pipeline):
    """
    Wrap a per-image transform into a batch-level transform.

    The returned callable takes a batch dict with an ``"image"`` list,
    converts each image to RGB, and runs it through ``preprocess_pipeline``.
    It stores the results under ``"pixel_values"`` and returns the same dict,
    mutated in place.
    """
    def preprocess_batch(batch):
        rgb_images = (image.convert("RGB") for image in batch["image"])
        batch["pixel_values"] = list(map(preprocess_pipeline, rgb_images))
        return batch

    return preprocess_batch

The function sets up the pre-trained Vision Transformer (ViT) model with Low-Rank Adaptation (LoRA) for fine-tuning.
It takes label mapping dictionaries (label2id and id2label) as input

AutoModelForImageClassification.from_pretrained
* loads a pre-trained ViT model
* label2id and id2label: Mappings between class names and integer IDs (required for the classification head).
* ignore_mismatched_sizes=True: lets the checkpoint load even though its classification-head shape differs from the number of classes in the target dataset; the mismatched head weights are not loaded and the new head is freshly initialized.
  
LoraConfig defines how LoRA adapters are applied to the model.
LoRA adaptation parameters:
* r=16: Sets the rank of the low-rank matrices to 16 (higher values = more capacity but more parameters)
* lora_alpha=16: Sets the scaling factor for the LoRA updates. The update is scaled by lora_alpha / r, so setting it equal to r (as here) gives a scale of 1.
* target_modules=["query", "value"]: Specifies which attention matrices to adapt (only modifying the query and value matrices in the self-attention mechanism, not the key matrices)
* lora_dropout=0.1: Adds 10% dropout to LoRA layers for regularization during training
* bias="none": Doesn't apply LoRA to bias parameters
* modules_to_save=["classifier"]: Ensures the classifier layer (the final classification head) remains fully trainable, not frozen or replaced by LoRA.

get_peft_model injects LoRA into the base_model based on lora_config.

print_trainable_parameters: Outputs the number and percentage of trainable parameters. With this configuration the percentage is below 1%, versus 100% for full fine-tuning.

In [7]:
# Model Setup
def initialize_model(
    label2id,
    id2label,
    *,
    r: int = 16,
    lora_alpha: int = 16,
    lora_dropout: float = 0.1,
    target_modules=("query", "value"),
):
    """
    Load the pre-trained ViT and wrap it with LoRA adapters.

    1. Load MODEL_CHECKPOINT with a classification head built from the given
       label mappings (``ignore_mismatched_sizes=True`` allows the head shape
       to differ from the checkpoint's).
    2. Add LoRA adapters to ``target_modules`` (query and value by default).
    3. Keep the ``classifier`` layer fully trainable via ``modules_to_save``.

    Args:
        label2id: Mapping from class name to integer id.
        id2label: Mapping from integer id to class name.
        r: LoRA rank.
        lora_alpha: LoRA scaling factor.
        lora_dropout: Dropout rate for the LoRA layers.
        target_modules: Names of the attention sub-modules to adapt.

    Returns:
        The PEFT-wrapped model. Its trainable-parameter share is printed.
    """
    # Load base model
    base_model = AutoModelForImageClassification.from_pretrained(
        MODEL_CHECKPOINT,
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True,
    )

    # Configure LoRA; the defaults reproduce the original hard-coded settings
    lora_config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        target_modules=list(target_modules),
        lora_dropout=lora_dropout,
        # No bias params
        bias="none",
        # Full fine-tune the classifier head
        modules_to_save=["classifier"],
    )

    # Create LoRA model
    lora_model = get_peft_model(base_model, lora_config)
    print_trainable_parameters(lora_model)
    return lora_model

configure_training_parameters sets up all the configuration needed for efficiently training a Vision Transformer model with LoRA.

TrainingArguments
* output_dir="./checkpoints": Specifies where to save model checkpoints during training
* per_device_train_batch_size=BATCH_SIZE: Sets the batch size for training, using a predefined constant
* per_device_eval_batch_size=BATCH_SIZE: Sets the batch size for evaluation, using the same value
* gradient_accumulation_steps=4: Accumulates gradients over 4 batches before updating weights, effectively increasing the batch size by 4x without increasing memory usage
* fp16=use_fp16: Enables mixed precision training (16-bit floating-point numbers where possible) when use_fp16 is True, which speeds up training and reduces memory usage. fp16 requires a CUDA GPU, so it must be False on CPU.
* learning_rate=5e-3: Sets a relatively high learning rate (0.005), which is appropriate for LoRA fine-tuning since fewer parameters are being updated
* num_train_epochs=EPOCHS: Sets the number of full passes through the training data
* logging_steps=10: Logs training metrics every 10 training steps
* evaluation_strategy="epoch" (named eval_strategy in newer transformers releases): Evaluates the model on the test dataset after each epoch
* save_strategy="epoch": Saves a checkpoint after each epoch
* load_best_model_at_end=True: Reloads the best checkpoint at the end of training. Checkpoints are ranked by metric_for_best_model if it is set, otherwise by evaluation loss.

compute_metrics:
* Takes model predictions and reference labels as input
* Converts the raw prediction logits to class predictions by taking the argmax along the class dimension
* Computes and returns the accuracy by comparing predictions to reference labels

collate_fn:
* Takes a batch of examples (each containing processed image tensors and labels)
* Stacks the "pixel_values" from each example into a single tensor batch
* Converts the labels into a tensor
* Returns a dictionary with the batched tensors

In [22]:
# Training Setup
def configure_training_parameters(train_dataset, test_dataset):
    """
    Build the Trainer configuration plus the metric and collator callbacks.

    Configures:
    - mixed precision (fp16), only when a CUDA GPU is available
    - gradient accumulation over 4 steps
    - per-epoch evaluation and checkpointing
    - reloading the best checkpoint (by accuracy) at the end

    Args:
        train_dataset: Training split. Not used here; kept for interface compatibility.
        test_dataset: Evaluation split. Not used here; kept for interface compatibility.

    Returns:
        (args, compute_metrics, collate_fn) to pass to ``Trainer``.
    """
    import inspect

    # fp16 mixed precision requires a CUDA device; decide it here instead of
    # depending on a `use_fp16` global that is never defined in the notebook.
    use_fp16 = torch.cuda.is_available()

    # `evaluation_strategy` was renamed `eval_strategy` in newer transformers
    # releases (the old name was later removed); use whichever name this install accepts.
    accepted = inspect.signature(TrainingArguments.__init__).parameters
    eval_strategy_kwarg = "eval_strategy" if "eval_strategy" in accepted else "evaluation_strategy"

    args = TrainingArguments(
        output_dir="./checkpoints",
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        # One optimizer step every 4 batches (effective batch = 4 * BATCH_SIZE per device)
        gradient_accumulation_steps=4,
        fp16=use_fp16,
        learning_rate=5e-3,
        num_train_epochs=EPOCHS,
        logging_steps=10,
        save_strategy="epoch",
        load_best_model_at_end=True,
        # Rank checkpoints by the "accuracy" value that compute_metrics returns
        metric_for_best_model="accuracy",
        # The datasets use set_transform, which reads the raw "image" column on
        # the fly. The Trainer would otherwise drop that column because the
        # model's forward() has no "image" argument.
        remove_unused_columns=False,
        # Declare the label key explicitly: the PEFT wrapper's forward signature
        # may not let the Trainer infer it, which can leave label_ids empty.
        label_names=["labels"],
        **{eval_strategy_kwarg: "epoch"},
    )

    # Accuracy metric
    metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        # Logits -> predicted class ids (argmax over the class dimension)
        predictions = np.argmax(eval_pred.predictions, axis=-1)
        return metric.compute(predictions=predictions, references=eval_pred.label_ids)

    # Data collator: stack per-example tensors into a batch
    def collate_fn(examples):
        return {
            "pixel_values": torch.stack([e["pixel_values"] for e in examples]),
            "labels": torch.tensor([e["label"] for e in examples]),
        }

    return args, compute_metrics, collate_fn

predict function handles the entire inference process for classifying food images with the LoRA-adapted Vision Transformer model

Steps:
* Prepares preprocessing from the original model checkpoint’s image processor settings, so inference inputs are consistent with what the model expects.
* Loads the base Vision Transformer model, with ignore_mismatched_sizes=True to handle any potential mismatches between the original classification head and the fine-tuned one.
* Loads and applies the trained LoRA adapter weights from MODEL_SAVE_PATH to the base model using the PeftModel.from_pretrained() method.
* Opens the image file from the provided path using PIL's Image.open. Converts the image to RGB format. Returns the processed inputs in PyTorch tensor format with batch dimension
* Passes the processed inputs to the model and collects the outputs. outputs.logits: Contains unnormalized prediction scores (logits) for all food classes.
* Finds the index of the highest logit value using argmax(), which represents the predicted class ID.
* Converts this numeric ID back to a human-readable food class label using the model's id2label mapping

In [9]:
# Inference/Prediction Function
def predict(image_path: str, id2label=None, label2id=None):
    """
    Predict the food class of a single image using the trained LoRA adapter.

    1. Rebuild the base model with a classification head sized for the dataset's
       labels, so that the adapter's saved classifier weights fit.
    2. Load the LoRA adapter from MODEL_SAVE_PATH.
    3. Preprocess the image with the same pipeline used during training.
    4. Run inference and map the top logit back to a label name.

    Args:
        image_path: Path to an image file readable by PIL.
        id2label: Optional id -> name mapping (e.g. the one returned by
            ``load_and_prepare_dataset``). If both mappings are omitted, the
            label names are read from the DATASET_NAME metadata.
        label2id: Optional name -> id mapping; derived from ``id2label`` if omitted.

    Returns:
        The predicted label name.
    """
    if id2label is None and label2id is None:
        from datasets import load_dataset_builder
        names = load_dataset_builder(DATASET_NAME).info.features["label"].names
        id2label = dict(enumerate(names))
    if id2label is None:
        id2label = {i: name for name, i in label2id.items()}
    if label2id is None:
        label2id = {name: i for i, name in id2label.items()}

    # The head must have the same number of classes as during training.
    # Otherwise the adapter's saved classifier weights cannot be loaded, and
    # id2label would only hold generic names.
    base_model = AutoModelForImageClassification.from_pretrained(
        MODEL_CHECKPOINT,
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True,
    )

    # Load LoRA adapter
    model = PeftModel.from_pretrained(base_model, MODEL_SAVE_PATH)
    model.eval()

    # Use the training-time transform (shorter-edge resize + center crop) so
    # inference inputs match what the model was fine-tuned on.
    preprocess_pipeline = create_preprocessing_pipeline()
    with Image.open(image_path) as image:
        pixel_values = preprocess_pipeline(image.convert("RGB")).unsqueeze(0)

    # Run inference
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)

    # Get predicted class
    predicted_id = outputs.logits.argmax(dim=-1).item()
    return model.config.id2label[predicted_id]

In [18]:
# Load and prepare data
train_dataset, test_dataset, label2id, id2label = load_and_prepare_dataset()

# Sanity checks: both splits are non-empty and the label maps are exact inverses.
assert len(train_dataset) > 0 and len(test_dataset) > 0, "empty train/test split"
assert len(label2id) == len(id2label), "label maps differ in size"
assert all(id2label[i] == name for name, i in label2id.items()), "label2id and id2label disagree"

print(type(train_dataset))
print(type(label2id))

<class 'datasets.arrow_dataset.Dataset'>
<class 'dict'>


In [19]:
# Create preprocessing pipeline
preprocess_pipeline = create_preprocessing_pipeline()
batch_preprocess_fn = build_preprocess_batch(preprocess_pipeline)

# Attach the batch transform to both splits (train first, then test).
for dataset_split in (train_dataset, test_dataset):
    dataset_split.set_transform(batch_preprocess_fn)

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


In [None]:
# Initialize model
model = initialize_model(label2id, id2label)

# Training arguments plus the metric and collator callbacks
args, compute_metrics, collate_fn = configure_training_parameters(
    train_dataset, test_dataset
)

# Wire model, data and callbacks into a Trainer
trainer = Trainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

In [None]:
# Start training
# Run fine-tuning with the Trainer built in the previous cell. As the cell's
# last expression, the value returned by trainer.train() is displayed.
print("Starting training...")
trainer.train()

In [None]:
# Evaluate and save
results = trainer.evaluate()
final_accuracy = results["eval_accuracy"]
print(f"Final accuracy: {final_accuracy:.2%}")

# Save the trained model to MODEL_SAVE_PATH and report the size on disk
trainer.save_model(MODEL_SAVE_PATH)
print_model_size(MODEL_SAVE_PATH)

In [None]:
# Predict
test_image = "path_to_your_image.jpg"  # replace with a real image path

# Guard the placeholder path so "Restart & Run All" does not fail here.
if os.path.isfile(test_image):
    print(f"Prediction: {predict(test_image)}")
else:
    print(f"Skipping prediction: {test_image!r} not found; set test_image to a real image path.")