# 0. Setup

In [None]:
!nvidia-smi -L

In [None]:
import copy
import datetime
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from sklearn.metrics import accuracy_score
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

from misc.helpers import calculate_label_distributions
from model.vit_for_small_dataset import ViT
from utils.imageset_handler import ImageQualityDataset

# 1. Build Model

### 1.1 Define Variables

In [None]:
# --- ViT hyperparameters (forwarded to the ViT constructor in the next cell) ---
image_size = 256
patch_size = 16
num_classes = 5  # one class per image quality level
dim = 1024
depth = 6
heads = 16
mlp_dim = 2048
emb_dropout = 0.1

# --- Training run ---
num_epochs = 100
# Optional checkpoint to start from; a falsy value skips the weight-loading cell.
# Example: 'results/weights/AIO1.2.pth'
pretrained_model_path = None

# --- Input data ---
dataset_root = 'assets/Dataset/DS0'
csv_file = 'assets/Dataset/Obs0.csv'

# --- Output folder for checkpoints, the results CSV and the loss curve ---
results_path = './results/weights/AIO0'
os.makedirs(results_path, exist_ok=True)

### 1.2 Compile

In [None]:
# Gather the constructor arguments in one mapping so the architecture
# configuration can be inspected (or logged) as a single object.
vit_config = dict(
    image_size=image_size,
    patch_size=patch_size,
    num_classes=num_classes,
    dim=dim,
    depth=depth,
    heads=heads,
    mlp_dim=mlp_dim,
    emb_dropout=emb_dropout,
)

model = ViT(**vit_config)
print(model)

### 1.3 Load pretrained weights

In [None]:
# Start from a checkpoint when one is configured; otherwise the freshly
# constructed model is used unchanged.
if pretrained_model_path:
    # map_location='cpu' lets a checkpoint that was written on a GPU machine load
    # on a CPU-only one as well; the model is moved to the training device later.
    pretrained_state = torch.load(pretrained_model_path, map_location='cpu')
    model.load_state_dict(pretrained_state)
    # The architecture was already printed when the model was built, and loading
    # weights does not change it, so only confirm which file was loaded.
    print(f"Loaded pretrained weights from {pretrained_model_path}")

# 2. Load Dataset

### 2.1 Add Augmentation (Transformation)

