<h1>
<center>AI vs. Human: Interpretable Binary Classification with CNNs on the <br>Dalle-Recognition Dataset</center>
</h1>

<font size="3">
This notebook offers a detailed exploration of binary image classification using Convolutional Neural Networks (CNNs), aiming to differentiate between images created by AI (Class: Fake) and those created by humans (Class: Real) within the "Dalle-Recognition" dataset. Key activities encompass:
<br>
<ul>
<li><strong>Data Transformation and Loading</strong>: Demonstrates essential image transformations for normalization, followed by efficient loading techniques, to prepare the dataset for binary classification.</li>
    
<li><strong>Image Denoising</strong>: Models like DALL-E start from a noise pattern and generate artwork from a textual prompt, which makes noise a confounding factor when distinguishing AI-generated from human-created images: inherent noise patterns can mislead a classifier into tagging human-made art as AI-generated. Wavelet-based denoising is applied to neutralize this noise, with the aim of reducing such confusion and improving classification accuracy.</li>
    
<li><strong>Model Preparation and Experiments</strong>: The notebook conducts rigorous training sessions. It systematically experiments with 6 distinct architectural designs and various learning rates to find the optimal setup for accurately classifying images into AI-generated or human-created categories.</li>
    
<li><strong>Performance Evaluation</strong>: Employs precise metrics—precision, recall, F1 score, and accuracy—and generates extensive classification reports and confusion matrices. These tools collectively offer in-depth insights into the model's capability to differentiate between the two classes of images effectively.</li>
    
<li><strong>Visualization Techniques</strong>: Incorporates a range of visualization functions to plot metrics over epochs, showcase samples from each dataset category, and illustrate class distribution, facilitating an intuitive grasp of the dataset specifics and the model's learning trajectory.</li>
    
<li><strong>Interpretability Analysis</strong>: Applies the Integrated Gradients technique for a granular analysis of the model's decision-making process. It highlights specific image features that significantly influence the classification outcome, providing transparency and understanding of the model's predictive behavior.</li>
</ul>
</font>

## General Setup

<font size="3"> 
Package imports and system configuration. 
</font>

In [None]:
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler
import torchvision
import torch.nn as nn
import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

import shutil
from random import sample
import random
from PIL import Image

import json
from captum.attr import IntegratedGradients
from captum.attr import visualization as viz
import matplotlib.cm as cm
from scipy.ndimage import zoom
import torch.nn.functional as F
from sklearn.metrics import classification_report
import pywt


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
current_path = os.getcwd()

<font size="3"> 
Datasets paths. 
</font>

In [None]:
dataset_path = os.path.join(current_path, 'io', 'input', 'dataset')
train_data_path = os.path.join(dataset_path, 'train')
test_data_path = os.path.join(dataset_path, 'test')
metrics_plot_path = os.path.join(current_path, 'io', 'output', 'plots')
saved_models_path = os.path.join(current_path, 'io', 'output', 'models')
results_path = os.path.join(current_path, 'io', 'output', 'results')
experimental_evaluation_path = os.path.join(metrics_plot_path, 'experimental_evaluation')
interpretability_results_path = os.path.join(current_path, 'io', 'output', 'interpretability')


<font size="3"> 
Setting Random Seeds for Reproducibility. 
</font>

In [None]:
def set_seeds(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # For CUDA devices
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seeds()

## Data Transformation & Loading 

<font size="3"> 
This function applies wavelet transform-based denoising to an image, utilizing the discrete wavelet transform (DWT) for noise estimation and reduction across multiple levels of detail. It aims to mitigate noise by adjusting wavelet coefficients based on a calculated threshold, which is derived from the image's noise characteristics, to improve image clarity and quality. 
</font>

In [None]:
def wavelet_denoise(image_data, wavelet='db1', level=1):
    # Perform discrete wavelet transform (DWT)
    coeffs = pywt.wavedec2(image_data, wavelet, level=level)
    
    # Estimate noise sigma using the Median Absolute Deviation (MAD) of the finest-level diagonal detail coefficients
    sigma = np.median(np.abs(coeffs[-1][-1])) / 0.6745
    
    # Check if sigma is zero or NaN and adjust it
    if sigma == 0 or np.isnan(sigma):
        sigma = 1e-10  # A small positive value to avoid division by zero or NaN operations
    
    threshold = sigma * np.sqrt(2 * np.log(image_data.size))
    
    # Apply threshold to the detail coefficients at each level
    coeffs_thresholded = [coeffs[0]] + [tuple(pywt.threshold(i, threshold, mode='soft') for i in detail) for detail in coeffs[1:]]
    
    # Reconstruct the image from the thresholded coefficients
    denoised_data = pywt.waverec2(coeffs_thresholded, wavelet)
    
    return denoised_data
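
<font size="3">
The thresholding step above can be illustrated without a wavelet library. The following is a minimal sketch, assuming a hand-written soft-threshold helper that is intended to mirror the "soft" mode used in the function; the helper names and the sample coefficients are hypothetical and serve only to show how the MAD-based threshold separates small noise-like coefficients from large edge responses.
</font>

```python
import numpy as np

def soft_threshold(values, threshold):
    """Shrink values toward zero by `threshold`; smaller magnitudes become 0."""
    return np.sign(values) * np.maximum(np.abs(values) - threshold, 0.0)

def universal_threshold(detail_coeffs, n_samples):
    """MAD-based noise estimate scaled by sqrt(2 * ln(n))."""
    sigma = np.median(np.abs(detail_coeffs)) / 0.6745
    return sigma * np.sqrt(2.0 * np.log(n_samples))

# Hypothetical detail coefficients: mostly small noise plus two large responses
detail = np.array([0.2, -0.1, 0.3, -0.2, 0.1, 8.0, -9.0, 0.15])
threshold = universal_threshold(detail, n_samples=detail.size)
shrunk = soft_threshold(detail, threshold)

print("threshold:", threshold)
print("shrunk coefficients:", shrunk)
```

<font size="3">
In this toy input only the two large coefficients survive, each reduced in magnitude by the threshold, which is the behaviour the denoiser relies on.
</font>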


<font size="3"> 
This function extends the wavelet-based denoising approach to RGB images by applying the wavelet_denoise method separately to each color channel and then recombining them. This approach ensures that noise is effectively reduced in each channel of the color image, enhancing overall image quality while preserving color integrity. 
</font>

In [None]:
def wavelet_denoise_color(image, wavelet='db1', level=1):
    denoised_channels = []
    for channel in range(3):  # Process each channel: R, G, B
        # Extract the current channel
        channel_data = image[:, :, channel]
        # Denoise the current channel using the previous wavelet_denoise function
        denoised_channel = wavelet_denoise(channel_data, wavelet=wavelet, level=level)
        denoised_channels.append(denoised_channel)
        
    # Stack the denoised channels back together
    denoised_image = np.stack(denoised_channels, axis=2)
    return denoised_image


<font size="3">
This class is a custom PyTorch transformation that applies wavelet-based denoising to each channel of an RGB image. Designed for integration into PyTorch's data preprocessing pipelines, it converts images to numpy arrays for denoising and then back to PIL images, enabling noise reduction in RGB images before they are passed into neural network models.
</font>

In [None]:
class WaveletDenoiseTransformRGB:
    def __init__(self, wavelet='db1', level=1):
        self.wavelet = wavelet
        self.level = level
    
    def __call__(self, img):
        # Convert PIL image to numpy array
        img_np = np.array(img)
        
        # Apply denoising for RGB image
        denoised_np = wavelet_denoise_color(img_np, wavelet=self.wavelet, level=self.level)
        
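        # Clip to the valid 8-bit range first so out-of-range values do not wrap around when cast
        denoised_np = np.clip(denoised_np, 0, 255)
        # Convert numpy array back to PIL Image
        img_denoised = Image.fromarray(denoised_np.astype('uint8'), 'RGB')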
        return img_denoised
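
<font size="3">
Wavelet reconstruction returns floating-point values that can fall slightly outside the 8-bit range. The sketch below shows one defensive way to convert such an array back to 8-bit pixels; the helper name <code>to_uint8</code> and the sample values are hypothetical, and rounding before the cast is a design choice rather than something the class above requires.
</font>

```python
import numpy as np

def to_uint8(array):
    """Round to the nearest integer and clamp to [0, 255] before casting."""
    return np.clip(np.rint(array), 0, 255).astype(np.uint8)

# Hypothetical reconstructed values, two of them outside the valid range
reconstructed = np.array([[-3.2, 0.4], [127.6, 258.9]])
safe = to_uint8(reconstructed)
print(safe)
```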
    

<font size="3"> 
This function loads and preprocesses image datasets for training and testing. It checks for dataset existence, applies transformations like resizing and converting to tensor, and returns the datasets along with class names and labels.
</font>

In [None]:
def load_datasets(train_data_path, test_data_path, img_height, img_width, denoising=False):
    
    if not os.path.exists(train_data_path) or not os.path.exists(test_data_path):
        print('Preprocessed dataset does not exist; please create it with the Data_preprocessing notebook')
        return False, False, False, False
      
    if denoising:
        data_transform = transforms.Compose([
            transforms.Resize((img_height, img_width)),
            WaveletDenoiseTransformRGB(),
            transforms.ToTensor(),])
    else:
        data_transform = transforms.Compose([
            transforms.Resize((img_height, img_width)),
            transforms.ToTensor(),])

    train_dataset = torchvision.datasets.ImageFolder(root=train_data_path, transform=data_transform)
    test_dataset = torchvision.datasets.ImageFolder(root=test_data_path, transform=data_transform)
    
    class_names = train_dataset.classes
    class_to_label = train_dataset.class_to_idx
    
    return train_dataset, test_dataset, class_names, class_to_label
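
<font size="3">
<code>ImageFolder</code> assigns integer labels to class subfolders in sorted order, which determines which class the single sigmoid output of the models refers to. The sketch below reproduces that mapping in plain Python; the folder names <code>fake</code> and <code>real</code> are assumptions, since the actual names depend on how the dataset was prepared.
</font>

```python
def build_class_to_idx(folder_names):
    """Assign consecutive integer labels to class folders in sorted order."""
    classes = sorted(folder_names)
    return {name: index for index, name in enumerate(classes)}

# Hypothetical folder names; the order in which they are listed does not matter
class_to_label_example = build_class_to_idx(["real", "fake"])
print(class_to_label_example)
```

<font size="3">
Under these assumed names, label 1 corresponds to the real class, so a model output close to 1 would mean "real".
</font>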


<font size="3"> 
This function creates data loaders for training, validation, and testing datasets. It optionally splits the training dataset into training and validation sets based on a specified ratio, facilitating model evaluation during training. 
</font>

In [None]:
def get_data_loaders(train_dataset, test_dataset, batch_size, train_val_split_ratio, use_validation_set):
    
    if use_validation_set:
        train_size = int(train_val_split_ratio * len(train_dataset))
        val_size = len(train_dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    else:
        val_loader = []
        
    # Shuffle training batches each epoch; ImageFolder orders samples by class
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, val_loader, test_loader
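
<font size="3">
The split arithmetic used above can be checked in isolation. This small sketch, with a hypothetical helper name and sample counts, shows that truncating the training size and assigning the remainder to validation always accounts for every sample.
</font>

```python
def split_sizes(n_samples, train_val_split_ratio):
    """Return (train_size, val_size); the validation set absorbs the rounding remainder."""
    train_size = int(train_val_split_ratio * n_samples)
    val_size = n_samples - train_size
    return train_size, val_size

train_size, val_size = split_sizes(1001, 0.8)
print(train_size, val_size)
```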


## Data analysis

<font size="3"> 
This function displays a set of random images from a given dataset, alongside their class labels, in a grid format.
</font>

In [None]:
def plot_random_images(dataset, class_names, dataset_name, num_images=6):
    fig, axes = plt.subplots(2, 3, figsize=(5, 5))
    fig.suptitle("DALL-E Recognition Dataset", fontsize=14)

    for ax in axes.ravel():
        # Choose a random index
        idx = np.random.randint(len(dataset))

        # Get the image and label at the random index
        image, label = dataset[idx]

        # Plot the image
        ax.imshow(np.transpose(image.numpy(), (1, 2, 0)))
        ax.set_title(f"Class: {class_names[label]}")
        ax.axis("off")

    plt.tight_layout()
    plt.savefig(metrics_plot_path + '/Random_Images' + '.pdf')
    plt.show()


<font size="3"> 
This function assesses and visualizes the balance of image classes within train and test datasets. It calculates the number of images per class for both datasets, plots these counts in a comparative bar chart, and saves the chart as a PDF. 
</font>

In [None]:
def plot_class_balance(train_dataset, test_dataset, class_names):
    # Calculate the number of images per class in train dataset
    train_class_counts = {class_name: 0 for class_name in class_names}
    for _, label in train_dataset:
        class_name = class_names[label]
        train_class_counts[class_name] += 1

    # Calculate the number of images per class in test dataset
    test_class_counts = {class_name: 0 for class_name in class_names}
    for _, label in test_dataset:
        class_name = class_names[label]
        test_class_counts[class_name] += 1

    # Plotting
    fig, ax = plt.subplots(figsize=(10, 6))
    class_labels = list(train_class_counts.keys())
    x = range(len(class_labels))

    ax.bar(x, train_class_counts.values(), width=0.4, align='center', label='Train Set')
    ax.bar(x, test_class_counts.values(), width=0.4, align='edge', label='Test Set')
    ax.set_xticks(x)
    ax.set_xticklabels(class_labels)
    ax.set_xlabel('Class')
    ax.set_ylabel('Number of Images')
    ax.set_title('Number of Images per Class in Train and Test Sets')
    ax.legend()

    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(metrics_plot_path + '/Class_Balance' + '.pdf')
    plt.show()
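
<font size="3">
Iterating over a dataset to count labels loads and transforms every image. As a possible alternative, <code>ImageFolder</code> datasets expose their integer labels through a <code>targets</code> list, so the counts can be derived from labels alone. The sketch below uses a hypothetical list of targets and hypothetical class names in place of a real dataset.
</font>

```python
from collections import Counter

def count_per_class(targets, class_names):
    """Count how many samples carry each integer label, keyed by class name."""
    counts = Counter(targets)
    return {name: counts.get(index, 0) for index, name in enumerate(class_names)}

# Hypothetical stand-in for `train_dataset.targets` and `train_dataset.classes`
example_targets = [0, 0, 1, 0, 1, 1, 1]
example_counts = count_per_class(example_targets, ["fake", "real"])
print(example_counts)
```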
    

<font size="3"> 
This function prints the shape of the first batch of images and labels from a specified data loader, providing a quick overview of batch size and image dimensions. 
</font>

In [None]:
def get_loader_shapes(loader, loader_name):
    print(f"Shapes of batches in {loader_name}:")
    for i, (images, labels) in enumerate(loader):
        print(f"Batch {i+1}: {images.shape}, {labels.shape}\n")
        break
        

## CNN Architectures

<font size="3"> 
This class defines a flexible CNN model for image classification, allowing for variations in architecture based on the model name provided during initialization. It includes configurations for different numbers of convolutional layers, pooling layers, batch normalization, and dropout, enabling experimentation with model complexity and regularization techniques.
</font>

In [None]:
class CNN_Model(nn.Module):
    def __init__(self, model_name, input_channels):
        super(CNN_Model, self).__init__()
        self.model_name = model_name
        
        # Define common layers
        self.conv1 = nn.Conv2d(input_channels, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 1)  # Output layer remains the same for all models
        
        # Additional layers for models with batch normalization and more pooling
        self.bn1 = nn.BatchNorm2d(16)
        self.bn2 = nn.BatchNorm2d(32)
        self.bn3 = nn.BatchNorm2d(64)
        self.bn4 = nn.BatchNorm2d(128)

        # Different first fully connected layers for different models
        self.fc1_model1 = nn.Linear(32 * 64 * 64, 128)
        self.fc1_model2 = nn.Linear(64 * 32 * 32, 128)
        self.fc1_model3 = nn.Linear(64 * 16 * 16, 128)
        self.fc1_model4 = nn.Linear(128 * 8 * 8, 128)
        self.fc1_model5 = nn.Linear(128 * 8 * 8, 128)
        self.fc1_model6 = nn.Linear(128 * 4 * 4, 128)

    def forward(self, x):
        if self.model_name == '2Conv1Pool':
            x = F.relu(self.conv1(x))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(-1, 64 * 64 * 32)
            x = F.relu(self.fc1_model1(x))
            x = torch.sigmoid(self.fc2(x))
        
        elif self.model_name == '3Conv2Pool':
            x = self.pool(F.relu(self.conv1(x)))
            x = F.relu(self.conv2(x))
            x = self.pool(F.relu(self.conv3(x)))
            x = x.view(-1, 64 * 32 * 32)
            x = F.relu(self.fc1_model2(x))
            x = torch.sigmoid(self.fc2(x))
            
        elif self.model_name == '3Conv3Pool':
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = self.pool(F.relu(self.conv3(x)))
            x = x.view(-1, 64 * 16 * 16)
            x = F.relu(self.fc1_model3(x))
            x = torch.sigmoid(self.fc2(x))      
        
        elif self.model_name == '4Conv4Pool':
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x))) 
            x = self.pool(F.relu(self.conv3(x)))  
            x = self.pool(F.relu(self.conv4(x)))
            x = x.view(-1, 128 * 8 * 8)
            x = F.relu(self.fc1_model4(x))
            x = torch.sigmoid(self.fc2(x))
        
        elif self.model_name == '4Conv4Pool_BatchNorm':
            x = self.pool(F.relu(self.bn1(self.conv1(x))))
            x = self.pool(F.relu(self.conv2(x)))
            x = self.pool(F.relu(self.conv3(x)))
            x = self.pool(F.relu(self.conv4(x)))
            x = x.view(-1, 128 * 8 * 8)
            x = F.relu(self.fc1_model5(x))
            x = torch.sigmoid(self.fc2(x))
        
        elif self.model_name == '4Conv4Pool_BatchNorm_Dropout':
            x = self.pool(F.relu(self.bn1(self.conv1(x))))
            x = self.pool(F.relu(self.conv2(x)))
            x = self.pool(F.relu(self.bn3(self.conv3(x))))
            x = self.pool(F.relu(self.conv4(x)))
            x = self.pool(x)
            x = self.dropout(x)
            x = x.view(-1, 128 * 4 * 4)
            x = F.relu(self.fc1_model6(x))
            x = torch.sigmoid(self.fc2(x))
            
        return x
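
<font size="3">
The input sizes of the first fully connected layers follow from the number of pooling steps in each variant. The sketch below recomputes them, assuming 128x128 input images (an assumption inferred from the layer sizes, not stated in the notebook): every 3x3 convolution with padding 1 keeps the spatial size, and every 2x2 max pool halves it.
</font>

```python
def flattened_features(input_size, out_channels, num_pools):
    """Number of features after `num_pools` 2x2 poolings of a square feature map."""
    side = input_size // (2 ** num_pools)
    return out_channels * side * side

INPUT_SIZE = 128  # assumption inferred from the fc1 layer sizes

# (channels of the last convolution, number of pooling steps) per variant
configs = {
    '2Conv1Pool': (32, 1),
    '3Conv2Pool': (64, 2),
    '3Conv3Pool': (64, 3),
    '4Conv4Pool': (128, 4),
    '4Conv4Pool_BatchNorm': (128, 4),
    '4Conv4Pool_BatchNorm_Dropout': (128, 5),
}

fc1_inputs = {name: flattened_features(INPUT_SIZE, channels, pools)
              for name, (channels, pools) in configs.items()}
for name, size in fc1_inputs.items():
    print(f"{name}: {size}")
```

<font size="3">
If a different image size is passed to the data transformation, these values and the corresponding <code>fc1</code> layers would need to change accordingly.
</font>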

## Models Training & Evaluation

<font size="3"> 
This function evaluates a model's performance on a validation set by calculating key metrics: accuracy, precision, recall, and F1 score. It processes the validation data in batches, generating predictions, and then converts these into binary outcomes based on a specified threshold.
</font>

In [None]:
def calculate_metrics(model, device, val_loader, threshold=0.5):
    model.eval()
    predicted_labels = []
    true_labels = []
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)

            predicted_labels.extend(outputs.cpu().detach().numpy().flatten())
            true_labels.extend(labels.cpu().detach().numpy())
    
    # Convert predicted probabilities to binary predictions based on threshold
    predicted_labels_binary = (np.array(predicted_labels) >= threshold).astype(np.float32)
    
    # Calculate metrics
    accuracy = accuracy_score(true_labels, predicted_labels_binary)
    precision = precision_score(true_labels, predicted_labels_binary)
    recall = recall_score(true_labels, predicted_labels_binary)
    f1 = f1_score(true_labels, predicted_labels_binary)
    
    return accuracy, precision, recall, f1
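
<font size="3">
To make the thresholding and the four metrics concrete, the sketch below computes them by hand from confusion counts on a tiny, made-up set of probabilities and labels. The helper name is hypothetical, and returning 0.0 for an empty denominator is a choice made for this example.
</font>

```python
import numpy as np

def binary_metrics(probabilities, true_labels, threshold=0.5):
    """Threshold probabilities, then derive metrics from the confusion counts."""
    predictions = (np.asarray(probabilities) >= threshold).astype(int)
    truth = np.asarray(true_labels)
    tp = int(np.sum((predictions == 1) & (truth == 1)))
    fp = int(np.sum((predictions == 1) & (truth == 0)))
    fn = int(np.sum((predictions == 0) & (truth == 1)))
    tn = int(np.sum((predictions == 0) & (truth == 0)))
    accuracy = (tp + tn) / truth.size
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
    return accuracy, precision, recall, f1

# Made-up model outputs and ground-truth labels
example_probs = [0.9, 0.4, 0.7, 0.2, 0.6]
example_labels = [1, 1, 0, 0, 1]
acc, prec, rec, f1 = binary_metrics(example_probs, example_labels)
print(acc, prec, rec, f1)
```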


<font size="3">
This function orchestrates the training process for a given model over a specified number of epochs, optionally performing validation. Here's a breakdown of its steps and objectives:
<br>
<ol>
<li>Initialization: Sets up the optimizer (Adam) with a given learning rate, loss function (Binary Cross Entropy Loss), and moves the model to the designated computing device.</li>
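<li>Training Loop: For each batch, runs the forward pass, computes the loss, backpropagates, and updates the weights; the epoch loss is the batch losses averaged over all training samples.</li>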
<li>Validation (Optional): If validation mode is enabled, evaluates the model on the validation set after each training epoch using the calculate_metrics function, then prints and records these metrics.</li>
<li>Metrics Tracking: Accumulates training and, if enabled, validation metrics over all epochs to monitor performance trends.</li>
<li>Output: Returns the trained model and a dictionary of metrics for further analysis.</li>
</ol>
<br>

</font>

In [None]:
def train_model(model, device, train_loader, val_loader, learning_rate, num_epochs, validation_mode):
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = nn.BCELoss()
    model.to(device)

    train_loss_ls, train_accuracy_ls, train_precision_ls, train_recall_ls, train_f1_ls = [], [], [], [], []
    eval_accuracy_ls, eval_precision_ls, eval_recall_ls, eval_f1_ls = [], [], [], []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in tqdm(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = loss_fn(outputs, labels.view(-1, 1).float())

            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)
        
        train_accuracy, train_precision, train_recall, train_f1 = calculate_metrics(model, device, train_loader)
        print("Epoch {}, Train || Cross Entropy Loss: {:.3f}".format(epoch+1, epoch_loss))
        print("Epoch {}, Train || Accuracy: {:.3f}, Precision: {:.3f}, Recall: {:.3f}, F1-Score: {:.3f}".format(epoch+1, train_accuracy, train_precision, train_recall, train_f1))
        train_loss_ls.append(epoch_loss)
        train_accuracy_ls.append(train_accuracy)
        train_precision_ls.append(train_precision)
        train_recall_ls.append(train_recall)
        train_f1_ls.append(train_f1)
                                                                                                                
        if validation_mode:
            eval_accuracy, eval_precision, eval_recall, eval_f1 = calculate_metrics(model, device, val_loader)
            print("Epoch {}, Evaluation || Accuracy: {:.3f}, Precision: {:.3f}, Recall: {:.3f}, F1-Score: {:.3f}".format(epoch+1, eval_accuracy, eval_precision, eval_recall, eval_f1))        
            eval_accuracy_ls.append(eval_accuracy)
            eval_precision_ls.append(eval_precision)
            eval_recall_ls.append(eval_recall)
            eval_f1_ls.append(eval_f1)
        
    # Collect all tracked metrics once training has finished
    metrics = {'train_loss': train_loss_ls,'train_accuracy': train_accuracy_ls, 'train_precision': train_precision_ls,
               'train_recall': train_recall_ls, 'train_f1': train_f1_ls, 'eval_accuracy': eval_accuracy_ls,
               'eval_precision': eval_precision_ls, 'eval_recall': eval_recall_ls, 'eval_f1': eval_f1_ls}
        
    return model, metrics 
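
<font size="3">
The training loop multiplies each batch loss by the batch size before summing, so a smaller final batch does not distort the epoch average. The sketch below contrasts this sample-weighted mean with a plain mean over batches; the loss values and batch sizes are made up for illustration.
</font>

```python
def epoch_mean_loss(batch_losses, batch_sizes):
    """Sample-weighted mean: each batch contributes in proportion to its size."""
    total = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)

# Made-up per-batch mean losses; the last batch is smaller than the others
example_losses = [0.70, 0.60, 0.90]
example_sizes = [32, 32, 8]

weighted = epoch_mean_loss(example_losses, example_sizes)
unweighted = sum(example_losses) / len(example_losses)
print(weighted, unweighted)
```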


<font size="3"> 
This function generates a classification report for a given model on the test dataset. It evaluates the model's performance by comparing its predictions against the true labels, providing detailed insights into metrics like precision, recall, and F1-score for each class.
</font>

In [None]:
def compute_classification_report(model, test_loader, device='cpu'):
    model.eval()
    model.to(device)
    
    true_labels = []
    pred_labels = []
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            
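            # The model emits a single sigmoid probability per image, so threshold it;
            # an argmax over a one-element dimension would always return class 0
            preds = (outputs >= 0.5).long().view(-1)
            
            true_labels.extend(labels.cpu().numpy())
            pred_labels.extend(preds.cpu().numpy())
    
    # Build the per-class precision, recall, and F1 report
    report = classification_report(true_labels, pred_labels, output_dict=False)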
    return report


<font size="3"> 
This function initializes the weights of convolutional and linear layers in a neural network using Xavier uniform initialization for weights and setting biases to zero, ensuring optimal starting points for training.
</font>

In [None]:
def init_weights(m):
    if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
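
<font size="3">
Xavier uniform initialization draws weights from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)). The sketch below evaluates that bound for two layers of the model, assuming 3 input channels (RGB) for the first convolution; the helper names are hypothetical.
</font>

```python
import math

def xavier_uniform_bound(fan_in, fan_out, gain=1.0):
    """Half-width `a` of the uniform range U(-a, a) used by Xavier initialization."""
    return gain * math.sqrt(6.0 / (fan_in + fan_out))

def conv2d_fans(in_channels, out_channels, kernel_size):
    """Fan-in and fan-out of a square-kernel convolution."""
    receptive_field = kernel_size * kernel_size
    return in_channels * receptive_field, out_channels * receptive_field

# conv1: 3 -> 16 channels with a 3x3 kernel (3 input channels is an assumption)
conv1_bound = xavier_uniform_bound(*conv2d_fans(3, 16, 3))
# fc2: 128 -> 1
fc2_bound = xavier_uniform_bound(128, 1)
print(conv1_bound, fc2_bound)
```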

<font size="3"> 
This function saves a dictionary to a JSON file at the specified path and filename.
</font>

In [None]:
def save_dict(path, filename, data):
    with open(path + '/' + filename + '.json', 'w') as f:
        json.dump(data, f)

<font size="3"> 
This function loads a dictionary from a JSON file using the given path and filename.
</font>

In [None]:
def load_dict(path, filename):
    with open(path + '/' + filename + '.json', 'r') as f:
        data_loaded = json.load(f)
    return data_loaded
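
<font size="3">
A quick round trip confirms that a metrics dictionary of plain lists survives saving and loading unchanged. The sketch below uses a temporary directory and <code>os.path.join</code> instead of string concatenation; the helper names, file name, and metric values are hypothetical.
</font>

```python
import json
import os
import tempfile

def save_metrics(path, filename, data):
    """Write `data` as JSON to <path>/<filename>.json."""
    with open(os.path.join(path, filename + '.json'), 'w') as f:
        json.dump(data, f)

def load_metrics(path, filename):
    """Read the JSON file written by `save_metrics`."""
    with open(os.path.join(path, filename + '.json'), 'r') as f:
        return json.load(f)

# Made-up metrics in the same shape that the training function returns
example_metrics = {'train_loss': [0.69, 0.52], 'train_accuracy': [0.55, 0.74]}

with tempfile.TemporaryDirectory() as tmp_dir:
    save_metrics(tmp_dir, 'example_run', example_metrics)
    restored = load_metrics(tmp_dir, 'example_run')

print(restored == example_metrics)
```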

<font size="3"> 
This function saves a model's state dictionary to a specified file, facilitating model persistence and later retrieval.</font>

In [None]:
def save_model(model, filename):
    torch.save({
        'model_state_dict': model.state_dict(),
    }, filename)

<font size="3"> 
This function loads a model's state dictionary from a specified file, updating the model for further use or evaluation.
</font>

In [None]:
def load_model(model, filename):
    checkpoint = torch.load(filename)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model

## Interpretability

<font size="3">
This function calculates the Integrated Gradients for a single image from the test loader to interpret a model's predictions, specifically highlighting the importance of input features (pixels) in the prediction of a class (real or fake). Here are the core steps:
<br>
<ol>
<li>Fetch a Batch: Retrieves a batch of images and labels from the test loader and selects the image and label at the requested index.</li>
<li>Prepare the Input: Puts the selected image in a tensor suitable for the model (adds a batch dimension).</li>
<li>Model Evaluation: Switches the model to evaluation mode and computes the output for the input image.</li>
<li>Predicted Probabilities: Derives the predicted probabilities of the real and fake classes from the model's output.</li>
<li>Class Name Resolution: Identifies the class name of the target label using the class_to_label mapping.</li>
<li>Integrated Gradients Calculation: Generates attributions using Integrated Gradients, indicating the importance of each pixel for the model's prediction.</li>
<li>Result Compilation: Constructs a string summarizing the model's prediction and prepares the attributions for visualization.</li>
</ol>
<br>

</font>

In [None]:
def calculate_integrated_gradients(model, test_loader, class_to_label, index_image):
    # Fetch a batch of images and labels
    data_iter = iter(test_loader)
    images, labels = next(data_iter)

    # Select the requested image in the batch and move it to the model's device
    model_device = next(model.parameters()).device
    input_tensor = images[index_image].unsqueeze(0).to(model_device)
    target_label = labels[index_image].item()

    model.eval()
    with torch.no_grad():
        output = model(input_tensor)

    # The forward pass already ends with a sigmoid, so the output is a probability
    # (interpreted here as the probability of the REAL class); do not apply sigmoid again
    predicted_prob_real = output.item()
    # Calculate the probability of the opposite class
    predicted_prob_fake = 1 - predicted_prob_real
    # Find the class names
    target_class_name = [name for name, label in class_to_label.items() if label == target_label][0]
    result = f"Target class:{target_class_name} \n The model predicts REAL:{predicted_prob_real:.2%} & FAKE:{predicted_prob_fake:.2%}"

    # Initialize Integrated Gradients
    integrated_gradients = IntegratedGradients(model)
    # Generate attributions
    attributions = integrated_gradients.attribute(input_tensor, target=0, n_steps=50)
    attributions_np = attributions.detach().cpu().numpy()[0]
    attributions_sum = attributions_np.sum(axis=0)
    attributions_sum = attributions_sum / np.max(np.abs(attributions_sum))
    return result, attributions_sum, input_tensor
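
<font size="3">
To show what Integrated Gradients computes, the sketch below applies it to a tiny logistic model in NumPy, where the gradient is available in closed form. It uses a midpoint Riemann sum along the straight path from an all-zero baseline to the input; this is a simplification, and Captum's own integration scheme may differ. The weights and input are made up. The final check illustrates the completeness property: the attributions sum to approximately the difference between the output at the input and at the baseline.
</font>

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def integrated_gradients(x, baseline, weights, n_steps=50):
    """Midpoint-rule approximation of Integrated Gradients for f(x) = sigmoid(w . x)."""
    alphas = (np.arange(n_steps) + 0.5) / n_steps          # midpoints in (0, 1)
    path = baseline + alphas[:, None] * (x - baseline)      # points along the straight path
    probs = sigmoid(path @ weights)
    grads = (probs * (1.0 - probs))[:, None] * weights      # df/dx at each path point
    return (x - baseline) * grads.mean(axis=0)

# Made-up model weights and input
weights = np.array([2.0, -1.0, 0.5])
x = np.array([1.0, 0.5, -1.0])
baseline = np.zeros_like(x)

attributions = integrated_gradients(x, baseline, weights)
output_difference = sigmoid(x @ weights) - sigmoid(baseline @ weights)
completeness_gap = attributions.sum() - output_difference
print(attributions, completeness_gap)
```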


## Results Plotting 

<font size="3"> 
This function smooths a series of data points using exponential moving average, reducing noise for clearer trend visualization.
</font>

In [None]:
def smooth_curve(points, factor=0.6):
    smoothed_points = []
    for point in points:
        if smoothed_points:
            previous = smoothed_points[-1]
            smoothed_points.append(previous * factor + point * (1 - factor))
        else:
            smoothed_points.append(point)
    return smoothed_points
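
<font size="3">
The recursion above is equivalent to a weighted sum in which older points are discounted by increasing powers of the smoothing factor. The sketch below restates the recursion and compares it with that closed form on made-up values; both helper names are hypothetical.
</font>

```python
def ema_recursive(points, factor=0.6):
    """Exponential moving average: s_n = factor * s_(n-1) + (1 - factor) * p_n, with s_0 = p_0."""
    smoothed = []
    for point in points:
        if smoothed:
            smoothed.append(smoothed[-1] * factor + point * (1 - factor))
        else:
            smoothed.append(point)
    return smoothed

def ema_closed_form(points, factor=0.6):
    """Same average written as an explicit weighted sum of all earlier points."""
    smoothed = []
    for n in range(len(points)):
        value = factor ** n * points[0]
        value += sum((1 - factor) * factor ** (n - k) * points[k] for k in range(1, n + 1))
        smoothed.append(value)
    return smoothed

# Made-up accuracy values over four epochs
example_points = [0.5, 0.9, 0.7, 0.8]
recursive = ema_recursive(example_points)
closed = ema_closed_form(example_points)
print(recursive)
print(closed)
```

<font size="3">
A larger factor gives more weight to history and produces a smoother curve, at the cost of lagging behind recent changes.
</font>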


<font size="3"> 
This function plots a smoothed version of a given metric over epochs, saves the plot as a PDF, and displays it, aiding in performance analysis.
</font>

In [None]:
def plot_single_metric(metric, metric_label, metrics_plot_path, set_name, title):
    smooth_metric = smooth_curve(metric)
    plt.plot(range(1, len(smooth_metric) + 1), smooth_metric, label=set_name)
    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel(metric_label)
    plt.legend()
    plt.savefig(metrics_plot_path + '/' + set_name + '_' + metric_label + '.pdf')
    plt.show()
    

<font size="3"> 
This function plots and saves a comprehensive view of training and evaluation metrics (accuracy, precision, recall, F1-score) over epochs, using subplots for clear comparison and analysis of model performance trends.
</font>

In [None]:
def plot_all_metrics_train_val(metrics, plot_name, metrics_plot_path, evaluation_mode, title):
    smooth_accuracy_train = smooth_curve(metrics['train_accuracy'])
    smooth_accuracy_eval = smooth_curve(metrics['eval_accuracy'])
    smooth_precision_train = smooth_curve(metrics['train_precision'])
    smooth_precision_eval = smooth_curve(metrics['eval_precision'])
    smooth_recall_train = smooth_curve(metrics['train_recall'])
    smooth_recall_eval = smooth_curve(metrics['eval_recall'])
    smooth_f1_train = smooth_curve(metrics['train_f1'])
    smooth_f1_eval = smooth_curve(metrics['eval_f1'])
    
    fig, axs = plt.subplots(nrows=2, ncols=2,figsize=(12, 8))
    fig.suptitle(title, fontsize=16)
    # Plot the Accuracy metric on the top-left subplot
    axs[0, 0].plot(range(1, len(smooth_accuracy_train) + 1), smooth_accuracy_train, label='Train')
    axs[0, 0].plot(range(1, len(smooth_accuracy_eval) + 1), smooth_accuracy_eval, label=evaluation_mode)
    axs[0, 0].set_xlabel('Epochs')
    axs[0, 0].set_ylabel('Accuracy')
    axs[0, 0].set_title('Accuracy')
    axs[0, 0].legend()
    # Plot the Precision metric on the top-right subplot
    axs[0, 1].plot(range(1, len(smooth_precision_train) + 1), smooth_precision_train, label='Train')
    axs[0, 1].plot(range(1, len(smooth_precision_eval) + 1), smooth_precision_eval, label=evaluation_mode)
    axs[0, 1].set_xlabel('Epochs')
    axs[0, 1].set_ylabel('Precision')
    axs[0, 1].set_title('Precision')
    axs[0, 1].legend()
    # Plot the Recall metric on the bottom-left subplot
    axs[1, 0].plot(range(1, len(smooth_recall_train) + 1), smooth_recall_train, label='Train')
    axs[1, 0].plot(range(1, len(smooth_recall_eval) + 1), smooth_recall_eval, label=evaluation_mode)
    axs[1, 0].set_xlabel('Epochs')
    axs[1, 0].set_ylabel('Recall')
    axs[1, 0].set_title('Recall')
    axs[1, 0].legend()
    # Plot the F1-Score metric on the bottom-right subplot
    axs[1, 1].plot(range(1, len(smooth_f1_train) + 1), smooth_f1_train, label = 'Train')
    axs[1, 1].plot(range(1, len(smooth_f1_eval) + 1), smooth_f1_eval, label = evaluation_mode)
    axs[1, 1].set_xlabel('Epochs')
    axs[1, 1].set_ylabel('F1-Score')
    axs[1, 1].set_title('F1-Score')
    axs[1, 1].legend()
    # Adjust the spacing between subplots
    plt.tight_layout(rect=[0, 0.03, 1, 0.98])
    # Show the plot
    plt.savefig(metrics_plot_path + '/' + plot_name + '_All_Metrics' + '.pdf')
    plt.show()
    

<font size="3"> 
This function visualizes an original image alongside its attribution heatmap to interpret the model's focus, and saves the combined plot for review.
</font>

In [None]:
def plot_interpretability(interpretability_results_path, result, attributions_sum, input_tensor, index_image):

    input_image_np = input_tensor.detach().cpu().numpy()[0] 
    fig, ax = plt.subplots(1, 2, figsize=(8, 4)) 
    
    # Visualize the image and the attributions
    ax[0].imshow(np.transpose(input_image_np, (1, 2, 0)))
    ax[0].axis('off')
    ax[0].set_title('Original Image', fontsize=9)
    ax[1].imshow(attributions_sum, cmap='jet')
    ax[1].axis('off')
    ax[1].set_title('Attribution Heatmap', fontsize=9)
    

    plt.tight_layout(pad=5.0)
    plt.suptitle(result, fontsize=11, verticalalignment='top')
    plt.savefig(interpretability_results_path + '/' + 'interpretability' + str(index_image) + '.pdf', bbox_inches='tight')  # This option reduces the padding around the saved figure
    plt.show()
    

<font size="3">
This function generates an overlayed visualization by combining the original image with a heatmap based on Integrated Gradients attributions, highlighting regions influencing the model's prediction. Here are the core steps:
<br>
<ol>
<li>Normalize and prepare the input image and attributions for visualization, converting tensor data to NumPy arrays and applying normalization to both.</li>
<li>Threshold attributions to identify key regions influencing the model's predictions, creating a binary mask for significant areas.</li>
<li>Generate a heatmap from the normalized attributions using a colormap, emphasizing areas of interest.</li>
<li>Resize the heatmap to match the original image's dimensions, then apply the binary mask so that only the significant regions remain.</li>
<li>Overlay the heatmap onto the original image with adjusted transparency, blending significant attribution areas with the original visual content.</li>
<li>Display and save the combined overlay image, offering a visual interpretation of the model's decision-making areas.</li>
</ol>
<br>

</font>

In [None]:
def plot_interpretability_overlayed_image(interpretability_results_path, result, attributions_sum, input_tensor, index_image, threshold=0.8):
    # Convert the input tensor to numpy for visualization
    input_image_np = input_tensor.detach().cpu().numpy()[0]  # Assuming the input is (1, C, H, W)
    input_image_np = np.transpose(input_image_np, (1, 2, 0))  # Convert to H, W, C for visualization
    
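    # Normalize the input image to 0-1 out of place, so the caller's tensor (which shares memory with this array) is left unchanged
    input_image_np = input_image_np - input_image_np.min()
    input_image_np = input_image_np / max(float(input_image_np.max()), 1e-8)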
    
    # Normalize the attribution sum to 0-1 for the heatmap (the small epsilon avoids division by zero for a constant map)
    attributions_norm = (attributions_sum - attributions_sum.min()) / (attributions_sum.max() - attributions_sum.min() + 1e-8)
    
    # Apply threshold to attributions
    mask = attributions_norm > threshold
    # Expand mask for RGB dimensions
    mask = np.stack([mask, mask, mask], axis=2)
    
    # Create heatmap
    heatmap = cm.jet(attributions_norm)[:, :, :3]  # Use the jet colormap, take the first 3 channels for RGB

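    # Resize the heatmap (and its mask) to match the input image dimensions if necessary
    if heatmap.shape != input_image_np.shape:
        zoom_factors = (input_image_np.shape[0] / heatmap.shape[0], input_image_np.shape[1] / heatmap.shape[1], 1)
        heatmap = zoom(heatmap, zoom_factors, order=1)
        # Nearest-neighbour resize keeps the mask binary and the same shape as the heatmap
        mask = zoom(mask.astype(np.uint8), zoom_factors, order=0).astype(bool)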
    
    # Apply the mask to the heatmap
    heatmap_masked = np.where(mask, heatmap, np.zeros_like(heatmap))
    
    # Overlay the masked heatmap on the original image
    overlayed_image = (1 * input_image_np + 0.4 * heatmap_masked).clip(0, 1)  # 0.4 is the heatmap blending weight; adjust it to control how strongly the heatmap shows
    
    # Display the overlayed image
    plt.figure(figsize=(8, 4))
    plt.imshow(overlayed_image)
    plt.axis('off')
    plt.title(result, fontsize=10)
    plt.savefig(interpretability_results_path + '/' +'interpretability_overlayed_heatmap' + str(index_image) + '.pdf')
    plt.show()


<h1>
<center>Execution</center>
</h1>

<font size="3"> 
Global Variables
</font>

In [None]:
img_height = 128
img_width = 128
input_channels = 3
batch_size = 64
train_val_split_ratio = 0.8

## Execution: Data Loading & Analysis

In [None]:
train_dataset, test_dataset, class_names, class_to_label = load_datasets(train_data_path, test_data_path, img_height, img_width, denoising=False)

#plot_class_balance(train_dataset, test_dataset, class_names)
plot_random_images(train_dataset, class_names, "Train Set")
print("Class to label mapping:", class_to_label)

## Execution: Tuning Using Validation Set

<font size="3"> 
Local Variables
</font>

In [None]:
denoising = True
use_validation_set = True
num_epochs = 20
learning_rate = 0.005
set_name = 'Validation'

networks = ['2Conv1Pool','3Conv2Pool', '3Conv3Pool', '4Conv4Pool',
            '4Conv4Pool_BatchNorm','4Conv4Pool_BatchNorm_Dropout']

final_arcitecture = '4Conv4Pool'
learning_rate_ls = [0.0001, 0.001, 0.005]

<font size="3"> 
Data Loading
</font>

In [None]:
train_dataset, test_dataset, class_names, class_to_label = load_datasets(train_data_path, test_data_path, img_height, img_width, denoising)
train_loader, val_loader, test_loader = get_data_loaders(train_dataset, test_dataset, batch_size, train_val_split_ratio, use_validation_set)

get_loader_shapes(train_loader, "Train Set")

### Experiments with Different CNN Architectures

In [None]:
for network in networks:
    print(f"Training of {network}")
    model = CNN_Model(network, input_channels)
    model.apply(init_weights)
    model, metrics_val = train_model(model, device, train_loader, val_loader, learning_rate, num_epochs, validation_mode=True)
    results_name = network + '_' + set_name
    save_dict(results_path, results_name, metrics_val)  

<font size="3"> 
Analyze and Plot Different CNN Results
</font>

In [None]:
for network in networks:
    title = f"Model: {network}"
    results_name = network + '_' + set_name
    metrics_val = load_dict(results_path, results_name)
    plot_all_metrics_train_val(metrics_val, results_name, metrics_plot_path, set_name, title)

### Experiments with Different Learning Rates

In [None]:
for learning_rate_value in learning_rate_ls:
    print(f"Training of {final_arcitecture} with Learning Rate: {learning_rate_value}")
    model = CNN_Model(final_arcitecture, input_channels)
    model.apply(init_weights)
    model, metrics_val = train_model(model, device, train_loader, val_loader, learning_rate_value, num_epochs, validation_mode=True)
    results_name = final_arcitecture + '_lr_' + str(learning_rate_value) + '_' + set_name
    save_dict(results_path, results_name, metrics_val)

<font size="3"> 
Analyze and Plot Learning Rate Influence
</font>

In [None]:
for learning_rate_value in learning_rate_ls:
    title = f"Model: {final_arcitecture} with LR: {learning_rate_value}"
    results_name = final_arcitecture + '_lr_' + str(learning_rate_value) + '_' + set_name
    metrics_val = load_dict(results_path, results_name)
    plot_all_metrics_train_val(metrics_val, results_name, metrics_plot_path, set_name, title)

## Execution: Final Model Training 

<font size="3"> 
Local Variables
</font>

In [None]:
denoising = True
set_name = 'Test'
final_arcitecture = '4Conv4Pool'
results_name = final_arcitecture + '_' + set_name
num_epochs = 8
learning_rate = 0.001
use_validation_set = False

<font size="3"> 
Data Loading
</font>

In [None]:
train_dataset, test_dataset, class_names, class_to_label = load_datasets(train_data_path, test_data_path, img_height, img_width, denoising)
train_loader, val_loader, test_loader = get_data_loaders(train_dataset, test_dataset, batch_size, train_val_split_ratio, use_validation_set)

<font size="3"> 
Model Training
</font>

In [None]:
model = CNN_Model(final_arcitecture, input_channels)
model.apply(init_weights)
model, metrics_test = train_model(model, device, train_loader, test_loader, learning_rate, num_epochs, validation_mode=True)

<font size="3"> 
Save Model and Results
</font>

In [None]:
save_model(model, saved_models_path + '/' + final_arcitecture + '.pth')
save_dict(results_path, results_name, metrics_test)

## Execution: Evaluation on Test Set

<font size="3"> 
Local Variables
</font>

In [None]:
denoising = True
set_name = 'Test'
final_arcitecture = '4Conv4Pool'
use_validation_set = False
results_name = final_arcitecture + '_' + set_name

<font size="3"> 
Data Loading 
</font>

In [None]:
train_dataset, test_dataset, class_names, class_to_label = load_datasets(train_data_path, test_data_path, img_height, img_width, denoising)
train_loader, val_loader, test_loader = get_data_loaders(train_dataset, test_dataset, batch_size, train_val_split_ratio, use_validation_set)

<font size="3"> 
Metrics Calculation
</font>

In [None]:
model = CNN_Model(final_arcitecture, input_channels)
model = load_model(model, saved_models_path + '/' + final_arcitecture + '.pth')
test_accuracy, test_precision, test_recall, test_f1 = calculate_metrics(model, device, test_loader)
print("Evaluation on Test || Accuracy: {:.3f}, Precision: {:.3f}, Recall: {:.3f}, F1-Score: {:.3f}".format(test_accuracy, test_precision, test_recall, test_f1))
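
<font size="3">
For reference, the four numbers printed above are all derived from confusion-matrix counts. The sketch below assumes binary labels with 1 as the positive class. `binary_metrics` is a hypothetical helper written for illustration; it is not the notebook's `calculate_metrics`, which is defined earlier and may, for example, use a different averaging scheme.
</font>

```python
def binary_metrics(y_true, y_pred, positive=1):
    """Return (accuracy, precision, recall, f1) for binary labels."""
    pairs = list(zip(y_true, y_pred))
    # Confusion-matrix counts with respect to the positive class
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    tn = len(pairs) - tp - fp - fn

    # Each ratio falls back to 0.0 when its denominator is empty
    accuracy = (tp + tn) / len(pairs) if pairs else 0.0
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return accuracy, precision, recall, f1


# Dummy labels for illustration only, not results from the notebook
y_true = [1, 1, 1, 0, 0, 0, 0, 1]
y_pred = [1, 0, 1, 0, 1, 0, 0, 1]
print("Accuracy: {:.3f}, Precision: {:.3f}, Recall: {:.3f}, F1-Score: {:.3f}".format(*binary_metrics(y_true, y_pred)))
```

<font size="3">
With these dummy labels there are 3 true positives, 1 false positive, 1 false negative, and 3 true negatives, so all four metrics come out to 0.750.
</font>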

<font size="3"> 
Plot Test Metrics
</font>

In [None]:
metrics_test = load_dict(results_path, results_name)
title = f"Model: {final_arcitecture}"
plot_all_metrics_train_val(metrics_test, results_name, metrics_plot_path, set_name, title)
plot_single_metric(metrics_test['eval_f1'], 'F1-Score', metrics_plot_path, set_name, title)

## Execution: Interpretability

<font size="3"> 
Local Variables
</font>

In [None]:
denoising = True
final_arcitecture = '4Conv4Pool'
use_validation_set = False

<font size="3"> 
Data Loading 
</font>

In [None]:
train_dataset, test_dataset, class_names, class_to_label = load_datasets(train_data_path, test_data_path, img_height, img_width, denoising)
train_loader, val_loader, test_loader = get_data_loaders(train_dataset, test_dataset, batch_size, train_val_split_ratio, True)


<font size="3"> 
Calculate Integrated Gradients
</font>

In [None]:
index_image = 30
model = CNN_Model(final_arcitecture, input_channels)
model = load_model(model, saved_models_path + '/' + final_arcitecture + '.pth')
result, attributions_sum, input_tensor = calculate_integrated_gradients(model, test_loader, class_to_label, index_image)
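
<font size="3">
The attributions computed above come from Integrated Gradients, which accumulates the model's gradients along a straight path from a baseline input to the actual input. `calculate_integrated_gradients` is defined earlier in the notebook and is not reproduced here. The sketch below illustrates the technique on a toy logistic model with an analytic gradient, using a midpoint Riemann approximation of the path integral. All names, weights, and inputs are hypothetical.
</font>

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def model(x, w, b):
    """Toy 'network': a single logistic unit returning a probability."""
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)


def model_grad(x, w, b):
    """Analytic gradient of the toy model's output with respect to x."""
    p = model(x, w, b)
    return [p * (1.0 - p) * wi for wi in w]


def integrated_gradients(x, baseline, w, b, steps=200):
    """Approximate IG_i = (x_i - x'_i) * integral_0^1 dF/dx_i(x' + a * (x - x')) da."""
    avg_grad = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps  # midpoint of the k-th sub-interval of [0, 1]
        point = [bi + alpha * (xi - bi) for xi, bi in zip(x, baseline)]
        grad = model_grad(point, w, b)
        avg_grad = [a + g / steps for a, g in zip(avg_grad, grad)]
    # Scale the path-averaged gradient by the input's distance from the baseline
    return [(xi - bi) * g for xi, bi, g in zip(x, baseline, avg_grad)]


# Hypothetical weights and input, chosen only for illustration
w, b = [2.0, -1.0, 0.5], 0.1
x, baseline = [1.0, 0.5, -2.0], [0.0, 0.0, 0.0]

attributions = integrated_gradients(x, baseline, w, b)
# Completeness check: the attributions should sum to F(x) - F(baseline)
gap = sum(attributions) - (model(x, w, b) - model(baseline, w, b))
print(attributions, gap)
```

<font size="3">
For an image model the gradient would come from backpropagation rather than a closed form, and the per-channel attributions are typically summed into a single 2D map, which is presumably what `attributions_sum` holds.
</font>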


<font size="3"> 
Plot Interpretability
</font>

In [None]:
plot_interpretability(interpretability_results_path, result, attributions_sum, input_tensor, index_image)
plot_interpretability_overlayed_image(interpretability_results_path, result, attributions_sum, input_tensor, index_image, threshold=0.6)


## Experimental Evaluation Plots

<font size="3"> 
Model Architecture
</font>

In [None]:
# Result file name -> legend label for each architecture
architecture_results = {
    '2Conv1Pool_Validation': '2Conv1Pool',
    '3Conv2Pool_Validation': '3Conv2Pool',
    '3Conv3Pool_Validation': '3Conv3Pool',
    '4Conv4Pool_Validation': '4Conv4Pool',
    '4Conv4Pool_BatchNorm_Validation': '4Conv4PoolBN',
    '4Conv4Pool_BatchNorm_Dropout_Validation': '4Conv4PoolBNDR',
}

# Plot the smoothed validation F1-score of every architecture on one figure
for file_name, label in architecture_results.items():
    smoothed_f1 = smooth_curve(load_dict(results_path, file_name)['eval_f1'])
    plt.plot(range(1, len(smoothed_f1) + 1), smoothed_f1, label=label)

plt.title('Model Architecture')
plt.xlabel('Epochs')
plt.ylabel('F1-Score')
plt.legend()
#plt.savefig(experimental_evaluation_path + "/Model_Architecture.pdf")
plt.show()
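
<font size="3">
Every curve in this section is passed through `smooth_curve` before plotting. That helper is defined earlier in the notebook and is not reproduced here. A common choice for this kind of smoothing is an exponential moving average, sketched below under that assumption. The name `ema_smooth` and the default `factor` are hypothetical and may differ from the notebook's own helper.
</font>

```python
def ema_smooth(points, factor=0.9):
    """Exponentially smooth a sequence of per-epoch metric values.

    Hypothetical stand-in for the notebook's smooth_curve; `factor` is the
    weight given to the previous smoothed value (an assumed default).
    """
    smoothed = []
    for point in points:
        if smoothed:
            # Blend the previous smoothed value with the new observation
            smoothed.append(smoothed[-1] * factor + point * (1 - factor))
        else:
            # The first point has no history, so keep it unchanged
            smoothed.append(point)
    return smoothed


# Dummy values for illustration only, not results from the notebook
print(ema_smooth([0.0, 1.0, 1.0], factor=0.5))  # [0.0, 0.5, 0.75]
```

<font size="3">
The output has the same length as the input, so `range(1, len(smoothed) + 1)` still indexes epochs. A larger `factor` gives a flatter curve that lags the raw values more.
</font>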

<font size="3"> 
Learning Rate Influence
</font>

In [None]:
# Result file name -> legend label for each learning rate
learning_rate_results = {
    '4Conv4Pool_lr_0.0001_Validation': 'LR: 0.0001',
    '4Conv4Pool_lr_0.001_Validation': 'LR: 0.001',
    '4Conv4Pool_lr_0.005_Validation': 'LR: 0.005',
}

# Plot the smoothed validation F1-score for every learning rate on one figure
for file_name, label in learning_rate_results.items():
    smoothed_f1 = smooth_curve(load_dict(results_path, file_name)['eval_f1'])
    plt.plot(range(1, len(smoothed_f1) + 1), smoothed_f1, label=label)

plt.title('Learning Rate Influence')
plt.xlabel('Epochs')
plt.ylabel('F1-Score')
plt.legend()
plt.savefig(experimental_evaluation_path + "/Learning_Rate_Infuence.pdf")
plt.show()

<font size="3"> 
Validation vs Training: F1-Score
</font>

In [None]:
m1 = load_dict(results_path, '4Conv4Pool_Validation')


# Smoothed F1-score curves of the selected architecture
smoothed_val_f1 = smooth_curve(m1['eval_f1'])
smoothed_train_f1 = smooth_curve(m1['train_f1'])

plt.plot(range(1, len(smoothed_val_f1) + 1), smoothed_val_f1, label='Validation')
plt.plot(range(1, len(smoothed_train_f1) + 1), smoothed_train_f1, label='Training')

plt.title('Validation vs Training: F1-Score')
plt.xlabel('Epochs')
plt.ylabel('F1-Score')
plt.legend()
plt.savefig(experimental_evaluation_path + "/Validation_vs_Training.pdf")
plt.show()

<font size="3"> 
Wavelet Denoising Influence
</font>

In [None]:
m1 = load_dict(results_path, '4Conv4Pool_Test_noise')
m2 = load_dict(results_path, '4Conv4Pool_Test')

# Smoothed test F1-score with and without wavelet denoising
smoothed_raw_f1 = smooth_curve(m1['eval_f1'])
smoothed_denoised_f1 = smooth_curve(m2['eval_f1'])

plt.plot(range(1, len(smoothed_raw_f1) + 1), smoothed_raw_f1, label='Raw Images')
plt.plot(range(1, len(smoothed_denoised_f1) + 1), smoothed_denoised_f1, label='Wavelet Denoising')

plt.title('Wavelet Denoising Influence')
plt.xlabel('Epochs')
plt.ylabel('F1-Score')
plt.legend()
plt.savefig(experimental_evaluation_path + "/Wavelet_Denoising.pdf")
plt.show()