# In [None]:
import os
import torch


##################################################
# Device Setup
##################################################
# Use the MPS backend when PyTorch reports it as available; otherwise run on the CPU.
DEVICE = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
print(f"Using device:{DEVICE}")

# os.cpu_count() returns None when the core count cannot be determined. Fall back to 0
# so NUM_WORKERS is always an int that can be passed as a DataLoader `num_workers` value
# (0 means batches are loaded in the main process).
NUM_WORKERS = os.cpu_count() or 0
print(f"Number of CPU cores: {NUM_WORKERS}")


##################################################
# Hyperparameters
##################################################
# --- Data ---
BATCH_SIZE = 32       # passed as `batch_size` when the data loaders are created
IMG_SIZE = 224        # images are resized to IMG_SIZE x IMG_SIZE by the transforms
INPUT_CHANNELS = 3

# --- Vision Transformer architecture ---
# Names mirror the keyword arguments of the VisionTransformer constructor.
PATCH_SIZE = 16
NUM_HEADS = 12
MLP_SIZE = 3072
NUM_LAYERS = 12

# --- Regularisation ---
EMBEDDING_DROPOUT = 0.1
ATTENTION_DROPOUT = 0.1
MLP_DROPOUT = 0.1

# --- Classification head ---
NUM_CLASSES = 200


##################################################
# Data Loading
##################################################
from torchvision.transforms import v2


def _resize_step():
    """Return a fresh resize step that maps every image to IMG_SIZE x IMG_SIZE."""
    return v2.Resize(
        (IMG_SIZE, IMG_SIZE),
        interpolation=v2.InterpolationMode.BILINEAR,
        antialias=True,
    )


def _tensor_conversion_steps():
    """Return fresh steps that convert to a scaled float32 image and normalise it.

    The mean/std values match the commonly used ImageNet channel statistics.
    """
    return [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]


# Training pipeline: resize, random augmentation, then tensor conversion + normalisation.
TRAIN_TRANSFORMATIONS = v2.Compose([
    _resize_step(),

    # Geometric augmentations
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomRotation(degrees=15),
    v2.RandomPerspective(distortion_scale=0.2, p=0.5),

    # Photometric augmentation
    v2.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),

    *_tensor_conversion_steps(),
])
print(f"Training Transformations: {TRAIN_TRANSFORMATIONS}")

# Validation/test pipeline: same resize and normalisation, but no random augmentation.
TEST_TRANSFORMATIONS = v2.Compose([
    _resize_step(),
    *_tensor_conversion_steps(),
])
print(f"Validation/Test Transformations: {TEST_TRANSFORMATIONS}")



from torchvision import datasets
from torch.utils.data import DataLoader, random_split, Subset
def create_train_val_dataloaders(root_dir:str, train_transformations:v2.Compose, val_transformations:v2.Compose, 
                       batch_size:int, num_workers:int, train_val_split:float=0.2, seed:int=42):
    """Build training and validation DataLoaders from a single image folder.

    The folder is opened twice with `ImageFolder` so that each split gets its own
    transform pipeline, and one seeded random permutation of the sample indices
    is used to divide the samples into two disjoint subsets.

    Args:
        root_dir: Directory passed to `datasets.ImageFolder` as `root`.
        train_transformations: Transform applied to samples of the training subset.
        val_transformations: Transform applied to samples of the validation subset.
        batch_size: Batch size used by both loaders.
        num_workers: Number of worker processes used by both loaders.
        train_val_split: Fraction of all samples assigned to the VALIDATION subset
            (rounded down); the rest goes to training. Must lie in [0, 1].
        seed: Seed for the generator that shuffles the indices before splitting.
            The default of 42 reproduces the previously hard-coded split.

    Returns:
        A `(train_loader, val_loader, class_names)` tuple. The training loader
        shuffles each epoch, the validation loader does not, and `class_names`
        is the `classes` list reported by `ImageFolder`.

    Raises:
        ValueError: If `train_val_split` is outside [0, 1].
    """
    # A fraction outside [0, 1] would give a negative or oversized split and
    # silently wrong slices below, so reject it up front.
    if not 0.0 <= train_val_split <= 1.0:
        raise ValueError(f"train_val_split must be between 0 and 1, got {train_val_split}")

    # Two independent ImageFolder instances over the same files: same sample
    # order, different transforms.
    train_full_dataset = datasets.ImageFolder(root=root_dir, transform=train_transformations)
    val_full_dataset = datasets.ImageFolder(root=root_dir, transform=val_transformations)

    class_names = train_full_dataset.classes
    print(f"Number of classes: {len(class_names)}")
    print(f"Class names: {class_names}")

    total_samples = len(train_full_dataset)
    val_size = int(total_samples * train_val_split)
    train_size = total_samples - val_size
    print(f"Splitting dataset: Total={total_samples}, Train={train_size}, Validation={val_size}")

    # One seeded generator drives the permutation so the split is reproducible.
    generator = torch.Generator().manual_seed(seed)
    indices = torch.randperm(total_samples, generator=generator).tolist()
    train_indices = indices[:train_size]
    val_indices = indices[train_size:]

    # Both subsets index into the same underlying file list, so the index sets
    # are disjoint even though they come from different dataset objects.
    train_subset = Subset(train_full_dataset, train_indices)
    val_subset = Subset(val_full_dataset, val_indices)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, val_loader, class_names


