# 0. Setup

In [None]:
# List the NVIDIA GPUs reported by the driver (IPython shell escape).
!nvidia-smi -L

In [None]:
import os
import datetime
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
import numpy as np
import pandas as pd


from model.vit_for_small_dataset import ViT
from utils.imageset_handler import ImageQualityDataset

# 1. Build Model

### 1.1 Define Variables

In [None]:
# --- ViT architecture (the same values are passed to every ViT(...) built below) ---
image_size, patch_size = 256, 16
num_classes = 5  # Number of classes for image quality levels
dim, depth, heads = 1024, 6, 16
mlp_dim = 2048
emb_dropout = 0.1

# --- Run configuration ---
pretrained_model_path = None  # optional .pth state_dict to warm-start from
num_epochs = 50
results_path = ('/home/maxgan/WORKSPACE/UNI/BA/'
                'vision-transformer-for-image-quality-perception-of-individual-observers/'
                'results/weights/all_distored_imgs')

dataset_root = ('/home/maxgan/WORKSPACE/UNI/BA/'
                'vision-transformer-for-image-quality-perception-of-individual-observers/'
                'assets/DatasetObjective/allDistorted')
csv_file = ('/home/maxgan/WORKSPACE/UNI/BA/'
            'vision-transformer-for-image-quality-perception-of-individual-observers/'
            'assets/DatasetObjective/objective_imagesquality_scores.csv')

### 1.2 Instantiate Model

In [None]:
# Gather the architecture settings once and unpack them into the constructor.
vit_kwargs = {
    'image_size': image_size,
    'patch_size': patch_size,
    'num_classes': num_classes,
    'dim': dim,
    'depth': depth,
    'heads': heads,
    'mlp_dim': mlp_dim,
    'emb_dropout': emb_dropout,
}
model = ViT(**vit_kwargs)
print(model)

### 1.3 Load pretrained weights

In [None]:
# Optionally warm-start from a saved state_dict.
# map_location='cpu' lets a checkpoint written on a GPU machine load on a
# CPU-only one; the model is moved to the training device before training.
if pretrained_model_path:
    state_dict = torch.load(pretrained_model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    print(f"Loaded pretrained weights from {pretrained_model_path}")

# 2. Load Dataset

### 2.1 Add Augmentation (Transformation)

In [None]:
# Training-time preprocessing: random crop/flip augmentation, then tensor
# conversion and normalisation with the ImageNet per-channel mean and std.
# NOTE(review): RandomResizedCrop rescales a random sub-region to 256x256,
# which may alter the very distortions whose quality is being rated -
# confirm this augmentation is intended for a quality-assessment task.
transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(256),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

### 2.2 Create Dataset

In [None]:
# Labelled image dataset built from the score CSV and the image folder
# (presumably ratings from csv_file, pixels from dataset_root - see utils/imageset_handler).
dataset = ImageQualityDataset(
    csv_file,
    dataset_root,
    transform=transform,
)

### 2.3 Split the dataset into training and validation sets

In [None]:
# Fraction of the samples held out for validation (despite the name this is
# the validation split, not a separate test set).
test_size = 0.2
num_train = int(len(dataset) * (1 - test_size))
num_val = len(dataset) - num_train

print('Splitting Dataset..')
# Seeded generator: the split is identical across kernel restarts, so runs
# with different hyper-parameters are validated on the same images.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [num_train, num_val], generator=split_generator
)

# Both subsets of `dataset` share its random training augmentation. Validate on
# the same indices of a second dataset instance with deterministic
# preprocessing, so validation loss/accuracy do not fluctuate with random crops.
# NOTE(review): assumes ImageQualityDataset orders samples identically every
# time it is built from the same CSV - confirm.
eval_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet mean and std
])
val_dataset = torch.utils.data.Subset(
    ImageQualityDataset(csv_file, dataset_root, transform=eval_transform),
    val_dataset.indices,
)

print(f"Number of Data to train: {num_train}")
print(f"Number of Data to validate: {num_val}")

# 3. Train

### 3.1 Define Training Parameters

In [None]:
batch_size = 128              # samples per optimisation step
learning_rate = 0.5 * 1e-4    # i.e. 5e-5, half of the common 1e-4 default

### 3.2 Init Optimizer, loss function and dataloader

