# 0. Setup

In [None]:
# IPython shell escape: list the NVIDIA GPUs visible on this machine.
!nvidia-smi -L

In [None]:
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, mean_squared_error, classification_report, confusion_matrix
from scipy.stats import entropy
import numpy as np
import pandas as pd
from PIL import Image
import seaborn as sns
from model.vit_for_small_dataset_custom import ViT
from utils.imageset_handler import ImageQualityDataset


In [None]:
def find_pth_files(directory_path):
    """
    Recursively collect the ``.pth`` files below a directory.

    The result is sorted so that repeated runs visit the checkpoints in the
    same order; ``os.walk`` alone yields entries in a filesystem-dependent
    order.

    Args:
        directory_path (str): The directory path to search for .pth files.

    Returns:
        List[str]: Sorted list of paths to the .pth files, each formed by
        joining the walked directory with the file name. Empty when the
        directory contains no such file or does not exist.
    """
    pth_files = [
        os.path.join(root, file_name)
        for root, _dirs, file_names in os.walk(directory_path)
        for file_name in file_names
        if file_name.endswith(".pth")
    ]
    pth_files.sort()
    return pth_files

# 1. Build Model

### 1.1 Define Variables

In [None]:
# Hyperparameters forwarded unchanged to the ViT constructor in section 1.2.
# NOTE(review): the exact meaning of each value is defined by
# model/vit_for_small_dataset_custom.py - confirm there.

# Input geometry
image_size = 256
patch_size = 16

# Output head
num_classes = 5  # Number of classes for image quality levels

# Transformer dimensions
dim = 1024
depth = 6
heads = 16
mlp_dim = 2048

# Regularisation
emb_dropout = 0.1

### 1.2 Compile

In [None]:
# Instantiate the network from the hyperparameters defined in section 1.1
# and print its layer structure for a quick sanity check.
model = ViT(**{
    "image_size": image_size,
    "patch_size": patch_size,
    "num_classes": num_classes,
    "dim": dim,
    "depth": depth,
    "heads": heads,
    "mlp_dim": mlp_dim,
    "emb_dropout": emb_dropout,
})
print(model)

# 2. Load Dataset

In [None]:
# Single-checkpoint paths kept for reference (not used by the evaluation loop):
#weights_path = f'{results_path}/vit_model_20230821_121731_epoch_2of20_valLoss_1.572_valAcc_0.267_batchsize_64_lr_0.0_TestImg.pth'
# weights_path = f'/home/maxgan/WORKSPACE/UNI/BA/vision-transformer-for-image-quality-perception-of-individual-observers/results/weights/TEST/vit_model_20230821_120855_epoch_16of20_valLoss_7.457_valAcc_0.233_batchsize_64_lr_0.0_TestImg.pth'

# Directory that is searched recursively for .pth checkpoints.
weights_dir = (
    '/home/maxgan/WORKSPACE/UNI/BA/'
    'vision-transformer-for-image-quality-perception-of-individual-observers/'
    'results/weights/all_distored_imgs_1'
)

# CSV file and image folder handed to ImageQualityDataset.
csv_file = (
    '/home/maxgan/WORKSPACE/UNI/BA/'
    'vision-transformer-for-image-quality-perception-of-individual-observers/'
    'assets/Test/AccTestCsv/objectiveAccTest.csv'
)
dataset_root = (
    '/home/maxgan/WORKSPACE/UNI/BA/'
    'vision-transformer-for-image-quality-perception-of-individual-observers/'
    'assets/Test/TestImg'
)

batch_size = 64

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### 2.1 Add Augmentation (Transformation)

In [None]:
# Per-channel normalisation statistics passed to transforms.Normalize
# (also reused later to undo the normalisation for plotting).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Deterministic evaluation preprocessing: resize, centre-crop to 256x256,
# convert to a tensor and normalise.
# No random augmentation (e.g. RandomHorizontalFlip) is applied here: random
# flips at test time make the reported metrics and attention plots differ
# from run to run.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

