In [1]:
from datasets import load_dataset
from transformers.utils.dummy_vision_objects import ImageGPTFeatureExtractor
import os
from PIL import ImageDraw, ImageFont, Image
import torch
import numpy as np
from datasets import load_metric
from transformers import ViTForImageClassification, TrainingArguments, Trainer, ViTFeatureExtractor
import cv2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def transform(example_batch):
    # Take a list of PIL images and turn them to pixel values
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')

    # Don't forget to include the labels!
    inputs['label'] = example_batch['label']
    return inputs


dataset = load_dataset("imagefolder", data_files={"train": ["./objects/train/**", "./objects/test/**",], "val": "./objects/val/**"})

prepared_ds = dataset.with_transform(transform)
labels = dataset['train'].features['label'].names


Resolving data files: 100%|██████████| 35011/35011 [00:00<00:00, 37698.43it/s] 
Resolving data files: 100%|██████████| 3970/3970 [00:00<00:00, 9430.28it/s] 
Using custom data configuration default-146ea55b81da53a9


Downloading and preparing dataset imagefolder/default to /home/ztchen/.cache/huggingface/datasets/imagefolder/default-146ea55b81da53a9/0.0.0/48efdc62d40223daee675ca093d163bcb6cb0b7d7f93eb25aebf5edca72dc597...


Downloading data files #0:   0%|          | 0/2189 [00:00<?, ?obj/s]
[A


[A[A[A

[A[A











[A[A[A[A[A[A[A[A[A[A[A[A






[A[A[A[A[A[A[A



Downloading data files #0: 100%|██████████| 2189/2189 [00:00<00:00, 13895.00obj/s]













[A[A[A[A[A[A[A[A[A[A[A[A[A










Downloading data files #4: 100%|██████████| 2188/2188 [00:00<00:00, 19368.61obj/s]
Downloading data files #2: 100%|██████████| 2189/2189 [00:00<00:00, 13915.35obj/s]










Downloading data files #5: 100%|██████████| 2188/2188 [00:00<00:00, 37510.52obj/s]
Downloading data files #3: 100%|██████████| 2188/2188 [00:00<00:00, 19228.40obj/s]
Downloading data files #1: 100%|██████████| 2189/2189 [00:00<00:00, 19639.97obj/s]








[A[A[A[A[A[A[A[A




Downloading data files #13: 100%|██████████| 2188/2188 [00:00<00:00, 22384.46obj/s]
Downloading data files #8: 100%|██████████| 2188/2188 [00:00<00:00, 20308.30obj/s]






Downloading data files #14: 100%|██████████| 21

Dataset imagefolder downloaded and prepared to /home/ztchen/.cache/huggingface/datasets/imagefolder/default-146ea55b81da53a9/0.0.0/48efdc62d40223daee675ca093d163bcb6cb0b7d7f93eb25aebf5edca72dc597. Subsequent calls will reuse this data.


100%|██████████| 2/2 [00:00<00:00,  6.22it/s]


In [3]:
# pre-trained models
model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# trainer

training_args = TrainingArguments(
  output_dir="./vit-base-beans-r9",
  per_device_train_batch_size=16,
  evaluation_strategy="steps",
  num_train_epochs=4,
  fp16=True,
  save_steps=500,
  eval_steps=500,
  logging_steps=50,
  learning_rate=2e-4,
  save_total_limit=2,
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
)

metric = load_metric("accuracy")
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['label'] for x in batch])
    }

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["val"],
    tokenizer=feature_extractor,
)

train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

metrics = trainer.evaluate(prepared_ds['val'])
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)

In [None]:
# inference

model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
model = ViTForImageClassification.from_pretrained('./vit-base-beans-demo-v5/')
model.cuda()

# dataset = load_dataset("imagefolder", data_files={"train": "./objects/train/**", "test": "./objects/test/**", "val": "./objects/val/**"})


In [None]:
dir = './objects/test/negative/'
imgs = [cv2.imread(dir + '/' + i) for i in os.listdir(dir)[:300]]
imgs = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in imgs]

inputs = feature_extractor(imgs, return_tensors="pt")
inputs['pixel_values'] = inputs['pixel_values'].cuda()

In [None]:
with torch.no_grad():
    logits = model(**inputs).logits
for i in logits.argmax(-1).cpu():
    print(i == 1)