In [1]:
# Check https://huggingface.co/blog/fine-tune-vit for more information
from datasets import load_dataset, load_metric
import random
from PIL import ImageDraw, ImageFont, Image
from transformers import ViTFeatureExtractor, ViTForImageClassification, Trainer, TrainingArguments
import torch
import numpy as np
import os

In [2]:
# select GPU before run anything, because Trainer will automatically use all GPU
os.environ["CUDA_VISIBLE_DEVICES"]="0, 1"

In [3]:
torch.cuda.device_count()

2

In [4]:
ds = load_dataset('beans')

Using custom data configuration default
Reusing dataset beans (/root/.cache/huggingface/datasets/beans/default/0.0.0/90c755fb6db1c0ccdad02e897a37969dbf070bed3755d4391e269ff70642d791)


  0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# convert image into 1 * 3 * 224 * 224
model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)

In [6]:
def transform(example_batch):
    # Take a list of PIL images and turn them to pixel values
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')

    # Don't forget to include the labels!
    inputs['labels'] = example_batch['labels']
    return inputs

In [7]:
prepared_ds = ds.with_transform(transform)

In [8]:
# dataset
prepared_ds['train'][0:2]

{'pixel_values': tensor([[[[-0.4510, -0.4745, -0.4902,  ...,  0.4824,  0.4745,  0.3490],
          [-0.4039, -0.4510, -0.4745,  ...,  0.3176,  0.3333,  0.2863],
          [-0.2627, -0.2863, -0.3020,  ...,  0.1843,  0.2471,  0.2314],
          ...,
          [ 0.6706,  0.6706,  0.6706,  ...,  0.1216,  0.0980,  0.0353],
          [ 0.6627,  0.6627,  0.6627,  ...,  0.1373,  0.1059,  0.0510],
          [ 0.6078,  0.6392,  0.6549,  ...,  0.1294,  0.1059, -0.0039]],

         [[-0.6549, -0.6941, -0.6941,  ...,  0.1765,  0.1686,  0.0353],
          [-0.4745, -0.5765, -0.6471,  ...,  0.0039,  0.0196, -0.0275],
          [-0.2000, -0.2706, -0.3412,  ..., -0.1529, -0.0902, -0.1059],
          ...,
          [ 0.1843,  0.1843,  0.1686,  ..., -0.0353, -0.0039, -0.0275],
          [ 0.1608,  0.1686,  0.1608,  ..., -0.1216, -0.1137, -0.1059],
          [ 0.0980,  0.1373,  0.1608,  ..., -0.1529, -0.1373, -0.1765]],

         [[-0.8588, -0.8667, -0.8667,  ...,  0.0588,  0.0824, -0.0510],
          [-0

In [9]:
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

In [10]:
metric = load_metric("accuracy")
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

In [11]:
labels = ds['train'].features['labels'].names

model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
# The other arguments are in the model.config
model.config.id2label

{'0': 'angular_leaf_spot', '1': 'bean_rust', '2': 'healthy'}

In [13]:
# epoch 50
# evaluate and save per 64 step
# can change to epoch if save_strategy="epoch" and evaluation_strategy="epoch"
# save all checkpoints in ./cp
training_args = TrainingArguments(
  output_dir="./cp",
  per_device_train_batch_size=64,
  dataloader_num_workers=16,
  evaluation_strategy="steps",
  num_train_epochs=50,
  fp16=True,
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=2e-4,
  save_total_limit=20,
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
)

In [14]:
# as data is already prepared, extractor is not needed
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
#     tokenizer=feature_extractor,
)

Using amp half precision backend


In [15]:
train_results = trainer.train()

***** Running training *****
  Num examples = 1034
  Num Epochs = 50
  Instantaneous batch size per device = 64
  Total train batch size (w. parallel, distributed & accumulation) = 128
  Gradient Accumulation steps = 1
  Total optimization steps = 450


Step,Training Loss,Validation Loss,Accuracy
100,0.0113,0.065817,0.984962
200,0.0054,0.071564,0.984962
300,0.0037,0.076429,0.984962
400,0.0031,0.078705,0.984962


***** Running Evaluation *****
  Num examples = 133
  Batch size = 16
Saving model checkpoint to ./cp/checkpoint-100
Configuration saved in ./cp/checkpoint-100/config.json
Model weights saved in ./cp/checkpoint-100/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 133
  Batch size = 16
Saving model checkpoint to ./cp/checkpoint-200
Configuration saved in ./cp/checkpoint-200/config.json
Model weights saved in ./cp/checkpoint-200/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 133
  Batch size = 16
Saving model checkpoint to ./cp/checkpoint-300
Configuration saved in ./cp/checkpoint-300/config.json
Model weights saved in ./cp/checkpoint-300/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 133
  Batch size = 16
Saving model checkpoint to ./cp/checkpoint-400
Configuration saved in ./cp/checkpoint-400/config.json
Model weights saved in ./cp/checkpoint-400/pytorch_model.bin


Training completed. Do not forget to share your model on huggingfa

In [16]:
trainer.save_model()

Saving model checkpoint to ./cp
Configuration saved in ./cp/config.json
Model weights saved in ./cp/pytorch_model.bin


In [17]:
trainer.log_metrics("train", train_results.metrics)

***** train metrics *****
  epoch                    =         50.0
  total_flos               = 3731224472GF
  train_loss               =       0.0279
  train_runtime            =   0:04:47.10
  train_samples_per_second =      180.072
  train_steps_per_second   =        1.567


In [18]:
# save as all_results.json and train_results.json
trainer.save_metrics("train", train_results.metrics)

In [19]:
# save as trainer_state.json contains all arguments in trainer
trainer.save_state()

In [21]:
# load saved parameters
model_final = ViTForImageClassification.from_pretrained("./cp")
model_first_100 = ViTForImageClassification.from_pretrained("./cp/checkpoint-100/")

loading configuration file ./cp/config.json
Model config ViTConfig {
  "_name_or_path": "google/vit-base-patch16-224-in21k",
  "architectures": [
    "ViTForImageClassification"
  ],
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "id2label": {
    "0": "angular_leaf_spot",
    "1": "bean_rust",
    "2": "healthy"
  },
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "angular_leaf_spot": "0",
    "bean_rust": "1",
    "healthy": "2"
  },
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "problem_type": "single_label_classification",
  "qkv_bias": true,
  "torch_dtype": "float32",
  "transformers_version": "4.19.2"
}

loading weights file ./cp/pytorch_model.bin
All model checkpoint weights were used when initializing ViTForImageClassification.

All 

In [22]:
pretrained_model = torch.load("./cp/pytorch_model.bin", map_location='cpu')

In [24]:
# mannually load parameters
# This one is for our own designed models
print(model_final.load_state_dict(pretrained_model, strict=False))

<All keys matched successfully>