### 2.2 Loading

In [None]:
# Build the test dataset from the CSV file and image folder, then wrap it in
# a DataLoader; shuffle=False keeps the iteration order fixed.
test_dataset = ImageQualityDataset(csv_file, dataset_root, transform=transform)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

# 3. Evaluate

### 3.1 Evaluating best weights by calculating class probability, MSE of most likely class and weighted sum, mean entropy, accuracy and classification report

In [None]:

# Evaluate every checkpoint found below `weights_dir` on the test loader.
# Per checkpoint this collects accuracy, MSE (arg-max class and weighted sum),
# mean entropy, mean KL divergence, a classification report and a confusion
# matrix figure, plus one example prediction (first sample of the first batch).
weight_files = find_pth_files(weights_dir)

results = []
example_pred_results = []

# Class indices double as scores for the probability-weighted prediction.
# Built once instead of once per batch.
weighting_factors = [0, 1, 2, 3, 4]
class_weights = torch.tensor(weighting_factors)

# Run inference wherever the model currently lives; everything that is turned
# into NumPy below is moved to the CPU first.
model_device = next(model.parameters()).device

for weight_file in weight_files:
    print(f'Weights-file: {os.path.basename(weight_file)} will be evaluated')
    # map_location lets checkpoints saved on another device (e.g. a GPU)
    # load on the device the model actually uses.
    model.load_state_dict(torch.load(weight_file, map_location=model_device))
    model.eval()

    true_labels = []
    test_preds = []
    entropies = []
    weighted_sums = []
    kl_divs = []
    with torch.no_grad():

        for i, (images, labels) in enumerate(test_loader, 0):
            print(f"Example Prediction of Batch: {i}")
            outputs, _ = model(images.to(model_device))
            outputs = outputs.cpu()
            labels = labels.cpu()

            _, preds = torch.max(outputs, 1)
            preds_np = preds.numpy()
            labels_np = labels.numpy()
            test_preds.extend(preds_np)
            true_labels.extend(labels_np)

            # Convert logits to probabilities. The log-probabilities come from
            # log_softmax rather than log(softmax(x)), which can underflow to
            # -inf and turn the KL divergence into NaN.
            probabilities = nn.functional.softmax(outputs, dim=1)
            log_probabilities = nn.functional.log_softmax(outputs, dim=1)

            # Entropy (in bits) of each predicted distribution
            entropy_values = entropy(probabilities.numpy(), base=2, axis=1)
            entropies.extend(entropy_values)

            # Convert labels to one-hot encoded format
            labels_one_hot = torch.zeros_like(probabilities)
            labels_one_hot.scatter_(1, labels.unsqueeze(1), 1)

            # KL divergence between the one-hot target and the prediction,
            # summed over the classes of each sample
            kl_div = torch.nn.functional.kl_div(log_probabilities, labels_one_hot, reduction='none')
            kl_div = torch.sum(kl_div, dim=1).numpy()
            kl_divs.extend(kl_div)

            # Weighted sum of the class probabilities
            weighted_sum = torch.sum(probabilities * class_weights, dim=1).numpy()
            weighted_sums.extend(weighted_sum)

            # Keep the first sample of the first batch as an example record.
            if i == 0:
                example_pred_result = {
                    "Weights File": os.path.basename(weight_file),
                    "True Label": labels_np[0],
                    "Predicted Label": preds_np[0],
                    "Weighted Sum of Probability": weighted_sum[0],
                    "Predicted Class Probability": probabilities[0],
                    "Entropy Value": entropy_values[0],
                    "KL Divergence": kl_div[0],
                }
                example_pred_results.append(example_pred_result)

            # Print the first sample of every batch.
            print(f'True-Label: {labels_np[0]}')
            print(f'Predicted-Label: {preds_np[0]}')
            print(f'Weighted Sum of Probability: {round(weighted_sum[0],4)}')  # Weighted sum of the probabilities
            print(f'Predicted-Class-Probality: {[round(prob,4) for prob in probabilities[0].numpy()]}')
            print(f'Entropy Value: {round(entropy_values[0],4)}') # High Value: spreading; Low Value: concentrated
            print(f'KL Divergence: {round(kl_div[0],4)}\n')

    # Calculate the MSE of weighted sum and ground truth
    mse_weighted = mean_squared_error(true_labels, weighted_sums)

    # Calculate the MSE of most likely class and ground truth
    mse = mean_squared_error(true_labels, test_preds)

    # Calculate the Mean Entropy
    mean_entropy = np.mean(entropies)

    # Calculate the Mean KL Divergence
    mean_kl_div = np.mean(kl_divs)

    # Calculate Accuracy
    accuracy = accuracy_score(true_labels, test_preds)

    # Generate classification report
    class_report = classification_report(true_labels, test_preds)

    # Generate confusion matrix
    confusion = confusion_matrix(true_labels, test_preds)

    print('Model Summary:')
    print(f'Weight: {os.path.basename(weight_file)}, Accuracy: {accuracy}, Mean Entropy: {mean_entropy}, Mean KL Div: {mean_kl_div:.4f}, weighted mean mse: {mse_weighted},\nClassification Report:\n{class_report}')

    # Save the confusion matrix next to the checkpoint. splitext swaps only
    # the file extension; str.replace would touch every ".pth" in the path.
    plt.figure(figsize=(8, 6))
    sns.heatmap(confusion, annot=True, fmt="d", cmap="Blues", cbar=False)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.savefig(os.path.splitext(weight_file)[0] + "_confusion.png")
    plt.close()

    # Store the results
    results.append({
        "Weights File": os.path.basename(weight_file),
        "Accuracy": accuracy,
        "MSE": mse,
        "MSE weighted": mse_weighted,
        "Mean Entropy": mean_entropy,
        "Mean KL Divergence": mean_kl_div,
        "Classification Report": class_report
    })

