# ***SpatialTemporal-RISE***

In this notebook, we delve into the ST-RISE framework, an extension of the original RISE method, which aims to provide interpretability for machine learning models in the context of spatiotemporal data. The ST-RISE approach focuses on understanding the influence of input features over time and space, particularly in applications where temporal dynamics play a critical role.

## Objectives

The primary objectives of ST-RISE are:

1. **Temporal Sensitivity**: To analyze how model predictions vary with changes in input features across different time steps.
2. **Spatial Relevance**: To evaluate the significance of spatial patterns in influencing model outputs, enabling a better understanding of local feature importance.
3. **Integration of Temporal and Spatial Information**: To combine the insights from both spatial and temporal perspectives, providing a comprehensive view of the model's behavior.

## Methodology

The ST-RISE method builds upon the existing RISE framework by incorporating a three-dimensional Gaussian noise generation process that spans the temporal axis and the two spatial axes of the input data. This allows for the creation of saliency maps that highlight the most influential features at each time step while accounting for their spatial relationships.
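
As a sketch of the idea, and assuming the diagonal covariance used in the mask-generation code of this notebook, each perturbation mask can be written as an unnormalized Gaussian bump centered at a random position $(t_c, h_c, w_c)$ and multiplied by a random sign $\beta \in \{-1, +1\}$. The symbols below mirror the code parameters `sigma_t`, `sigma_x` and `sigma_y`, which sit on the diagonal of the covariance matrix and therefore act as variances (they are not squared):

$$
M(t, h, w) = \beta \, \exp\!\left(-\frac{1}{2}\left[\frac{(t - t_c)^2}{\sigma_t} + \frac{(h - h_c)^2}{\sigma_x} + \frac{(w - w_c)^2}{\sigma_y}\right]\right)
$$

The mask equals $\beta$ at the center and decays toward zero away from it, so a larger $\sigma_t$ spreads the perturbation over more time steps.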

In the following sections, we will explore the implementation details, conduct experiments using the ST-RISE framework, and analyze the results to assess the effectiveness of this method in providing insights into model predictions.

## Structure of the Notebook

- **Data Preparation**: Loading and preprocessing the spatiotemporal dataset.
- **ST-RISE Implementation**: Detailed steps to implement the ST-RISE methodology.
- **Results Analysis**: Visualizing and interpreting the saliency videos generated by the ST-RISE framework.
- **Evaluation**: Assessing the saliency videos using the **Insertion** and **Deletion** metrics.

By the end of this notebook, we aim to enhance our understanding of model interpretability in spatiotemporal contexts and demonstrate the utility of the ST-RISE approach in achieving this goal.


In [None]:
# Dependencies

# Specify the TensorFlow version used
# TensorFlow 2.15.0 is chosen for compatibility with Keras and other dependencies.
!pip install tensorflow==2.15.0  # cuDNN 8.9, CUDA 12.2
!pip install keras==2.15.0  # Keras for high-level neural networks
# General libraries for data manipulation and numerical computations
!pip install numpy==1.25.2  # Fundamental package for numerical computing
!pip install pandas==2.0.3  # Data manipulation and analysis
# Libraries for handling spatial and raster data
!pip install rioxarray==0.15.5  # Raster I/O with xarray
!pip install rasterio==1.3.10  # Reading and writing raster datasets
!pip install geopandas==0.13.2  # Geospatial data manipulation
# Libraries for data visualization
!pip install matplotlib==3.7.1  # Comprehensive library for static, animated, and interactive visualizations
!pip install seaborn==0.13.1  # Statistical data visualization based on Matplotlib
!pip install matplotlib_scalebar==0.8.1  # Adds a scale bar to Matplotlib plots
!pip install adjustText==1.1.1  # Helps adjust text in Matplotlib for better readability
# Libraries for scientific computing and handling NetCDF data
!pip install xarray==2023.7.0  # N-D array support with labeled axes
!pip install netCDF4==1.6.5  # Reading and writing NetCDF files
!pip install cftime==1.6.3  # Provides time handling capabilities
# Other useful libraries
!pip install imageio==2.31.6  # For reading and writing image data
!pip install cmasher==1.8.0  # Colormaps for Matplotlib
!pip install great-circle-calculator==1.3.1  # Calculate great circle distances between points
!pip install geopy==2.4.1  # For geocoding and distance calculations
!pip install scipy==1.11.4  # Scientific and technical computing

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import activations
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.models import load_model
import sys
import numpy as np
import copy

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/gdrive')
%cd /gdrive

#### ***Loading Data from Google Drive***

In [None]:
# Define the base path for data storage in Google Drive
# Change the base_path to point to the correct directory where your data is stored
base_path = "./MyDrive/Water_Resources/"

# Specify different subdirectories for various data and model types
data_path = base_path + "data/training_validation_test_splits"  # Path for training, validation, and test data
model_path = base_path + "trained_models/"  # Path for trained models
modules_path = base_path + "python_modules/"  # Path for additional Python modules
xai_path = base_path + "XAI/"  # Path for XAI-related resources
results_path = xai_path + "results/"  # Path for storing results from XAI analysis

In [None]:
# IMPORT DATA FOR VOTTIGNASCO

# Paths to the .npy files containing the data for Vottignasco
v_test_OHE_path = data_path + '/Vottignasco_00425010001_test_month_OHE.npy'  # One-Hot Encoded test data
v_test_image_path = data_path + '/Vottignasco_00425010001_test_normalized_image_sequences.npy'  # Normalized image sequences for testing
v_test_target_dates_path = data_path + '/Vottignasco_00425010001_test_target_dates.npy'  # Target dates for the test data

# Load the NumPy arrays from the specified .npy files
vottingasco_test_OHE = np.load(v_test_OHE_path)  # Load One-Hot Encoded data
vottignasco_test_image = np.load(v_test_image_path)  # Load normalized image sequences
vottignasco_test_dates = np.load(v_test_target_dates_path)  # Load target dates for the test set


In [None]:
print(len(vottignasco_test_dates))
print(len(vottignasco_test_image))
print(len(vottingasco_test_OHE))

105
105
105
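
The three printed lengths should always agree, since every instance needs an image sequence, an encoding and a target date. A minimal sketch of a sanity check follows; `check_aligned` is a hypothetical helper, and the synthetic arrays only mimic the real data (the image shape follows the `(104, 5, 8, 3)` input size used later, while the 12-column encoding shape is an assumption).

```python
import numpy as np

def check_aligned(*arrays):
    """Return the shared number of instances, or raise if the arrays disagree."""
    lengths = {len(a) for a in arrays}
    if len(lengths) != 1:
        raise ValueError(f"Mismatched number of instances: {sorted(lengths)}")
    return lengths.pop()

# Synthetic stand-ins (hypothetical shapes; the real arrays come from the .npy files)
images = np.zeros((105, 104, 5, 8, 3), dtype=np.float32)
ohe = np.zeros((105, 104, 12), dtype=np.float32)
dates = np.arange(105)

n_instances = check_aligned(images, ohe, dates)
print(n_instances)
```

Running such a check right after loading makes a truncated or mismatched file fail early rather than during the explanation step.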


In [None]:
# Visualize the data for Vottignasco

# Display the one-hot encoded months of the first instance in the test set
print(vottingasco_test_OHE[0], "\n")  # One-hot encoding of the months for the first test instance

# Display the first normalized frame of the first instance in the test set
print(vottignasco_test_image[0][0], "\n")  # First frame (height x width x channels) of the first test instance


[[1. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]] 

[[[-0.55874205 -0.63231546 -1.0530392 ]
  [-0.62497324 -0.24734885 -0.95617104]
  [-0.64852315 -0.29456472 -1.0120971 ]
  [-0.6475669   0.01120646 -0.9414161 ]
  [-0.6483382   0.04486901 -0.9291374 ]
  [-0.64837635  0.21673024 -0.7445328 ]
  [-0.6482272   0.01929929 -0.9028175 ]
  [-0.64826417 -0.35680774 -0.71746933]]

 [[-0.56017417 -0.30140942 -1.0722721 ]
  [-0.6341238  -0.3249078  -1.1285268 ]
  [-0.64822614 -0.2947506  -0.9674306 ]
  [-0.6474924   0.08621608 -0.8394658 ]
  [-0.6481902   0.17821798 -0.8374907 ]
  [-0.64775074  0.14878109 -0.8802751 ]
  [-0.6477138   0.09261598 -1.0055645 ]
  [-0.64804226  0.06027458 -1.1226149 ]]

 [[-0.64782256 -0.32245472 -1.1192324 ]
  [-0.64856017 -0.17655824 -0.96336424]
  [-0.64591223 -0.43483806 -1.0460285 ]
  [-0.6466841   0.08755382 -0.7752676 ]
  [-0.6481902   0.13593481 -0.814374  

#### ***Models Loading***

In [None]:
from keras.models import load_model

# To load the entire model, the custom layer doprout_custom has to be defined:

mc_dropout = True

# Definition of the custom class doprout_custom: with mc_dropout enabled, dropout stays active at inference time
class doprout_custom(tf.keras.layers.SpatialDropout1D):
    def call(self, inputs, training=None):
        if mc_dropout:
            return super().call(inputs, training=True)
        else:
            return super().call(inputs, training=False)

In [None]:
# Find the paths of the ensemble models for Vottignasco in Google Drive
import os

# Define the base directory for the Vottignasco models
base_dir = model_path + "seq2val/Vottignasco"  # Directory containing the models
lstm_suffix = 'time_dist_LSTM'  # Suffix used to identify LSTM model files

# Initialize lists to store model and weight paths
vott_lstm_models = []  # List to store paths of LSTM models
vott_lstm_weights = []  # List to store paths of model weights

def extract_index(filename):
    """Function to extract the final index from the filename."""
    return int(filename.split('_')[-1].split('.')[0])  # Extracts the numeric index from the filename

# Traverse the directory and add relevant files to their respective lists
for root, _, files in os.walk(base_dir):
    for filename in files:
        full_path = os.path.join(root, filename)  # Get the full path of the file
        if lstm_suffix in filename:  # Check if the file is an LSTM model
            if filename.endswith(".keras"):  # Check for model files
                vott_lstm_models.append(full_path)  # Add model path to the list
            else:  # Otherwise, it's a weights file
                vott_lstm_weights.append(full_path)  # Add weights path to the list

# Sort the model and weight paths based on the extracted index from the filenames
vott_lstm_models = sorted(vott_lstm_models, key=lambda x: extract_index(os.path.basename(x)))  # Sort models
vott_lstm_weights = sorted(vott_lstm_weights, key=lambda x: extract_index(os.path.basename(x)))  # Sort weights
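
The numeric sort key matters because a plain string sort would place index 10 before index 2. A small sketch with hypothetical filenames (the real names on Drive may differ) shows the difference:

```python
def extract_index(filename):
    """Extract the trailing integer index from a name like 'prefix_7.keras'."""
    return int(filename.split('_')[-1].split('.')[0])

# Hypothetical filenames, used only for illustration
names = [f"time_dist_LSTM_{i}.keras" for i in (10, 2, 1)]

lexicographic = sorted(names)                # string order: 1, 10, 2
numeric = sorted(names, key=extract_index)   # numeric order: 1, 2, 10

print(lexicographic)
print(numeric)
```

Sorting models and weights with the same key keeps the two lists aligned by ensemble member.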


In [None]:
# Lists that will hold the loaded models
vott_lstm_models_loaded = []  # List to store the loaded Vottignasco LSTM models
racc_lstm_models_loaded = []  # (Assuming this will be used later for another set of models)

# Loop to load the LSTM models
for i in range(10):
    print(f"Loading LSTM models {i+1}")  # Indicate the loading progress of LSTM models

    # Load VOTTIGNASCO model
    model_lstm_path = vott_lstm_models[i]  # Get the path for the current LSTM model
    # Load the CNN+LSTM model using the specified path
    model = load_model(model_lstm_path, custom_objects={"doprout_custom": doprout_custom})  # Load the model with custom objects
    # Append the loaded model to the list of loaded models
    vott_lstm_models_loaded.append(model)  # Add the loaded model to the list


In [None]:
vott_lstm_models_loaded

[<keras.src.engine.functional.Functional at 0x7903b227f9a0>,
 <keras.src.engine.functional.Functional at 0x7903abf13fa0>,
 <keras.src.engine.functional.Functional at 0x7903b227f7c0>,
 <keras.src.engine.functional.Functional at 0x7903abfeff70>,
 <keras.src.engine.functional.Functional at 0x7903a03568f0>,
 <keras.src.engine.functional.Functional at 0x7903a02900a0>,
 <keras.src.engine.functional.Functional at 0x7903b0a1bac0>,
 <keras.src.engine.functional.Functional at 0x7903abfece20>,
 <keras.src.engine.functional.Functional at 0x7903a0099480>,
 <keras.src.engine.functional.Functional at 0x79038c731600>]

## ***SPATIAL-TEMPORAL RISE***

### ***3D Gaussian Noise***

In [None]:
import matplotlib.pyplot as plt

def plot_noise_slice(noise, time_step=0):
    # Assumes that noise has the shape (time_steps, height, width, 3)
    # Extract the corresponding channel for visualization
    noise_slice = noise[time_step, :, :, 0]  # Select the first channel for simplicity
    # Display the extracted noise slice using a colormap
    plt.imshow(noise_slice, cmap='jet', vmin=-1, vmax=+1.0)
    # Set the title of the plot to indicate the time step being visualized
    plt.title(f'Noise at time step {time_step}')
    # Add a colorbar to provide a reference for the color scale
    plt.colorbar()
    # Show the plot
    plt.show()


In [None]:
def plot_mean_noise_over_time(noise):
    # Calculate the mean noise intensity over height, width, and channels
    mean_noise = np.mean(noise, axis=(1, 2, 3))  # Mean across height, width, and channels
    # Plot the mean noise intensity against time steps
    plt.plot(mean_noise)
    # Set the title of the plot
    plt.title('Mean Noise over Time')
    # Label the x-axis as 'Time step'
    plt.xlabel('Time step')
    # Label the y-axis as 'Mean Noise Intensity'
    plt.ylabel('Mean Noise Intensity')
    # Display the plot
    plt.show()


#### ***Generate Masks: Gaussian-3D Noise***

In [None]:
def gaussian_perturbation(shape, center, sigma_t=50, sigma_x=1.0, sigma_y=1.0):
    # Unpack the shape tuple into individual dimensions
    time_steps, height, width, channels = shape
    t_center, h_center, w_center = center

    # Define the mean vector (μ) representing the center of the noise
    mu = np.array([t_center, h_center, w_center])

    # Covariance matrix for the three dimensions
    sigma = np.array([
        [sigma_t, 0, 0],  # Variance along the time dimension
        [0, sigma_x, 0],  # Variance along the height dimension
        [0, 0, sigma_y]   # Variance along the width dimension
    ])

    # Calculate the inverse of the covariance matrix (Σ^−1)
    sigma_inv = np.linalg.inv(sigma)

    # Generate 3D coordinates (time, height, width)
    t_coords, h_coords, w_coords = np.meshgrid(
        np.arange(time_steps), np.arange(height), np.arange(width), indexing='ij'
    )

    # Flatten the coordinates for easier calculations
    positions = np.vstack([t_coords.ravel(), h_coords.ravel(), w_coords.ravel()]).T

    # Calculate the multivariate Gaussian for each point
    diff = positions - mu  # Difference from the mean
    exponent = -0.5 * np.sum(diff @ sigma_inv * diff, axis=1)  # Exponent term for the Gaussian

    # Reshape to obtain the original shape (time_steps, height, width)
    noise_3D = np.exp(exponent).reshape(time_steps, height, width)

    return noise_3D  # Return the generated Gaussian noise
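
Because the covariance matrix above is diagonal, the same bump factorizes into three one-dimensional Gaussians. The sketch below is an alternative formulation under that assumption, not the notebook's implementation; `gaussian_bump_separable` and its `var_*` parameter names are hypothetical.

```python
import numpy as np

def gaussian_bump_separable(shape, center, var_t=50.0, var_x=1.0, var_y=1.0):
    """Unnormalized 3D Gaussian built as an outer product of three 1D Gaussians."""
    time_steps, height, width = shape[:3]
    t_c, h_c, w_c = center
    # The var_* values play the role of the diagonal covariance entries (variances)
    g_t = np.exp(-0.5 * (np.arange(time_steps) - t_c) ** 2 / var_t)
    g_h = np.exp(-0.5 * (np.arange(height) - h_c) ** 2 / var_x)
    g_w = np.exp(-0.5 * (np.arange(width) - w_c) ** 2 / var_y)
    # Broadcasting combines the three axes into a (time_steps, height, width) volume
    return g_t[:, None, None] * g_h[None, :, None] * g_w[None, None, :]

bump = gaussian_bump_separable((104, 5, 8, 3), center=(50, 2, 3))
peak = np.unravel_index(np.argmax(bump), bump.shape)
print(bump.shape, peak)
```

The volume peaks at the chosen center with value 1 and avoids building the full list of coordinates, which can help when many masks are generated.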


In [None]:
import random
from tqdm import tqdm

def generate_masks_gaussian3D(N, shape, sigma_t, sigma_x, sigma_y):
    # Unpack the shape tuple into individual dimensions
    time_steps, height, width, channels = shape
    # Initialize an array to hold the generated masks
    masks = np.zeros((N, time_steps, height, width))

    # Generate N masks
    for i in tqdm(range(N), desc='Generating masks'):
        beta = random.choice([-1, 1])  # Randomly choose a sign for the noise (positive or negative)

        # Randomly select the center coordinates within the input volume
        t_center = np.random.randint(0, time_steps)
        h_center = np.random.randint(0, height)
        w_center = np.random.randint(0, width)
        center = (t_center, h_center, w_center)

        # Generate 3D Gaussian noise based on the selected center
        noise_3D = gaussian_perturbation(shape, center, sigma_t, sigma_x, sigma_y)

        # The Gaussian is non-negative, so multiplying by beta sets the sign of the whole mask
        masks[i] = beta * noise_3D

    return masks  # Return the array of generated masks
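
The centers and signs above are drawn from global random state, so two runs produce different masks. If reproducibility is needed, one option is to draw them from a seeded NumPy generator. This is only a sketch: `sample_centers_and_signs` and its `seed` argument are hypothetical and not part of the notebook.

```python
import numpy as np

def sample_centers_and_signs(n_masks, shape, seed=0):
    """Draw mask centers (t, h, w) and signs from a seeded generator."""
    rng = np.random.default_rng(seed)
    time_steps, height, width = shape[:3]
    centers = np.column_stack([
        rng.integers(0, time_steps, size=n_masks),  # upper bound is exclusive
        rng.integers(0, height, size=n_masks),
        rng.integers(0, width, size=n_masks),
    ])
    signs = rng.choice([-1, 1], size=n_masks)
    return centers, signs

centers, signs = sample_centers_and_signs(5, (104, 5, 8, 3), seed=42)
print(centers.shape, signs.shape)
```

Calling the function twice with the same seed returns identical centers and signs, which makes saliency videos comparable across runs.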


In [None]:
# Sum the generated mask to the specified channel at each time step
def additive_gaussianNoise_onechannel(images, masks, index_feat_to_perturb):
    masked = []  # Initialize a list to store the masked images

    # Iterate over all the N generated masks
    for mask in masks:
        masked_images = copy.deepcopy(images)  # Create a deep copy of the original images

        # Perturb only the specified channel
        masked_images[..., index_feat_to_perturb] = np.add(
            masked_images[..., index_feat_to_perturb],
            mask  # Add the mask to the specified channel of the images
        )

        masked.append(masked_images)  # Append the perturbed images to the list

    return masked  # Return the list of masked images
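
The loop above copies the image sequence once per mask. Assuming the masks are stored as a single array of shape `(N, T, H, W)`, the same perturbation can be sketched with broadcasting; `add_masks_to_channel` is a hypothetical name and returns one stacked array instead of a list.

```python
import numpy as np

def add_masks_to_channel(images, masks, channel):
    """Add each mask to one channel of the image sequence.

    images: (T, H, W, C) array; masks: (N, T, H, W) array.
    Returns an (N, T, H, W, C) array and leaves `images` untouched.
    """
    # np.repeat allocates a new array, so the original images are not modified
    perturbed = np.repeat(images[None, ...], len(masks), axis=0)
    perturbed[..., channel] += masks
    return perturbed

images = np.zeros((4, 2, 3, 3))
masks = np.random.default_rng(0).normal(size=(5, 4, 2, 3))
perturbed = add_masks_to_channel(images, masks, channel=1)
print(perturbed.shape)
```

Only the selected channel changes; the other channels keep their original values in every perturbed copy.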

### ***RISE-3D***

In [None]:
def ensemble_predict(models, images, x3_exp):
    # Check whether a list of images was passed; if so, its length is the batch size
    if isinstance(images, list):
        len_x3 = len(images)  # Number of masked images
    else:
        len_x3 = 1  # Only one image provided
        images = np.expand_dims(images, axis=0)  # Expand dimensions to match expected input shape

    # Prepare the input for the ensemble model
    Y_test = np.stack(images)  # Stack images into a NumPy array
    Y_test_x3 = np.tile(x3_exp, (len_x3, 1, 1))  # Duplicate x3_exp for each image

    # Initialize a list to collect predictions from each model
    all_preds = []

    # Iterate through the models and gather predictions
    for model in models:  # Use the models passed as a parameter
        preds = model.predict([Y_test, Y_test_x3], verbose=0)  # Make predictions
        all_preds.append(preds)  # Store predictions

    # Convert the list of predictions into a NumPy array for easier mean calculation
    all_preds_array = np.array(all_preds)

    # Calculate the mean predictions across all models (along axis 0)
    mean_preds = np.mean(all_preds_array, axis=0)

    return mean_preds  # Return the mean predictions
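
The averaging step can be checked without loading any trained model. The sketch below uses a hypothetical `ConstantModel` stand-in that exposes the same `predict([images, ohe], verbose=0)` call; the input shapes are assumptions chosen to resemble the notebook's data.

```python
import numpy as np

class ConstantModel:
    """Hypothetical stand-in for a Keras model: predicts one constant per sample."""
    def __init__(self, value):
        self.value = value

    def predict(self, inputs, verbose=0):
        images, _ = inputs
        return np.full((len(images), 1), self.value, dtype=float)

models = [ConstantModel(1.0), ConstantModel(2.0), ConstantModel(6.0)]
images = np.zeros((4, 104, 5, 8, 3))  # batch of 4 image sequences
ohe = np.zeros((4, 104, 12))          # assumed encoding shape

# Stack per-model predictions: shape (n_models, batch, 1), then average over models
all_preds = np.array([m.predict([images, ohe], verbose=0) for m in models])
mean_preds = all_preds.mean(axis=0)
print(mean_preds.ravel())
```

Each sample receives the mean of the three constants, which confirms that the reduction runs over the model axis and not over the batch.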

In [None]:
def calculate_saliency_map(preds_array, masks):
    """
    Calculate the mean saliency map given a series of predictions and masks.

    :param preds_array: Array of predictions (number of masks x prediction dimensions).
    :param masks: Array of masks (number of masks x mask dimensions).
    :return: Mean saliency map.
    """
    sal = []  # Initialize a list to store saliency calculations for each mask
    for j in range(len(masks)):
        sal_i = preds_array[j] * np.abs(masks[j])  # Element-wise multiplication of predictions and the absolute mask
        sal.append(sal_i.reshape(-1, 5, 8))  # Reshape to match the original frame format

    # Absolute value of the masks; masks has shape (N, time_steps, 5, 8), so np.squeeze only drops singleton dimensions, if any
    masks_squeezed = np.squeeze(np.abs(masks))
    # Now calculate the mean along axis 0
    ev_masks = np.mean(masks_squeezed, axis=0)

    # Calculate the saliency map as the mean saliency weighted by the inverse of the mean masks
    sal = (1 / ev_masks) * np.mean(sal, axis=0)

    return sal  # Return the computed mean saliency map
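
The computation above is a weighted average of the absolute masks divided by their plain average. A vectorized sketch of the same formula follows, assuming one scalar weight per mask; `saliency_from_weights` is a hypothetical name.

```python
import numpy as np

def saliency_from_weights(weights, masks):
    """Saliency = mean_j(w_j * |M_j|) / mean_j(|M_j|), computed per pixel.

    weights: (N,) array of prediction changes; masks: (N, T, H, W) array.
    """
    abs_masks = np.abs(masks)
    # Contract the mask axis: sum_j w_j * |M_j|, then divide by N for the mean
    weighted_mean = np.tensordot(weights, abs_masks, axes=(0, 0)) / len(masks)
    return weighted_mean / abs_masks.mean(axis=0)

rng = np.random.default_rng(0)
demo_masks = rng.uniform(0.1, 1.0, size=(6, 4, 5, 8))
demo_weights = np.full(6, 2.0)
demo_saliency = saliency_from_weights(demo_weights, demo_masks)
print(demo_saliency.shape)
```

When every mask changes the prediction by the same amount, the saliency is constant and equal to that amount, which is a quick way to validate the normalization. The division assumes the mean absolute mask is non-zero at every pixel.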


In [None]:
def explain(models, x1, x3, generate_masks_func, apply_mask_func,
            index_feat_to_perturb=None, N=1000, sigma_t=100.0, sigma_spatial=1.0, input_size=(104, 5, 8, 3)):
    """
    Implementation of RISE (Randomized Input Sampling for Explanation)

    :param models: List of pretrained models.
    :param x1: Input images.
    :param x3: One-hot encoded input.
    :param generate_masks_func: Function to generate masks.
    :param apply_mask_func: Function to apply masks to images.
    :param index_feat_to_perturb: Index of the channel (feature) to perturb.
    :param N: Number of masks to generate.
    :param sigma_t: Sigma for the temporal dimension in mask generation.
    :param sigma_spatial: Sigma for both spatial dimensions in mask generation (passed as sigma_x and sigma_y).
    :param input_size: Size of the input.
    :return: Saliency Video.
    """

    # Original prediction without masks
    pred_original = ensemble_predict(models, x1, x3)  # Average predictions from the ensemble
    #print("original pred:", pred_original)

    # Generate and apply Gaussian masks
    masks = generate_masks_func(N, input_size, sigma_t, sigma_x=sigma_spatial, sigma_y=sigma_spatial)
    masked_images = apply_mask_func(x1, masks, index_feat_to_perturb)

    # Predictions on perturbed images with the ensemble
    preds_masked = ensemble_predict(models, masked_images, x3)
    #print("pred_masked:", preds_masked)

    # Calculate the absolute difference between original and masked predictions
    diff_pred = [abs(pred_original - pred_masked) for pred_masked in preds_masked]
    #print("diff_pred:", diff_pred)

    # Concatenate the difference predictions along the specified axis
    weights_array = np.concatenate(diff_pred, axis=0)

    # Calculate the initial saliency map based on the weights
    sal = calculate_saliency_map(weights_array, masks)

    # The final saliency map is stored in `sal`
    print("Process completed.", "\n")
    print("Final saliency map:")
    print(sal)
    return sal  # Return the calculated saliency map


### ***Evaluation: Insertion and Deletion***

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error

def calculate_auc(x, y):
    """
    Compute the area under the curve (AUC) using the trapezoidal rule.

    :param x: x-axis values (fraction of pixels inserted or removed).
    :param y: y-axis values (computed errors).
    :return: Area under the curve.
    """
    return np.trapz(y, x)
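
As a worked example of the trapezoidal rule, the sketch below computes the area by hand instead of calling `np.trapz`; `trapezoid_area` is a hypothetical helper used only to make the arithmetic visible.

```python
import numpy as np

def trapezoid_area(x, y):
    """Sum of trapezoid areas: width of each interval times the mean of its endpoints."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    return float(np.sum(np.diff(x) * (y[:-1] + y[1:]) / 2.0))

# Two intervals of width 0.5:
#   first:  0.5 * (0 + 1) / 2 = 0.25
#   second: 0.5 * (1 + 1) / 2 = 0.50
area = trapezoid_area([0.0, 0.5, 1.0], [0.0, 1.0, 1.0])
print(area)  # 0.75
```

With the x-axis normalized to the range 0 to 1, the AUC is directly comparable across saliency videos with the same number of pixels.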

In [None]:
import numpy as np

# Returns a list of (value, [t, x, y]) tuples with pixels ordered by decreasing importance
def get_flatten_saliency_video_ordered_by_importance(saliency_video):
    # Flatten the saliency video to a 1D array
    flatten_saliency_video = saliency_video.flatten()
    # Get the [t, x, y] indices of every pixel, in the same order as flatten().
    # np.argwhere(saliency_video) would skip zero-valued pixels and misalign the two arrays.
    indices = np.argwhere(np.ones(saliency_video.shape, dtype=bool))

    # Order the flat positions by decreasing saliency; the stable sort breaks ties by position
    order = np.argsort(-flatten_saliency_video, kind='stable')

    # Return the list of (saliency_value_of_pixel, [t, x, y]) in descending order of importance
    return [(flatten_saliency_video[i], indices[i]) for i in order]
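
The ordering can be illustrated on a tiny saliency video. This sketch ranks flat positions with `np.argsort` and maps them back to `(t, x, y)` coordinates with `np.unravel_index`; the array values are made up, and one of them is zero to show that zero-valued pixels still receive a rank.

```python
import numpy as np

# One frame with 2x2 pixels (made-up values, including an exact zero)
saliency = np.array([[[0.2, 0.0],
                      [0.9, 0.5]]])

# Flat positions sorted by decreasing saliency (stable sort keeps ties in position order)
flat_order = np.argsort(-saliency.ravel(), kind="stable")

# Convert flat positions back to (t, x, y) coordinates, one row per pixel
coords = np.column_stack(np.unravel_index(flat_order, saliency.shape))
print(coords.tolist())  # [[0, 1, 0], [0, 1, 1], [0, 0, 0], [0, 0, 1]]
```

The most salient pixel (0.9) comes first and the zero-valued pixel comes last, so all four pixels take part in the insertion and deletion curves.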

#### ***Insertion***

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
import copy

def update_image_with_important_pixel(image, initial_image, frame, x, y):
    """
    Update the image by inserting the most important pixel.

    :param image: Current image.
    :param initial_image: Original image.
    :param frame: Frame number to insert into the image.
    :param x: Spatial x-coordinate of the pixel.
    :param y: Spatial y-coordinate of the pixel.
    :return: Updated image.
    """
    new_image = copy.deepcopy(image)
    new_image[frame][x, y] = initial_image[frame][x, y]  # Restore all channels of the original pixel
    return new_image

def insertion(models, original_images, x3, pixel_sorted_by_saliency_value_with_indices, initial_blurred_images):
    """
    Calculate the insertion metric for a given explanation.

    :param models: List of pretrained models.
    :param original_images: Original image.
    :param x3: One-hot encoding for prediction.
    :param pixel_sorted_by_saliency_value_with_indices: Indices of pixels in order of importance.
    :param initial_blurred_images: Initial image with all pixels set to zero.
    :return: List of errors at each insertion step.
    """

    # Original prediction
    original_prediction = ensemble_predict(models, original_images, x3)[0]
    print("Original prediction:", original_prediction)

    # List to store images with gradually added frames
    insertion_images = [initial_blurred_images.copy()]

    # Prediction on the initial image (all pixels blurred)
    I_prime = initial_blurred_images.copy()

    # Gradually add the most important pixels frame by frame
    for _, pixel_with_indices in pixel_sorted_by_saliency_value_with_indices:
        frame, x, y = pixel_with_indices  # Extract frame (t), x and y
        I_prime = update_image_with_important_pixel(I_prime, original_images, frame, x, y)
        insertion_images.append(I_prime)

    # Calculate predictions on images with gradually added frames
    new_predictions = ensemble_predict(models, insertion_images, x3)

    # Calculate MSE with respect to each new prediction
    errors = [mean_squared_error(original_prediction, masked_pred) for masked_pred in new_predictions[1:]]
    initial_error = mean_squared_error(original_prediction, new_predictions[0])  # MSE of the blurred image
    print(f"Initial Prediction with ALL Blurred pixel, pred: {new_predictions[0]}, error: {initial_error}")

    for t, error in enumerate(errors):
        t_f, t_x, t_y = pixel_sorted_by_saliency_value_with_indices[t][1]  # Extract frame and pixel coordinates
        print(f"Insert pixel in frame {t_f} at position ({t_x},{t_y}), new prediction: {new_predictions[t + 1]}, error: {error}")

    total_errors = [initial_error] + errors
    # Normalize the fraction of inserted pixels so that it runs from 0 to 1
    x = [i / len(pixel_sorted_by_saliency_value_with_indices) for i in range(len(total_errors))]

    # Calculate the AUC
    auc = calculate_auc(x, total_errors)
    print(f"Area under the curve (AUC): {auc}")

    # Plot the error curve and area under the curve (AUC)
    plt.plot(x, total_errors, label='Error curve')
    plt.fill_between(x, total_errors, color='skyblue', alpha=0.4)
    plt.text(x[-1] * 0.95, max(total_errors) * 0.9, f'AUC: {auc:.2f}', horizontalalignment='right')
    plt.xlabel('Fraction of pixels inserted')
    plt.ylabel('Mean Squared Error')
    plt.title('Insertion Metric Curve')
    plt.legend()
    plt.show()

    return total_errors, auc
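
To see why a good pixel ordering gives a lower insertion AUC, consider a toy setting in which the "model" simply sums all pixel values. Everything in this sketch is hypothetical (`toy_insertion_curve`, `sum_model`, and the two-pixel video); it only mirrors the insert-then-measure loop above.

```python
import numpy as np

def sum_model(video):
    """Toy stand-in for the ensemble: the prediction is the sum of all pixels."""
    return float(video.sum())

def toy_insertion_curve(original, order, predict):
    """Squared error to the original prediction after each pixel insertion."""
    current = np.zeros_like(original)  # start from an all-zero baseline
    target = predict(original)
    errors = [(predict(current) - target) ** 2]
    for t, h, w in order:
        current[t, h, w] = original[t, h, w]
        errors.append((predict(current) - target) ** 2)
    return errors

original = np.array([[[3.0, 1.0]]])  # T=1, H=1, W=2, single channel for brevity
good_order = toy_insertion_curve(original, [(0, 0, 0), (0, 0, 1)], sum_model)
bad_order = toy_insertion_curve(original, [(0, 0, 1), (0, 0, 0)], sum_model)
print(good_order)  # [16.0, 1.0, 0.0]
print(bad_order)   # [16.0, 9.0, 0.0]
```

Inserting the more influential pixel first makes the error drop faster, so the area under the insertion curve is smaller for the better explanation.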


In [None]:
def get_errors_from_insertion(test_image, test_OHE, nr_instance, saliency_video, initial_blurred_images, models):
    original_images = copy.deepcopy(test_image[nr_instance])
    x3 = copy.deepcopy(test_OHE[nr_instance])

    pixel_sorted_by_saliency_value_with_indices = get_flatten_saliency_video_ordered_by_importance(saliency_video)

    errors, auc = insertion(models, original_images, x3, pixel_sorted_by_saliency_value_with_indices, initial_blurred_images)

    return errors, auc

#### ***Deletion***

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
import copy  # Importing the copy module to create deep copies of images

def update_image_by_removing_pixel(image, frame, x, y):
    """
    Update the image by removing the most important pixels.

    :param image: Current image.
    :param frame: Frame number to modify in the image.
    :param x: Spatial coordinate x of the pixel.
    :param y: Spatial coordinate y of the pixel.
    :return: Updated image with the specified pixel removed.
    """
    new_image = copy.deepcopy(image)  # Create a deep copy of the image to avoid modifying the original
    new_image[frame][x, y] = 0  # Set all channels of the pixel at (x, y) to zero

    return new_image

def deletion(models, original_images, x3, saliency_video, pixel_sorted_by_saliency_value_with_indices):
    """
    Calculate the deletion metric for a given explanation.

    :param models: List of pre-trained models.
    :param original_images: Original image.
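    :param x3: One-hot encoding for the prediction.
    :param saliency_video: Saliency video being evaluated (not used directly; the ordering comes from the sorted pixel list).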
    :param pixel_sorted_by_saliency_value_with_indices: Indices of pixels in order of importance.
    :return: List of errors at each removal step.
    """
    # Get the original prediction from the ensemble of models
    original_prediction = ensemble_predict(models, original_images, x3).flatten()
    print("Original prediction:", original_prediction)

    # List to store images after gradually removing frames
    deletions_images = []

    # Initialize the modified image as a copy of the original
    I_prime = original_images.copy()

    # Gradually remove the most important pixels for each frame
    for _, pixel_with_indices in pixel_sorted_by_saliency_value_with_indices:
        frame, x, y = pixel_with_indices  # Extract frame (t), x, and y coordinates

        I_prime = update_image_by_removing_pixel(I_prime, frame, x, y)  # Update the image by removing the pixel
        deletions_images.append(I_prime)  # Append the modified image to the list

    # Calculate predictions for all modified images with pixels gradually removed
    new_predictions = ensemble_predict(models, deletions_images, x3)
    # Calculate the mean squared error (MSE) compared to the original prediction
    errors = [mean_squared_error(original_prediction, masked_pred) for masked_pred in new_predictions]

    initial_error = 0.0  # Initialize the initial error as zero (no error)
    print(f"Initial Prediction with Original Images, prediction: {original_prediction}, error: {initial_error}")
    for t, error in enumerate(errors):
        t_f = pixel_sorted_by_saliency_value_with_indices[t][1][0]  # Frame number of the pixel removed
        t_x = pixel_sorted_by_saliency_value_with_indices[t][1][1]  # X coordinate of the pixel removed
        t_y = pixel_sorted_by_saliency_value_with_indices[t][1][2]  # Y coordinate of the pixel removed
        print(f"Remove pixel in frame {t_f} in pos ({t_x},{t_y}), new prediction: {new_predictions[t]}, error: {error}")

    total_errors = [initial_error] + errors  # Initial error + errors from all removed pixels

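    # Fraction of pixels removed, from 0 to 1 (derived from the number of pixels instead of a hard-coded 4160)
    x = [i / len(pixel_sorted_by_saliency_value_with_indices) for i in range(len(total_errors))]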

    # Calculate the area under the curve (AUC) for the error values
    auc = calculate_auc(x, total_errors)
    print(f"Area under the curve (AUC): {auc}")

    # Plot the error curve and fill the area under the curve (AUC)
    plt.plot(x, total_errors, label='Error curve')
    plt.fill_between(x, total_errors, color='lightcoral', alpha=0.4)
    # Position the AUC text to the right of the title
    plt.text(1.02, 1.02, f'AUC: {auc:.2f}',
         horizontalalignment='left',
         transform=plt.gca().transAxes,  # Coordinates relative to the axes (from 0 to 1)
         fontsize=11)
    plt.xlabel('Fraction of pixels removed')  # Label for the x-axis
    plt.ylabel('Mean Squared Error')  # Label for the y-axis
    plt.title('Deletion Metric Curve')  # Title of the plot
    plt.legend()  # Show the legend
    plt.show()  # Display the plot

    return total_errors, auc  # Return the total errors and the AUC value


In [None]:
def get_errors_from_deletion(test_image, test_OHE, nr_instance, saliency_video, models):
    original_images = copy.deepcopy(test_image[nr_instance])
    x3 = copy.deepcopy(test_OHE[nr_instance])

    pixel_sorted_by_saliency_value_with_indices = get_flatten_saliency_video_ordered_by_importance(saliency_video)

    errors, auc = deletion(models, original_images, x3, saliency_video, pixel_sorted_by_saliency_value_with_indices)

    return errors, auc