<a href="https://colab.research.google.com/github/dmaresza/PyTorch-Course/blob/main/FoodVision_ViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Notebook to create a Vision Transformer (ViT) feature extraction model, train it on the Food101 dataset, and deploy it to Hugging Face Spaces using Gradio.

## Setup

In [1]:
# Getting necessary imports
import torch
import torchvision

from torch import nn
from torchvision import transforms

try:
  from torchinfo import summary
except:
  print("[INFO] Couldn't find torchinfo... installing it.")
  !pip install -q torchinfo
  from torchinfo import summary

# Check torch version
torch.__version__

[INFO] Couldn't find torchinfo... installing it.


'2.4.0+cu121'

In [2]:
# some Colab shell commands weren't working so had to add this bit
import locale

# Force the reported preferred encoding to UTF-8.
# The replacement mirrors the stdlib signature (optional `do_setlocale`) so that
# callers using `locale.getpreferredencoding(False)` keep working instead of
# raising a TypeError.
locale.getpreferredencoding = lambda do_setlocale=True: "UTF-8"

In [3]:
# Global device string: use the GPU when PyTorch can see one, otherwise the CPU
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"
device

'cuda'

## Instantiate model

In [4]:
# Pretrained ViT-B/16: grab the default weights, build the model from them,
# and take the preprocessing pipeline that ships with those weights
vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
vit_model = torchvision.models.vit_b_16(weights=vit_weights)
vit_transforms = vit_weights.transforms()

# Turn off gradient tracking for every parameter of the pretrained model
for param in vit_model.parameters():
  param.requires_grad_(False)

# Replace the classifier head with a new linear layer (768 features in, 101 classes out);
# it is created after the freeze above, so its parameters remain trainable
food101_head = nn.Linear(768, 101)
vit_model.heads = nn.Sequential(food101_head)

# Show a layer-by-layer summary for a single 3x224x224 input
summary(vit_model, input_size=(1, 3, 224, 224))

Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:01<00:00, 210MB/s]


Layer (type:depth-idx)                        Output Shape              Param #
VisionTransformer                             [1, 101]                  768
├─Conv2d: 1-1                                 [1, 768, 14, 14]          (590,592)
├─Encoder: 1-2                                [1, 197, 768]             151,296
│    └─Dropout: 2-1                           [1, 197, 768]             --
│    └─Sequential: 2-2                        [1, 197, 768]             --
│    │    └─EncoderBlock: 3-1                 [1, 197, 768]             (7,087,872)
│    │    └─EncoderBlock: 3-2                 [1, 197, 768]             (7,087,872)
│    │    └─EncoderBlock: 3-3                 [1, 197, 768]             (7,087,872)
│    │    └─EncoderBlock: 3-4                 [1, 197, 768]             (7,087,872)
│    │    └─EncoderBlock: 3-5                 [1, 197, 768]             (7,087,872)
│    │    └─EncoderBlock: 3-6                 [1, 197, 768]             (7,087,872)
│    │    └─EncoderBlock: 3-

## Get dataset

In [5]:
from torchvision import datasets
from pathlib import Path

# Everything is downloaded to / read from this directory
data_dir = Path("data")

# Build the train and test splits (in that order) with identical settings;
# both use the transforms bundled with the pretrained ViT weights
train_data, test_data = (
    datasets.Food101(root=data_dir,
                     split=split_name,
                     transform=vit_transforms,
                     download=True)
    for split_name in ("train", "test")
)

# Number of samples in each split
len(train_data), len(test_data)

Downloading https://data.vision.ee.ethz.ch/cvl/food-101.tar.gz to data/food-101.tar.gz


100%|██████████| 4996278331/4996278331 [03:50<00:00, 21686204.29it/s]


Extracting data/food-101.tar.gz to data


(75750, 25250)

In [6]:
import os

# Large batch size to make better use of GPU
BATCH_SIZE = 512
# os.cpu_count() may return None when the count cannot be determined;
# fall back to 0 (load data in the main process) so DataLoader gets a valid int.
NUM_WORKERS = os.cpu_count() or 0

# Training batches are reshuffled every epoch
train_dataloader = torch.utils.data.DataLoader(dataset=train_data,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True,
                                               num_workers=NUM_WORKERS)