# Create a DataFrame and save to CSV
results_df = pd.DataFrame(results)
results_path = os.path.join(weights_dir, "model_comparison_results.csv")
results_df.to_csv(results_path, index=False)

# Save the example predictions (one per checkpoint) to a second CSV file
example_printouts_df = pd.DataFrame(example_pred_results)
example_printout_file = os.path.join(weights_dir, "model_comparison_results_examples.csv")
example_printouts_df.to_csv(example_printout_file, index=False)

### 3.2 Plot Distribution

In [None]:
# Ratings CSV to analyse and the PNG file the bar chart is written to.
# Note: this rebinds `csv_file` from section 2.
csv_file = (
    "/home/maxgan/WORKSPACE/UNI/BA/TIQ/assets/Test/AccTestCsv/"
    "shinyxAccTest20-01-2023.csv"
)
output_image_path = (
    "/home/maxgan/WORKSPACE/UNI/BA/TIQ/assets/Test/AccTestCsv/"
    "rating_distribution_shinyxAccTest.png"
)

In [None]:
# Read the ratings file; the first row is skipped and columns are addressed
# by position (column 1 holds the numeric rating).
data = pd.read_csv(csv_file, header=None, skiprows=1)

# Map rating values to their corresponding labels (insertion order = ordinal
# order from worst to best)
rating_labels = {
    1: "Bad",
    2: "Insufficient",
    3: "Fair",
    4: "Good",
    5: "Excellent"
}
data["Rating_Label"] = data[1].map(rating_labels)

# Count occurrences per label and arrange them from "Bad" to "Excellent".
# Sorting the string index would order the classes alphabetically, and
# labels that never occur would be missing instead of showing a count of 0.
class_counts = (
    data["Rating_Label"]
    .value_counts()
    .reindex(list(rating_labels.values()), fill_value=0)
)
# Calculate total number of images
total_images = class_counts.sum()

