# Transfer Learning on google/vit-base-patch16-224
- Pretrained checkpoint from the Hugging Face Hub

In [3]:
!pip3 install --quiet evaluate transformers

In [22]:
import io
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import nn

from huggingface_hub import notebook_login
from datasets import Dataset
from transformers import AutoImageProcessor, ViTForImageClassification
from datasets import Image as HFImage
from PIL import Image

from transformers import TrainingArguments, Trainer
import evaluate

## Load Dataset

In [15]:
# Directory holding the QR image files, and the CSV holding their class labels.
qr_images_path = "data/qr_images/qr_images"
qr_labels_path = "data/qr_labels.csv"

In [28]:
def get_all_qr_images(filepath: str) -> list:
    """
    List the QR image files in a directory, ordered by their numeric ID.

    File names are expected to look like ``<prefix>_<number>.<ext>`` (for
    example ``qr_12.png``). The text after the first underscore and before the
    next dot is parsed as an integer and used as the sort key, so ``qr_2``
    comes before ``qr_10``. Sub-directories are skipped.

    Args:
        filepath: Directory that contains the QR image files.

    Returns:
        list: ``os.path.join(filepath, name)`` for every file, in numeric order.

    Raises:
        IndexError: If a file name contains no underscore.
        ValueError: If the part between the first underscore and the next dot
            is not an integer.
    """

    def _qr_id(file_name: str) -> int:
        # "qr_12.png" -> "12.png" -> "12" -> 12
        after_first_underscore = file_name.split("_")[1]
        return int(after_first_underscore.split(".")[0])

    file_names = [
        entry
        for entry in os.listdir(filepath)
        if os.path.isfile(os.path.join(filepath, entry))
    ]
    file_names.sort(key=_qr_id)

    # Full paths, so downstream code can open the files directly.
    return [os.path.join(filepath, entry) for entry in file_names]

In [29]:
# Class label of each QR image; the CSV's first column is used as the index.
y_labels = pd.read_csv(qr_labels_path, index_col=0)
labels = y_labels.loc[:, 'label'].to_list()
labels[:5]

[1, 1, 1, 1, 1]

In [None]:
# Build a Hugging Face Dataset from the image file paths. Casting the column to
# the `Image` feature makes rows yield decoded images when they are accessed.
image_paths = get_all_qr_images(qr_images_path)

# Images and labels are paired purely by position, so their counts must agree.
# NOTE(review): this also assumes the CSV rows are in the same numeric order as
# the sorted image files — confirm against the CSV's index column.
if len(image_paths) != len(labels):
    raise ValueError(
        f"Found {len(image_paths)} images but {len(labels)} labels; "
        "they must be paired one-to-one."
    )

dataset = Dataset.from_dict({"image": image_paths, "label": labels})
dataset = dataset.cast_column("image", HFImage())

dataset[255:265]['label']

[1, 0, 1, 1, 0, 0, 1, 1, 1, 1]

## Preprocessing

In [None]:
# Load the image processor that matches the pretrained checkpoint. It prepares
# model inputs (resizing, normalization, conversion to PyTorch/NumPy tensors)
# and can post-process model outputs.
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
processor

# Map each raw label value in the dataset to a contiguous class id
# (0 .. num_classes - 1), ordered by the sorted raw values. For integer 0/1
# labels this is the identity mapping, but it stays valid for other label sets.
label2id = {label: class_id for class_id, label in enumerate(sorted(set(labels)))}

In [None]:
def _to_rgb_image(x):
    """Return one dataset image as an RGB PIL image, whatever form it arrives in."""
    # With an `Image()`-cast column and default decoding, rows arrive as PIL
    # images. With decoding disabled they arrive as {'bytes': ..., 'path': ...}
    # dicts, where 'bytes' is None for images referenced by file path.
    if isinstance(x, Image.Image):
        return x.convert('RGB')
    if x.get('bytes') is not None:
        return Image.open(io.BytesIO(x['bytes'])).convert('RGB')
    return Image.open(x['path']).convert('RGB')


