# Dependencies and Configs

In [None]:
pip install torch torchvision pyyaml

**Loading the configuration from the YAML file. [You may modify the file to suit your needs.]**

In [None]:
import yaml

# Read the run configuration (paths, model name, hyper-parameters) from the YAML
# file one directory above the notebook. The encoding is pinned so the file is
# parsed the same way regardless of the platform's default locale encoding.
# `safe_load` only builds plain Python objects, never arbitrary ones.
with open('../config.yaml', 'r', encoding='utf-8') as config_file:
    config = yaml.safe_load(config_file)

# Preparing Train and Validation Datasets + Pre Processing

In [None]:
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from pathlib import Path

**Define Transforms**

In [None]:
import random

In [None]:
# Training pipeline: the Random* / ColorJitter steps are data augmentation, while
# ToTensor + Normalize are the actual pre-processing shared with validation.
rotation_angles = [30, 60, 90, 120, 150, 180, 210, 240, 270, 300, 330, 360]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(size=224),  # Randomly crop the image to 224x224, with random scale and aspect ratio
    # Pick one of the listed angles afresh for EVERY sample. Each candidate is a
    # RandomRotation whose range is collapsed to a single value (angle, angle), and
    # RandomChoice selects one candidate per call. Calling random.choice(...) inline
    # would instead run once, when the pipeline is built, and freeze a single range.
    transforms.RandomChoice([
        transforms.RandomRotation(degrees=(angle, angle)) for angle in rotation_angles
    ]),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),  # Randomly change the brightness, contrast, saturation, and hue
    transforms.ToTensor(),                   # Convert the image to a PyTorch tensor with shape (C, H, W)
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize the image with mean and std values across RGB channels
])

# Deterministic pipeline for validation/test images: resize and normalise only, no augmentation.
# NOTE: the misspelled name `validation_transofrm` is kept so existing references keep working.
validation_transofrm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
validation_transform = validation_transofrm  # correctly spelled alias for new code

**Loading the Dataset**

In [None]:
# Root folder of the images, taken from the YAML config. ImageFolder expects one
# sub-directory per class and derives the class labels from the directory names.
data_directory = Path(config['data_directory'])
# `train_transform` is the augmentation pipeline defined in the transforms cell;
# every sample drawn from this dataset object goes through it.
dataset = datasets.ImageFolder(root=data_directory, transform=train_transform)

**Splitting into Reproducible Train, Validation, and Test Datasets**

In [None]:
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split
import numpy as np

In [None]:
# Seed every RNG in play (Python, NumPy, PyTorch) so the split, the shuffling and
# the random augmentations are reproducible across runs.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

train_and_val_share = 0.85  # 85% for train + validation, 15% for test
train_share = 0.85      # 85% of the remaining 85% for training, 15% for validation

# Split indices for train_val and test sets
train_val_indices, test_indices = train_test_split(
    list(range(len(dataset))),
    test_size=1 - train_and_val_share,
    random_state=seed,
    shuffle=True
)

# Further split train_val_indices into training and validation
train_indices, val_indices = train_test_split(
    train_val_indices,
    test_size=1 - train_share,
    random_state=seed,
    shuffle=True
)

# Every sample must land in exactly one of the three splits.
assert len(train_indices) + len(val_indices) + len(test_indices) == len(dataset)

# A Subset inherits the transform of the dataset it wraps. Validation and test
# images must not be randomly augmented, so they are drawn from a second view of
# the same folder that applies the deterministic pipeline. Both ImageFolder
# objects read the same root, so the indices computed above address the same
# files in each of them.
eval_dataset = datasets.ImageFolder(root=data_directory, transform=validation_transofrm)

train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(eval_dataset, val_indices)
test_dataset = Subset(eval_dataset, test_indices)  # Test dataset remains untouched

print(f' Train Size: {len(train_dataset)} \n Validation Size: {len(val_dataset)} \n Test Size: {len(test_dataset)}')

**Defining the Train and Validation DataLoaders**

In [None]:
# Batch size is read from the YAML config and shared by both loaders.
batch_size = config['batch_size']

# Training batches are reshuffled on every pass over the data.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Validation batches keep a fixed order.
validation_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# class name -> integer index, exactly as assigned by ImageFolder.
label2id = dataset.class_to_idx
# integer index -> class name: the inverse of the mapping above.
id2label = {index: class_name for class_name, index in label2id.items()}
print(id2label)

# Training Model

**Load the pretrained ViT model**