In [None]:
# Draw the per-class counts as a bar chart, save it to disk and show it.
fig, ax = plt.subplots(figsize=(10, 6))
class_counts.plot(kind="bar", color='skyblue', ax=ax)
ax.set_title("Image Rating Distribution Person1 (shinyx)")
ax.set_xlabel("Rating")
ax.set_ylabel("Number of Images")
plt.setp(ax.get_xticklabels(), rotation=45)
fig.tight_layout()
fig.savefig(output_image_path)
plt.show()

# Print the same numbers as a table
print("Rating Distribution Table:")
print(class_counts)

### 3.3 Plot Mean Attention

In [None]:
# Resize and normalize attention scores to match the original image patch dimensions
def resize_and_normalize_attention_maps(attention_maps, image_patch_size, image_size):
    """
    Turn per-layer attention scores into image-sized heat maps in [0, 1].

    Each 2-D map is handled according to its shape, where
    ``grid = image_size // image_patch_size``:

    * ``(grid*grid, grid*grid)`` - a token-to-token attention matrix
      (query x key). It is averaged over the query axis, giving the mean
      attention every key patch receives, and reshaped to ``(grid, grid)``.
      Previously such a matrix was indexed as if it were already a spatial
      grid, so only its top-left ``grid x grid`` entries reached the image.
    * any other shape - treated as a spatial grid whose entry ``[i, j]``
      fills the pixel block at row ``i`` / column ``j``; blocks outside the
      image are dropped and uncovered pixels stay 0.

    Args:
        attention_maps: Iterable of 2-D arrays (e.g. an array of shape
            ``(layers, seq, seq)``).
        image_patch_size (int): Side length in pixels of one patch.
        image_size (int): Side length in pixels of the square output map.

    Returns:
        list[np.ndarray]: One ``(image_size, image_size)`` float array per
        input map, min-max scaled to [0, 1]. A map without any variation is
        returned as all zeros instead of NaN (0/0).
    """
    grid_size = image_size // image_patch_size if image_patch_size else 0
    num_patches = grid_size * grid_size
    pixel_block = np.ones((image_patch_size, image_patch_size))

    resized_and_normalized_attention_maps = []
    for attention_map in attention_maps:
        attention_map = np.asarray(attention_map, dtype=float)

        if attention_map.shape == (num_patches, num_patches):
            # Token-to-token matrix: collapse the query axis to get one score
            # per key patch, then lay the scores out on the patch grid.
            patch_scores = attention_map.mean(axis=0).reshape(grid_size, grid_size)
        else:
            patch_scores = attention_map

        # Expand every score to a patch-sized block of pixels in one step.
        upsampled = np.kron(patch_scores, pixel_block)

        # Copy into a fixed-size canvas, cropping anything beyond the image.
        resized_attention_map = np.zeros((image_size, image_size))
        rows = min(image_size, upsampled.shape[0])
        cols = min(image_size, upsampled.shape[1])
        resized_attention_map[:rows, :cols] = upsampled[:rows, :cols]

        # Normalize the attention map to range [0, 1]
        min_value = np.min(resized_attention_map)
        max_value = np.max(resized_attention_map)
        value_range = max_value - min_value
        if value_range > 0:
            normalized_attention_map = (resized_attention_map - min_value) / value_range
        else:
            # Constant map: avoid dividing by zero.
            normalized_attention_map = np.zeros_like(resized_attention_map)

        resized_and_normalized_attention_maps.append(normalized_attention_map)

    return resized_and_normalized_attention_maps

In [None]:
# Checkpoint used for the attention visualisations below.
weight_file = 'results/weights/all_distored_imgs_1/vit_model_20230911_064440_epoch_148of150_valLoss_0.108_valAcc_0.953_batchsize_128_lr_0.0_allDistorted.pth'