In [None]:
# Cross-entropy treats each rating as a class; mse_criterion is not used by
# the active training loop (kept for a regression-style experiment).
criterion = nn.CrossEntropyLoss()
mse_criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Shared loader settings; only the training loader shuffles, so validation
# batches always come in the same order.
_loader_opts = dict(batch_size=batch_size, num_workers=4)
train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_opts)
val_dataloader = DataLoader(val_dataset, shuffle=False, **_loader_opts)

### 3.3 Train-Loop

In [None]:
# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)  # nn.Module.to moves in place and returns the same module

# Best-checkpoint bookkeeping used by the training loop.
best_model_weights = None
best_val_loss = float('inf')

In [None]:
# Per-epoch mean training / validation loss, used for the loss plot.
train_losses = []
val_losses = []

# One entry per checkpoint written (i.e. per new best validation loss).
model_results = []

# Dataset folder name, appended to checkpoint and result file names.
last_folder = os.path.basename(dataset_root)
# Make sure the output directory exists before the first torch.save.
os.makedirs(results_path, exist_ok=True)

print("Starting training...")
for epoch in range(num_epochs):
    # ---- Training ----
    model.train()
    train_loss = 0.0
    train_preds = []
    train_labels = []

    for images, labels in train_dataloader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # loss.item() is the batch mean; weight it by the batch size so the
        # epoch average below is a per-sample mean even for a short last batch.
        train_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)  # Get predicted labels
        train_preds.extend(preds.cpu().numpy())
        train_labels.extend(labels.cpu().numpy())

    # Computed once per epoch, not after every batch over the growing lists.
    train_accuracy = accuracy_score(train_labels, train_preds)

    # ---- Validation ----
    # NOTE: only the cross-entropy objective is active; mse_criterion is unused.
    model.eval()
    val_loss = 0.0
    val_preds = []
    val_labels = []

    with torch.no_grad():
        for images, labels in val_dataloader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)

            _, preds = torch.max(outputs, 1)  # Get predicted labels
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

    val_accuracy = accuracy_score(val_labels, val_preds)
    train_loss /= len(train_dataset)
    val_loss /= len(val_dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}, Training Acc: {train_accuracy:.4f}, Validation Loss: {val_loss:.4f}, Validation Acc: {val_accuracy:.4f}')

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    # Save the weights with the best validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # state_dict() returns references to the live parameter tensors, and a
        # shallow dict.copy() keeps those references, so later optimizer steps
        # would overwrite the "best" weights. Take a detached CPU snapshot.
        best_model_weights = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        # ':g' keeps small learning rates readable (5e-05 instead of 0.0).
        model_name = (
            f"vit_model_{timestamp}_epoch_{epoch+1}of{num_epochs}"
            f"_valLoss_{best_val_loss:.3f}_valAcc_{val_accuracy:.3f}"
            f"_batchsize_{batch_size}_lr_{learning_rate:g}_{last_folder}.pth"
        )
        best_model_path = os.path.join(results_path, model_name)
        torch.save(best_model_weights, best_model_path)

        model_results.append({
            'model_name': model_name,
            'validation_loss': val_loss,
            'validation_accuracy': val_accuracy,
            'batch_size': batch_size,
            'learning_rate': learning_rate,
            'epoch': epoch + 1,
        })

In [None]:
# Create a DataFrame from model_results: one row per checkpoint saved in training.
results_df = pd.DataFrame(model_results)

# Recompute the dataset folder name instead of relying on the value set inside
# the training loop's checkpoint branch, which exists only if one was saved.
last_folder = os.path.basename(dataset_root)
os.makedirs(results_path, exist_ok=True)

# Save the results as a CSV file
results_csv_path = os.path.join(results_path, f'model_results_{last_folder}.csv')
results_df.to_csv(results_csv_path, index=False)

# Name the loss figure after the last (i.e. best) saved checkpoint; fall back
# to a generic name when no checkpoint was written (e.g. every loss was NaN).
if model_results:
    figure_stem = os.path.splitext(model_results[-1]['model_name'])[0]
else:
    figure_stem = f'loss_curve_{last_folder}'
figure_name = figure_stem + '.png'
figure_path = os.path.join(results_path, figure_name)

