<a href="https://colab.research.google.com/github/mirpouya/Transformer_EDU/blob/main/Vision_Transformer_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets

In [2]:
!pip install datasets transformers torch



In [3]:
# load the CIFAR-10 dataset from the Hugging Face Hub
from datasets import load_dataset

# train split
dataset_train = load_dataset(
    "cifar10",
    split = "train",
    # `ignore_verifications` is deprecated in recent `datasets` releases;
    # "all_checks" is its documented equivalent of ignore_verifications=False.
    # Switch to "no_checks" if you hit a splits verification error.
    verification_mode = "all_checks"
)



In [4]:
# inspect the training split: its feature columns and row count
dataset_train

Dataset({
    features: ['img', 'label'],
    num_rows: 50000
})

In [5]:
# test split
dataset_test = load_dataset(
    "cifar10",
    split = "test",
    # `ignore_verifications=True` is deprecated in recent `datasets` releases;
    # "no_checks" is its documented equivalent (skip split/checksum verification)
    verification_mode = "no_checks"
)



In [6]:
# inspect the test split: its feature columns and row count
dataset_test

Dataset({
    features: ['img', 'label'],
    num_rows: 10000
})

In [7]:
# check how many labels we have - number of classes
labels = dataset_test.features["label"]
# Take the declared class count from the ClassLabel metadata. This avoids
# reading the whole label column, and it does not undercount if a split
# happens to miss a class.
num_classes = labels.num_classes

num_classes, labels