# Same hyperparameters as in section 1.1; they must match the checkpoint for
# load_state_dict to succeed.
image_size=256
patch_size=16
num_classes=5  # Number of classes for image quality levels
dim=1024
depth=6
heads=16
mlp_dim=2048
emb_dropout=0.1

model = ViT(
    image_size=image_size,
    patch_size=patch_size,
    num_classes=num_classes,
    dim=dim,
    depth=depth,
    heads=heads,
    mlp_dim=mlp_dim,
    emb_dropout=emb_dropout,
    # pool='mean'
)
print(model)

# The freshly built model lives on the CPU, so load the weights onto the CPU
# as well; without map_location a checkpoint saved from a GPU cannot be
# loaded on a machine without CUDA.
model.load_state_dict(torch.load(weight_file, map_location='cpu'))
model.eval()

In [None]:
# Image to visualise. A test-set sample is kept here for reference; the
# active path is the one assigned below.
# image_path = 'assets/Test/TestImg/596ILSVRC2013_train_00009848.JPEG_I5_Q95.jpeg'
image_path = '/home/maxgan/Downloads/vit_model_20230909_021136_epoch_89of150_valLoss_0.134_valAcc_0.941_batchsize_128_lr_0.0_allDistorted_confusion.jpg'

# Force three channels so the 3-channel normalisation in `transform` also
# works for grayscale or RGBA files.
image_org = Image.open(image_path).convert('RGB')

image = transform(image_org)

# Only the forward pass needs no_grad; everything below is post-processing.
with torch.no_grad():
    out, att = model(image.unsqueeze(0))

print(f"Out: {out.shape}")
_, preds = torch.max(out, 1)
print(f"Predicted Class: {preds.cpu().numpy()[0]}")
probabilities = nn.functional.softmax(out, dim=1)
# Format probabilities in a readable way
formatted_probs = [[f'{p:.4f}' for p in prob_list] for prob_list in probabilities.numpy()]
print(f"Probabilities: {formatted_probs[0]}")

# Define weighting factors
weighting_factors = [0,1,2,3,4]
# Calculate the weighted sum of probabilities
weighted_sum = torch.sum(probabilities * torch.tensor(weighting_factors), dim=1).cpu().numpy()
# Format weighted sum in a readable way (loop variable no longer shadows the builtin `sum`)
formatted_weighted_sum = [f'{value:.4f}' for value in weighted_sum]
print(f"Predicted Class (weighted): {formatted_weighted_sum[0]}")

print(f"Attention - Shape: {att.shape}") # batch, layers (depth), heads, sequence length, sequence length
attn_patches = att[:, :, :,1:, 1:] # No Class-Token
attn_patches = attn_patches.squeeze(0) # No Batch
att_mean = torch.mean(attn_patches, dim=1) # Mean of Heads

att_mean_scores = att_mean.cpu().numpy()

print(f"Attention - Mean: {att_mean.shape}") # layers (depth), sequence length, sequence length

# Undo the normalisation for display: convert to an HxWxC NumPy array
# explicitly (instead of relying on the Tensor * ndarray fallback) and clip
# to the [0, 1] range that imshow expects for float RGB data.
image = image.permute(1, 2, 0).numpy() * np.array(std) + np.array(mean)
image = np.clip(image, 0.0, 1.0)

print(f"Image: {image.shape}")
# First figure: the de-normalised input image on its own
plt.figure(figsize=(8, 8))
plt.imshow(image)

# Resize and normalize the per-layer attention maps, then average over layers
resized_and_normalized_attention_maps = resize_and_normalize_attention_maps(
    att_mean_scores, patch_size, image_size
)
resized_and_normalized_attention_maps = np.stack(resized_and_normalized_attention_maps, axis=0)
resized_and_normalized_attention_map = np.mean(resized_and_normalized_attention_maps, axis=0)

