# Fine Tuning an Vision Model with LoRA(Low-Rank Adaptation)

In [1]:

%pip install transformers accelerate evaluate datasets peft -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m48.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m22.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m49.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
model = "google/vit-base-patch16-224-in21k"

In [3]:
import os
import torch
from peft import PeftModel, LoraConfig, get_peft_model
from transformers import AutoModelForImageClassification

# to print model size from disk
def print_model_size(path):
    size = 0
    for f in os.scandir(path):
        size += os.path.getsize(f)

    print(f"Model size: {(size / 1e6):.2} MB")

# function for print the trainable paramaters
def print_trainable_parameters(model, label):
    parameters, trainable = 0, 0

    for _, p in model.named_parameters():
        parameters += p.numel()
        trainable += p.numel() if p.requires_grad else 0

    print(f"{label} trainable parameters: {trainable:,}/{parameters:,} ({100 * trainable / parameters:.2f}%)")

# function to split the dataset
def split_dataset(dataset):
    dataset_splits = dataset.train_test_split(test_size=0.1)
    return dataset_splits.values()

#creating the helper fucnction for label_mapping
def create_label_mappings(dataset):
    label2id, id2label = dict(), dict()
    for i, label in enumerate(dataset.features["label"].names):
        label2id[label] = i
        id2label[i] = label

    return label2id, id2label

In [6]:
from datasets import load_dataset, load_from_disk, Dataset
import shutil
import os

# Delete Huggingface cache (use with caution, will delete all cached datasets)
shutil.rmtree(os.path.expanduser("~/.cache/huggingface/datasets"), ignore_errors=True)

## Loading and Preparing the Datasets

In [14]:
from datasets import load_dataset

# Food101: Load the full training split
dataset1_full = load_dataset("food101", split="train", download_mode="force_redownload")

# Select the first 10,000 examples
dataset1 = dataset1_full.select(range(10000))

# Cats vs Dogs: Load the full training split
dataset2_full = load_dataset("microsoft/cats_vs_dogs", split="train", trust_remote_code=True, download_mode="force_redownload")

# Select the first 10,000 examples (adjust if needed)
dataset2 = dataset2_full.select(range(10000))

# Rename for consistency
dataset2 = dataset2.rename_column("labels", "label")

# Now you can use your split_dataset function as before
dataset1_train, dataset1_test = split_dataset(dataset1)
dataset2_train, dataset2_test = split_dataset(dataset2)

Downloading readme:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]