(10,
 ClassLabel(names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'], id=None))

In [8]:
# peek at the first training example: a PIL image under "img" and an integer class id under "label"
dataset_train[0]

{'img': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32 at 0x7F87DC9D4610>,
 'label': 0}

In [9]:
# map the first training example's class id to its human-readable name
first_label = dataset_train[0]["label"]
first_label, labels.names[first_label]

(0, 'airplane')

## <b> Loading Vision Transformer Feature Extractor </b>

In [10]:
from transformers import ViTFeatureExtractor

In [11]:
# pretrained ViT checkpoint on the Hugging Face Hub; the same name is reused to load the model later
# NOTE(review): recent transformers releases deprecate ViTFeatureExtractor in favour of ViTImageProcessor — consider switching.
model_name_or_path = "google/vit-base-patch16-224-in21k"
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)



In [12]:
# show the extractor's preprocessing config (resize target, rescale factor, normalization mean/std)
feature_extractor

ViTFeatureExtractor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTFeatureExtractor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

In [13]:
# testing feature extractor: preprocess a single training image into PyTorch tensors
first_img = dataset_train[0]["img"]
example = feature_extractor(first_img, return_tensors = "pt")

In [14]:
# pixel_values tensor produced for the first training image (rescaled and normalized values)
example

{'pixel_values': tensor([[[[ 0.3961,  0.3961,  0.3961,  ...,  0.2941,  0.2941,  0.2941],
          [ 0.3961,  0.3961,  0.3961,  ...,  0.2941,  0.2941,  0.2941],
          [ 0.3961,  0.3961,  0.3961,  ...,  0.2941,  0.2941,  0.2941],
          ...,
          [-0.1922, -0.1922, -0.1922,  ..., -0.2863, -0.2863, -0.2863],
          [-0.1922, -0.1922, -0.1922,  ..., -0.2863, -0.2863, -0.2863],
          [-0.1922, -0.1922, -0.1922,  ..., -0.2863, -0.2863, -0.2863]],

         [[ 0.3804,  0.3804,  0.3804,  ...,  0.2784,  0.2784,  0.2784],
          [ 0.3804,  0.3804,  0.3804,  ...,  0.2784,  0.2784,  0.2784],
          [ 0.3804,  0.3804,  0.3804,  ...,  0.2784,  0.2784,  0.2784],
          ...,
          [-0.2471, -0.2471, -0.2471,  ..., -0.3412, -0.3412, -0.3412],
          [-0.2471, -0.2471, -0.2471,  ..., -0.3412, -0.3412, -0.3412],
          [-0.2471, -0.2471, -0.2471,  ..., -0.3412, -0.3412, -0.3412]],

         [[ 0.4824,  0.4824,  0.4824,  ...,  0.3647,  0.3647,  0.3647],
          [ 0

In [15]:
import numpy as np
# shape of the batched pixel tensor: (batch, channels, height, width)
tuple(example["pixel_values"].shape)

(1, 3, 224, 224)

In [16]:
# original image size before preprocessing; PIL's .size is a (width, height) tuple
dataset_train[0]["img"].size

(32, 32)

### <b> Select device (GPU if available, otherwise CPU) </b>

In [17]:
# prefer the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [18]:
# preprocessing

def preprocessing(batch):
  """Convert a batch of CIFAR-10 examples into ViT model inputs.

  Presumably applied as an on-the-fly dataset transform. The call site is not
  visible in this section, so confirm it.

  Args:
    batch: mapping with an "img" list of PIL images and a "label" list of
      integer class ids (the column names of the CIFAR-10 splits).

  Returns:
    The feature extractor output (holding "pixel_values") moved to `device`,
    with an added "labels" entry copied from batch["label"].
  """
  # resize / rescale / normalize the images into a pixel_values tensor
  inputs = feature_extractor(
      batch["img"],
      return_tensors = "pt"
  ).to(device)
  # NOTE(review): moving tensors to `device` inside a data transform may clash
  # with DataLoader pin_memory when device is CUDA — confirm.

  # the dataset's label column is "label"; the model/collate side expects "labels"
  inputs["labels"] = batch["label"]
  return inputs

In [19]:
# collate function to create batch tensors
def collate_fn(batch):
  """Stack a list of per-example dicts into one model-ready batch.

  Args:
    batch: list of examples. Each has a "pixel_values" tensor (all the same
      shape) and an integer class id stored under "labels" (the key
      preprocessing() writes) or "label" (the raw CIFAR-10 column).

  Returns:
    dict with "pixel_values" (the tensors stacked along a new leading dim)
    and "labels" (a 1-D tensor of class ids).
  """
  def _label_of(example):
    # transformed examples carry "labels"; raw dataset rows carry "label"
    return example["labels"] if "labels" in example else example["label"]

  return {
      "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
      "labels": torch.tensor([_label_of(x) for x in batch])
  }

In [20]:
import numpy as np
from datasets import load_metric

# module-level accuracy metric, read by compute_metrics
# NOTE(review): `load_metric` is deprecated in recent `datasets` releases
# (`evaluate.load("accuracy")` is the suggested replacement) — confirm the
# installed version still provides it.
metric = load_metric("accuracy")

def compute_metrics(p):
  """Compute accuracy from model logits and reference labels.

  Args:
    p: evaluation output exposing `predictions` (logits, one row per example)
      and `label_ids` (true class ids). Presumably a transformers
      EvalPrediction passed by the Trainer — not shown here.

  Returns:
    Whatever the module-level accuracy `metric.compute` returns.
  """
  # logits -> predicted class id per example
  predictions = np.argmax(p.predictions, axis = 1)
  # `metric` is the module-level metric. Do not assign to it anywhere in this
  # function: Python would then treat it as a local, and this call would
  # raise UnboundLocalError.
  return metric.compute(
      predictions = predictions,
      references = p.label_ids
  )


In [21]:
!pip install transformers[torch]



In [22]:
!pip install accelerate -U



In [23]:
# set training arguments

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir = "./cifar",
    per_device_train_batch_size = 16,
    # `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
    # releases, and the old name was later removed
    eval_strategy = "steps",
    num_train_epochs = 4,
    # fp16 = True,  # uncomment for mixed-precision training on a CUDA GPU
    # with load_best_model_at_end=True, save_steps must be a multiple of eval_steps
    save_steps = 100,
    eval_steps = 100,
    logging_steps = 10,
    learning_rate = 2e-4,
    save_total_limit = 2,
    # keep columns the model's forward() does not accept (e.g. "img"),
    # which preprocessing() needs
    remove_unused_columns = False,
    push_to_hub = False,
    load_best_model_at_end = True,
)

In [24]:
# import vision transformer for image classification

from transformers import ViTForImageClassification

# CIFAR-10 class names, ordered by class id
labels = dataset_train.features["label"].names

model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels = len(labels),
    # store readable class names in the config; otherwise they default to generic LABEL_i names
    id2label = {str(i): name for i, name in enumerate(labels)},
    label2id = {name: str(i) for i, name in enumerate(labels)},
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