# Plot against the epochs actually recorded so an interrupted run (fewer than
# num_epochs entries) still plots instead of raising a length mismatch.
epochs_run = range(1, len(train_losses) + 1)
plt.figure()
plt.plot(epochs_run, train_losses, label='Training Loss')
plt.plot(epochs_run, val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.savefig(figure_path)
plt.show()

In [None]:
%%bash
# WARNING: powers this machine off immediately - presumably meant to run
# unattended after training finishes. Presumably relies on passwordless sudo,
# since the cell cannot answer an interactive password prompt.
sudo shutdown -h now

# 4. Plotting and Evaluating

### 4.1 Define Model and Dataset to Evaluate

In [None]:
# Checkpoint to evaluate.
#weights_path = f'{results_path}/vit_model_20230821_121731_epoch_2of20_valLoss_1.572_valAcc_0.267_batchsize_64_lr_0.0_TestImg.pth'
weights_path = '/home/maxgan/WORKSPACE/UNI/BA/vision-transformer-for-image-quality-perception-of-individual-observers/results/weights/TEST/vit_model_20230821_120855_epoch_16of20_valLoss_7.457_valAcc_0.233_batchsize_64_lr_0.0_TestImg.pth'

# Evaluation data. ImageQualityDataset is called as (csv_file, dataset_root),
# so csv_file must be the score CSV and dataset_root the image folder
# (these two values were previously swapped).
csv_file = '/home/maxgan/WORKSPACE/UNI/BA/vision-transformer-for-image-quality-perception-of-individual-observers/assets/Test/AccTestCsv/objectiveAccTest.csv'
dataset_root = '/home/maxgan/WORKSPACE/UNI/BA/vision-transformer-for-image-quality-perception-of-individual-observers/assets/Test/TestImg'

In [None]:
# Rebuild the trained architecture and load the checkpoint to evaluate.
model = ViT(
    image_size=image_size,
    patch_size=patch_size,
    num_classes=num_classes,
    dim=dim,
    depth=depth,
    heads=heads,
    mlp_dim=mlp_dim,
    emb_dropout=emb_dropout
)
# map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines;
# compare() moves the model to the available device afterwards.
model.load_state_dict(torch.load(weights_path, map_location='cpu'))

# No transform here: prediction_quality applies its own preprocessing per sample.
dataset = ImageQualityDataset(csv_file, dataset_root)

### 4.2 Plot Confusion Matrix

In [None]:
from sklearn.metrics import confusion_matrix
def prediction_quality(image_path, model, transform=None):
    """Predict the 1-based quality rating of one image.

    Args:
        image_path: The image to rate. Despite the name, ``compare`` passes the
            raw sample yielded by the dataset (presumably a PIL image - confirm
            against ImageQualityDataset). A filesystem path (``str`` or
            ``os.PathLike``) is also accepted and is opened as RGB.
        model: Network whose output holds one score per rating class per row;
            the arg-max index is taken as the predicted class.
        transform: Optional preprocessing callable producing a (C, H, W)
            tensor. Defaults to a deterministic resize, 256x256 centre crop and
            ImageNet normalisation.

    Returns:
        int: predicted class index + 1, i.e. a rating starting at 1.
    """
    if transform is None:
        # Deterministic evaluation preprocessing: random crop/flip here would
        # make the same image receive different ratings from call to call.
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet mean and std
        ])

    image = image_path
    if isinstance(image, (str, os.PathLike)):
        from PIL import Image  # local import: PIL is only needed for path inputs
        with Image.open(image) as opened:
            image = opened.convert('RGB')

    # Add the leading batch dimension the model expects.
    image_tensor = transform(image).unsqueeze(0)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(image_tensor.to(device))
        _, predicted_label = torch.max(outputs, 1)
    # Shift the 0-based class index to a 1-based rating.
    return predicted_label.item() + 1

def compare(dataset, model):
    """Rate every sample of *dataset* with *model* via ``prediction_quality``.

    Args:
        dataset: Iterable of ``(sample, label)`` pairs with 0-based labels.
        model: The classifier to evaluate; moved to the GPU when available.

    Returns:
        tuple[list, list]: ``(predictions, ground_truth_ratings)``, both
        1-based; the ground truth is each dataset label shifted by +1.
    """
    target_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(target_device)
    model.eval()

    predicted, expected = [], []
    with torch.no_grad():
        for sample, label in dataset:
            predicted.append(prediction_quality(sample, model))
            expected.append(label + 1)
    return predicted, expected