# Test batches keep a fixed order
test_dataloader = torch.utils.data.DataLoader(dataset=test_data,
                                              batch_size=BATCH_SIZE,
                                              shuffle=False,
                                              num_workers=NUM_WORKERS)

# Number of batches per epoch for each split
len(train_dataloader), len(test_dataloader)

(148, 50)

## Train

In [7]:
# Cross-entropy loss with label smoothing of 0.1
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

# Adam with its default hyperparameters, handed every parameter of the model
optimizer = torch.optim.Adam(vit_model.parameters())

In [8]:
from tqdm.auto import tqdm

# 10 epochs to give model sufficient time to train, but not take TOO long
EPOCHS = 10

vit_model.to(device)

for epoch in tqdm(range(EPOCHS)):
  ### TRAINING
  vit_model.train()

  # Totals are accumulated per sample (not per batch) so that a shorter final
  # batch is not given the same weight as a full one in the epoch averages.
  train_loss, train_acc, train_samples = 0, 0, 0

  for X, y in train_dataloader:
    # Forward pass
    X, y = X.to(device), y.to(device)
    y_pred = vit_model(X)

    # loss_fn averages over the batch; scale by the batch size to get a per-sample sum
    loss = loss_fn(y_pred, y)
    num_samples = len(y)
    train_loss += loss.item() * num_samples
    train_samples += num_samples

    # Standard update: clear old gradients, backpropagate, take an optimizer step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Count correct predictions (argmax over logits; softmax would not change the argmax)
    y_pred_class = y_pred.argmax(dim=1)
    train_acc += (y_pred_class == y).sum().item()

  # Get per-sample average train loss & train accuracy for 1 epoch
  train_loss /= train_samples
  train_acc /= train_samples

  ### TESTING
  vit_model.eval()

  test_loss, test_acc, test_samples = 0, 0, 0

  with torch.inference_mode():
    for X, y in test_dataloader:
      X, y = X.to(device), y.to(device)
      test_pred_logits = vit_model(X)

      # Accumulate the per-sample loss sum
      loss = loss_fn(test_pred_logits, y)
      num_samples = len(y)
      test_loss += loss.item() * num_samples
      test_samples += num_samples

      # Count correct predictions
      test_pred_labels = test_pred_logits.argmax(dim=1)
      test_acc += (test_pred_labels == y).sum().item()

  # Get per-sample average test loss & test accuracy for 1 epoch
  test_loss /= test_samples
  test_acc /= test_samples

  # Print out results for each epoch
  print(
      f"Epoch: {epoch + 1} | "
      f"train_loss: {train_loss:.4f} | "
      f"train_acc: {train_acc:.4f} | "
      f"test_loss: {test_loss:.4f} | "
      f"test_acc: {test_acc:.4f}"
  )

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 2.5963 | train_acc: 0.5122 | test_loss: 1.9517 | test_acc: 0.6732
Epoch: 2 | train_loss: 1.9912 | train_acc: 0.6638 | test_loss: 1.7974 | test_acc: 0.7219
Epoch: 3 | train_loss: 1.8673 | train_acc: 0.7021 | test_loss: 1.7328 | test_acc: 0.7411
Epoch: 4 | train_loss: 1.7953 | train_acc: 0.7271 | test_loss: 1.6944 | test_acc: 0.7516
Epoch: 5 | train_loss: 1.7476 | train_acc: 0.7419 | test_loss: 1.6703 | test_acc: 0.7591
Epoch: 6 | train_loss: 1.7116 | train_acc: 0.7529 | test_loss: 1.6557 | test_acc: 0.7615
Epoch: 7 | train_loss: 1.6847 | train_acc: 0.7616 | test_loss: 1.6444 | test_acc: 0.7668
Epoch: 8 | train_loss: 1.6616 | train_acc: 0.7707 | test_loss: 1.6369 | test_acc: 0.7683
Epoch: 9 | train_loss: 1.6426 | train_acc: 0.7770 | test_loss: 1.6317 | test_acc: 0.7703
Epoch: 10 | train_loss: 1.6268 | train_acc: 0.7825 | test_loss: 1.6254 | test_acc: 0.7722


