In [1]:
import os
from pathlib import Path

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import VisionTransformer

from going_modular.going_modular import engine
from helper_functions import plot_loss_curves
from helper_functions import set_seeds

In [2]:
# Define the EarlyStopping callback class
class EarlyStopping:
    """Flag when a monitored validation loss has stopped improving.

    Call the instance once per epoch with that epoch's validation loss. A call
    counts as an improvement only when the loss is strictly below
    ``best_val_loss - min_delta``; anything else (a worse loss, an equal loss,
    or NaN) increments ``counter``. Once ``counter`` reaches ``patience``,
    ``early_stop`` is set to True and is never cleared again.

    Args:
        patience: number of consecutive non-improving calls tolerated before
            ``early_stop`` is set.
        verbose: if True, print a message on every non-improving call.
        min_delta: minimum decrease versus the best loss seen so far that
            counts as an improvement (must be >= 0). The default 0.0 means any
            strict decrease counts and a plateau does not.

    Raises:
        ValueError: if ``patience`` or ``min_delta`` is negative.
    """

    def __init__(self, patience=5, verbose=False, min_delta=0.0):
        if patience < 0:
            raise ValueError(f"patience must be >= 0, got {patience}")
        if min_delta < 0:
            raise ValueError(f"min_delta must be >= 0, got {min_delta}")
        self.patience = patience
        self.verbose = verbose
        self.min_delta = min_delta
        self.counter = 0              # consecutive non-improving calls
        self.best_val_loss = None     # best (lowest) loss seen so far
        self.early_stop = False

    def __call__(self, val_loss):
        # `<` is False for NaN, so a NaN loss is treated as "no improvement"
        # instead of becoming the new best (which would disable stopping).
        if self.best_val_loss is None or val_loss < self.best_val_loss - self.min_delta:
            self.best_val_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f'INFO: Validation loss did not improve for {self.counter} epoch(s).')
        if self.counter >= self.patience:
            self.early_stop = True

In [3]:
# Function to create data loaders
def create_dataloaders(train_dir: str, test_dir: str, transform: transforms.Compose, batch_size: int, num_workers: int,
                       test_transform: "transforms.Compose | None" = None, pin_memory: "bool | None" = None):
    """Build train/test DataLoaders from two ImageFolder-style directories.

    Args:
        train_dir: root folder whose sub-folders are the training classes.
        test_dir: root folder whose sub-folders are the test classes.
        transform: transform applied to training images (and to test images
            when ``test_transform`` is None).
        batch_size: samples per batch for both loaders.
        num_workers: worker processes per loader.
        test_transform: optional separate transform for test images, e.g. to
            keep augmentation out of evaluation. Defaults to ``transform``.
        pin_memory: passed to both DataLoaders. None (default) enables it only
            when CUDA is available, since pinning has no benefit without a GPU
            and DataLoader warns about it.

    Returns:
        (train_dataloader, test_dataloader, class_names) where class_names is
        ImageFolder's class list for the training folder.

    Raises:
        ValueError: if the train and test folders contain different class
            sub-folders (their label indices would not line up).
    """
    train_data = datasets.ImageFolder(train_dir, transform=transform)
    test_data = datasets.ImageFolder(test_dir, transform=transform if test_transform is None else test_transform)

    # ImageFolder numbers classes by their folder names; differing folder sets
    # would silently give the same index different meanings in train vs test.
    if train_data.classes != test_data.classes:
        raise ValueError(
            f"Train and test folders have different classes: {train_data.classes} vs {test_data.classes}"
        )

    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

    return train_dataloader, test_dataloader, train_data.classes


In [4]:
# Device setup: use the CUDA GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


In [5]:
# Setup directory paths to train and test images.
# DATA_DIR can be overridden with the EYE_DATASET_DIR environment variable so the
# notebook runs on other machines; the fallback is the original local location.
DATA_DIR = Path(os.environ.get("EYE_DATASET_DIR", "/Users/sandundesilva/Downloads/Datasets"))
train_dir = str(DATA_DIR / "train")
test_dir = str(DATA_DIR / "test")