In [None]:
# Preprocessing / augmentation pipeline handed to the dataset.
transform = transforms.Compose([
    # The crop size is tied to image_size (the value the ViT is built with) so the
    # two cannot drift apart when the configuration cell is edited.
    transforms.RandomResizedCrop(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # Normalize with ImageNet mean and std
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

### 2.2 Create Dataset

In [None]:
# Build the dataset from csv_file and dataset_root, passing the transform
# pipeline defined above.
dataset = ImageQualityDataset(
    csv_file,
    dataset_root,
    transform=transform,
)

### 2.3 Split the dataset into training and validation sets

In [None]:
test_size = 0.2   # fraction of the samples held out for validation
split_seed = 42   # fixed seed so every run produces the same train/validation split

num_train = int(len(dataset) * (1 - test_size))
num_val = len(dataset) - num_train

print('Splitting Dataset..')
# Without an explicit generator the split depends on the global RNG state, so the
# validation set (and every validation metric) would differ between kernel restarts.
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [num_train, num_val],
    generator=torch.Generator().manual_seed(split_seed),
)

print(f"Number of Data to train: {num_train}")
print(f"Number of Data to validate: {num_val}")

# 3. Train

### 3.1 Define Training Parameters

In [None]:
# Optimisation hyperparameters: samples per step and the initial Adam learning rate.
batch_size = 128
learning_rate = 5e-5

### 3.2 Init Optimizer, loss function and dataloader

In [None]:
# Classification loss optimised by the training loop.
criterion = nn.CrossEntropyLoss()
# Additional loss object; instantiated here but not used by the loop in this section.
mse_criterion = nn.MSELoss()

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Halve the learning rate (factor=0.5) when the value passed to scheduler.step()
# stops improving (patience=5); the training loop passes the validation loss.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
    verbose=True,
)

In [None]:
# Settings shared by both loaders.
loader_kwargs = dict(batch_size=batch_size, num_workers=4)

# Only the training data is shuffled; validation batches keep a fixed order.
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

### 3.3 Train-Loop

In [None]:
# Train on the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model.to(device)

# Best-checkpoint bookkeeping, updated by the training loop.
best_model_weights = None
best_val_loss = float('inf')

In [None]:
# Per-epoch loss history, plotted after training.
train_losses = []
val_losses = []

# One record per saved checkpoint (i.e. per improvement of the validation loss).
model_results = []

# Dataset tag used in checkpoint and result file names. Computed before the loop
# so it exists even when no epoch improves on best_val_loss (e.g. the loss is NaN).
last_folder = os.path.basename(dataset_root)

print("Starting training...")
for epoch in range(num_epochs):
    # ---------- Training ----------
    model.train()
    train_loss = 0.0
    train_preds = []
    train_labels = []

    for batch in train_dataloader:
        # Take the first item as images and the last as labels. This accepts
        # (image, label) batches as well as batches with extra fields in between,
        # which is how the second training loop of this notebook unpacks the
        # same dataset class.
        images = batch[0].to(device)
        labels = batch[-1].to(device)

        optimizer.zero_grad()
        outputs = model(images)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The criterion averages over the batch; weight by the batch size so the
        # division by len(train_dataset) below yields a per-sample mean.
        train_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)  # predicted class index per sample
        train_preds.extend(preds.cpu().numpy())
        train_labels.extend(labels.cpu().numpy())
    train_accuracy = accuracy_score(train_labels, train_preds)

    # ---------- Validation ----------
    model.eval()
    val_loss = 0.0
    val_preds = []
    val_labels = []

    with torch.no_grad():
        for batch in val_dataloader:
            images = batch[0].to(device)
            labels = batch[-1].to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)
            preds = outputs.argmax(dim=1)
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

    val_accuracy = accuracy_score(val_labels, val_preds)
    train_loss /= len(train_dataset)
    val_loss /= len(val_dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}, Training Acc: {train_accuracy:.4f}, Validation Loss: {val_loss:.4f}, Validation Acc: {val_accuracy:.4f}')

    # ReduceLROnPlateau monitors the validation loss.
    scheduler.step(val_loss)
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    # Save a checkpoint whenever the validation loss improves. Each improvement
    # is written to its own timestamped file; earlier checkpoints are kept.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Deep copy: state_dict() hands out references to the live parameter
        # tensors, so a shallow .copy() would keep changing as training continues
        # and would not be a snapshot of the best epoch.
        best_model_weights = copy.deepcopy(model.state_dict())

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        # The learning rate is written in scientific notation; a fixed-point
        # ".1f" renders small rates such as 5e-5 as "0.0".
        model_name = (
            f"vit_model_{timestamp}_epoch_{epoch+1}of{num_epochs}"
            f"_valLoss_{best_val_loss:.3f}_valAcc_{val_accuracy:.3f}"
            f"_batchsize_{batch_size}_lr_{learning_rate:.1e}_{last_folder}.pth"
        )
        best_model_path = os.path.join(results_path, model_name)
        torch.save(best_model_weights, best_model_path)

        # Record the checkpoint for the results CSV written after training.
        model_info = {
            'model_name': model_name,
            'validation_loss': val_loss,
            'validation_accuracy': val_accuracy,
            'batch_size': batch_size,
            'learning_rate': learning_rate,
            'epoch': epoch + 1
        }
        model_results.append(model_info)

### 3.4 Save results

In [None]:
# Dataset tag for the output file names. Recomputed here so this cell does not
# depend on a variable left behind by the training loop.
last_folder = os.path.basename(dataset_root)

# One row per saved checkpoint.
results_df = pd.DataFrame(model_results)
results_csv_path = os.path.join(results_path, f'model_results_{last_folder}.csv')
results_df.to_csv(results_csv_path, index=False)

# The loss curve is stored next to the checkpoints, tagged with the same dataset name.
figure_name = f'Train_Val_Curve_{last_folder}.png'
figure_path = os.path.join(results_path, figure_name)

