# Exploration and Comparison of Transformers for Image Classification

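## DeiT (Data-efficient Image Transformer)

This notebook first fine-tunes a RegNet teacher (`facebook/regnet-x-040`) on the RESISC45 dataset (`timm/resisc45`, 45 classes). It then fine-tunes a distilled DeiT-small student (`facebook/deit-small-distilled-patch16-224`) with knowledge distillation from that teacher. After two epochs each, the teacher reaches 94.05% test accuracy and the DeiT student reaches 91.83%.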


### Prerequisites

In [2]:
import os

# Work from the repository root so the project packages `src` and `utils`
# (imported below) resolve. Guarded so that re-running this cell is a no-op
# instead of stepping up one more directory every time.
# NOTE(review): assumes the notebook sits one level below the directory that
# contains `src/` — confirm against the repo layout.
if not os.path.isdir('src'):
    os.chdir('..')

import torch
import torch.nn as nn

from transformers import AutoImageProcessor, DeiTImageProcessor
from datasets import load_dataset, concatenate_datasets

from src.dataset_builder import ImageDataset
from src.models import Backbone
from src.train import train_model, evaluate_model

from utils.config import Config
from utils.train_utils import *
from utils.models_utils import *

### GPU

In [3]:
# List every visible CUDA device with its total memory. The byte count is divided
# by 1024**3, so the figure is in GiB (binary gigabytes) and is labelled as such.
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
        vram = torch.cuda.get_device_properties(i).total_memory / (1024 ** 3)
        print(f"  VRAM: {vram:.2f} GiB")
else:
    print("CUDA is not available.")

GPU 0: NVIDIA A100 80GB PCIe MIG 1g.10gb
  VRAM: 9.50 GB


### Data preparation

In [4]:
# Load the `timm/resisc45` dataset from the Hugging Face Hub as three raw splits
# (train / validation / test); they are wrapped for the models further down.
train, val, test = load_dataset(
    path='timm/resisc45',
    split=['train', 'validation', 'test'],
)

In [10]:
# Image processor for the checkpoint registered under the 'DeiT' key.
# NOTE(review): `model_names` is not defined in this notebook; it presumably comes
# from one of the `from utils.* import *` lines — consider importing it explicitly.
processor = AutoImageProcessor.from_pretrained(
    pretrained_model_name_or_path=model_names['DeiT'],
)

In [11]:
# Wrap each raw split in the project's ImageDataset together with the DeiT
# processor (which presumably preprocesses each image). The same split objects
# are later used for the RegNet teacher, so teacher and student see identical inputs.
# NOTE(review): the RegNet checkpoint's own preprocessing config is not used — confirm intended.
train_split, val_split, test_split = (
    ImageDataset(dataset=raw_split, processor=processor)
    for raw_split in (train, val, test)
)

In [12]:
# Number of target classes, read from the training split; it sizes the new
# classification heads of both models built in the next cell.
num_classes = train_split.get_num_classes()

# The heads are sized from `train_split` only, so fail fast if the other splits
# disagree about the label space.
# NOTE(review): if `get_num_classes` counts labels *present* in a split rather than
# reading the label schema, a split missing a class would trip these checks.
assert val_split.get_num_classes() == num_classes, "validation split class count differs from train"
assert test_split.get_num_classes() == num_classes, "test split class count differs from train"

### Model

In [13]:
config = Config()

# Student: the checkpoint registered under 'DeiT', with its classifier heads
# resized to `num_classes` (see the load log below).
student = Backbone(
    model_name=model_names['DeiT'],
    num_classes=num_classes,
)

# Teacher: the checkpoint registered under 'RegNet'; it is fine-tuned first and
# then passed as `teacher=` when training the student.
teacher = Backbone(
    model_name=model_names['RegNet'],
    num_classes=num_classes,
)

config.json:   0%|          | 0.00/69.6k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/89.8M [00:00<?, ?B/s]

Some weights of DeiTForImageClassificationWithTeacher were not initialized from the model checkpoint at facebook/deit-small-distilled-patch16-224 and are newly initialized because the shapes did not match:
- cls_classifier.weight: found shape torch.Size([1000, 384]) in the checkpoint and torch.Size([45, 384]) in the model instantiated
- cls_classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([45]) in the model instantiated
- distillation_classifier.weight: found shape torch.Size([1000, 384]) in the checkpoint and torch.Size([45, 384]) in the model instantiated
- distillation_classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([45]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


config.json:   0%|          | 0.00/69.6k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/89.0M [00:00<?, ?B/s]

Some weights of RegNetForImageClassification were not initialized from the model checkpoint at facebook/regnet-x-040 and are newly initialized because the shapes did not match:
- classifier.1.weight: found shape torch.Size([1000, 1360]) in the checkpoint and torch.Size([45, 1360]) in the model instantiated
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([45]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
# Report model sizes: student first, then teacher.
for backbone in (student, teacher):
    get_model_params(backbone)

Parameters: 21.70M
Parameters: 20.82M


### Fine-tuning #1 - Teacher

In [15]:
# Fine-tune the RegNet teacher on the training split, reporting validation
# metrics after each epoch. The resulting model supervises the DeiT student later.
train_model(
    teacher, train_split, val_split, config,
    fine_tune=True,
    architecture='regnet',
)

Train: 100%|██████████| 2363/2363 [06:21<00:00,  6.19it/s]
Val: 100%|██████████| 788/788 [00:48<00:00, 16.16it/s]


Epochs: 1/2 | train_loss: 1.0008 | train_acc: 0.7732 | val_loss: 0.2507 | val_acc: 0.9267


Train: 100%|██████████| 2363/2363 [06:19<00:00,  6.22it/s]
Val: 100%|██████████| 788/788 [00:45<00:00, 17.31it/s]

Epochs: 2/2 | train_loss: 0.2914 | train_acc: 0.9233 | val_loss: 0.1859 | val_acc: 0.9475





### Evaluation #1 - Teacher

In [16]:
# Held-out test-set loss and accuracy of the fine-tuned RegNet teacher.
evaluate_model(teacher, test_split, config)

Test: 100%|██████████| 788/788 [00:46<00:00, 17.10it/s]

test_loss: 0.2087 | test_acc: 0.9405





### Fine-tuning #2 - DeiT with RegNet teacher

In [21]:
# Fine-tune the DeiT student with distillation from the fine-tuned RegNet teacher.
# NOTE(review): confirm `train_model` keeps the teacher frozen (eval mode, no
# gradient updates) while distilling.
train_model(
    student, train_split, val_split, config,
    teacher=teacher,
    with_distillation=True,
    fine_tune=True,
    architecture='deit',
)

Train: 100%|██████████| 2363/2363 [07:29<00:00,  5.26it/s]
Val: 100%|██████████| 788/788 [01:20<00:00,  9.78it/s]


Epochs: 1/2 | train_loss: 0.5964 | train_acc: 0.8470 | val_loss: 0.2798 | val_acc: 0.9200


Train: 100%|██████████| 2363/2363 [07:37<00:00,  5.16it/s]
Val: 100%|██████████| 788/788 [01:23<00:00,  9.47it/s]

Epochs: 2/2 | train_loss: 0.1572 | train_acc: 0.9529 | val_loss: 0.2475 | val_acc: 0.9243





### Evaluation #2 - DeiT with RegNet teacher

In [22]:
# Held-out test-set loss and accuracy of the distilled DeiT student; same split
# and config as the teacher's evaluation, so the numbers are directly comparable.
evaluate_model(student, test_split, config)

Test: 100%|██████████| 788/788 [00:49<00:00, 15.76it/s]

test_loss: 0.2745 | test_acc: 0.9183



