<a target="_blank" href="https://colab.research.google.com/github/duoan/ReplicateAI/blob/master/stage2_representation/2020_VisionTransformer/notebook/Vision%20Transformer_demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# 🧠 ReplicateAI Demo Notebook — Vision Transformer (2020)
## An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale (2020)

> https://arxiv.org/abs/2010.11929

![ViT](../figures/vit.png)

In [1]:
# Smoke-test cell: confirms that the kernel is up and executing code.
print("Notebook environment ready.")

Notebook environment ready.


In [2]:
%load_ext autoreload

In [3]:
import os
import sys

# Make the project's local `src` directory importable (it provides `model`).
# The membership check keeps the cell idempotent: re-running it does not
# prepend a duplicate entry to sys.path each time.
_src_dir = os.path.abspath('../src')
if _src_dir not in sys.path:
    sys.path.insert(0, _src_dir)

# Training

## Import Dependencies

In [4]:
import os
import numpy as np
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score
from torch import stack, tensor
from torchvision import datasets, transforms
from transformers import TrainingArguments, Trainer

from torchsummary import summary
from torchview import draw_graph

from model import ViTConfig, VisionTransformer

  from .autonotebook import tqdm as notebook_tqdm


## Create Model

In [5]:
# Build the Vision Transformer with num_classes=10 (Imagenette has ten
# classes); every other ViTConfig field keeps its default value.
config = ViTConfig(num_classes=10)
model = VisionTransformer(config)