## Save the model

In [9]:
# File name and directory for the trained weights
model_name = "pretrained_vit_feature_extractor_food101.pth"
model_path = Path("models")

# Create the directory (and any missing parents); no error if it already exists
model_path.mkdir(parents=True, exist_ok=True)
model_save_path = model_path / model_name

# Persist only the state_dict (parameters and buffers), not the whole model object
torch.save(vit_model.state_dict(), model_save_path)

## Build Gradio App

In [10]:
import shutil

# Create directory for demo files
demo_path = Path("demos/food101/")

if demo_path.exists():
  shutil.rmtree(demo_path)
  demo_path.mkdir(parents=True, exist_ok=True)
else:
  demo_path.mkdir(parents=True, exist_ok=True)

!ls demos/food101/

In [11]:
# Folder inside the demo that holds the example images
demo_examples_path = demo_path / "examples"
demo_examples_path.mkdir(parents=True, exist_ok=True)

# Three images from the downloaded Food101 data to ship as examples
examples = [Path(image_file) for image_file in (
    'data/food-101/images/cannoli/2034686.jpg',
    'data/food-101/images/guacamole/36147.jpg',
    'data/food-101/images/steak/1053665.jpg',
)]

# Copy each image (keeping its file name) into the examples folder
for example in examples:
  destination = demo_examples_path / example.name
  print(f"[INFO] Copying {example} to {destination}")
  shutil.copy2(example, destination)

[INFO] Copying data/food-101/images/cannoli/2034686.jpg to demos/food101/examples/2034686.jpg
[INFO] Copying data/food-101/images/guacamole/36147.jpg to demos/food101/examples/36147.jpg
[INFO] Copying data/food-101/images/steak/1053665.jpg to demos/food101/examples/1053665.jpg


In [12]:
# Class names as exposed by the training dataset
class_names = train_data.classes

# Text file inside the demo folder that will hold them
class_names_path = demo_path / "class_names.txt"

# One class name per line, with no trailing newline after the last one
with class_names_path.open("w") as f:
  print(f"[INFO] Saving Food101 class names to {class_names_path}")
  f.write("\n".join(class_names))

[INFO] Saving Food101 class names to demos/food101/class_names.txt


In [13]:
%%writefile demos/food101/model.py
# File for instantiating ViT model
import torch
import torchvision
from torch import nn

def create_vit_model(num_classes:int=3,
                     seed:int=42):
  """Creates a frozen pretrained ViT-B/16 with a new linear classifier head.

  Args:
    num_classes: Number of output units of the new classifier head.
    seed: Value passed to torch.manual_seed right before the head is created,
      so the head's random initialisation is repeatable.

  Returns:
    A (model, transforms) tuple: the ViT-B/16 model with its new head, and the
    preprocessing transforms that belong to the pretrained weights.
  """
  # Create ViT_B_16 pretrained weights, transforms and model
  weights = torchvision.models.ViT_B_16_Weights.DEFAULT
  transforms = weights.transforms()
  model = torchvision.models.vit_b_16(weights=weights)

  # Freeze all of the base layers
  for param in model.parameters():
    param.requires_grad = False

  # Seed the global RNG so the new head is initialised identically for a given seed
  torch.manual_seed(seed)

  # Change classifier head to suit our needs (created after the freeze, so it stays trainable)
  model.heads = nn.Sequential(
      nn.Linear(in_features=768,
                out_features=num_classes)
  )

  return model, transforms

Writing demos/food101/model.py


In [14]:
# Move (not copy) the saved state_dict out of models/ and into the demo folder,
# so the weights file sits next to app.py, which loads it by this file name.
!mv models/pretrained_vit_feature_extractor_food101.pth demos/food101/

In [15]:
%%writefile demos/food101/app.py
# File to build Gradio application
### 1. Imports and class names setup ###
import gradio as gr
import os
import torch

from model import create_vit_model
from timeit import default_timer as timer
from typing import Tuple, Dict

# Set up class names (one per line in class_names.txt, surrounding whitespace removed)
with open("class_names.txt", "r") as f:
  class_names = [food_name.strip() for food_name in f.readlines()]

