## Vision Transformers (ViT)
- Fine-tuning
- AI Tools: Huggingface, pytorch
- Dataset: cifar10
- Reference 
  * [ViT-hugging Face](https://huggingface.co/docs/transformers/model_doc/vit)
  * [ViT-hf-ft-cifar10-pytorch ](https://github.com/supersjgk/Transformers/blob/main/VisionTransformers/Vision_Transformers_Hugging_Face_Fine_Tuning_Cifar10_PyTorch.ipynb) 
  * [vit-pytorch](https://github.com/lucidrains/vit-pytorch)

In [None]:
import numpy as np

import torch
import torchvision
from torchvision.transforms import Normalize, Resize, ToTensor, Compose
from torchvision.transforms import ToPILImage

from torchinfo import summary

from transformers import pipeline
from transformers import ViTImageProcessor, ViTForImageClassification
from transformers import TrainingArguments, Trainer
from datasets import load_dataset

from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

from PIL import Image
from IPython.display import display
import matplotlib.pyplot as plt

import requests
from datetime import datetime

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Point the Hugging Face download cache at a dedicated local directory.
import os
cache_path = 'd:/HF_cache'
os.environ['HF_HOME'] = cache_path
# NOTE(review): this cell runs after `transformers`/`datasets` were already imported;
# presumably those libraries read the cache variables at import time, which would
# explain the observation below -- confirm, and if so set them before the imports.
os.environ['TRANSFORMERS_CACHE'] = cache_path # seems not to work

## 1. ImageNet 1k - Image Classification 
## [Hugging Face] Inference Methods for ViT model
 크게 4가지가 있음
1. pipeline()
2. AutoModel():
   - image_processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)
   - model = AutoModelForImageClassification.from_pretrained(model_name, device_map="auto", attn_implementation="sdpa")
3. ViTForImageClassification
4. ViTModel, ViTImageProcessor (ViTImageProcessorFast)
### Downloading pre-trained weights
- Model: ViT-B/16
  * input image size = 224
  * patch size = 16
  * Transformer-encoder configurations
  * Classification Head
  * dropout = 0.1


In [None]:
model_name = "google/vit-base-patch16-224"

In [None]:
# load a sample image

In [None]:
def load_image_url(url, timeout=10):
    """Download an image over HTTP(S) and return it as a PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The opened image in PIL format.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    from io import BytesIO

    response = requests.get(url, timeout=timeout)
    # Fail loudly on an error status instead of handing an error page to PIL.
    response.raise_for_status()
    # Decode from the fully downloaded body rather than the raw socket stream,
    # which is not content-decoded and keeps the connection open.
    return Image.open(BytesIO(response.content)) # PIL format

In [None]:
# load an image
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = load_image_url(url)
image

In [None]:
image.size

In [None]:
# 1. inference using pipeline 

In [None]:
# Build an image-classification pipeline for the chosen checkpoint.
vit_pipeline = pipeline(
    task="image-classification",  # name of the pipeline task
    model=model_name,
    torch_dtype=torch.float,
    device=device,
    use_fast=True,
)

start_time = datetime.now()
# The input is a URL, so the measured time presumably includes the image download.
result = vit_pipeline(inputs='http://images.cocodataset.org/val2017/000000039769.jpg')
end_time = datetime.now()
elapsed_time = end_time - start_time
# Report real seconds: str(timedelta) renders as "H:MM:SS.ffffff", which
# contradicted the "sec" unit printed after it.
print(f'processing time: {elapsed_time.total_seconds():.3f} sec')

In [None]:
result

In [None]:
# 2. inference using ViT Hugging Face models
# Model Card: https://huggingface.co/google/vit-base-patch16-224
# Load the preprocessing configuration published with the checkpoint.
# NOTE(review): `use_fast`/`device` are passed to `ViTImageProcessor` here; whether this
# class honours them (cf. ViTImageProcessorFast) is not visible in this file -- confirm.
processor = ViTImageProcessor.from_pretrained(model_name, device=device, use_fast=True)

# Normalisation statistics and target size stored in the processor.
mu, sigma = processor.image_mean, processor.image_std
size = processor.size
print(size, mu, sigma)

In [None]:
model = ViTForImageClassification.from_pretrained(model_name, device_map=device)
# print(model.classifier) #The google/vit-base-patch16-224 model is originally fine tuned on imagenet-1K with 1000 output classes

# print(model.config)

In [None]:
# model

In [None]:
# model.device
#inputs['pixel_values'].device
#inputs.keys()
#inputs['pixel_values'].shape

In [None]:
# Time one preprocessing + forward pass on the sample image (using logits).
start_time = datetime.now()
inputs = processor(images=image, return_tensors='pt').to(device)
# Inference only: disable autograd so no graph is recorded for the forward pass.
with torch.no_grad():
    outputs = model(**inputs)  # the model outputs logits
end_time = datetime.now()
elapsed_time = end_time - start_time
# total_seconds() gives a real number of seconds; str(timedelta) is "H:MM:SS.ffffff".
print(f'processing time: {elapsed_time.total_seconds():.3f} sec')

In [None]:
# Top-1 class taken directly from the logits.
logits = outputs.logits

# The model predicts one of the 1000 ImageNet classes: take the index of the
# largest logit and map it to its label.
predicted_class_idx = torch.argmax(logits, dim=-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

In [None]:
# top-k using softmax or logits
# Softmax over the logits of the first (only) item in the batch.
prob_output = torch.softmax(outputs.logits[0], dim=0)
# Three highest probabilities together with their class indices.
scores, indices = torch.topk(prob_output, 3)
print("Predicted class:", [(model.config.id2label[i.item()], s.item()) for (s, i) in zip(scores, indices)] )
# Displayed for comparison: top-1 index, top-3 of the probabilities, top-3 of the raw logits.
prob_output.argmax().item(), torch.topk(prob_output, 3), torch.topk(logits, 3)

### Loading local image files to classify them

In [None]:
import os

# Directory with the local sample images, kept as a string with a trailing separator.
# os.path.join uses the platform's separator, so this is ".\images\" on Windows
# (the previous hard-coded value) and "./images/" elsewhere.
# Alternative location: "../AI-Application-Specialist-Vision/images/"
image_path = os.path.join(".", "images", "")
url1 = os.path.join(image_path, 'Granny_smith_and_cross_section.jpg')
url2 = os.path.join(image_path, 'Free!_(3987584939).jpg')
url = os.path.join(image_path, 'ILSVRC2012_val_00000466.jpg')

# Convert to RGB so every image has three channels.
image1 = Image.open(url1).convert("RGB")
image2 = Image.open(url2).convert("RGB")
image = Image.open(url).convert("RGB")
image1, image2, image

In [None]:
ratio = 0.2 
image1.resize((int(image1.size[0] * ratio), int(image1.size[1] * ratio)))

In [None]:
ratio = 0.2 
image2.resize((int(image2.size[0] * ratio), int(image2.size[1] * ratio)))

In [None]:
#!dir {$image_path} /B
#image_path = ".\\images\\"
!dir {image_path} /B

In [None]:
inputs = processor(images=image1, return_tensors='pt').to(device)
# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    outputs = model(**inputs)  # the model outputs logits
display(image1.resize((200, 200))) # you may specify interpolation methods(ex, Image.LANCZOS) for resizing
# model predicts one of the 1000 ImageNet classes
print("Predicted class:", model.config.id2label[outputs.logits.argmax(-1).item()])

In [None]:
inputs = processor(images=image2, return_tensors='pt').to(device)
# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    outputs = model(**inputs)  # the model outputs logits
display(image2.resize((200,200)))
# model predicts one of the 1000 ImageNet classes
print("Predicted class:", model.config.id2label[outputs.logits.argmax(-1).item()])

In [None]:
inputs = processor(images=image, return_tensors='pt').to(device)
# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    outputs = model(**inputs)  # the model outputs logits
display(image.resize((200,200)))
# model predicts one of the 1000 ImageNet classes
print("Predicted class:", model.config.id2label[outputs.logits.argmax(-1).item()])

## 2. Fine-tuning ViT using cifar10


## Dataset Preparation
### Loading the Data

In [None]:
# Load the official CIFAR-10 train and test splits.
train_dataset, test_dataset = load_dataset("cifar10", split=["train","test"])
# Split the training data into a train part and a 10% validation part.
# A fixed seed keeps the split identical across runs, so validation numbers
# from different runs stay comparable.
splits = train_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = splits['train']
val_dataset = splits['test']
train_dataset, val_dataset, test_dataset

In [None]:
train_dataset.features, train_dataset.num_rows, train_dataset[0]

In [None]:
# Two-way lookup between integer class ids and class names.
itos = dict(enumerate(train_dataset.features['label'].names))
stoi = {name: idx for idx, name in itos.items()}
itos, stoi

In [None]:
# sample access 
img, label = train_dataset[0]['img'], itos[train_dataset[0]['label']]
print(label)
img

In [None]:
print(img.size)

In [None]:
type(img)

### Preprocessing Data or Data Augmentation

In [None]:
model_name = "google/vit-base-patch16-224"
# Re-create the image processor of the checkpoint for the fine-tuning section.
processor = ViTImageProcessor.from_pretrained(model_name, device=device, use_fast=True)

# NOTE: these assignments are commented out, so any later use of `mu`, `sigma` or
# `size` relies on the values set in the earlier inference section.
#mu, sigma = processor.image_mean, processor.image_std
#size = processor.size

In [None]:
processor

### You may add or modify data augmentation functions

In [None]:
# Preprocessing applied to every image before it reaches the model.
# All settings are read from `processor`, so this cell no longer depends on
# `size`/`mu`/`sigma` left over from an earlier cell. The former `.cuda()` on
# Resize was dropped: the transform is applied to PIL images, not GPU tensors.
_transf = Compose([
    # Explicit (height, width) so every image gets exactly the processor's size.
    Resize((processor.size['height'], processor.size['width'])),
    ToTensor(),
    Normalize(mean=processor.image_mean, std=processor.image_std),
])


def transf(arg):
    """Add a ``pixels`` entry with the transformed images of a batch.

    Args:
        arg: Batch dictionary whose ``img`` entry is a sequence of images
            supporting ``.convert('RGB')``.

    Returns:
        The same dictionary, with ``pixels`` holding one transformed tensor
        per image.
    """
    arg['pixels'] = [_transf(image.convert('RGB')) for image in arg['img']]
    return arg

In [None]:
# Register `transf` as the transform of all three splits.
for _split in (train_dataset, val_dataset, test_dataset):
    _split.set_transform(transf)

In [None]:
train_dataset[0].keys()

In [None]:
# Inspect one transformed training example.
ex = train_dataset[0]['pixels']
print(ex.shape)
print(torch.min(ex), torch.max(ex))
# Rescale for display.
# NOTE(review): (x + 1) / 2 only undoes Normalize when mean = std = 0.5 -- confirm
# against the processor's image_mean / image_std.
ex = (ex+1)/2
print(torch.min(ex), torch.max(ex))

# Convert the tensor back to a PIL image and show it without axes.
exi = ToPILImage()(ex)
plt.imshow(exi)
#plt.show()
plt.axis('off')

### Model - Fine-Tuning
- weight initialization method

In [None]:
# ImageNet 1k
# model = ViTForImageClassification.from_pretrained(model_name, device_map=device)
#The google/vit-base-patch16-224 model is originally fine tuned on imagenet-1K with 1000 output classes

# Fine-tune vit model using CIFAR10 dataset
# Create the model with 10 output classes and the CIFAR-10 id <-> name mappings.
# NOTE(review): ignore_mismatched_sizes=True presumably lets the checkpoint's 1000-class
# head be skipped in favour of a newly initialised 10-class head -- confirm.
ft_model = ViTForImageClassification.from_pretrained(model_name, num_labels=10,  ignore_mismatched_sizes=True, 
                                                     id2label=itos, label2id=stoi).to(device)

In [None]:
ft_model

In [None]:
# print(ft_model.classifier)

In [None]:
ft_model.device, ft_model.config

In [None]:
summary(ft_model, input_size=(1, 3, 224, 224))

### Hugging Face Trainer

In [None]:
# training hyperparameters
batch_size = 32
num_train_epochs = 5 # 10

In [None]:
# Number of full batches contained in the training split.
max_steps_per_epoch = train_dataset.num_rows//batch_size
# Shortened "epoch" length used for this run; swap in the line below for full passes.
steps_per_epoch = 200 
# steps_per_epoch = max_steps_per_epoch 
steps_per_epoch, max_steps_per_epoch

### [trainer-callbacks](https://huggingface.co/docs/transformers/v4.52.3/en/trainer)
- callbacks=[EarlyStoppingCallback()]


In [None]:
args = TrainingArguments(
    # --- output / logging locations ---
    output_dir="aias-vit-cifar-10",
    overwrite_output_dir=True,
    logging_dir='logs',
    # --- optimisation (default optimizer; weight decay left at its default) ---
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # --- schedule: use either num_train_epochs or max_steps, not both ---
    max_steps=steps_per_epoch * num_train_epochs,
    # --- evaluate once per (shortened) epoch, log twice per epoch ---
    eval_strategy="steps",  # alternative: "epoch"
    eval_steps=steps_per_epoch,
    logging_steps=steps_per_epoch//2,
    # --- checkpoints are not written during training ---
    # To keep the best model instead: save_strategy="steps" or "epoch",
    # save_total_limit=1, load_best_model_at_end=True, metric_for_best_model="accuracy".
    save_strategy="no",
    # Presumably required so the `img` column is still available to the
    # on-the-fly transform -- confirm.
    remove_unused_columns=False,
)

# examples: single batch
def collate_fn(examples):
    """Assemble a list of dataset examples into one model-ready batch.

    Args:
        examples: Sequence of dicts, each with a ``pixels`` tensor and an
            integer ``label``.

    Returns:
        dict: ``pixel_values`` (the ``pixels`` tensors stacked along a new
        first dimension) and ``labels`` (a tensor of the labels).
    """
    pixel_list = []
    label_list = []
    for sample in examples:
        pixel_list.append(sample["pixels"])
    for sample in examples:
        label_list.append(sample["label"])
    return {
        "pixel_values": torch.stack(pixel_list),
        "labels": torch.tensor(label_list),
    }

def compute_metrics(eval_pred):
    """Compute classification accuracy from a ``(predictions, labels)`` pair.

    Args:
        eval_pred: Pair whose first item holds one row of class scores per
            example and whose second item holds the true class ids.

    Returns:
        dict: ``{"accuracy": value}`` as computed by ``accuracy_score``.
    """
    scores, labels = eval_pred
    predicted_ids = np.argmax(scores, axis=1)
    # scikit-learn's signature is accuracy_score(y_true, y_pred); pass the
    # ground truth first so the call matches the documented argument order.
    return dict(accuracy=accuracy_score(labels, predicted_ids))

# Wire model, hyperparameters, data and metric function into one Trainer.
trainer = Trainer(
    model=ft_model,
    args=args,  # TrainingArguments
    data_collator=collate_fn,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=processor,  # preprocessor (formerly passed as `tokenizer=`)
    compute_metrics=compute_metrics,
)

In [None]:
#?trainer.compute_loss

In [None]:
#args

### Training the model for fine-tuning

 ## 1. last-layer 만 먼저 fine-tuning하고 싶다면, 해당 layer만 trainable로 변경

In [None]:
# Freeze the model that the Trainer actually optimises (`ft_model`); freezing the
# ImageNet `model` from the inference section would leave training unaffected.
for name, param in ft_model.named_parameters():
    param.requires_grad = False
# Re-enable gradients for the classification head only.
for name, param in ft_model.classifier.named_parameters():
    param.requires_grad = True

# Print each parameter's name, shape and whether it is trainable.
for name, param in ft_model.named_parameters():
    print(f"{name:55} | shape: {str(param.shape):28} | trainable: {param.requires_grad}")

In [None]:
# Run the training loop and measure its wall-clock duration.
start_time = datetime.now()
result = trainer.train()
end_time = datetime.now()
# timedelta covering the whole trainer.train() call
elapsed_time = end_time - start_time
In [None]:
# 2. full-train(default), 전체 layer의 parameters를 trainable로 변경하여 전체 weight 학습

In [None]:
# Disabled by default. Uncomment to make every parameter of the fine-tuned model
# (`ft_model`, the one passed to the Trainer) trainable again.
# for name, param in ft_model.named_parameters():
#     param.requires_grad = True

# for name, param in ft_model.named_parameters():
#     print(f"{name:55} | shape: {str(param.shape):28} | trainable: {param.requires_grad}")

In [None]:
# start_time = datetime.now()
# result2 = trainer.train()
# end_time = datetime.now()
# elapsed_time = end_time - start_time

In [None]:
# Print real seconds: str(timedelta) renders as "H:MM:SS.ffffff", which
# contradicted the "sec" unit printed after it.
print(f'training time: {elapsed_time.total_seconds():.1f} sec')

In [None]:
def print_result(result):
    """Print the runtime and throughput stored in a training result.

    Args:
        result: Object whose ``metrics`` mapping contains ``train_runtime``
            and ``train_samples_per_second``.
    """
    stats = result.metrics
    elapsed = stats['train_runtime']
    print(f"Time: {elapsed:.2f}")
    rate = stats['train_samples_per_second']
    print(f"Samples/second: {rate:.2f}")

In [None]:
print_result(result)

In [None]:
trainer.state.log_history

In [None]:
trainer.state.log_history[-1]

In [None]:
# Training / evaluation history as tables.
import pandas as pd

# Everything except the last entry (the overall train results).
log_history = pd.DataFrame(trainer.state.log_history[:-1])
# Rows that carry a training loss, minus the columns that are empty for them.
log_history1 = log_history.loc[log_history['loss'].notna()].dropna(axis='columns', how='any')
# Rows that carry an evaluation loss, minus the columns that are empty for them.
log_history2 = log_history.loc[log_history['eval_loss'].notna()].dropna(axis='columns', how='any')

In [None]:
# log_history1

In [None]:
# log_history2

In [None]:
log_history1[['step', 'loss']].plot(x='step', y='loss')
plt.title('train loss vs steps')

In [None]:
log_history2[['step', 'eval_loss']].plot(x='step', y='eval_loss' )
log_history2[['step', 'eval_accuracy']].plot(x='step', y='eval_accuracy')

## Evaluation

In [None]:
# Evaluation using the trainer.predict API.
# [NOTE] the test_dataset carries label information, so metrics are computed too.
start_time = datetime.now()
outputs = trainer.predict(test_dataset)
end_time = datetime.now()
elapsed_time = end_time - start_time
print(f"Accuracy at test dataset: {outputs.metrics['test_accuracy']}")
# total_seconds() gives a real number of seconds; a bare timedelta prints as "H:MM:SS.ffffff".
print(f"Processing time to evaluate test dataset: {elapsed_time.total_seconds():.1f} sec")

In [None]:
# Take a look at `outputs`: headline accuracy first.
# Double quotes outside, single quotes inside: reusing the same quote character
# within an f-string is a SyntaxError before Python 3.12.
print(f"accuracy={outputs.metrics['test_accuracy']}")

In [None]:
print(outputs.metrics) 
outputs.predictions.shape,outputs.label_ids.shape 
#dir(outputs)

### inference results for the selected index in the test_dataset

In [None]:
# Show one test image next to its predicted and true class.
idx = 10
ex = test_dataset[idx]['pixels']
# NOTE(review): (x + 1) / 2 only undoes Normalize when mean = std = 0.5 -- confirm.
ex = (ex+1)/2

exi = ToPILImage()(ex)
display(exi)
# The predicted class is the index of the largest score in that example's prediction row.
print(f'predicted: {itos[np.argmax(outputs.predictions[idx])]}, ground truth: {itos[outputs.label_ids[idx]]}')

### Confusion Matrix

In [None]:
# Class names for the axis ticks.
labels = train_dataset.features['label'].names

# Ground truth ids and the highest-scoring class per example.
y_true = outputs.label_ids
y_pred = np.argmax(outputs.predictions, axis=1)

cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
disp.plot(xticks_rotation=45)

## Save the model

In [None]:
output_dir="vit-cifar10"
trainer.save_model(output_dir)

### Load the model and verify it

In [None]:
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Reload the processor and the model from the directory the trainer saved to.
re_processor = AutoImageProcessor.from_pretrained(output_dir, use_fast=True)
re_model = AutoModelForImageClassification.from_pretrained(output_dir, device_map=device)
print(re_model.classifier) # head of the reloaded model; expected to show 10 outputs (the fine-tuned CIFAR-10 head), not the base checkpoint's 1000 -- verify in the printout
# or
#re_model2 = ViTForImageClassification.from_pretrained(output_dir, device_map=device)

In [None]:
re_model

In [None]:
re_processor