In [1]:
import torch
import numpy as np
from datasets import load_dataset, load_metric
from transformers import ViTFeatureExtractor, ViTForImageClassification, TrainingArguments, Trainer
from torch.utils.data import DataLoader
from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor)

# Load the CIFAR-100 train and test splits from the Hugging Face hub
# (the dataset id is 'cifar100', i.e. the 100-class variant).
train_ds, test_ds = load_dataset(
    path='cifar100',
    split=['train', 'test'],
)

Reusing dataset cifar100 (/home/ecbm4040/.cache/huggingface/datasets/cifar100/cifar100/1.0.0/f365c8b725c23e8f0f8d725c3641234d9331cd2f62919d1381d1baa5b3ba3142)


  0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
# Hold out 10% of the training set for validation. The seed is fixed so the
# same validation examples are selected on every Restart & Run All; without
# it the split (and therefore every reported metric) changes between runs.
# NOTE: this cell rebinds train_ds, so running it a second time without first
# re-running the loading cell splits the already-reduced training set again.
splits = train_ds.train_test_split(test_size=0.1, seed=42)
train_ds = splits['train']
val_ds = splits['test']

# Build the class-index <-> class-name lookups from the dataset's
# 'fine_label' feature. `idx` is used instead of `id` to avoid shadowing
# the builtin.
label_names = train_ds.features['fine_label'].names
id2label = dict(enumerate(label_names))
label2id = {label: idx for idx, label in id2label.items()}

In [3]:
# Fetch the preprocessing configuration (target size, normalization
# statistics) that belongs to the pre-trained ViT checkpoint.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

# Per-channel normalization using the checkpoint's own mean / std.
normalize = Normalize(
    mean=feature_extractor.image_mean,
    std=feature_extractor.image_std,
)

# NOTE(review): feature_extractor.size is handed straight to torchvision, so
# it is assumed to be an int (or a sequence torchvision accepts) -- confirm
# for the installed transformers version.