In [6]:
# Define class names. These should match the class sub-folder names under the
# train/test directories; they are listed alphabetically, which is the order
# torchvision's ImageFolder uses when assigning label indices.
class_names = [
    "Blepharitis",
    "Conjunctivitis",
    "Entropion",
    "EyelidTumor",
    "HealthyEye",
    "Mastopathy",
    "NuclearSclerosis",
    "PigmentedKeratitis",
]

In [7]:
# Define transforms: resize, centre-crop to the 224x224 input the ViT takes,
# convert to a tensor and normalise with the ImageNet channel statistics.
pretrained_vit_transforms = transforms.Compose(
    [
        transforms.Resize(256),        # shorter image side -> 256 px
        transforms.CenterCrop(224),    # central 224x224 patch
        transforms.ToTensor(),         # PIL image -> float tensor in [0, 1]
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [8]:
# Create data loaders. The third return value is the class list ImageFolder
# derived from the training sub-folders; it must agree with the hand-written
# `class_names`, which sizes the classifier head and names its outputs.
train_dataloader, test_dataloader, dataset_class_names = create_dataloaders(train_dir=train_dir,
                                                                            test_dir=test_dir,
                                                                            transform=pretrained_vit_transforms,
                                                                            batch_size=32,
                                                                            num_workers=4)
assert dataset_class_names == class_names, (
    f"class_names {class_names} does not match the dataset folders {dataset_class_names}"
)

In [9]:
# Initialize ViT-B/16 with its ImageNet-1k pretrained weights.
# `pretrained=True` is deprecated (torchvision >= 0.13); IMAGENET1K_V1 is the
# checkpoint that flag used to load, requested through the supported `weights=` API.
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1
pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights).to(device)



In [10]:
# Freeze the pretrained backbone: no parameter that exists at this point will
# receive gradients. Layers created afterwards are unaffected.
for weight in pretrained_vit.parameters():
    weight.requires_grad_(False)

In [11]:
# Modify classifier head: replace it with a fresh, trainable linear layer that
# maps the ViT embedding to one logit per class. The input width is read from
# the model instead of hard-coding 768, and the RNG is seeded first so the
# randomly initialised head (and the later shuffling) is reproducible.
torch.manual_seed(42)
pretrained_vit.heads = nn.Linear(in_features=pretrained_vit.hidden_dim, out_features=len(class_names)).to(device)


In [12]:
# Create optimizer and loss function.
# Only parameters that still require gradients (the new head) are given to
# Adam; the frozen backbone never gets gradients, so listing it adds nothing.
trainable_params = [p for p in pretrained_vit.parameters() if p.requires_grad]
assert trainable_params, "No trainable parameters found - was the new head frozen too?"
optimizer = torch.optim.Adam(params=trainable_params, lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

In [13]:
# Define the EarlyStopping callback (stop after 3 non-improving epochs, verbose).
# NOTE(review): in the visible cells nothing reads `early_stop`;
# train_with_early_stopping builds its own EarlyStopping from `patience`.
early_stop = EarlyStopping(verbose=True, patience=3)

In [14]:
def _run_epoch(model, loader, loss_fn, device, optimizer=None):
    """Run one pass over `loader`: a training pass when `optimizer` is given,
    otherwise a gradient-free evaluation pass.

    Returns (mean_loss, accuracy), both weighted by samples seen. Assumes
    `loss_fn` returns a per-batch mean (as nn.CrossEntropyLoss does by default).

    Raises:
        ValueError: if the loader yields no samples.
    """
    is_train = optimizer is not None
    model.train(is_train)
    total_loss = 0.0
    n_correct = 0
    n_seen = 0
    with torch.set_grad_enabled(is_train):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += batch_size
    if n_seen == 0:
        raise ValueError("DataLoader yielded no samples")
    return total_loss / n_seen, n_correct / n_seen


def train_with_early_stopping(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader,
                              optimizer: torch.optim.Optimizer, loss_fn: nn.Module, epochs: int,
                              device: "str | torch.device", patience: int,
                              early_stopping: "EarlyStopping | None" = None,
                              restore_best_weights: bool = False):
    """Train `model`, evaluating on `val_loader` after every epoch, and stop
    early when the validation loss stops improving.

    Args:
        model: network to train (updated in place).
        train_loader / val_loader: batches of (images, labels).
        optimizer: optimizer over the parameters to train.
        loss_fn: loss returning a batch mean.
        epochs: maximum number of epochs.
        device: device the batches are moved to (the model must already be there).
        patience: patience for the EarlyStopping created when `early_stopping`
            is None.
        early_stopping: optional pre-built callable with an `early_stop`
            attribute (e.g. an EarlyStopping). Note an instance that has
            already triggered stops training after the first epoch.
        restore_best_weights: if True, reload the weights from the epoch with
            the lowest validation loss before returning (keeps a full copy of
            the state dict in memory).

    Returns:
        dict with per-epoch lists "train_loss", "train_acc", "test_loss",
        "test_acc" (the "test_*" entries are the validation metrics; key names
        presumably match what helper_functions.plot_loss_curves reads - verify).
    """
    if early_stopping is None:
        early_stopping = EarlyStopping(patience=patience, verbose=True)

    history = {"train_loss": [], "train_acc": [], "test_loss": [], "test_acc": []}
    best_val_loss = float("inf")
    best_state = None

    for epoch in range(epochs):
        train_loss, train_accuracy = _run_epoch(model, train_loader, loss_fn, device, optimizer)
        val_loss, val_accuracy = _run_epoch(model, val_loader, loss_fn, device)

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_accuracy)
        history["test_loss"].append(val_loss)
        history["test_acc"].append(val_accuracy)

        print(f'Epoch {epoch+1}/{epochs}, Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')

        if restore_best_weights and val_loss < best_val_loss:
            best_val_loss = val_loss
            # Clone so later optimizer steps don't mutate the snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

        # Check for improvement in validation loss
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    return history


In [15]:

# Usage: fine-tune the new head for up to 10 epochs with patience 3.
# NOTE(review): the test split is passed as the validation set, so early
# stopping and the reported validation accuracy are tuned on test data -
# consider carving out a separate validation split.
# The trailing ';' keeps Jupyter from displaying the call's return value.
train_with_early_stopping(
    model=pretrained_vit,
    train_loader=train_dataloader,
    val_loader=test_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=10,
    device=device,
    patience=3,
);


Epoch 1/10, Training Loss: 1.1473, Training Accuracy: 0.5560, Validation Loss: 1.0213, Validation Accuracy: 0.6317
Epoch 2/10, Training Loss: 0.8793, Training Accuracy: 0.6643, Validation Loss: 0.9585, Validation Accuracy: 0.6335
Epoch 3/10, Training Loss: 0.7914, Training Accuracy: 0.6936, Validation Loss: 0.9225, Validation Accuracy: 0.6762
Epoch 4/10, Training Loss: 0.7338, Training Accuracy: 0.7155, Validation Loss: 0.9007, Validation Accuracy: 0.6619
Epoch 5/10, Training Loss: 0.6956, Training Accuracy: 0.7310, Validation Loss: 0.9020, Validation Accuracy: 0.6459
INFO: Validation loss did not improve for 1 epoch(s).
Epoch 6/10, Training Loss: 0.6649, Training Accuracy: 0.7432, Validation Loss: 0.8776, Validation Accuracy: 0.6495
Epoch 7/10, Training Loss: 0.6388, Training Accuracy: 0.7575, Validation Loss: 0.8945, Validation Accuracy: 0.6495
INFO: Validation loss did not improve for 1 epoch(s).
Epoch 8/10, Training Loss: 0.6193, Training Accuracy: 0.7601, Validation Loss: 0.8912, 