In [None]:
from timm import create_model
# Name of the timm architecture / checkpoint to fine-tune, taken from the YAML config.
pre_trained_model = config['pre_trained_model']
# Load the pretrained weights and size the classification head to the number of
# classes found in the image folder. `train_dataset` is a Subset, so the class
# list lives on the ImageFolder it wraps (`.dataset`).
model = create_model(
    pre_trained_model,
    pretrained=True,
    num_classes=len(train_dataset.dataset.classes),
)

**Moving model to GPU if available**

In [None]:
# Pick the fastest available backend: an NVIDIA GPU (CUDA) first, then Apple
# Silicon (MPS), and finally fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("CUDA is available and will be used.")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("MPS is available and will be used.")
else:
    device = torch.device("cpu")
    print("Neither CUDA nor MPS is available, using CPU.")

**Defining loss function and optimizer**

In [None]:
import torch.nn as nn
import torch.optim as optim

In [None]:
criterion = nn.CrossEntropyLoss()  # For classification
# Take the learning rate from the YAML config rather than a hard-coded constant.
# The float() cast matters: PyYAML loads a value written as `1e-4` (no decimal
# point) as a string, which Adam would reject.
optimizer = optim.Adam(model.parameters(), lr=float(config['learning_rate']))

**Implementing the Early Stop Method**

In [None]:
def early_stopping(val_loss, best_loss, patience_counter, patience, min_delta):
    """Update the early-stopping state after one validation pass.

    Args:
        val_loss: Validation loss of the epoch that just finished.
        best_loss: Best validation loss seen so far, or None before the first epoch.
        patience_counter: Number of consecutive epochs without improvement so far.
        patience: Accepted for the caller's convenience; not used inside this
            function (the caller compares the returned counter against it).
        min_delta: Amount by which `val_loss` must undercut `best_loss` to count
            as an improvement.

    Returns:
        A `(best_loss, patience_counter)` tuple: `(val_loss, 0)` when the epoch
        improved on the best loss (or no best loss exists yet), otherwise the
        unchanged best loss with the counter incremented by one.
    """
    improved = best_loss is None or val_loss < best_loss - min_delta
    if not improved:
        # No sufficient improvement: keep the old best and count the stale epoch.
        return best_loss, patience_counter + 1
    # New best: remember it and reset the stale-epoch counter.
    return val_loss, 0

**The Training Loop**

In [None]:
from tqdm import tqdm  # For progress bar

In [None]:
# Move the model onto the device selected in the "Moving model to GPU if
# available" cell. Re-detecting the device here would duplicate that logic and
# could silently override its choice.
model.to(device)
print(device)

In [None]:
# Hyper-parameters for the loop, all taken from the YAML config.
num_epochs = config['num_epochs']
patience = config['patience']          # epochs without improvement tolerated before stopping
learning_rate = config['learning_rate']
# Smallest decrease in validation loss that counts as an improvement. This is
# unrelated to the learning rate; it is read from an optional `min_delta` config
# key and defaults to 0.0 (any decrease counts).
min_delta = float(config.get('min_delta', 0.0))

best_loss = None
patience_counter = 0

for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    running_loss = 0.0

    # Train over batches
    for images, labels in tqdm(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()  # Clear previous gradients
        loss.backward()  # Compute gradients
        optimizer.step()  # Update weights

        running_loss += loss.item()

    # Print epoch loss (mean of the per-batch losses)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader)}')

    # Validation Loop
    model.eval()  # Set model to evaluation mode
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():  # Disable gradient calculation for validation
        for images, labels in validation_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            # Accumulate the batch loss; early stopping below depends on it, so
            # leaving it at 0.0 would make every epoch look like "no improvement".
            val_loss += criterion(outputs, labels).item()

            predicted = outputs.argmax(dim=1)  # index of the highest logit per sample
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Print validation metrics
    avg_val_loss = val_loss / len(validation_loader)
    print(f'Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {100 * correct / total:.2f}%')

    # Early stopping check
    best_loss, patience_counter = early_stopping(avg_val_loss, best_loss, patience_counter, patience, min_delta)

    if patience_counter >= patience:
        print("Early stopping triggered!")
        break


print("Training completed!")

**Saving the Model**

In [None]:
# Save the trained weights (state_dict only, not the full pickled model) to
# ../models/trained_<model name>.pth
models_dir = '../models'
model_name = f'trained_{config["pre_trained_model"]}.pth'
model_path = Path(models_dir) / model_name
# torch.save does not create missing directories, so make sure the target exists.
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_path)
print(f"Model saved at {model_path}")