In [None]:
# Compare model predictions with ground-truth ratings for the evaluation dataset.
predictions, ground_truth_ratings = compare(dataset, model)

# Pin the label set to every rating 1..num_classes so the matrix is always
# num_classes x num_classes. Without it, a rating that never occurs is dropped,
# and the remaining rows/columns shift against the 1..5 axis ticks.
conf_matrix = confusion_matrix(
    ground_truth_ratings,
    predictions,
    labels=list(range(1, num_classes + 1)),
)

In [None]:
# Save next to the training results, named after the evaluated checkpoint.
weights_stem = os.path.splitext(os.path.basename(weights_path))[0]
figure_name = f"Confusion_Matrix_{weights_stem}.png"
figure_path = os.path.join(results_path, figure_name)

# Derive the ticks from the matrix itself so the 1-based rating labels always
# line up with its rows and columns.
rating_ticks = np.arange(conf_matrix.shape[0])

# Plot the confusion matrix as a heatmap with annotations
plt.figure(figsize=(6, 6))
plt.imshow(conf_matrix, cmap='Blues', interpolation='nearest')
plt.colorbar()
plt.xlabel('Predicted Ratings')
plt.ylabel('Ground Truth Ratings')
plt.title('Confusion Matrix')
plt.xticks(rating_ticks, rating_ticks + 1)
plt.yticks(rating_ticks, rating_ticks + 1)

# Annotate every cell with its count. In the 'Blues' colormap high counts are
# drawn dark, so use white text above half the maximum for contrast.
threshold = conf_matrix.max() / 2
for i in range(conf_matrix.shape[0]):
    for j in range(conf_matrix.shape[1]):
        text_color = 'white' if conf_matrix[i, j] > threshold else 'black'
        plt.text(j, i, str(conf_matrix[i, j]), ha='center', va='center', color=text_color)
plt.savefig(figure_path)
plt.show()

### 4.3 Plot Distribution

In [None]:
# Ratings of a single observer ("shinyx") and where to save their histogram.
# NOTE(review): these point into a "TIQ" checkout rather than the
# vision-transformer-... project used elsewhere - confirm this is intended.
csv_file = ("/home/maxgan/WORKSPACE/UNI/BA/TIQ/assets/Test/AccTestCsv/"
            "shinyxAccTest20-01-2023.csv")
output_image_path = ("/home/maxgan/WORKSPACE/UNI/BA/TIQ/assets/Test/AccTestCsv/"
                     "rating_distribution_shinyxAccTest.png")

In [None]:
# Read the rating CSV without a header; its first line (presumably a header
# row) is skipped and column 1 holds the 1..5 rating.
data = pd.read_csv(csv_file, header=None, skiprows=1)

# Map rating values to their corresponding labels
rating_labels = {
    1: "Bad",
    2: "Insufficient",
    3: "Fair",
    4: "Good",
    5: "Excellent"
}
data["Rating_Label"] = data[1].map(rating_labels)

# Count images per label in rating order (Bad -> Excellent). sort_index() would
# order the labels alphabetically; reindex keeps the rating order and shows
# ratings that never occur as explicit zero counts.
class_counts = (
    data["Rating_Label"]
    .value_counts()
    .reindex(list(rating_labels.values()), fill_value=0)
)
# Calculate total number of images
total_images = class_counts.sum()

In [None]:
# Bar chart of the per-rating image counts, saved to output_image_path.
plt.figure(figsize=(10, 6))
ax = class_counts.plot(kind="bar", color='skyblue')
ax.set_title("Image Rating Distribution Person1 (shinyx)")
ax.set_xlabel("Rating")
ax.set_ylabel("Number of Images")
plt.xticks(rotation=45)
plt.tight_layout()
plt.savefig(output_image_path)
plt.show()

# Echo the same counts as a text table.
print("Rating Distribution Table:", class_counts, sep="\n")

### 4.4 Plot SSIM