# Second figure: the image with the mean attention map overlaid
plt.figure(figsize=(8, 8))
plt.imshow(image)
plt.imshow(resized_and_normalized_attention_map, alpha=0.7, cmap='viridis', interpolation='nearest')
plt.title(f'Mean Attention')
plt.colorbar()
plt.show()


### 3.4 Plot Attention Per Layer

In [None]:
image_path = 'assets/Test/TestImg/25519ILSVRC2014_train_00005904.JPEG_I3_Q22.jpeg'

# Force three channels so the 3-channel normalisation in `transform` also
# works for grayscale or RGBA files.
image = Image.open(image_path).convert('RGB')

image = transform(image)

# Only the forward pass needs no_grad; everything below is post-processing.
with torch.no_grad():
    out, att = model(image.unsqueeze(0))

_, preds = torch.max(out, 1)
print(f"Predicted Class: {preds.cpu().numpy()[0]}")
probabilities = nn.functional.softmax(out, dim=1)
print(f"Probabilities: {probabilities.cpu().numpy()[0]}")
print(f"Attention - Shape: {att.shape}") # batch, layers (depth), heads, sequence length, sequence length
attn_patches = att[:, :, :,1:, 1:] # No Class-Token
attn_patches = attn_patches.squeeze(0) # No Batch
att_mean = torch.mean(attn_patches, dim=1) # Mean of Heads
att_mean_scores = att_mean.cpu().numpy()

print(f"Attention - Mean: {att_mean.shape}") # layers (depth), sequence length, sequence length

# Undo the normalisation before plotting (as in the mean-attention cell);
# showing the normalised tensor directly gives clipped, wrongly coloured
# images. Result: HxWxC NumPy array in [0, 1].
image = image.permute(1, 2, 0).numpy() * np.array(std) + np.array(mean)
image = np.clip(image, 0.0, 1.0)
print(f"Image: {image.shape}")

# First figure: the de-normalised input image on its own
plt.figure(figsize=(8, 8))
plt.imshow(image)

# Resize and normalize attention maps
resized_and_normalized_attention_maps = resize_and_normalize_attention_maps(
    att_mean_scores, patch_size, image_size
)
# One figure per layer: the image with that layer's attention map overlaid
for i, attention_map in enumerate(resized_and_normalized_attention_maps):
    print(f"Resized Attention Shape: {attention_map.shape}")
    plt.figure(figsize=(8, 8))
    plt.imshow(image)
    plt.imshow(attention_map, alpha=0.7, cmap='viridis', interpolation='nearest')
    plt.title(f'Layer {i+1} Attention')
    plt.colorbar()
    plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Re-plot the per-layer attention maps from `att_mean_scores` (computed in
# the previous cell) without running the model again.
# The maps are computed once; both names are kept bound because the original
# cell defined both with identical content.
resized_attention_maps = resize_and_normalize_attention_maps(att_mean_scores, patch_size, image_size)
resized_and_normalized_attention_maps = resized_attention_maps

image_path = 'assets/Test/TestImg/25519ILSVRC2014_train_00005904.JPEG_I3_Q22.jpeg'

# Force three channels so the 3-channel normalisation in `transform` also
# works for grayscale or RGBA files.
image = Image.open(image_path).convert('RGB')
image = transform(image)

# De-normalised HxWxC copy for display; showing the normalised tensor
# directly gives clipped, wrongly coloured images.
display_image = image.permute(1, 2, 0).numpy() * np.array(std) + np.array(mean)
display_image = np.clip(display_image, 0.0, 1.0)

# First figure: the input image on its own
plt.figure(figsize=(8, 8))
plt.imshow(display_image)

# One figure per layer: the image with that layer's attention map overlaid
for i, attention_map in enumerate(resized_attention_maps):
    plt.figure(figsize=(8, 8))
    plt.imshow(display_image)
    plt.imshow(attention_map, alpha=0.7, cmap='viridis', interpolation='nearest')
    plt.title(f'Layer {i+1} Attention')
    plt.colorbar()
    plt.show()
