In [None]:
%%capture
!pip install ftfy regex tqdm matplotlib opencv-python torch scipy scikit-image datasets transformers transformers[torch] accelerate -U
!pip install git+https://github.com/openai/CLIP.git
!pip install tqdm
!pip install torchvision tensorboard pillow
!pip install evaluate
!pip install einops
!pip install timm
!pip install captum

import urllib.request
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from scipy.ndimage import filters
from torch import nn

from google.colab import drive
# Mount Google Drive at /content/drive; later cells read from and copy to
# paths under /content/drive/MyDrive.
drive.mount('/content/drive')

In [None]:
# Workaround for an error that appeared on `import clip`: pin torch and its
# companion packages to the 1.13.1 / cu117 builds. Run this cell only if that
# error occurs.
!pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchtext==0.14.1 torchaudio==0.13.1 torchdata==0.5.1 --extra-index-url https://download.pytorch.org/whl/cu117

# Load Dataset

In [None]:
import datasets
from datasets import load_dataset
from torchvision import transforms

def _non_rgb_indices(dataset):
    """Return the row indices whose image is not in RGB mode.

    Args:
        dataset: An indexable dataset whose rows expose, under the "image"
            key, an object with a ``mode`` attribute.

    Returns:
        list[int]: Indices in ascending order where ``image.mode != 'RGB'``.
    """
    return [
        idx
        for idx in range(len(dataset))
        if dataset[idx]["image"].mode != 'RGB'
    ]


train_ds = load_dataset('food101', split="train")
validation_ds = load_dataset('food101', split="validation")

# Indices of non-RGB images in each split; the next cell uses these lists to
# build filtered copies of the datasets.
exclude_idx = _non_rgb_indices(train_ds)
exclude_idx2 = _non_rgb_indices(validation_ds)


In [None]:
# Create new datasets excluding the indices collected in exclude_idx / exclude_idx2.
# Each exclusion list is converted to a set exactly once, so every membership
# test is O(1) instead of rebuilding the set for each candidate index.
train_excluded = set(exclude_idx)
validation_excluded = set(exclude_idx2)

train_ds_new = train_ds.select(
    i for i in range(len(train_ds)) if i not in train_excluded
)

validation_ds_new = validation_ds.select(
    i for i in range(len(validation_ds)) if i not in validation_excluded
)

# Bundle the filtered splits under the keys used by the training cell.
ds = datasets.DatasetDict({"train": train_ds_new, "validation": validation_ds_new})



# ViT (Swin Transformer V2) Implementation & Training

In [None]:
from transformers import Swinv2ForImageClassification, AutoImageProcessor, TrainingArguments, Trainer
from datasets import load_dataset, load_metric
import torch

# Load the image processor that turns PIL images into model inputs; it is
# created from the "microsoft/swinv2-tiny-patch4-window8-256" identifier.
processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")

# Aliases for the filtered splits.
# NOTE(review): neither alias is referenced again in the visible part of the
# notebook — confirm before relying on or removing them.
train_dataset = train_ds_new
eval_dataset = validation_ds_new

def process_example(example_batch):
    """Run a batch of images through ``processor`` and attach their labels.

    NOTE(review): this does the same work as ``transform`` and is not called
    in the visible part of the notebook.

    Args:
        example_batch: Mapping with an 'image' sequence and a 'label' entry.

    Returns:
        The processor output (requested as PyTorch tensors) with the batch's
        'label' value stored under the 'label' key.
    """
    images = list(example_batch['image'])
    encoded = processor(images, return_tensors='pt')
    encoded['label'] = example_batch['label']
    return encoded

def transform(example_batch):
    """Convert a batch of images to pixel values and keep the labels with them.

    Args:
        example_batch: Mapping with an 'image' sequence and a 'label' entry.

    Returns:
        The processor output (requested as PyTorch tensors) with the batch's
        'label' value stored under the 'label' key.
    """
    # Materialise the images as a list and let the processor build the tensors
    pil_images = list(example_batch['image'])
    batch_inputs = processor(pil_images, return_tensors='pt')
    # Carry the labels along under the same key the collator reads
    batch_inputs['label'] = example_batch['label']
    return batch_inputs

# Attach `transform` to the dataset dict via `with_transform`.
prepared_ds = ds.with_transform(transform)

