In [1]:
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from datasets import load_dataset, load_metric
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer

## Auxiliary Functions

In [2]:
model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTImageProcessor.from_pretrained(model_name_or_path)
def transform(example_batch):
    # Take a list of PIL images and turn them to pixel values
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')
    inputs['labels'] = example_batch['labels']
    return inputs

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

metric = load_metric("accuracy")
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

  metric = load_metric("accuracy")


## Create Image Dataset

1. Put all images under the folder '/imgs'
2. Put training images in '/imgs/train/'
3. Put meta information of training images in '/imgs/train/metadata.csv'. The first column should be the file name corresponding to each image. The following columns could be any additional information, such as label.
4. Put validation images in '/imgs/val/'
5. Put meta information of validation images in '/imgs/val/metadata.csv'. The first column should be the file name corresponding to each image. The following columns could be any additional information, such as label.

In [3]:
dataset = load_dataset("imagefolder", data_dir="planet-imgs-original/split1/")
prepared_ds = dataset.with_transform(transform)
prepared_ds

Resolving data files:   0%|          | 0/125 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/40 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/32 [00:00<?, ?it/s]

Using custom data configuration default-d65ce7d20941fc0b
Found cached dataset imagefolder (C:/Users/binha/.cache/huggingface/datasets/imagefolder/default-d65ce7d20941fc0b/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['image', 'labels'],
        num_rows: 124
    })
    test: Dataset({
        features: ['image', 'labels'],
        num_rows: 39
    })
    validation: Dataset({
        features: ['image', 'labels'],
        num_rows: 31
    })
})

## Load Model

In [4]:
model_name_or_path = 'google/vit-base-patch16-224-in21k'
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=2
)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
training_args = TrainingArguments(
    output_dir="./bufflegrass-finetune",
    per_device_train_batch_size=2,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    num_train_epochs=10,
    fp16=True,
    learning_rate=1e-5,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    load_best_model_at_end=True,
)

In [6]:
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    tokenizer=feature_extractor,
)

In [7]:
train_results = trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mbinhan96816[0m ([33mbeanham[0m). Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss,Accuracy
1,0.6927,0.665039,0.612903
2,0.6426,0.652848,0.645161
3,0.5547,0.650611,0.677419
4,0.4725,0.646477,0.677419
5,0.3791,0.70594,0.645161
6,0.3202,0.723196,0.709677
7,0.2705,0.778887,0.677419
8,0.2385,0.845501,0.645161
9,0.2183,0.857188,0.677419
10,0.2044,0.860161,0.677419


In [9]:
metrics = trainer.evaluate(prepared_ds['validation'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** eval metrics *****
  epoch                   =       10.0
  eval_accuracy           =     0.6774
  eval_loss               =     0.6465
  eval_runtime            = 0:00:00.54
  eval_samples_per_second =     57.372
  eval_steps_per_second   =      7.403


In [8]:
metrics = trainer.evaluate(prepared_ds['test'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** eval metrics *****
  epoch                   =       10.0
  eval_accuracy           =     0.5897
  eval_loss               =     0.6473
  eval_runtime            = 0:00:00.27
  eval_samples_per_second =    142.304
  eval_steps_per_second   =     18.244
