In [1]:
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset, load_metric
from transformers import ViTFeatureExtractor, ViTForImageClassification, TrainingArguments, Trainer

## Auxiliary Functions

In [2]:
# Pretrained ViT checkpoint. The model cell further down loads this same
# checkpoint, so images are preprocessed the way the model expects.
model_name_or_path = "google/vit-base-patch16-224-in21k"
feature_extractor = ViTFeatureExtractor.from_pretrained(
    pretrained_model_name_or_path=model_name_or_path
)
def transform(example_batch, extractor=None):
    """Turn a batch of dataset rows into ViT model inputs.

    Args:
        example_batch: batch dict from the dataset, holding an ``image`` list
            of PIL images and a matching ``labels`` list.
        extractor: callable used to preprocess the images. Defaults to the
            notebook-level ``feature_extractor``.

    Returns:
        The extractor's output (``pixel_values`` as PyTorch tensors, because
        ``return_tensors="pt"``) with the batch's ``labels`` attached unchanged.
    """
    if extractor is None:
        extractor = feature_extractor
    # Normalise every image to 3-channel RGB first. Grayscale ("L") or RGBA
    # files would otherwise reach the extractor with a channel count other
    # than the 3 the ViT checkpoint presumably expects.
    images = [image.convert("RGB") for image in example_batch["image"]]
    inputs = extractor(images, return_tensors="pt")
    inputs["labels"] = example_batch["labels"]
    return inputs

def collate_fn(batch):
    """Merge a list of transformed examples into one model-ready batch.

    Every example must carry a ``pixel_values`` tensor and a ``labels`` value
    (presumably an integer class id). The pixel tensors are stacked along a
    new leading batch dimension; the labels are gathered into one tensor.
    """
    pixel_values = torch.stack([example["pixel_values"] for example in batch])
    labels = torch.tensor([example["labels"] for example in batch])
    return {"pixel_values": pixel_values, "labels": labels}

# Accuracy metric loaded through `datasets.load_metric`.
# NOTE(review): recent `datasets` releases deprecate `load_metric` in favour of
# `evaluate.load` (a FutureWarning is echoed in this cell's output) — confirm
# against the installed version before upgrading.
metric = load_metric(path="accuracy")
def compute_metrics(p):
    """Return classification accuracy for one Trainer evaluation pass.

    Args:
        p: evaluation-prediction object with ``predictions`` (per-class scores,
           one row per example) and ``label_ids`` (true class ids).

    Returns:
        ``{"accuracy": float}``: the fraction of rows whose arg-max class equals
        the label. The Trainer reports it as the ``Accuracy`` column of the
        training log.
    """
    # Computed directly with numpy instead of through the module-level `metric`
    # object, so this no longer relies on hidden notebook state or on the
    # deprecated `datasets.load_metric` API.
    predicted = np.argmax(p.predictions, axis=1)
    return {"accuracy": float(np.mean(predicted == p.label_ids))}

  metric = load_metric("accuracy")


## Create Image Dataset

1. Put all images under the folder `imgs/`, relative to this notebook. This is the `data_dir` passed to `load_dataset` below.
2. Put training images in `imgs/train/`.
3. Put metadata for the training images in `imgs/train/metadata.csv`. Its first column, named `file_name`, must hold the file name of each image. Add a `labels` column holding each image's integer class id (0 or 1): the code below reads `labels` and trains a 2-class model. Any further columns may hold additional information.
4. Put validation images in `imgs/val/`. They are loaded as the `validation` split.
5. Put metadata for the validation images in `imgs/val/metadata.csv`, with the same columns as the training metadata (`file_name`, `labels`, and any extras).

In [3]:
# Build the train/validation splits from the `imgs/` folder layout described above.
dataset = load_dataset("imagefolder", data_dir="imgs/")

# Fail fast on a bad folder layout or metadata file. Later cells index the
# "train" and "validation" splits, and `transform` reads the "image" and
# "labels" columns.
assert {"train", "validation"} <= set(dataset), f"unexpected splits: {list(dataset)}"
for split_name, split in dataset.items():
    assert {"image", "labels"} <= set(split.column_names), (
        f"{split_name} split is missing image/labels columns: {split.column_names}"
    )

