Run on Google Colab.

In [1]:
!pip install datasets transformers > /dev/null

In [2]:
# First, let's connect Google Drive; `base` is the working folder on it that
# later cells write checkpoints, the final model and metrics into.
from google.colab import drive
from pathlib import Path

# force_remount=True asks Colab to mount again even if Drive is already mounted.
drive.mount('/content/drive', force_remount=True)
base = Path('/content/drive/MyDrive/cifar')

# parents=True: do not fail with FileNotFoundError if an intermediate folder is
# missing; exist_ok=True: re-running the cell is harmless.
base.mkdir(parents=True, exist_ok=True)

Mounted at /content/drive


In [3]:
import numpy as np
import pandas as pd
from datasets import load_dataset, load_metric
from transformers import ( ViTForImageClassification, ViTFeatureExtractor,
                           TrainingArguments, Trainer )

import torch
from torch.utils.data import DataLoader

In [4]:
# Load CIFAR-100 and carve a validation split out of its 'train' split.
data = load_dataset('cifar100')
# Fixed seed: without it the 70/30 shuffle differs on every run, so the subsets
# selected below (and therefore the reported metrics) would not be reproducible.
splits = data['train'].train_test_split(test_size=0.3, seed=42)
# Only the first 2000 / 1000 rows of each side are used
# (presumably to keep the run time manageable).
train_ds = splits['train'].select(range(2000))
val_ds = splits['test'].select(range(1000))

print(train_ds.shape,val_ds.shape)

Reusing dataset cifar100 (/root/.cache/huggingface/datasets/cifar100/cifar100/1.0.0/f365c8b725c23e8f0f8d725c3641234d9331cd2f62919d1381d1baa5b3ba3142)


  0%|          | 0/2 [00:00<?, ?it/s]

(2000, 3) (1000, 3)


In [5]:
# Name of the label column whose class names define the mappings below.
labelk = 'coarse_label'
# Index -> class name, in the order the dataset's label feature lists them ...
id2label = dict(enumerate(train_ds.features[labelk].names))
# ... and the inverse, class name -> index.
label2id = {name: idx for idx, name in id2label.items()}

In [6]:
# Single source of truth for the checkpoint, so the preprocessing config and the
# model weights are always loaded from the same place.
MODEL_NAME = 'google/vit-base-patch16-224-in21k'

feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_NAME)
# The number of labels and both label mappings are handed to the model so its
# config carries the class names.
model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
def train_transforms(example_batch):
    """Convert a batch of raw examples into model inputs.

    Runs the module-level ``feature_extractor`` over the batch's ``'img'``
    entries, requesting PyTorch tensors, and stores the batch's
    ``'coarse_label'`` values under ``'labels'``. No augmentation is applied.

    Args:
        example_batch: mapping with at least the keys ``'img'`` and
            ``'coarse_label'``.

    Returns:
        The feature extractor's output with an added ``'labels'`` entry.
    """
    images = list(example_batch['img'])
    encoded = feature_extractor(images, return_tensors='pt')
    encoded['labels'] = example_batch['coarse_label']
    return encoded

# Metric reported during evaluation.
# NOTE(review): `load_metric` is deprecated in newer `datasets` releases in
# favour of the separate `evaluate` package - confirm against the installed version.
metric_name = "accuracy"
metric = load_metric(metric_name)


def compute_metrics(p):
    """Score one evaluation pass with the module-level ``metric``.

    Args:
        p: object exposing ``predictions`` (per-class scores, reduced with an
            arg-max over axis 1) and ``label_ids`` (the reference labels).

    Returns:
        Whatever ``metric.compute`` returns for the predicted class indices.
    """
    predicted_classes = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=predicted_classes, references=p.label_ids)