# Use the number of epochs actually recorded for the x-axis: after an interrupted
# run the histories are shorter than num_epochs, and plotting them against
# range(1, num_epochs + 1) would fail with a length mismatch.
fig, ax = plt.subplots()
ax.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
ax.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss')
ax.legend()
fig.savefig(figure_path)
plt.show()

# 4. Train Persons AIO

### 4.1 Define Model, Dataset and Parameters

In [None]:
# --- ViT hyperparameters (forwarded to the ViT constructor below) ---
image_size = 256
patch_size = 16
num_classes = 5  # one class per image quality level
dim = 1024
depth = 6
heads = 16
mlp_dim = 2048
emb_dropout = 0.1

# --- Training run ---
num_epochs = 100
# Checkpoint to continue from. Alternatives used previously:
#   None
#   'results/weights/AIO0/vit_model_20230911_064440_epoch_148of150_valLoss_0.108_valAcc_0.953_batchsize_128_lr_0.0_allDistorted.pth'
pretrained_model_path = 'results/weights/FINAL/AIO5/AIO5_2/vit_model_20231204_153455_epoch_16of100_valLoss_1.094_valAcc_0.600_batchsize_128_lr_0.0_DS5.pth'

# --- Input data: separate image roots and CSV files for training and validation ---
train_dataset_root = 'assets/Dataset/DS5'
train_csv_file = 'assets/Obs_iterative/Obs5/Obs5_3.csv'
val_dataset_root = 'assets/Test/DSX'
val_csv_file = 'assets/Test/Obs5.csv'

# --- Output folder for checkpoints, the results CSV and the loss curve ---
results_path = 'results/weights/FINAL/AIO5/AIO5_3'
os.makedirs(results_path, exist_ok=True)

In [None]:
# Optimisation hyperparameters: samples per step and the initial Adam learning rate.
batch_size = 128
learning_rate = 2.5e-5