dataset

Resolving data files:   0%|          | 0/145 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/34 [00:00<?, ?it/s]

Using custom data configuration default-f50966761ba7237c


Downloading and preparing dataset imagefolder/default to C:/Users/binha/.cache/huggingface/datasets/imagefolder/default-f50966761ba7237c/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f...


Downloading data files:   0%|          | 0/147 [00:00<?, ?it/s]

Downloading data files: 0it [00:00, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Downloading data files:   0%|          | 0/36 [00:00<?, ?it/s]

Downloading data files: 0it [00:00, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Dataset imagefolder downloaded and prepared to C:/Users/binha/.cache/huggingface/datasets/imagefolder/default-f50966761ba7237c/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['image', 'labels'],
        num_rows: 144
    })
    validation: Dataset({
        features: ['image', 'labels'],
        num_rows: 33
    })
})

In [4]:
# Attach `transform` (image preprocessing + labels) to every split.
# NOTE(review): per the `datasets` docs, `with_transform` returns a new
# DatasetDict and applies `transform` lazily when rows are accessed, leaving
# `dataset` itself untouched — confirm for the installed version.
prepared_ds = dataset.with_transform(transform)

## Load Model

In [5]:
# Seed torch before building the model. The classification head is newly
# (randomly) initialised here, as the loading log below reports, and this
# happens before the Trainer is created and applies its own seed. Without
# this call, every run would start from different head weights.
torch.manual_seed(42)  # 42 matches TrainingArguments' default `seed`

# Reuse `model_name_or_path` from the feature-extractor cell instead of
# redefining it, so the model and its preprocessing cannot drift apart.
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=2,  # two classes: label ids are expected to be 0 and 1
)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
training_args = TrainingArguments(
    output_dir="./bufflegrass-finetune",
    per_device_train_batch_size=2,
    # Evaluate, checkpoint and log once per epoch. Evaluation and saving must
    # share a strategy for load_best_model_at_end to work.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    num_train_epochs=4,
    # fp16 mixed precision is CUDA-only (TrainingArguments rejects fp16=True
    # without a GPU, per its docs), so enable it only when a GPU is present
    # instead of crashing on CPU-only machines.
    fp16=torch.cuda.is_available(),
    learning_rate=2e-4,
    save_total_limit=2,
    # Keep the raw "image" column. Presumably the default would drop columns
    # the model's forward() does not accept, leaving `transform` without images.
    remove_unused_columns=False,
    push_to_hub=False,
    # Reload the best checkpoint after training. It is chosen by eval loss,
    # because metric_for_best_model is not set.
    load_best_model_at_end=True,
)

In [11]:
trainer = Trainer(
    model=model,
    args=training_args,
    # Data: lazily transformed splits, batched by `collate_fn`.
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    data_collator=collate_fn,
    # Evaluation: accuracy, computed at every evaluation pass.
    compute_metrics=compute_metrics,
    # Presumably passed so the image preprocessing config is saved with the
    # checkpoints. NOTE(review): newer transformers releases rename this
    # argument to `processing_class` — confirm for the installed version.
    tokenizer=feature_extractor,
)

In [12]:
# Fine-tune according to `training_args`: evaluate and checkpoint once per
# epoch, then (load_best_model_at_end=True) restore the best checkpoint.
# The per-epoch loss/accuracy table printed below comes from this call.
# NOTE(review): validation accuracy in that table stays between ~0.45 and
# ~0.58 across all 4 epochs, i.e. near chance for 2 classes — the model has
# not yet learned a useful signal; investigate data/labels or hyperparameters.
train_results = trainer.train()



Epoch,Training Loss,Validation Loss,Accuracy
1,0.7157,0.699078,0.575758
2,0.672,0.713527,0.454545
3,0.6562,0.697465,0.484848
4,0.6773,0.689586,0.515152