def collate_fn(batch):
    """Combine per-example dicts into a single batch dict.

    Args:
        batch: Sequence of dicts, each with a 'pixel_values' tensor and a
            'label' value.

    Returns:
        dict: 'pixel_values' holds the example tensors stacked along a new
        leading dimension; 'labels' holds the label values as one tensor.
    """
    pixel_tensors = []
    label_ids = []
    for sample in batch:
        pixel_tensors.append(sample['pixel_values'])
        label_ids.append(sample['label'])
    return {
        'pixel_values': torch.stack(pixel_tensors),
        'labels': torch.tensor(label_ids),
    }

# NOTE(review): `load_metric` is deprecated in newer `datasets` releases in
# favour of the separate `evaluate` package — confirm that the installed
# `datasets` version still provides it.
metric = load_metric("accuracy")
def compute_metrics(p):
    """Score a prediction bundle with the accuracy metric.

    Args:
        p: Object exposing ``predictions`` (scores whose argmax over axis 1
            is taken as the predicted class) and ``label_ids``.

    Returns:
        Whatever ``metric.compute`` returns for the predicted class ids
        against the reference label ids.
    """
    predicted_ids = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=predicted_ids, references=p.label_ids)

# Class names taken from the 'label' feature; their positions are used as the
# class ids in the id2label / label2id mappings below.
labels = ds['train'].features['label'].names

from transformers import AutoImageProcessor, AutoModelForImageClassification

# Fine-tuned checkpoint on the mounted Drive. Defined once so the model and
# the image processor are always loaded from the same directory.
CHECKPOINT_DIR = "/content/drive/MyDrive/vit-food-swinv2/vit-food-v1/checkpoint-142050/"

model_vit = AutoModelForImageClassification.from_pretrained(
    CHECKPOINT_DIR,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

# Label metadata (num_labels / id2label / label2id) is model configuration, so
# it is passed to the model above only; the image processor has no use for it.
feature_extractor = AutoImageProcessor.from_pretrained(CHECKPOINT_DIR)

training_args = TrainingArguments(
  # Must match the directory that the export cells copy and zip
  # (`vit-food-v1`); otherwise those cells look for a folder this run never
  # creates.
  output_dir="./vit-food-v1",
  per_device_train_batch_size=16,
  # Evaluate and checkpoint once per epoch. The two strategies are kept
  # identical because load_best_model_at_end=True is set below.
  # NOTE(review): newer transformers releases rename `evaluation_strategy` to
  # `eval_strategy` — confirm against the installed version.
  evaluation_strategy="epoch",
  num_train_epochs=15,
  logging_steps=10,
  learning_rate=2e-4,
  save_strategy="epoch",
  # Keep at most two checkpoints on disk.
  save_total_limit=2,
  # The transform reads the raw 'image' column, so unused columns must not be
  # dropped before it runs.
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True
)

# Wire the model, the training configuration, the transformed splits and the
# batching / metric helpers into a single Trainer.
trainer = Trainer(
    # What to train and how
    model=model_vit,
    args=training_args,
    # Data: transformed train / validation splits
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    # Batch assembly and evaluation metric
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    # The image processor is passed through the `tokenizer` argument
    tokenizer=processor,
)

# Train the model, then save the model, the training metrics and the trainer state
train_results=trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

# Evaluate on the validation split and log/save the resulting metrics under "eval"
metrics = trainer.evaluate(prepared_ds['validation'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

# ViT Predictions
# Run prediction on the validation split and take the argmax over the last
# axis of the raw outputs to get one predicted class index per example.
predictions = trainer.predict(prepared_ds['validation'])
vit_predictions = predictions.predictions.argmax(-1)

In [None]:
# Copy the local vit-food-v1/ directory into the vit-food-swinv2 folder on Drive.
# The destination is written relative to the working directory, so this
# presumably expects the notebook to run from /content — TODO confirm.
!cp -r vit-food-v1/ ../content/drive/MyDrive/vit-food-swinv2

In [None]:
# Zip the local vit-food-v1 directory and hand the archive to Colab's file
# download helper.
!zip -r vit-food-swinv2.zip vit-food-v1
from google.colab import files
files.download('vit-food-swinv2.zip')