In [None]:
# Preprocessing / augmentation pipeline handed to the datasets.
transform = transforms.Compose([
    # Resize and crop are tied to image_size (the value the ViT is built with) so
    # they cannot drift apart when the configuration cell is edited.
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # Normalize with ImageNet mean and std
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
model = ViT(
    image_size=image_size,
    patch_size=patch_size,
    num_classes=num_classes,
    dim=dim,
    depth=depth,
    heads=heads,
    mlp_dim=mlp_dim,
    emb_dropout=emb_dropout
)
if pretrained_model_path is not None:
    # map_location='cpu' lets a checkpoint that was written on a GPU machine load
    # on a CPU-only one as well; the model is moved to the training device later.
    model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu'))

# Validation should be deterministic: the validation loss drives both the LR
# scheduler and the best-checkpoint selection, so random augmentation only adds
# noise to it. Reuse the training pipeline minus its random horizontal flip.
val_transform = transforms.Compose([
    step for step in transform.transforms
    if not isinstance(step, transforms.RandomHorizontalFlip)
])

train_dataset = ImageQualityDataset(train_csv_file, train_dataset_root, transform=transform)
val_dataset = ImageQualityDataset(val_csv_file, val_dataset_root, transform=val_transform)

In [None]:
# Loss optimised by the training loop below.
criterion = nn.CrossEntropyLoss()
# Alternative loss objects; instantiated here but not used by the active code
# path of the loop (the KL-divergence variant is only referenced in comments).
mse_criterion = nn.MSELoss()
kl_div_criterion = nn.KLDivLoss(reduction='batchmean')

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Halve the learning rate (factor=0.5) when the value passed to scheduler.step()
# stops improving (patience=5); the training loop passes the validation loss.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
    verbose=True,
)

In [None]:
# Settings shared by both loaders.
loader_kwargs = dict(batch_size=batch_size, num_workers=4)

# Only the training data is shuffled; validation batches keep a fixed order.
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Train on the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model.to(device)

# Best-checkpoint bookkeeping, updated by the training loop.
best_model_weights = None
best_val_loss = float('inf')

### 4.2 Train AIO

In [None]:
# Per-epoch loss history, plotted after training.
train_losses = []
val_losses = []

# One record per saved checkpoint (i.e. per improvement of the validation loss).
model_results = []

# Dataset tag used in checkpoint file names. Computed before the loop so it
# exists even when no epoch improves on best_val_loss (e.g. the loss is NaN).
last_folder = os.path.basename(train_dataset_root)

print("Starting training...")
for epoch in range(num_epochs):
    # ---------- Training ----------
    model.train()
    train_loss = 0.0
    train_preds = []
    train_labels = []

    # Each batch carries three items; the middle one is not needed here.
    for images, _, labels in train_dataloader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)

        # NOTE(review): presumably turns the labels into per-class target
        # distributions for the loss -- confirm in misc.helpers.
        true_distributions = calculate_label_distributions(labels, device=device)
        loss = criterion(outputs, true_distributions)
        # Alternative that was tried: KL divergence between
        # log_softmax(outputs, dim=1) and true_distributions (kl_div_criterion).

        loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so the division by
        # len(train_dataset) below yields a per-sample mean.
        train_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)  # predicted class index per sample
        train_preds.extend(preds.cpu().numpy())
        train_labels.extend(labels.cpu().numpy())
    train_accuracy = accuracy_score(train_labels, train_preds)

    # ---------- Validation ----------
    model.eval()
    val_loss = 0.0
    val_preds = []
    val_labels = []

    with torch.no_grad():
        for images, _, labels in val_dataloader:
            images = images.to(device)
            labels = labels.to(device)

            true_distributions = calculate_label_distributions(labels=labels, device=device)
            outputs = model(images)
            loss = criterion(outputs, true_distributions)

            val_loss += loss.item() * images.size(0)
            preds = outputs.argmax(dim=1)
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

    val_accuracy = accuracy_score(val_labels, val_preds)
    train_loss /= len(train_dataset)
    val_loss /= len(val_dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}, Training Acc: {train_accuracy:.4f}, Validation Loss: {val_loss:.4f}, Validation Acc: {val_accuracy:.4f}')

    # ReduceLROnPlateau monitors the validation loss.
    scheduler.step(val_loss)
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    # Save a checkpoint whenever the validation loss improves. Each improvement
    # is written to its own timestamped file; earlier checkpoints are kept.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Deep copy: state_dict() hands out references to the live parameter
        # tensors, so a shallow .copy() would keep changing as training continues
        # and would not be a snapshot of the best epoch.
        best_model_weights = copy.deepcopy(model.state_dict())

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        # The learning rate is written in scientific notation; a fixed-point
        # ".1f" renders small rates such as 2.5e-5 as "0.0".
        model_name = (
            f"vit_model_{timestamp}_epoch_{epoch+1}of{num_epochs}"
            f"_valLoss_{best_val_loss:.3f}_valAcc_{val_accuracy:.3f}"
            f"_batchsize_{batch_size}_lr_{learning_rate:.1e}_{last_folder}.pth"
        )
        best_model_path = os.path.join(results_path, model_name)
        torch.save(best_model_weights, best_model_path)

        # Record the checkpoint for the results CSV written after training.
        model_info = {
            'model_name': model_name,
            'validation_loss': val_loss,
            'validation_accuracy': val_accuracy,
            'batch_size': batch_size,
            'learning_rate': learning_rate,
            'epoch': epoch + 1
        }
        model_results.append(model_info)

### 4.3 Save results

In [None]:
# Run tag built from the digits of the CSV *file name* only. Filtering the whole
# path would also pick up digits from the directory names (for
# 'assets/Obs_iterative/Obs5/Obs5_3.csv' that gives "553" instead of "53").
digit = ''.join(filter(str.isdigit, os.path.basename(train_csv_file)))

# One row per saved checkpoint.
results_df = pd.DataFrame(model_results)
results_csv_path = os.path.join(results_path, f'training_results_AIO{digit}.csv')
results_df.to_csv(results_csv_path, index=False)

# The loss curve is stored next to the checkpoints, tagged with the same run tag.
figure_name = f'Train_Val_Curve_AIO{digit}.png'
figure_path = os.path.join(results_path, figure_name)

# Use the number of epochs actually recorded for the x-axis: after an interrupted
# run the histories are shorter than num_epochs, and plotting them against
# range(1, num_epochs + 1) would fail with a length mismatch.
fig, ax = plt.subplots()
ax.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
ax.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss')
ax.legend()
fig.savefig(figure_path)
plt.show()