In [None]:
# Two variants of the same source image to compare with SSIM (edit as needed).
image_path1 = ("/home/maxgan/WORKSPACE/UNI/BA/TIQ/assets/Test/TestImg/"
               "4359ILSVRC2013_train_00022365.JPEG_I1_Q8.jpeg")
image_path2 = ("/home/maxgan/WORKSPACE/UNI/BA/TIQ/assets/DatasetObjective/allDistorted/"
               "4359ILSVRC2013_train_00022365.JPEG_I5_Q67.jpeg")

In [None]:
from PIL import Image
from skimage.metrics import structural_similarity as ssim

def compare_ssim(image_path1, image_path2):
    """Compute the structural similarity (SSIM) of two images in grayscale.

    Args:
        image_path1: Path of the first image.
        image_path2: Path of the second image (must have the same dimensions).

    Returns:
        tuple: ``(img1, img2, ssim_value)``, where the first two are the
        grayscale ('L' mode) PIL images that were compared.
    """
    def _as_grayscale(path):
        # Reuse the opened image as-is when it is already single-channel 'L'.
        picture = Image.open(path)
        return picture if picture.mode == 'L' else picture.convert('L')

    gray1 = _as_grayscale(image_path1)
    gray2 = _as_grayscale(image_path2)

    # skimage's ssim works on numpy arrays, not PIL images.
    score = ssim(np.array(gray1), np.array(gray2))
    return gray1, gray2, score

In [None]:
img1, img2, similarity = compare_ssim(image_path1, image_path2)
print(f"SSIM between the images: {similarity:.4f}")

# Side-by-side view: both grayscale inputs plus a text panel with the score.
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
for panel, picture, caption in zip(axs[:2], (img1, img2), ('Image 1', 'Image 2')):
    panel.imshow(picture, cmap='gray')
    panel.set_title(caption)

score_panel = axs[2]
score_panel.text(0.5, 0.5, f"SSIM: {similarity:.4f}", fontsize=12,
                 horizontalalignment='center', verticalalignment='center')
score_panel.axis('off')
score_panel.set_title('SSIM')

plt.tight_layout()
plt.show()

### 4.5 Plot Attention Map

In [None]:
# Test image whose layer attributions are visualised below.
image_path = ("/home/maxgan/WORKSPACE/UNI/BA/"
              "vision-transformer-for-image-quality-perception-of-individual-observers/"
              "assets/Test/TestImg/802ILSVRC2013_train_00013401.JPEG_I2_Q15.jpeg")

In [None]:
import torch
from captum.attr import LayerConductance

# Rebuild the trained architecture and load the checkpoint chosen in 4.1.
model = ViT(
    image_size=image_size,
    patch_size=patch_size,
    num_classes=num_classes,
    dim=dim,
    depth=depth,
    heads=heads,
    mlp_dim=mlp_dim,
    emb_dropout=emb_dropout
)
# map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines.
model.load_state_dict(torch.load(weights_path, map_location='cpu'))
model.eval()

# Deterministic preprocessing (no random crop/flip) so the attribution is
# reproducible; convert('RGB') guarantees the 3 channels Normalize expects.
attention_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet mean and std
])
with Image.open(image_path) as image:
    # The model and LayerConductance both need a leading batch dimension.
    image_tensor = attention_transform(image.convert('RGB')).unsqueeze(0)

# NOTE(review): this attribute path must match the ViT implementation in
# model/vit_for_small_dataset.py - confirm `transformer.encoder[-1].norm2` exists.
layer = model.transformer.encoder[-1].norm2

# Conductance of the chosen layer's output with respect to one class logit.
layer_cond = LayerConductance(model, layer)
target_class = 1  # 0-based class index (rating 2); change as needed
attributions = layer_cond.attribute(image_tensor, target=target_class)

# Plot at most 16 maps: a 4x4 subplot grid has no index beyond 16.
num_maps = min(attributions.shape[1], 16)
plt.figure(figsize=(10, 10))
for i in range(num_maps):
    plt.subplot(4, 4, i + 1)
    # NOTE(review): assumes attributions[0, i] is 2-D; confirm the layer's output shape.
    plt.imshow(attributions[0, i].detach().cpu().numpy(), cmap='viridis')
    plt.axis('off')
    plt.title(f'Attention Map {i}')
plt.tight_layout()
plt.show()


# 5. Run Model