In [None]:
!pip install -q datasets transformers

In [None]:
import os
from huggingface_hub import notebook_login
import numpy as np
from datasets import load_dataset
from datasets import load_metric
from transformers import ViTFeatureExtractor
from transformers import ViTForImageClassification
from transformers import TrainingArguments
from transformers import Trainer

import torch
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

In [None]:
notebook_login()

In [None]:
%%capture
!sudo apt -qq install git-lfs
!git config --global credential.helper store

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!unzip "/content/drive/MyDrive/food.zip" -d /tmp/foodimg

In [None]:
ds = load_dataset("imagefolder", data_dir="/tmp/foodimg")
ds = ds['train']

In [None]:
data = ds.train_test_split(test_size=0.10)

In [None]:
data

### Pushing the dataset to the public Hugging Face Hub

In [None]:
# Upload the train/test DatasetDict to the Hub under the account logged in via notebook_login().
data.push_to_hub(repo_id="lordgrim18/indian_food_images")

### Loading the dataset back from the public Hub

In [None]:
data = load_dataset("lordgrim18/indian_food_images")

In [None]:
ex = data['train'][400]
ex

In [None]:
image = ex['image']
image

In [None]:
labels = data['train'].features['label']
labels

In [None]:
labels.int2str(ex['label'])

In [None]:
metric = load_metric("accuracy")

In [None]:
data

In [None]:
labels = data["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = i
    id2label[i] = label

In [None]:
model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)

In [None]:
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
train_transforms = Compose(
        [
            RandomResizedCrop(feature_extractor.size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )
val_transforms = Compose(
        [
            RandomResizedCrop(feature_extractor.size),
            #CenterCrop(feature_extractor.size),
            ToTensor(),
            normalize,
        ]
    )

In [None]:
def preprocess_train(example_batch):
    """Attach augmented ``pixel_values`` tensors to a batch of training examples.

    Every image in ``example_batch["image"]`` is converted to RGB and run through
    ``train_transforms``. The batch mapping is updated in place and returned.
    """
    rgb_images = (img.convert("RGB") for img in example_batch["image"])
    example_batch["pixel_values"] = [train_transforms(img) for img in rgb_images]
    return example_batch

In [None]:
def preprocess_val(example_batch):
    """Attach evaluation ``pixel_values`` tensors to a batch of validation examples.

    Every image in ``example_batch["image"]`` is converted to RGB and run through
    ``val_transforms``. The batch mapping is updated in place and returned.
    """
    pixel_values = []
    for img in example_batch["image"]:
        pixel_values.append(val_transforms(img.convert("RGB")))
    example_batch["pixel_values"] = pixel_values
    return example_batch

In [None]:
# split up training into training + validation
train_ds = data['train']
val_ds = data['test']

In [None]:
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)

In [None]:
train_ds[1]

In [None]:
# Classification head sized to the dataset's classes. Reuse the label2id / id2label
# maps built from the ClassLabel names, so label2id maps names to integer ids
# (the previous inline maps stored ids as strings).
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)

In [None]:
training_args = TrainingArguments(
    'finetuned-indian-food',
    per_device_train_batch_size=16,
    evaluation_strategy="steps",  # called `eval_strategy` in newer transformers releases
    num_train_epochs=4,
    # Mixed precision needs a CUDA device; TrainingArguments raises on fp16=True
    # without one, so fall back to fp32 on CPU-only runtimes.
    fp16=torch.cuda.is_available(),
    # save_steps must be a multiple of eval_steps for load_best_model_at_end.
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    # Keep the raw "image" column: preprocess_train/preprocess_val read it on access.
    remove_unused_columns=False,
    push_to_hub=True,
    report_to='tensorboard',
    load_best_model_at_end=True,
    hub_strategy="end",
)

In [None]:
def compute_metrics(eval_pred):
    """Return the accuracy of arg-max predictions against the reference labels.

    ``eval_pred.predictions`` holds per-class scores with the class axis at
    position 1; ``eval_pred.label_ids`` holds the true class ids.
    """
    logits, references = eval_pred.predictions, eval_pred.label_ids
    predicted_ids = np.argmax(logits, axis=1)
    return metric.compute(predictions=predicted_ids, references=references)

In [None]:
def collate_fn(batch):
    """Stack per-example tensors into the batch dict the model's forward() expects.

    Each element of ``batch`` provides ``pixel_values`` (a tensor) and ``label``
    (an int). Pixel tensors are stacked along a new leading batch dimension and
    labels become a 1-D tensor under the key ``labels``.
    """
    pixel_values = torch.stack([example['pixel_values'] for example in batch])
    label_ids = torch.tensor([example['label'] for example in batch])
    return {'pixel_values': pixel_values, 'labels': label_ids}

In [None]:
# The feature extractor is passed as `tokenizer` so the Trainer saves it with the
# model; the Hub repo can then be reloaded with AutoFeatureExtractor later.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)

In [None]:
train_results = trainer.train()

In [None]:
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

In [None]:
metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

## Uploading the model to the Hub

In [None]:
# Model-card metadata describing what was fine-tuned, for which task and on which data.
kwargs = dict(
    finetuned_from=model.config._name_or_path,
    tasks="image-classification",
    dataset='indian_food_images',
    tags=['image-classification'],
)

# Push to the Hub when enabled in the training arguments; otherwise only write the model card.
if not training_args.push_to_hub:
    trainer.create_model_card(**kwargs)
else:
    trainer.push_to_hub('🍻 cheers', **kwargs)

## Testing the model on a single image

In [None]:
from PIL import Image
import requests
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
import torch

In [None]:
url = 'https://cdn.pixabay.com/photo/2016/10/25/13/42/indian-1768906_960_720.jpg'
image = Image.open(requests.get(url, stream=True).raw)
image

In [None]:
repo_name = "lordgrim18/finetuned-indian-food"

feature_extractor = AutoFeatureExtractor.from_pretrained(repo_name)
model = AutoModelForImageClassification.from_pretrained(repo_name)

In [None]:
# prepare image for the model
encoding = feature_extractor(image.convert("RGB"), return_tensors="pt")
print(encoding.pixel_values.shape)

In [None]:
# forward pass
with torch.no_grad():
  outputs = model(**encoding)
  logits = outputs.logits

In [None]:
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

# Pipeline
## A `pipeline` gives a quicker way to run inference with the model

In [None]:
from transformers import pipeline

In [None]:
pipe = pipeline("image-classification", "lordgrim18/finetuned-indian-food")

In [None]:
url = 'https://cdn.pixabay.com/photo/2016/10/25/13/42/indian-1768906_960_720.jpg'
image = Image.open(requests.get(url, stream=True).raw)

pipe(image)

In [None]:
pipe = pipeline("image-classification", 
                model=model,
                feature_extractor=feature_extractor)

In [None]:
pipe(image)