In [1]:
import os
import cv2
import numpy as np
import random
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, ToTensor
from transformers import ViTFeatureExtractor, BitForImageClassification, TrainingArguments, Trainer, ViTForImageClassification
import evaluate
from PIL import Image
import matplotlib.pyplot as plt

from transformers import ViTImageProcessor

from transformers import EarlyStoppingCallback

  from .autonotebook import tqdm as notebook_tqdm


In [21]:
# Dataset root and its three split folders (train / test / validation).
path_to_data = os.path.abspath('../../data2')
train_dir, test_dir, val_dir = (
    os.path.join(path_to_data, split_name)
    for split_name in ("train", "test", "validation")
)

# Base folder for training outputs (used later to build the Trainer output_dir).
data_dir = "temp"

# Load the image processor that belongs to the pretrained checkpoint.
# Hugging Face: google/vit-base-patch16-224-in21k
model_id = "google/vit-base-patch16-224-in21k"
image_processor = ViTImageProcessor.from_pretrained(model_id)

In [44]:
# Custom transformation pipeline for the dataset

def transform(image):
    """Turn a single image into the pixel tensor fed to the model.

    The image is run through the module-level ``image_processor`` requesting
    PyTorch tensors. The processor output carries a leading batch axis, which
    is dropped so that the DataLoader can batch the samples itself.

    Args:
        image: One input image (presumably the PIL image handed over by
            ImageFolder -- confirm against the dataset loader).

    Returns:
        The ``pixel_values`` tensor for this image, without the batch axis.
    """
    processed = image_processor(image, return_tensors="pt")
    pixel_values = processed["pixel_values"]
    return pixel_values.squeeze(0)

In [45]:
# Build one ImageFolder per split; every split shares the same `transform`.
train_dataset, test_dataset, val_dataset = (
    ImageFolder(split_dir, transform=transform)
    for split_dir in (train_dir, test_dir, val_dir)
)

# Show the summary of the training split (size, root folder, transform).
print(train_dataset)

Dataset ImageFolder
    Number of datapoints: 8580
    Root location: c:\Users\siddi\Documents\UAB Documents\Semester 2\CS 685\Project\cs685\data2\train
    StandardTransform
Transform: <function transform at 0x00000234B11579C0>


In [46]:
def collate_fn(batch):
    """Assemble a list of ``(image_tensor, label)`` samples into one batch.

    Args:
        batch: Sequence of 2-item samples: an image tensor and its label.
            All image tensors must share the same shape so they can be stacked.

    Returns:
        dict with two entries:
            ``"pixel_values"``: the image tensors stacked along a new first axis.
            ``"labels"``: a tensor holding the labels in the same order.
    """
    # Transpose the list of pairs into one tuple of images and one of labels.
    pixel_tensors, class_ids = zip(*batch)

    collated = {}
    collated["pixel_values"] = torch.stack(pixel_tensors)
    collated["labels"] = torch.tensor(class_ids)
    return collated

In [47]:
# Accuracy metric, loaded once and reused for every evaluation pass.
metric = evaluate.load("accuracy")


def compute_metrics(p):
    """Compute accuracy for one evaluation pass.

    Args:
        p: Evaluation result exposing ``predictions`` (per-class scores, one
            row per sample) and ``label_ids`` (the reference labels).

    Returns:
        Whatever ``metric.compute`` returns for the predicted class indices
        versus the reference labels.
    """
    scores, gold_labels = p.predictions, p.label_ids
    # The predicted class is the index of the highest score in each row.
    return metric.compute(
        references=gold_labels,
        predictions=np.argmax(scores, axis=1),
    )

In [48]:
# Prepare the model: the number of target classes is the number of classes
# the training dataset reports.
class_names = train_dataset.classes
num_classes = len(class_names)
print(num_classes)

15


In [49]:
# The checkpoint ships without a classifier for this task, so the head is
# randomly initialised when the model is created. Seed the RNG first so a
# fresh "Restart & Run All" starts from the same head weights.
torch.manual_seed(42)

# Besides the number of labels, hand the model the class-name <-> index mapping
# taken from the training dataset. Without it the saved config (and any later
# inference) only knows generic "LABEL_<i>" names instead of the real classes.
model = ViTForImageClassification.from_pretrained(
    model_id,
    num_labels=num_classes,
    id2label=dict(enumerate(train_dataset.classes)),
    label2id=dict(train_dataset.class_to_idx),
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [50]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Move the model to the selected device (the returned model is the cell output).
model.to(device)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermed

In [51]:
# Training configuration: up to 200 epochs, combined with early stopping.
# Evaluation and checkpointing share a 100-step cadence. No
# metric_for_best_model is set, so the library default decides which
# checkpoint counts as "best" (validation loss per the Trainer docs -- confirm
# for the installed version).

training_args = TrainingArguments(
    # Output location
    output_dir=f"{data_dir}/vit_custom",
    push_to_hub=False,
    # Optimisation
    num_train_epochs=200,
    learning_rate=2e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # Evaluation / checkpoint schedule
    eval_strategy="steps",
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    load_best_model_at_end=True,
    # Progress logging
    logging_steps=10,
    # Batches are assembled by the custom collate function, so the Trainer
    # must not try to prune dataset columns.
    remove_unused_columns=False,
)

In [52]:
# Early stopping: end training once 10 evaluations in a row bring no improvement.
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_threshold=0.0,  # minimum gain required; 0.0 accepts any improvement
    early_stopping_patience=10,    # evaluations tolerated without improvement
)

In [53]:
# Wire the model, data splits, batching, metrics and early stopping into one Trainer.

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,  # the validation split is used for evaluation during training
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping_callback],
    processing_class=image_processor,  # the ViTImageProcessor loaded for model_id
)

In [54]:
# Run training, then persist the model, the training metrics and the trainer state.
train_results = trainer.train()
train_metrics = train_results.metrics

trainer.save_model()
trainer.log_metrics("train", train_metrics)
trainer.save_metrics("train", train_metrics)
trainer.save_state()

Step,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
# Score the model on the held-out test split.
# The run is labelled "test" end to end: metric keys get the "test_" prefix and
# the results are logged/saved under the "test" split. Previously they were
# reported as "eval", which made them indistinguishable from the validation
# metrics produced during training and wrote them to the eval results file.
# NOTE: `metrics` therefore holds "test_*" keys rather than "eval_*".
test_output = trainer.predict(test_dataset, metric_key_prefix="test")
metrics = test_output.metrics
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)