In [1]:
from transformers import ViTFeatureExtractor, ViTForImageClassification, TrainingArguments, Trainer
from datasets import load_dataset
from torch.utils.data import default_collate
from evaluate import load
from pathlib import Path
import numpy as np

In [2]:
def compute_metrics(p):
    metric = load("accuracy")
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)


def transforms(example_batch):
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')
    inputs['labels'] = example_batch['label']
    return inputs

In [3]:
# Input
model_name = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

path_train = Path.home() / 'Desktop/dogs-vs-cats/train'
dataset_train = load_dataset("imagefolder", data_dir=str(path_train), split='train')
splits = dataset_train.train_test_split(test_size=0.2)
dataset_test_valid = splits['test'].train_test_split(test_size=0.5)

# Set the train and validation data
train_data, val_data = splits['train'], dataset_test_valid['train']
train_data.set_transform(transforms)
val_data.set_transform(transforms)

# Set the test data
test_data = dataset_test_valid['test']
test_data.set_transform(transforms)

Resolving data files:   0%|          | 0/25000 [00:00<?, ?it/s]

Using custom data configuration default-d5564b158da7eecc
Found cached dataset imagefolder (C:/Users/Kevin/.cache/huggingface/datasets/imagefolder/default-d5564b158da7eecc/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f)


In [4]:
# Model
labels = {'cat': 0, 'dog': 1}
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

# Train
training_args = TrainingArguments(
    output_dir="./vit_dog_cat",
    per_device_train_batch_size=32,
    evaluation_strategy="steps",
    num_train_epochs=3,
    fp16=True,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=default_collate,
    compute_metrics=compute_metrics,
    train_dataset=train_data,
    eval_dataset=val_data,
    tokenizer=feature_extractor,
)

train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Using cuda_amp half precision ba

Step,Training Loss,Validation Loss,Accuracy
100,0.0423,0.094767,0.9696
200,0.1204,0.052264,0.9864
300,0.0315,0.055145,0.984
400,0.0726,0.081907,0.9688
500,0.0736,0.06841,0.9776
600,0.0751,0.047189,0.9868
700,0.0104,0.045716,0.9864
800,0.0253,0.046584,0.9872
900,0.0627,0.046796,0.9852
1000,0.0126,0.062417,0.9828


***** Running Evaluation *****
  Num examples = 2500
  Batch size = 8
Saving model checkpoint to ./vit_dog_cat\checkpoint-100
Configuration saved in ./vit_dog_cat\checkpoint-100\config.json
Model weights saved in ./vit_dog_cat\checkpoint-100\pytorch_model.bin
Feature extractor saved in ./vit_dog_cat\checkpoint-100\preprocessor_config.json
Deleting older checkpoint [vit_dog_cat\checkpoint-1000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 2500
  Batch size = 8
Saving model checkpoint to ./vit_dog_cat\checkpoint-200
Configuration saved in ./vit_dog_cat\checkpoint-200\config.json
Model weights saved in ./vit_dog_cat\checkpoint-200\pytorch_model.bin
Feature extractor saved in ./vit_dog_cat\checkpoint-200\preprocessor_config.json
Deleting older checkpoint [vit_dog_cat\checkpoint-2000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 2500
  Batch size = 8
Saving model checkpoint to ./vit_dog_cat\checkpoint-300
Configuration saved in

***** train metrics *****
  epoch                    =          3.0
  total_flos               = 4330202356GF
  train_loss               =       0.0331
  train_runtime            =   0:14:13.74
  train_samples_per_second =       70.279
  train_steps_per_second   =        2.196