### 2. Model and transforms preparation ###
# Create model and transforms, with one output unit per class name
vit, vit_transforms = create_vit_model(
    num_classes=len(class_names))

# Load saved weights
vit.load_state_dict(
    torch.load(f="pretrained_vit_feature_extractor_food101.pth",
               map_location=torch.device("cpu"), # load the model to the CPU
               weights_only=True) # the file is a plain state_dict: refuse to unpickle anything else
)

### 3. Predict function ###
def predict(img) -> Tuple[Dict, float]:
  """Classifies one image with the ViT model and times the prediction.

  Args:
    img: The input image; it is passed straight to `vit_transforms`.

  Returns:
    A tuple of (dict mapping every class name to its predicted probability,
    elapsed prediction time in seconds rounded to 4 decimal places).
  """
  # Timing covers preprocessing, the forward pass and building the result dict
  start_time = timer()

  # Preprocess for ViT and add a batch dimension at index 0
  batch = vit_transforms(img).unsqueeze(0)

  # Put model into eval mode, make prediction
  vit.eval()
  with torch.inference_mode():
    logits = vit(batch)
    # Turn the prediction logits into probabilities across the class dimension
    pred_probs = torch.softmax(logits, dim=1)

  # Map each class name to the probability of the single image in the batch
  image_probs = pred_probs[0]
  pred_labels_and_probs = {
      class_name: float(image_probs[class_idx])
      for class_idx, class_name in enumerate(class_names)
  }

  # Elapsed time, rounded to 4 decimal places
  pred_time = round(timer() - start_time, 4)

  return pred_labels_and_probs, pred_time

### 4. Gradio app ###

# Text shown on the demo page
title = "FoodVision BIG 🍔👁"
description = "A [Vision Transformer (ViT)](https://pytorch.org/vision/stable/models/generated/torchvision.models.vit_b_16.html#torchvision.models.vit_b_16) feature extractor computer vision model to classify images of 101 classes of food from the [Food101 dataset](https://www.kaggle.com/datasets/dansbecker/food-101/data)."
article = "Created at [FoodVision](https://github.com/dmaresza/PyTorch-Course/blob/main/FoodVision_ViT.ipynb)."

# One single-item list per file found in the examples folder
example_list = [[f"examples/{example}"] for example in os.listdir("examples")]

# Input: a PIL image. Outputs: top-5 class label widget plus the prediction time.
image_input = gr.Image(type="pil")
prediction_outputs = [
    gr.Label(num_top_classes=5, label="Predictions"),
    gr.Number(label="Prediction time (s)"),
]

# Wire predict() to the input/output components
demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=prediction_outputs,
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Launch the demo
demo.launch(debug=False)

Writing demos/food101/app.py


In [16]:
%%writefile demos/food101/requirements.txt
torch==2.4.0
torchvision==0.19.0
gradio==4.42.0

Writing demos/food101/requirements.txt


In [17]:
# Change into the food101 directory and then zip it from the inside, so paths in the
# archive are relative to the demo folder. The archive is written one level up
# (demos/foodvision.zip); -r recurses into subfolders and -x skips compiled Python
# files, notebooks, __pycache__ and notebook checkpoint folders.
!cd demos/food101 && zip -r ../foodvision.zip * -x "*.pyc" "*.ipynb" "*__pycache__*" "*ipynb_checkpoints*"

  adding: app.py (deflated 54%)
  adding: class_names.txt (deflated 48%)
  adding: examples/ (stored 0%)
  adding: examples/36147.jpg (deflated 1%)
  adding: examples/2034686.jpg (deflated 1%)
  adding: examples/1053665.jpg (deflated 1%)
  adding: model.py (deflated 47%)
  adding: pretrained_vit_feature_extractor_food101.pth (deflated 7%)
  adding: requirements.txt (deflated 6%)


In [19]:
# Download the zipped demo through the browser when running inside Google Colab
try:
  from google.colab import files
except ImportError:
  # Only a missing google.colab package means "not in Colab"; anything else should surface.
  print("Not running in Google Colab, can't use google.colab.files.download(), please download foodvision.zip manually.")
else:
  # Kept outside the try so a genuine download failure is not reported as "not in Colab"
  files.download("demos/foodvision.zip")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>