In [3]:

from PIL import Image
from torchinfo import summary
import torch
import os
import warnings
warnings.filterwarnings("ignore")
from typing import Tuple

from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split

from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

from datasets import load_dataset


In [None]:

# pip install datasets


In [None]:
# !pip install torchvision
# !pip install torchinfo
# !pip install -q git+https://github.com/huggingface/transformers.git


In [4]:

from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
dataset_folder_name = 'drive/MyDrive/short_dataset'
dataset = load_dataset("imagefolder", data_dir=dataset_folder_name)


Resolving data files:   0%|          | 0/500 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/250 [00:00<?, ?it/s]

In [6]:
dataset


DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 500
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 250
    })
})

In [7]:
from transformers import ViTImageProcessor

model_name_or_path = 'google/vit-base-patch16-224-in21k'
processor = ViTImageProcessor.from_pretrained(model_name_or_path)

In [8]:
def transform(example_batch):
    inputs = processor([x for x in example_batch['image']], return_tensors='pt')

    inputs['label'] = example_batch['label']
    return inputs

In [9]:
prepared_ds = dataset.with_transform(transform)


In [10]:
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['label'] for x in batch])
    }

In [None]:

# pip install evaluate

In [12]:
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import f1_score

import evaluate

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred

    predictions = np.argmax(predictions, axis=1)

    result_accuracy = accuracy.compute(predictions=predictions, references=labels)

    result = {
             'accuracy': np.mean([result_accuracy['accuracy']]),
             'kappa': np.mean([cohen_kappa_score(labels, predictions)]),
             'f1': np.mean([f1_score(labels, predictions, average='weighted')])
             }

    return result


In [13]:
from transformers import ViTForImageClassification

labels = dataset['train'].features['label'].names

model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels)
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# pip install accelerate -U


In [None]:
# pip install transformers -U

In [17]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./vit-base",
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=1,
    fp16=True,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
)

In [18]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["test"],
    tokenizer=processor,
)

In [19]:
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

Step,Training Loss,Validation Loss


***** train metrics *****
  epoch                    =        1.0
  total_flos               = 36085989GF
  train_loss               =     1.5689
  train_runtime            = 0:04:31.73
  train_samples_per_second =       1.84
  train_steps_per_second   =      0.118


In [20]:
metrics = trainer.evaluate(prepared_ds['test'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** eval metrics *****
  epoch                   =        1.0
  eval_accuracy           =       0.32
  eval_f1                 =     0.2555
  eval_kappa              =       0.15
  eval_loss               =     1.5599
  eval_runtime            = 0:02:42.49
  eval_samples_per_second =      1.538
  eval_steps_per_second   =      0.197