def transforms(batch):
    """
    Prepare a batch of raw image data and labels so it can be fed into the model.

    Suitable as a lazy dataset transform (e.g. ``Dataset.with_transform``).

    Args:
        batch: dict of columns; ``batch['image']`` holds the images (decoded
            PIL images or raw ``{'bytes', 'path'}`` dicts) and
            ``batch['label']`` the raw label values.

    Returns:
        The processor output (with ``return_tensors='pt'``) plus a ``labels``
        entry holding the class id of each example, looked up in ``label2id``.
    """
    batch['image'] = [_to_rgb_image(x) for x in batch['image']]
    inputs = processor(batch['image'], return_tensors='pt')
    inputs['labels'] = [label2id[y] for y in batch['label']]
    return inputs


def collate_fn(batch):
    """
    Combine per-example inputs (as produced by ``transforms``) into one batch.

    Args:
        batch: list of per-example dicts, each with a ``pixel_values`` tensor
            and a ``labels`` value.

    Returns:
        dict with ``pixel_values`` (the example tensors stacked along a new
        first dimension) and ``labels`` (a tensor built from the labels).
    """
    pixel_stack = torch.stack([example['pixel_values'] for example in batch])
    label_tensor = torch.tensor([example['labels'] for example in batch])
    return {'pixel_values': pixel_stack, 'labels': label_tensor}

## Evaluation Metrics

In [None]:
# Accuracy metric from the Hugging Face `evaluate` library, loaded once and
# reused by every evaluation pass.
accuracy = evaluate.load('accuracy')


def compute_metrics(eval_preds):
    """
    Score model predictions during evaluation.

    Args:
        eval_preds: a ``(logits, labels)`` pair; the class scores lie along
            axis 1 of the logits.

    Returns:
        dict: whatever ``accuracy.compute`` returns for the argmax predictions.
    """
    predicted_logits, reference_labels = eval_preds
    predicted_classes = np.argmax(predicted_logits, axis=1)
    return accuracy.compute(predictions=predicted_classes, references=reference_labels)

## Model

In [None]:
# Model initialization: load the pretrained ViT with a classification head sized
# for our classes. `ignore_mismatched_sizes=True` lets the checkpoint load even
# though its classifier weights have a different shape; that head is re-initialised.
# Class ids follow the sorted order of the raw label values; their names are the
# raw values as strings.
# NOTE(review): replace the names with meaningful class names if they are known.
class_values = sorted(set(labels))
model_id2label = {class_id: str(value) for class_id, value in enumerate(class_values)}
model_label2id = {name: class_id for class_id, name in model_id2label.items()}

model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=len(model_id2label),
    id2label=model_id2label,
    label2id=model_label2id,
    ignore_mismatched_sizes=True,
)

### Freezing layers

In [None]:
# Freeze every parameter outside the classification head: only parameters whose
# names start with 'classifier' stay trainable.
backbone_params = (
    param for param_name, param in model.named_parameters()
    if not param_name.startswith('classifier')
)
for param in backbone_params:
    param.requires_grad = False

# Report total vs. trainable parameter counts.
all_params = list(model.parameters())
num_params = sum(p.numel() for p in all_params)
trainable_params = sum(p.numel() for p in all_params if p.requires_grad)

print(f"{num_params = :,} | {trainable_params = :,}")

### Training

In [None]:
# Split the data and attach the preprocessing transform, so images are decoded
# and converted to model inputs on the fly when batches are drawn.
# NOTE(review): the 80/20 split and seed are placeholder choices — adjust as needed.
splits = dataset.train_test_split(test_size=0.2, seed=42)
train_ds = splits["train"].with_transform(transforms)
eval_ds = splits["test"].with_transform(transforms)

# Define training arguments and train.
# Uses `eval_strategy` and `processing_class` (transformers >= 4.46); the older
# `evaluation_strategy` argument has been removed from TrainingArguments.
training_args = TrainingArguments(
    output_dir="./vit-base-qr-classifier",
    per_device_train_batch_size=16,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=100,
    num_train_epochs=5,
    learning_rate=1e-4,
    save_total_limit=2,
    # Keep the raw 'image' column: the transform needs it, and the Trainer
    # otherwise drops columns the model's forward() does not accept.
    remove_unused_columns=False,
    # Pushing to the Hub requires an authenticated session (e.g. notebook_login()).
    push_to_hub=True,
    report_to='none',
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    processing_class=processor,
)

trainer.train()

## Model Evaluation

In [None]:
# Evaluate on the eval_dataset given to the Trainer; returns a dict of metrics.
trainer.evaluate()