## Load Dataset
We are going to use [Imagenette](https://github.com/fastai/imagenette#imagenette-1), a subset of 10 easily classified classes from ImageNet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute).

> This lets us train the model at low cost and take a deep dive into the model design.

In [6]:
# Root directory where the dataset archive is downloaded and extracted.
DATA_ROOT = "./data/imagenette"
os.makedirs(DATA_ROOT, exist_ok=True)

# Only request a download when the extracted dataset folder is missing.
# Some torchvision releases raise a RuntimeError when download=True is passed
# and the folder already exists, which would break both the second dataset
# construction below and any re-run of this cell.
# NOTE(review): "imagenette2" is assumed to be the folder torchvision extracts
# for size='full' — confirm against the installed torchvision version.
needs_download = not os.path.isdir(os.path.join(DATA_ROOT, "imagenette2"))

# Standard ViT/ImageNet preprocessing pipeline (224x224 output).
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # Normalization is crucial for ViT, even for testing
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Training split; this call fetches the archive if it is not on disk yet.
full_train_ds = datasets.Imagenette(
    root=DATA_ROOT, split='train', size='full', download=needs_download, transform=transform
)
# Validation split, used as the evaluation set. The data is already on disk at
# this point (pre-existing or fetched by the call above), so never re-download.
full_test_ds = datasets.Imagenette(
    root=DATA_ROOT, split='val', size='full', download=False, transform=transform
)

In [7]:
# Batch size shared by the model visualisation cells and the TrainingArguments.
batch_size = 32
# Full input shape: the batch dimension followed by a 3 x 224 x 224 image.
input_size = (batch_size,) + (3, 224, 224)

## Visualizing the Model Structure and Parameters

In [10]:
# Trace the model on CPU with a dummy input of shape `input_size` and save the
# rendered graph under ../figures/model (os.path.abspath already returns a str).
# NOTE(review): graph_dir="BT" is presumably the layout direction
# (bottom-to-top), not a directory — confirm in the torchview docs.
draw_graph(
    model,
    input_size=input_size,
    device='cpu',
    graph_dir="BT",
    expand_nested=True,
    save_graph=True,
    filename=os.path.abspath('../figures/model'),
)

<torchview.computation_graph.ComputationGraph at 0x10c194350>

![model](../figures/model.png)

In [9]:
# Print per-layer output shapes and parameter counts on CPU. The per-sample
# shape is taken from `input_size` without its leading batch dimension.
summary(model, input_size=input_size[1:], batch_size=batch_size, device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [32, 768, 14, 14]         590,592
           Flatten-2             [32, 768, 196]               0
 ViTPatchEmbedding-3             [32, 196, 768]               0
           Dropout-4             [32, 197, 768]               0
      ViTEmbedding-5             [32, 197, 768]               0
            Linear-6             [32, 197, 768]         589,824
            Linear-7             [32, 197, 768]         589,824
            Linear-8             [32, 197, 768]         589,824
           Dropout-9         [32, 12, 197, 197]               0
           Linear-10             [32, 197, 768]         590,592
          Dropout-11             [32, 197, 768]               0
     ViTAttention-12             [32, 197, 768]               0
        LayerNorm-13             [32, 197, 768]           1,536
           Linear-14            [32, 19

## Define the metrics and data collator

In [11]:
def compute_metrics(eval_pred):
    """
    Compute accuracy, weighted precision and weighted recall for an evaluation pass.

    `eval_pred` is unpacked into `(logits, labels)`. The predicted class for
    each sample is the argmax of the logits along axis 1.

    Returns a dict with the keys "accuracy", "precision" and "recall".
    """
    logits, targets = eval_pred
    # Labels may arrive as a torch tensor; the sklearn metrics are fed numpy arrays.
    targets = targets.cpu().numpy() if isinstance(targets, torch.Tensor) else targets

    predicted = np.argmax(logits, axis=1)

    # Weighted averaging across classes; zero_division=0 scores classes with no
    # predictions/samples as 0 instead of emitting a warning.
    averaging = {"average": 'weighted', "zero_division": 0}
    scores = {"accuracy": accuracy_score(targets, predicted)}
    scores["precision"] = precision_score(targets, predicted, **averaging)
    scores["recall"] = recall_score(targets, predicted, **averaging)
    return scores

def custom_data_collator(samples):
    """
    Turn a list of (image, label) samples into one batch dictionary.

    Element 0 of every sample is treated as the image tensor and element 1 as
    its label. Images are stacked along a new leading batch dimension and the
    labels are converted into a single tensor.

    Returns {"pixel_values": <stacked images>, "labels": <label tensor>}.
    NOTE(review): these keys presumably have to match the keyword arguments of
    the model's forward() — confirm against the model definition.
    """
    image_list = []
    label_list = []
    for item in samples:
        image_list.append(item[0])
        label_list.append(item[1])

    batch = {}
    batch["pixel_values"] = stack(image_list)
    batch["labels"] = tensor(label_list)
    return batch

## Define the trainer
Here we use the Hugging Face Transformers `Trainer` directly to train our model quickly. It lets the same training code run in different settings: locally, on a single GPU, on multiple GPUs, or across multiple nodes.


In [12]:
# Training configuration for a short smoke-test run.
args = TrainingArguments(
    "vit-imagenette-macbook-test",  # Output directory / unique name for test runs
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    eval_strategy="epoch",  # Evaluate once at the end of every epoch
    save_strategy="no",  # No need to save checkpoints for a quick test
    num_train_epochs=2,  # Run for only 2 epochs
    learning_rate=3e-4,
    weight_decay=0.05,
    logging_steps=10,
    load_best_model_at_end=False,  # No need to load best model
    report_to="none",
    optim="adamw_torch",  # Explicitly use a common optimizer
    disable_tqdm=False,  # Keep progress bar for visibility
    # Pin host memory only when CUDA is available.
    dataloader_pin_memory=torch.cuda.is_available(),
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=full_train_ds,
    eval_dataset=full_test_ds,
    compute_metrics=compute_metrics,
    data_collator=custom_data_collator,
)

# Use the device the training arguments resolved instead of querying
# torch.accelerator separately: that query can return None when no accelerator
# is present, and asking TrainingArguments keeps `device` consistent with what
# the Trainer itself trains on.
device = args.device

# Explicitly place the model on that device (the Trainer is expected to do the
# same, so this mainly serves as a visible check).
model.to(device)

VisionTransformer(
  (embedding): ViTEmbedding(
    (patch_embeddings): ViTPatchEmbedding(
      (projector): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (flatten): Flatten(start_dim=2, end_dim=-1)
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layers): ModuleList(
      (0-11): 12 x Sequential(
        (0): ViTAttention(
          (attn_drop): Dropout(p=0.0, inplace=False)
          (norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (query): Linear(in_features=768, out_features=768, bias=False)
          (key): Linear(in_features=768, out_features=768, bias=False)
          (value): Linear(in_features=768, out_features=768, bias=False)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_drop): Dropout(p=0.0, inplace=False)
        )
        (1): ViTFeedForward(
          (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (net): Sequential(
            (

## Train the model

In [13]:
# Run the training loop and keep the returned result so that later cells can
# inspect it; the bare name on the last line displays it as the cell output.
train_output = trainer.train()
train_output

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall
1,574.5496,291.497375,0.107516,0.107326,0.107516
2,172.6001,173.954941,0.098089,0.098391,0.098089


TrainOutput(global_step=592, training_loss=240.71912934006872, metrics={'train_runtime': 4836.2851, 'train_samples_per_second': 3.916, 'train_steps_per_second': 0.122, 'total_flos': 0.0, 'train_loss': 240.71912934006872, 'epoch': 2.0})