In [1]:
import os
import cv2
import numpy as np
import random
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, ToTensor
from transformers import ViTFeatureExtractor, BitForImageClassification, TrainingArguments, Trainer, ViTForImageClassification
from PIL import Image
from tqdm.auto import tqdm
from accelerate import Accelerator

import evaluate
import matplotlib.pyplot as plt
import torch.nn.functional as F

from transformers import ViTImageProcessor

from transformers import EarlyStoppingCallback

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Paths: dataset root (relative to this notebook) and its three split folders.
path_to_data = os.path.abspath('../../data2')
train_dir = os.path.join(path_to_data, "train")
test_dir = os.path.join(path_to_data, "test")
val_dir = os.path.join(path_to_data, "validation")

# Output directory for checkpoints (the training loop saves under <data_dir>/vit_custom).
data_dir = "temp"

# Seed every RNG the notebook relies on (Python `random`, NumPy, PyTorch) so the
# freshly initialised classification head and the shuffled training order are
# repeatable across Restart & Run All.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Load the image processor of the pretrained backbone.
# Hugging Face Hub: google/vit-base-patch16-224-in21k
model_id = "google/vit-base-patch16-224-in21k"
image_processor = ViTImageProcessor.from_pretrained(model_id)

In [3]:
# Per-image transform handed to ImageFolder: runs the pretrained model's
# image processor and returns a single (unbatched) pixel tensor.

def transform(image):
    """Turn one image into the model-ready pixel tensor for a single sample.

    `image_processor` returns a batched "pixel_values" tensor; the leading batch
    axis of size 1 is squeezed away so the DataLoader's collate step can stack
    samples itself.
    """
    processed = image_processor(image, return_tensors="pt")
    pixel_values = processed["pixel_values"]
    return pixel_values.squeeze(dim=0)

In [4]:
# Single Accelerator instance: it decides where tensors live and wraps the model,
# optimizer and dataloaders passed to accelerator.prepare() in later cells.
accelerator = Accelerator()
device = accelerator.device  # target device chosen by Accelerate

In [5]:
# Load datasets: one ImageFolder per split (train, test, validation), with every
# image passed through `transform`.
train_dataset, test_dataset, val_dataset = (
    ImageFolder(split_root, transform=transform)
    for split_root in (train_dir, test_dir, val_dir)
)

print(train_dataset)

Dataset ImageFolder
    Number of datapoints: 7953
    Root location: c:\Users\siddi\Documents\UAB Documents\Semester 2\CS 685\Project\cs685\data2\train
    StandardTransform
Transform: <function transform at 0x000002CBA4563D80>


In [6]:
def collate_fn(batch):
    """Merge a list of (pixel_tensor, label) samples into one model-input dict.

    Returns {"pixel_values": stacked pixel tensors, "labels": label tensor},
    the keyword names the ViT classification model accepts.
    """
    pixel_tensors, label_ids = zip(*batch)
    stacked_pixels = torch.stack(pixel_tensors)
    label_tensor = torch.tensor(label_ids)
    return {"pixel_values": stacked_pixels, "labels": label_tensor}

In [7]:
# Accuracy metric from the Hugging Face `evaluate` library, used by compute_metrics.
# Imported under an alias because the training cell defines a function named
# `evaluate`, which shadows the module; without the alias, re-running this cell
# after training fails on `evaluate.load`.
import evaluate as hf_evaluate

metric = hf_evaluate.load("accuracy")


def compute_metrics(p):
    """Compute accuracy for a transformers EvalPrediction.

    `p.predictions` holds the model logits, or a tuple whose first element is the
    logits (when the model returns extra outputs); `p.label_ids` holds the targets.
    Returns the dict produced by the accuracy metric.
    """
    logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=p.label_ids)

In [8]:
# Prepare the model: one output per class sub-folder of the training split.
num_classes = len(train_dataset.classes)

# ImageFolder numbers classes independently for each split (from that split's own
# sub-folders), so a class folder missing from validation/test would silently
# shift their label ids. Fail fast if any split disagrees with the training split.
for split_name, split_dataset in (("validation", val_dataset), ("test", test_dataset)):
    assert split_dataset.class_to_idx == train_dataset.class_to_idx, (
        f"{split_name} classes {split_dataset.classes} "
        f"differ from train classes {train_dataset.classes}"
    )

print(num_classes)

15