# Training pipeline: random crop + horizontal flip for augmentation, then
# conversion to a tensor and normalization.
_train_transforms = Compose([
    RandomResizedCrop(feature_extractor.size),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

# Evaluation pipeline: resize + center crop (no random ops), then conversion
# to a tensor and normalization.
_val_transforms = Compose([
    Resize(feature_extractor.size),
    CenterCrop(feature_extractor.size),
    ToTensor(),
    normalize,
])

def train_transforms(examples):
    """Add training-time 'pixel_values' to a batch of examples.

    Every entry of ``examples['img']`` is converted to RGB and run through
    the ``_train_transforms`` pipeline; the results are stored, in order,
    under ``examples['pixel_values']``.

    Args:
        examples: mapping with an ``'img'`` entry holding an iterable of
            images (presumably PIL images, since ``.convert`` is called on
            each one).

    Returns:
        The same ``examples`` object, mutated in place.
    """
    augmented = []
    for picture in examples['img']:
        augmented.append(_train_transforms(picture.convert("RGB")))
    examples['pixel_values'] = augmented
    return examples

def val_transforms(examples):
    """Add evaluation-time 'pixel_values' to a batch of examples.

    Every entry of ``examples['img']`` is converted to RGB and run through
    the ``_val_transforms`` pipeline; the results are stored, in order,
    under ``examples['pixel_values']``.

    Args:
        examples: mapping with an ``'img'`` entry holding an iterable of
            images (presumably PIL images, since ``.convert`` is called on
            each one).

    Returns:
        The same ``examples`` object, mutated in place.
    """
    processed = []
    for picture in examples['img']:
        processed.append(_val_transforms(picture.convert("RGB")))
    examples['pixel_values'] = processed
    return examples

def collate_fn(examples):
    """Merge a list of per-example dicts into a single batch dict.

    Args:
        examples: sequence of dicts, each with a ``'pixel_values'`` tensor
            and a ``'fine_label'`` value.

    Returns:
        Dict with ``'pixel_values'`` (the example tensors stacked along a new
        leading dimension) and ``'labels'`` (a tensor built from the
        ``'fine_label'`` values, in the same order).
    """
    image_tensors = [item["pixel_values"] for item in examples]
    targets = [item["fine_label"] for item in examples]
    batch = {
        "pixel_values": torch.stack(image_tensors),
        "labels": torch.tensor(targets),
    }
    return batch

# Register the per-split transforms: the augmenting pipeline for training,
# the non-random pipeline for validation and test.
for dataset, transform in (
    (train_ds, train_transforms),
    (val_ds, val_transforms),
    (test_ds, val_transforms),
):
    dataset.set_transform(transform)

# Load the pre-trained ViT checkpoint for image classification. The number of
# output classes is derived from the label map rather than hard-coded, so the
# classifier size and id2label / label2id can never disagree.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)


# Metric computed at each evaluation and used to select the best checkpoint.
metric_name = "accuracy"

# Fine-tuning configuration.
args = TrainingArguments(
    # Plain string literal: the previous f-string had no placeholders.
    output_dir="test-cifar-100",
    # Save and evaluate on the same schedule (once per epoch) so that
    # load_best_model_at_end can pair each checkpoint with its eval score.
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=1e-4,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    logging_dir='logs',
    # Keep the raw 'img' column: the dataset transforms read it to build
    # 'pixel_values' when examples are accessed.
    remove_unused_columns=False,
)

# Metric object used by compute_metrics below.
metric = load_metric(metric_name)


def compute_metrics(eval_pred):
    """Score one evaluation pass with the loaded metric.

    Args:
        eval_pred: a pair ``(predictions, labels)``; ``predictions`` is
            reduced with ``argmax`` over axis 1 (presumably per-class scores
            of shape (n_examples, n_classes) -- confirm against the Trainer).

    Returns:
        Whatever ``metric.compute`` returns for the predicted class indices
        and the reference labels.
    """
    scores, references = eval_pred
    predicted_classes = np.argmax(scores, axis=1)
    return metric.compute(predictions=predicted_classes, references=references)

# Wire model, hyperparameters, data and metric into a Trainer. The feature
# extractor is passed as `tokenizer`; the training log shows it being saved
# alongside each checkpoint.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# Fine-tune on the training split (evaluating on val_ds as configured).
trainer.train()

# Score the held-out test split and report its metrics.
outputs = trainer.predict(test_ds)
test_metrics = outputs.metrics
print(test_metrics)

***** Running training *****
  Num examples = 45000
  Num Epochs = 3
  Instantaneous batch size per device = 64
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 1
  Total optimization steps = 2112


Epoch,Training Loss,Validation Loss,Accuracy
1,2.843,0.923213,0.886
2,1.319,0.495538,0.9064
3,0.7673,0.399715,0.9182


***** Running Evaluation *****
  Num examples = 5000
  Batch size = 64
Saving model checkpoint to test-cifar-100/checkpoint-704
Configuration saved in test-cifar-100/checkpoint-704/config.json
Model weights saved in test-cifar-100/checkpoint-704/pytorch_model.bin
Feature extractor saved in test-cifar-100/checkpoint-704/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 5000
  Batch size = 64
Saving model checkpoint to test-cifar-100/checkpoint-1408
Configuration saved in test-cifar-100/checkpoint-1408/config.json
Model weights saved in test-cifar-100/checkpoint-1408/pytorch_model.bin
Feature extractor saved in test-cifar-100/checkpoint-1408/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 5000
  Batch size = 64
Saving model checkpoint to test-cifar-100/checkpoint-2112
Configuration saved in test-cifar-100/checkpoint-2112/config.json
Model weights saved in test-cifar-100/checkpoint-2112/pytorch_model.bin
Feature extractor saved in test-cifar

{'test_loss': 0.40992000699043274, 'test_accuracy': 0.9181, 'test_runtime': 153.011, 'test_samples_per_second': 65.355, 'test_steps_per_second': 1.026}