Downloading data:   0%|          | 0.00/490M [00:00<?, ?B/s][A[A[A


Downloading data:   2%|▏         | 10.5M/490M [00:00<00:24, 19.9MB/s][A[A[A


Downloading data:   9%|▊         | 41.9M/490M [00:00<00:05, 82.6MB/s][A[A[A


Downloading data:  17%|█▋        | 83.9M/490M [00:00<00:02, 150MB/s] [A[A[A


Downloading data:  24%|██▎       | 115M/490M [00:00<00:02, 185MB/s] [A[A[A


Downloading data:  30%|██▉       | 147M/490M [00:01<00:01, 193MB/s][A[A[A


Downloading data:  36%|███▋      | 178M/490M [00:01<00:01, 204MB/s][A[A[A


Downloading data:  43%|████▎     | 210M/490M [00:01<00:01, 209MB/s][A[A[A


Downloading data:  49%|████▉     | 241M/490M [00:01<00:01, 209MB/s][A[A[A


Downloading data:  56%|█████▌    | 273M/490M [00:01<00:01, 203MB/s][A[A[A


Downloading data:  62%|██████▏   | 304M/490M [00:01<00:00, 208MB/s][A[A[A


Downloading data:  69%|██████▊   | 336M/490M [00:01<00:00, 219MB/s][A[A[A


Downloading data:  75%|███████▍  | 367M/490M [00:

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/75750 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/25250 [00:00<?, ? examples/s]

NotImplementedError: Loading a dataset cached in a LocalFileSystem is not supported.

In [12]:
!pip install --upgrade datasets



In [15]:
from datasets import load_dataset

# Use a safe cache dir
cache_dir = "/tmp/hf-datasets-cache"

# Food101: Use streaming to avoid cache problems
dataset1_stream = load_dataset("food101", split="train", streaming=True)
dataset1 = [x for _, x in zip(range(10000), dataset1_stream)]

# Cats vs Dogs: Use streaming as well
dataset2_stream = load_dataset("microsoft/cats_vs_dogs", split="train", streaming=True, trust_remote_code=True)
dataset2 = [x for _, x in zip(range(10000), dataset2_stream)]

# If you need them as Hugging Face Datasets, convert lists to Dataset objects
from datasets import Dataset
dataset1 = Dataset.from_list(dataset1)
dataset2 = Dataset.from_list(dataset2)
dataset2 = dataset2.rename_column("labels", "label")

# Now you can split as usual
dataset1_train, dataset1_test = split_dataset(dataset1)
dataset2_train, dataset2_test = split_dataset(dataset2)

NotImplementedError: Loading a streaming dataset cached in a LocalFileSystem is not supported yet.

In [16]:
# Download and extract
!wget --no-check-certificate https://data.vision.ee.ethz.ch/cvl/food-101.tar.gz
!tar -xzf food-101.tar.gz

--2025-06-09 14:49:21--  https://data.vision.ee.ethz.ch/cvl/food-101.tar.gz
Resolving data.vision.ee.ethz.ch (data.vision.ee.ethz.ch)... 129.132.52.178, 2001:67c:10ec:36c2::178
Connecting to data.vision.ee.ethz.ch (data.vision.ee.ethz.ch)|129.132.52.178|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4996278331 (4.7G) [application/x-gzip]
Saving to: ‘food-101.tar.gz’


2025-06-09 14:52:58 (22.0 MB/s) - ‘food-101.tar.gz’ saved [4996278331/4996278331]



In [18]:
# 1. Install dependencies
!pip install datasets pillow tqdm --quiet

In [62]:


# 2. Download and prepare Food-101 (images and labels)
import os
from PIL import Image
from datasets import Dataset
from tqdm import tqdm

# Download Food-101
if not os.path.exists("food-101"):
    !wget --no-check-certificate https://data.vision.ee.ethz.ch/cvl/food-101.tar.gz
    !tar -xzf food-101.tar.gz

img_dir = "food-101/images"
categories = sorted(os.listdir(img_dir))
print("Sample categories:", categories[:5])

samples = []
for cat in tqdm(categories[:2], desc='Building Food-101 sample (2 classes)'):  # Just 2 classes for speed; use more if desired
    cat_dir = os.path.join(img_dir, cat)
    files = sorted(os.listdir(cat_dir))[:100]  # 100 images per class for speed
    for img_file in files:
        samples.append({"image": os.path.join(cat_dir, img_file), "label": cat})

dataset1 = Dataset.from_list(samples)
print("Food-101 example:", dataset1[0])

# 3. Download and prepare Cats vs Dogs using TensorFlow Datasets
import tensorflow_datasets as tfds

catsdogs = tfds.load('cats_vs_dogs', split='train', as_supervised=True)
catsdogs_list = []
for img, label in tfds.as_numpy(catsdogs.take(200)):  # 200 samples for speed
    # Save images to disk so they can be referenced in Hugging Face Dataset
    save_dir = 'cats_vs_dogs_imgs'
    os.makedirs(save_dir, exist_ok=True)
    label_name = 'cat' if label == 0 else 'dog'
    img_path = os.path.join(save_dir, f"{label_name}_{len(catsdogs_list)}.jpg")
    Image.fromarray(img).save(img_path)
    catsdogs_list.append({"image": img_path, "label": label_name})

dataset2 = Dataset.from_list(catsdogs_list)
print("Cats vs Dogs example:", dataset2[0])

# 4. Simple split function (80/20 train/test)
def split_dataset(ds, test_size=0.2):
    shuffled = ds.shuffle(seed=42)
    n_test = int(len(ds) * test_size)
    return shuffled.select(range(len(ds)-n_test)), shuffled.select(range(len(ds)-n_test, len(ds)))

dataset1_train, dataset1_test = split_dataset(dataset1)
dataset2_train, dataset2_test = split_dataset(dataset2)

print("Food-101 train/test:", len(dataset1_train), len(dataset1_test))
print("Cats vs Dogs train/test:", len(dataset2_train), len(dataset2_test))

Sample categories: ['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare']


Building Food-101 sample (2 classes): 100%|██████████| 2/2 [00:00<00:00, 572.21it/s]

Food-101 example: {'image': 'food-101/images/apple_pie/1005649.jpg', 'label': 'apple_pie'}





Cats vs Dogs example: {'image': 'cats_vs_dogs_imgs/dog_0.jpg', 'label': 'dog'}
Food-101 train/test: 160 40
Cats vs Dogs train/test: 160 40


In [63]:
def create_label_mappings(dataset):
    unique_labels = sorted(list(set(dataset["label"])))
    label2id = {label: i for i, label in enumerate(unique_labels)}
    id2label = {i: label for label, i in label2id.items()}
    return label2id, id2label

dataset1_label2id, dataset1_id2label = create_label_mappings(dataset1)
dataset2_label2id, dataset2_id2label = create_label_mappings(dataset2)

In [64]:

config = {
    "model1": {
        "train_data": dataset1_train,
        "test_data": dataset1_test,
        "label2id": dataset1_label2id,
        "id2label": dataset1_id2label,
        "epochs": 5,
        "path": "./lora-model1"
    },
    "model2": {
        "train_data": dataset2_train,
        "test_data": dataset2_test,
        "label2id": dataset2_label2id,
        "id2label": dataset2_id2label,
        "epochs": 1,
        "path": "./lora-model2"
    },
}

In [65]:
from transformers import AutoImageProcessor
from PIL import Image

image_processor = AutoImageProcessor.from_pretrained(model, use_fast=True)


from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    Resize,
    ToTensor,
)

preprocess_pipeline = Compose([
    Resize(image_processor.size["height"]),
    CenterCrop(image_processor.size["height"]),
    ToTensor(),
    Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
])

def make_preprocess(label2id, preprocess_pipeline):
    from PIL import Image
    def preprocess(batch):
        batch["pixel_values"] = [
            preprocess_pipeline(Image.open(image_path).convert("RGB")) for image_path in batch["image"]
        ]
        batch["label"] = [label2id[label] for label in batch["label"]]
        return batch
    return preprocess

# Now set the transform for each dataset with the correct label2id
for cfg in config.values():
    cfg["train_data"].set_transform(make_preprocess(cfg["label2id"], preprocess_pipeline))
    cfg["test_data"].set_transform(make_preprocess(cfg["label2id"], preprocess_pipeline))

# Fine-Tuning the Model

In [66]:
# Fine-Tuning the Model

import numpy as np
import evaluate
import torch
from peft import PeftModel, LoraConfig, get_peft_model
from transformers import AutoModelForImageClassification


metric = evaluate.load("accuracy")


def data_collate(examples):
    import torch
    # Check first label type for debugging
    if isinstance(examples[0]["label"], str):
        raise ValueError(
            f"Label is a string: {examples[0]['label']}. "
            "Convert string labels to integer IDs before batching!"
        )
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}


def compute_metrics(eval_pred):
    """
    Compute the model's accuracy on a batch of predictions.
    """
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)


def get_base_model(label2id, id2label):
    """
    Create an image classification base model from
    the model checkpoint.
    """
    return AutoModelForImageClassification.from_pretrained(
        model,
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True,
    )


def build_lora_model(label2id, id2label):
    """Build the LoRA model to fine-tune the base model."""
    model = get_base_model(label2id, id2label)
    print_trainable_parameters(model, label="Base model")

    config = LoraConfig(
        r=16,
        lora_alpha=16,
        target_modules=["query", "value"],
        lora_dropout=0.1,
        bias="none",
        modules_to_save=["classifier"],
    )

    lora_model = get_peft_model(model, config)
    print_trainable_parameters(lora_model, label="LoRA")

    return lora_model

### Let's now configure the fine-tuning process.

In [67]:
from transformers import TrainingArguments

batch_size = 128
training_arguments = TrainingArguments(
    output_dir="./model-checkpoints",
    remove_unused_columns=False,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-3,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=4,
    fp16=True,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    label_names=["labels"],
)

In [68]:

from transformers import Trainer

for cfg in config.values():
    training_arguments.num_train_epochs = cfg["epochs"]

    trainer = Trainer(
        build_lora_model(cfg["label2id"], cfg["id2label"]),
        training_arguments,
        train_dataset=cfg["train_data"],
        eval_dataset=cfg["test_data"],
        tokenizer=image_processor,
        compute_metrics=compute_metrics,
        data_collator=data_collate,
    )

    results = trainer.train()
    evaluation_results = trainer.evaluate(cfg['test_data'])
    print(f"Evaluation accuracy: {evaluation_results['eval_accuracy']}")

    # We can now save the fine-tuned model to disk.
    trainer.save_model(cfg["path"])
    print_model_size(cfg["path"])

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Base model trainable parameters: 85,800,194/85,800,194 (100.00%)
LoRA trainable parameters: 591,362/86,391,556 (0.68%)


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.52207,0.975
2,No log,0.370843,0.975
3,No log,0.255721,0.975
4,No log,0.182797,1.0
5,No log,0.149683,1.0


Evaluation accuracy: 1.0
Model size: 2.4 MB


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Base model trainable parameters: 85,800,194/85,800,194 (100.00%)
LoRA trainable parameters: 591,362/86,391,556 (0.68%)


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.512402,1.0


Evaluation accuracy: 1.0
Model size: 2.4 MB


## Running Inference

In [69]:
def build_inference_model(label2id, id2label, lora_adapter_path):
    """Build the model that will be use to run inference."""

    # Let's load the base model
    model = get_base_model(label2id, id2label)

    # Now, we can create the inference model combining the base model
    # with the fine-tuned LoRA adapter.
    return PeftModel.from_pretrained(model, lora_adapter_path)


def predict(image, model, image_processor):
    """Predict the class represented by the supplied image."""

    encoding = image_processor(image.convert("RGB"), return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoding)
        logits = outputs.logits

    class_index = logits.argmax(-1).item()
    return model.config.id2label[class_index]

In [70]:

for cfg in config.values():
    cfg["inference_model"] = build_inference_model(cfg["label2id"], cfg["id2label"], cfg["path"])
    cfg["image_processor"] = AutoImageProcessor.from_pretrained(cfg["path"])

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [73]:

samples = [
    {
        "image": "https://www.allrecipes.com/thmb/AtViolcfVtInHgq_mRtv4tPZASQ=/1500x0/filters:no_upscale():max_bytes(150000):strip_icc()/ALR-187822-baked-chicken-wings-4x3-5c7b4624c8554f3da5aabb7d3a91a209.jpg",
        "model": "model1",
    },
    {
        "image": "https://wallpapers.com/images/featured/kitty-cat-pictures-nzlg8fu5sqx1m6qj.jpg",
        "model": "model2",
    },
    {
        "image": "https://i.natgeofe.com/n/5f35194b-af37-4f45-a14d-60925b280986/NationalGeographic_2731043_3x4.jpg",
        "model": "model2",
    },
    {
        "image": "https://www.simplyrecipes.com/thmb/KE6iMblr3R2Db6oE8HdyVsFSj2A=/1500x0/filters:no_upscale():max_bytes(150000):strip_icc()/__opt__aboutcom__coeus__resources__content_migration__simply_recipes__uploads__2019__09__easy-pepperoni-pizza-lead-3-1024x682-583b275444104ef189d693a64df625da.jpg",
        "model": "model1"
    }
]

In [74]:

from PIL import Image
import requests

for sample in samples:
    image = Image.open(requests.get(sample["image"], stream=True).raw)

    inference_model = config[sample["model"]]["inference_model"]
    image_processor = config[sample["model"]]["image_processor"]

    prediction = predict(image, inference_model, image_processor)
    print(f"Prediction: {prediction}")

Prediction: baby_back_ribs
Prediction: cat
Prediction: dog
Prediction: apple_pie