In [9]:
# Pretrained ViT backbone with a freshly initialised classification head sized to
# the dataset. Recording the class-name mapping in the config means checkpoints
# saved with save_pretrained() report folder names instead of generic LABEL_<i> ids.
model = ViTForImageClassification.from_pretrained(
    model_id,
    num_labels=num_classes,
    id2label=dict(enumerate(train_dataset.classes)),
    label2id=dict(train_dataset.class_to_idx),
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# AdamW over all model parameters (the backbone is fine-tuned, not frozen).
learning_rate = 2e-4
optimizer = AdamW(model.parameters(), lr=learning_rate)

In [11]:
# Dataloaders
import multiprocessing

batch_size = 16

# DataLoader workers started with "spawn"/"forkserver" (the default on Windows and
# macOS) re-import __main__ and cannot find functions defined in a notebook, such
# as `transform` and `collate_fn`; those workers crash or hang. Only use worker
# processes when the start method is "fork".
num_workers = 4 if multiprocessing.get_start_method() == "fork" else 0


def make_loader(dataset, shuffle):
    """Return a DataLoader over ``dataset`` that batches samples with ``collate_fn``."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=num_workers,
    )


train_loader = make_loader(train_dataset, shuffle=True)
val_loader = make_loader(val_dataset, shuffle=False)
test_loader = make_loader(test_dataset, shuffle=False)

# test_loader is not prepared here, so its batches stay on the CPU until they are
# moved to accelerator.device explicitly (or the loader is prepared later).
model, optimizer, train_loader, val_loader = accelerator.prepare(model, optimizer, train_loader, val_loader)

In [None]:
# Custom training loop: fine-tune with periodic validation, keep the checkpoint
# with the best validation accuracy, and stop early when accuracy stops improving.
epochs = 1
eval_steps = 100              # validate every `eval_steps` optimizer steps (and at epoch end)
early_stopping_patience = 10  # evaluations without improvement before stopping
best_val_acc = float("-inf")  # so the first evaluation always writes a checkpoint
patience_counter = 0
best_model_dir = os.path.join(data_dir, "vit_custom")


def evaluate(model, dataloader):
    """Return the top-1 accuracy of ``model`` on every sample of ``dataloader``.

    Each batch is a dict with "pixel_values" and "labels" tensors (see
    ``collate_fn``). Tensors are moved to ``accelerator.device``, so loaders that
    were not passed through ``accelerator.prepare`` also work. Predictions and
    labels are collected with ``accelerator.gather_for_metrics`` so that, in a
    multi-process run, every process computes the same global accuracy.

    The model is left in eval mode; callers that continue training must call
    ``model.train()`` afterwards.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    total, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            pixel_values = batch["pixel_values"].to(accelerator.device)
            labels = batch["labels"].to(accelerator.device)
            logits = model(pixel_values=pixel_values).logits
            preds = torch.argmax(logits, dim=-1)
            preds, labels = accelerator.gather_for_metrics((preds, labels))
            correct += (preds == labels).sum().item()
            total += labels.numel()
    if total == 0:
        raise ValueError("evaluate() received a dataloader with no samples")
    return correct / total


steps_per_epoch = len(train_loader)
global_step = 0
stop_training = False

for epoch in range(epochs):
    model.train()
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
    for step, batch in enumerate(progress_bar):
        outputs = model(pixel_values=batch["pixel_values"], labels=batch["labels"])
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        global_step += 1

        # Counting completed steps avoids validating right after the first step,
        # and the end-of-epoch check guarantees the tail of every epoch is evaluated.
        is_last_step = step + 1 == steps_per_epoch
        if global_step % eval_steps == 0 or is_last_step:
            val_acc = evaluate(model, val_loader)
            # evaluate() switched the model to eval mode; restore training mode so
            # the remaining steps use training-mode behaviour (e.g. dropout).
            model.train()
            accelerator.print(f"Step {global_step}: Validation Accuracy = {val_acc:.4f}")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                patience_counter = 0
                accelerator.wait_for_everyone()
                unwrapped_model = accelerator.unwrap_model(model)
                # Only the main process writes files in multi-process runs.
                unwrapped_model.save_pretrained(
                    best_model_dir,
                    is_main_process=accelerator.is_main_process,
                    save_function=accelerator.save,
                )
            else:
                patience_counter += 1

            if patience_counter >= early_stopping_patience:
                accelerator.print("Early stopping triggered.")
                stop_training = True
                break

    if stop_training:
        break


Epoch 1/1:   0%|          | 0/498 [00:00<?, ?it/s]

In [None]:
# Load the best checkpoint written by the training loop and report test accuracy.
best_model_dir = os.path.join(data_dir, "vit_custom")
model = ViTForImageClassification.from_pretrained(best_model_dir)
# test_loader was never passed to accelerator.prepare(), so its batches would stay
# on the CPU while the prepared model sits on accelerator.device. Prepare both
# together so inputs and weights share a device; the separate name keeps re-runs
# of this cell from wrapping an already-prepared loader again.
model, prepared_test_loader = accelerator.prepare(model, test_loader)
test_acc = evaluate(model, prepared_test_loader)
accelerator.print(f"Test Accuracy: {test_acc:.4f}")


In [None]:
# # Define train arguments:
# # 200 epochs with early stopping

# training_args = TrainingArguments(
#     output_dir = data_dir + "/vit_custom", 
#     per_device_eval_batch_size=16,
#     per_device_train_batch_size=16,
#     eval_strategy="steps",
#     num_train_epochs=200,
#     save_steps=100,
#     eval_steps=100,
#     logging_steps=10,
#     learning_rate=2e-4,
#     save_total_limit=2,
#     remove_unused_columns=False,
#     push_to_hub=False,
#     load_best_model_at_end=True
# )

In [None]:
# # Define early stopping callback
# early_stopping_callback = EarlyStoppingCallback(
#     early_stopping_patience=10, # Stop after no improvement for 10 evaluation steps,
#     early_stopping_threshold=0.0 # minimum improvement that resets patience (0.0: any improvement counts)
# )

In [None]:
# # Define the Trainer

# trainer = Trainer(
#     model=model,
#     args=training_args,
#     data_collator=collate_fn,
#     compute_metrics=compute_metrics,
#     train_dataset=train_dataset,
#     eval_dataset=val_dataset,
#     processing_class=image_processor, # Use the ViTImageProcessor
#     callbacks=[early_stopping_callback]
# )

In [None]:
# # Train the model
# train_results = trainer.train()
# trainer.save_model()
# trainer.log_metrics("train", train_results.metrics)
# trainer.save_metrics("train", train_results.metrics)
# trainer.save_state()

Step,Training Loss,Validation Loss,Accuracy
100,1.1518,1.174051,0.709859


KeyboardInterrupt: 

In [None]:
# # Evaluate the model
# metrics = trainer.evaluate(test_dataset)
# trainer.log_metrics("eval", metrics)
# trainer.save_metrics("eval", metrics)