def collate_fn(batch):
    """Assemble a list of per-example dicts into one batch dict.

    Args:
        batch: sequence of mappings, each with a ``'pixel_values'`` tensor and
            a ``'labels'`` value.

    Returns:
        dict with ``'pixel_values'`` (the tensors stacked along a new leading
        dimension) and ``'labels'`` (a tensor built from the label values).
    """
    pixel_values = []
    labels = []
    for example in batch:
        pixel_values.append(example['pixel_values'])
        labels.append(example['labels'])

    return {
        'pixel_values': torch.stack(pixel_values),
        'labels': torch.tensor(labels),
    }

In [8]:
# Wrap both subsets with the same preprocessing function (train first, then val).
prepped_train, prepped_val = (
    subset.with_transform(train_transforms) for subset in (train_ds, val_ds)
)

In [10]:
# Training configuration; checkpoints are written under `base / "vit_basic"`.
# fp16 and step-based evaluation are left disabled.
args = TrainingArguments(
    output_dir=str(base / "vit_basic"),
    # --- optimisation ---
    num_train_epochs=2,
    per_device_train_batch_size=10,
    learning_rate=2e-4,
    # --- evaluation, logging, checkpointing ---
    evaluation_strategy="epoch",
    logging_steps=5,
    save_steps=100,
    save_total_limit=2,
    # train_transforms reads the raw 'img' / 'coarse_label' columns, so they
    # must not be dropped from the dataset.
    remove_unused_columns=False,
)

In [11]:
# Wire the model, the prepared datasets, batching and the metric together.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=prepped_train,
    eval_dataset=prepped_val,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    # The feature extractor is passed through the `tokenizer` slot; the run log
    # shows it being saved next to every checkpoint (preprocessor_config.json).
    tokenizer=feature_extractor,
)

In [12]:
# Fine-tune, then persist the final model, the training metrics and the trainer state.
train_results = trainer.train()
train_metrics = train_results.metrics

trainer.save_model(str(base / "model0"))
trainer.log_metrics("train", train_metrics)
trainer.save_metrics("train", train_metrics)
trainer.save_state()

***** Running training *****
  Num examples = 2000
  Num Epochs = 2
  Instantaneous batch size per device = 10
  Total train batch size (w. parallel, distributed & accumulation) = 10
  Gradient Accumulation steps = 1
  Total optimization steps = 400


Epoch,Training Loss,Validation Loss,Accuracy
1,1.1413,1.174239,0.73
2,0.4759,0.694833,0.853


Saving model checkpoint to /content/drive/MyDrive/cifar/vit_basic/checkpoint-100
Configuration saved in /content/drive/MyDrive/cifar/vit_basic/checkpoint-100/config.json
Model weights saved in /content/drive/MyDrive/cifar/vit_basic/checkpoint-100/pytorch_model.bin
Feature extractor saved in /content/drive/MyDrive/cifar/vit_basic/checkpoint-100/preprocessor_config.json
Saving model checkpoint to /content/drive/MyDrive/cifar/vit_basic/checkpoint-200
Configuration saved in /content/drive/MyDrive/cifar/vit_basic/checkpoint-200/config.json
Model weights saved in /content/drive/MyDrive/cifar/vit_basic/checkpoint-200/pytorch_model.bin
Feature extractor saved in /content/drive/MyDrive/cifar/vit_basic/checkpoint-200/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8
Saving model checkpoint to /content/drive/MyDrive/cifar/vit_basic/checkpoint-300
Configuration saved in /content/drive/MyDrive/cifar/vit_basic/checkpoint-300/config.json
Model weights save

***** train metrics *****
  epoch                    =         2.0
  total_flos               = 288726729GF
  train_loss               =      1.3114
  train_runtime            =  2:25:23.47
  train_samples_per_second =       0.459
  train_steps_per_second   =       0.046


In [13]:
# Score the fine-tuned model on the validation subset, then print and save the result.
metrics = trainer.evaluate(eval_dataset=prepped_val)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8


***** eval metrics *****
  epoch                   =        2.0
  eval_accuracy           =      0.853
  eval_loss               =     0.6948
  eval_runtime            = 0:10:29.27
  eval_samples_per_second =      1.589
  eval_steps_per_second   =      0.199
