# Transfer Learning on google/vit-base-patch16-224
- Pretrained checkpoint loaded from the Hugging Face Hub

In [None]:
!pip3 install --quiet evaluate transformers # For training and evaluation

In [9]:
import io
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import nn

from huggingface_hub import notebook_login
from datasets import load_dataset
from transformers import AutoImageProcessor, ViTForImageClassification

from transformers import TrainingArguments, Trainer
import evaluate

## Load Dataset

In [None]:
# Authenticate with the Hugging Face Hub. Required because the training
# arguments set `push_to_hub=True` (and for loading the dataset if it is private).
notebook_login()

In [None]:
# Load the QR classification dataset from the Hugging Face Hub
# (only works once it has been uploaded to "AaronLoera/Qr_Classifier").
dataset = load_dataset("AaronLoera/Qr_Classifier")
# A bare `dataset` expression is only displayed when it is the last line of a
# notebook cell, so print it explicitly.
print(dataset)

labels = ...  # TODO: fill in the list of class names for this dataset

## Preprocessing

In [None]:
# Load the image processor matching the pretrained checkpoint. It prepares
# input features for vision models and post-processes their outputs, applying
# transformations such as resizing, normalization, and conversion to PyTorch
# and NumPy tensors.
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
# Only the last expression of a notebook cell is displayed, so print explicitly.
print(processor)

label2id = ...  # TODO: fill in with dataset (map each class name to an integer id)

In [None]:
"""
This function prepares a batch of raw image data and labels so they can be fed into a model.
"""
def transforms(batch):
    batch['image'] = [Image.open(io.BytesIO(x['bytes'])).convert('RGB') for x in batch['image']]
    inputs = processor(batch['image'],return_tensors='pt')
    inputs['labels']=[label2id[y] for y in batch['label']]
    return inputs

"""
This function takes a list of inputs dictionaries (from transforms) and stacks 
    them into batch tensors.
"""
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

## Evaluation Metrics

In [None]:
"""
Compute evaluation metrics for model predictions using the Hugging Face
    `evaluate` library.
"""
accuracy = evaluate.load('accuracy')
def compute_metrics(eval_preds):
    logits, labels = eval_preds
    predictions = np.argmax(logits,axis=1)
    score = accuracy.compute(predictions=predictions, references=labels)
    return score

## Model

In [None]:
# Model initialization: load the pretrained ViT and size its classification
# head from the dataset's label mapping, so `num_labels`, `id2label` and
# `label2id` cannot disagree with each other.
# `ignore_mismatched_sizes=True` lets the checkpoint's head (trained for a
# different label set) be replaced by a freshly initialised one.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=len(label2id),
    id2label={idx: label for label, idx in label2id.items()},
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

### Freezing layers

In [None]:
# Freeze every parameter outside the classification head so only the head is
# trained. The counts are printed rather than hard-coded because the head's
# size (and therefore the trainable count) depends on the number of labels.
for name, param in model.named_parameters():
    if not name.startswith('classifier'):
        param.requires_grad = False

num_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"{num_params = :,} | {trainable_params = :,}")

### Training

In [None]:
# Define training arguments and train.
training_args = TrainingArguments(
    output_dir="./vit-base-qr-classifier",
    per_device_train_batch_size=16,
    # `eval_strategy` replaces the deprecated `evaluation_strategy` keyword,
    # which recent transformers releases no longer accept.
    eval_strategy="epoch",
    # Must match eval_strategy for load_best_model_at_end to be allowed.
    save_strategy="epoch",
    logging_steps=100,
    num_train_epochs=5,
    learning_rate=1e-4,
    save_total_limit=2,
    # Keep columns the model's forward() does not accept (e.g. 'image',
    # 'label'): `transforms` reads them to build the model inputs.
    remove_unused_columns=False,
    push_to_hub=True,
    report_to='none',
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=...,  # TODO: training split (with `transforms` applied)
    eval_dataset=...,   # TODO: validation split (with `transforms` applied)
    # `processing_class` replaces the deprecated `tokenizer` keyword; the
    # processor is saved (and pushed) alongside the model checkpoints.
    processing_class=processor,
)

trainer.train()

## Model Evaluation

In [None]:
# Evaluate the fine-tuned model. Replace `...` with the dataset split to score
# (e.g. a held-out test split); calling `trainer.evaluate()` with no argument
# uses the `eval_dataset` passed to the Trainer instead.
trainer.evaluate(...)