# Root folder handed to ImageFolder inside create_train_val_dataloaders.
ROOT_DIR = "./archive/CUB_200_2011/images"

# train_val_split=0.25 assigns a quarter of the samples to the validation loader.
train_dataloader, val_dataloader, class_names = create_train_val_dataloaders(
    root_dir=ROOT_DIR,
    train_transformations=TRAIN_TRANSFORMATIONS,
    val_transformations=TEST_TRANSFORMATIONS,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    train_val_split=0.25,
)

##################################################
# Visualise the datasets
##################################################
from helper_functions import visualize_dataset

# Same preview settings (16 images, 4 columns) for both loaders, training first.
for _loader, _title in (
    (train_dataloader, "Training Dataset"),
    (val_dataloader, "Validation Dataset"),
):
    visualize_dataset(dataset=_loader, class_names=class_names, num_images=16, name=_title, cols=4)


##################################################
# Vision Transformer Architecture
##################################################
from model import VisionTransformer

# Build the model from the hyperparameter constants rather than repeating their
# literal values, so that editing a constant actually changes the architecture.
# The constants currently hold exactly the values that were hard-coded here before.
vit_model = VisionTransformer(
    image_size=IMG_SIZE,
    input_channels=INPUT_CHANNELS,
    patch_size=PATCH_SIZE,
    num_heads=NUM_HEADS,
    mlp_size=MLP_SIZE,
    num_classes=NUM_CLASSES,
    num_layers=NUM_LAYERS,
    embedding_dropout=EMBEDDING_DROPOUT,
    attention_dropout=ATTENTION_DROPOUT,
    mlp_dropout=MLP_DROPOUT,
)

from torchinfo import summary

# Fetch one real batch so the summary uses the exact input shape the loader yields.
image_batch, label_batch = next(iter(train_dataloader))
summary(model=vit_model, input_size=tuple(image_batch.shape),
        col_names=["input_size", "output_size", "num_params", "trainable"], row_settings=["var_names"])

##################################################
# Optimizer & Loss function
##################################################
# NOTE(review): torch.optim.Adam applies `weight_decay` as an L2 penalty added to the
# gradient; if decoupled weight decay was intended, AdamW would be the variant to use.
optimizer = torch.optim.Adam(
    params=vit_model.parameters(),
    lr=3e-3,
    betas=(0.9, 0.999),
    weight_decay=0.3,
)

# Multi-class classification objective.
loss_fn = torch.nn.CrossEntropyLoss()


##################################################
# Training
##################################################
from helper_functions import set_seeds

# NOTE(review): seeding happens here, after the model and data loaders have already
# been created, so it presumably does not affect weight initialisation. Confirm
# this ordering is intended.
set_seeds()

from engine import train

# Run the training loop and keep whatever it returns for the plotting step.
history = train(
    model=vit_model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    epoch=10,
    device=DEVICE,
)


##################################################
# Plotting Training Curves
##################################################
from helper_functions import plot_loss_acc_curves
# `history` is the value returned by train(...) in the training section.
# presumably it holds per-epoch loss/accuracy values — TODO confirm against engine.train
plot_loss_acc_curves(results=history)