In [None]:
import pandas as pd
import numpy as np
import glob
import torch
import os
import ast
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# KL processing

In [None]:
# Dataset under analysis. It is the lookup key into kl_directories and is
# interpolated into every input/output path used below.
test_dataset = 'FS_Static'

In [None]:
# Dataset name -> directory holding that dataset's per-layer KL-divergence .npy dumps.
kl_directories = dict(
    FS_Static_C='./kl_FS_Static_C',
    FS_Static='./kl_FS_Static',
    RoadAnomaly='./kl_RoadAnomaly',
)

In [None]:
# Load every per-layer KL dump of the selected dataset into kl_dict: {layer_name: array}.
# A file contributes only if it ends in '.npy' and contains 'kl_div_'; the layer name is
# the text between 'kl_div_' and the trailing '.npy'.
directory_path = kl_directories[test_dataset]
kl_dict = {}

KL_FILE_PREFIX = 'kl_div_'
NPY_SUFFIX = '.npy'

# sorted(): os.listdir order is arbitrary, and it drives the key order of kl_dict and
# therefore the order in which layers are plotted further down.
for filename in sorted(os.listdir(directory_path)):
    if not filename.endswith(NPY_SUFFIX):
        continue
    prefix_pos = filename.find(KL_FILE_PREFIX)
    if prefix_pos == -1:
        # find() returns -1 here; the old `-1 + 7` slice silently produced a garbage key.
        print(f'Skipping {filename}: no {KL_FILE_PREFIX!r} in the file name')
        continue
    # Cut the suffix by length: find('.npy') would stop at an earlier '.npy' in the name.
    layer_name = filename[prefix_pos + len(KL_FILE_PREFIX):-len(NPY_SUFFIX)]
    kl_dict[layer_name] = np.load(os.path.join(directory_path, filename))



In [None]:
# Sanity check: dimensions of the KL array loaded for one reference layer.
np.shape(kl_dict['model.module.mod2.block1.bn1.0'])

In [None]:
# The KL dumps and the preprocessed inputs are indexed 0..59 for the 30 original
# images; real_idx maps dump entry i to original image i // 2.
# NOTE(review): assumes entries 2k and 2k+1 both belong to image k -- confirm against
# the export script.
# Take the first entry of every pair. The previous np.linspace(0, 59, 30) cast to int
# gave 0, 2, ..., 56, 59: its last pick was the *second* entry of pair 29, unlike all
# the others.
N_IMAGES = 30
images_idx = np.arange(0, 2 * N_IMAGES, 2)
real_idx = images_idx // 2

In [None]:
# Order in which the original PNGs (image_0 ... image_29) are read so that test_images
# lines up with the KL dumps: the indices sorted as *strings* (image_0, image_1,
# image_10, ..., image_19, image_2, image_20, ...), i.e. lexicographic file-name order,
# presumably how the inference loader listed the files (TODO confirm).
# The hand-typed list this replaces had typos (21, 25, 18 where 10, 11, 19 belong), so
# three images were loaded twice and three were never loaded.
ordered_idx = sorted(range(30), key=str)

In [None]:
# Load the images used for display.
#  - test_images_v1: preprocessed network inputs dumped at inference time, picked by
#    images_idx; squeeze() drops singleton axes and transpose((1,2,0)) moves the leading
#    axis (presumably channels) to the end for plotting.
#    NOTE(review): not used further down in this notebook.
#  - test_images / test_labels: original PNG images and their label masks, read in
#    ordered_idx order (presumably to match the KL dump order).
test_images_v1, test_images, test_labels = [], [], []
test_images_v1.extend(
    np.load(f'./input_images_{test_dataset}/image_{idx}.npy').squeeze().transpose((1,2,0))
    for idx in images_idx
)
for idx in ordered_idx:
    test_images.append(mpimg.imread(f'../original_images/{test_dataset}/images/image_{idx}.png'))
    test_labels.append(mpimg.imread(f'../original_images/{test_dataset}/labels/image_{idx}.png'))

In [None]:
def plot_images(test_images, cols=3):
    """Show images in a grid with `cols` columns, each titled by its list position.

    Args:
        test_images: sequence of images accepted by ``Axes.imshow``; ``cmap='gray'``
            only affects single-channel images.
        cols: number of grid columns.
    """
    num_images = len(test_images)
    if num_images == 0:
        return
    # Ceiling division: floor division dropped a partial last row, and the loop then
    # raised IndexError whenever num_images was not a multiple of cols.
    rows = (num_images + cols - 1) // cols

    # squeeze=False always returns a 2-D array of Axes, so flatten() also works for 1x1.
    fig, axs = plt.subplots(rows, cols, figsize=(16*cols, 9*rows), squeeze=False)
    axs = axs.flatten()

    for i, img in enumerate(test_images):
        axs[i].imshow(img, cmap='gray')
        axs[i].set_title(f'Image {i}')

    # Turn off axis decorations everywhere, including unused padding cells.
    for ax in axs:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

plot_images(test_images, cols=3)

In [None]:
def apply_masks(test_images, test_labels):
    """Zero out every image pixel whose label value is not positive.

    Args:
        test_images: sequence of equally-shaped HxWxC images; any channel count C
            (e.g. RGB or RGBA) is supported.
        test_labels: sequence of HxW label masks, indexed in parallel with
            ``test_images``; pixels whose label is > 0 are kept.

    Returns:
        ndarray with the shape and dtype of ``np.zeros_like(test_images)`` holding the
        masked images.
    """
    masked_images = np.zeros_like(test_images)

    for i, image in enumerate(test_images):
        keep = np.asarray(test_labels[i]) > 0
        # A trailing singleton axis broadcasts the 2-D mask across all channels; the
        # old np.repeat(..., 3) only matched 3-channel images and failed on RGBA.
        masked_images[i] = image * keep[..., None]

    return masked_images

# Preview: keep only the pixels that lie inside each image's label mask.
masked_images = apply_masks(test_images=test_images, test_labels=test_labels)
plot_images(test_images=masked_images, cols=3)

In [None]:
# For every layer, reduce each per-image KL array over its first axis (presumably the
# channel axis -- TODO confirm the dump layout), then keep only the dump entries
# selected by images_idx.
kl_dict_sum = {}
kl_dict_sum_select = {}
for key, images_kl in kl_dict.items():
    per_image_sums = [np.sum(image_kl, axis=0) for image_kl in images_kl]
    kl_dict_sum[key] = per_image_sums
    kl_dict_sum_select[key] = [per_image_sums[idx] for idx in images_idx]

In [None]:
from scipy.ndimage import zoom

# Target spatial size: the first original image's shape without its trailing axis.
image_size = test_images[0].shape[:-1]
kl_dict_upscaled = {}

for layer, values in kl_dict_sum_select.items():
    stacked = np.array(values)
    # The last len(image_size) axes of the stacked KL maps are taken as spatial.
    source_size = stacked.shape[-len(image_size):]
    zoom_factors = [target / source for target, source in zip(image_size, source_size)]
    # order=1 -> linear spline interpolation along each spatial axis.
    kl_dict_upscaled[layer] = [zoom(kl_map, zoom_factors, order=1) for kl_map in stacked]

In [None]:
def create_mosaic_heatmaps(images, heatmaps, layer_name,cols=3):
    """Draw each heatmap over its image in a grid, save the grid as a PNG and show it.

    Args:
        images: sequence of images drawn underneath (``cmap='gray'`` only affects
            single-channel images).
        heatmaps: sequence of 2-D arrays paired with ``images`` by position. Each is
            min-max normalised to [0, 1] on its own, so colours are comparable within a
            panel but not across panels. A heatmap with no positive value range
            (constant, or containing NaN) is drawn as all zeros instead of the NaNs the
            old division by zero produced.
        layer_name: shown in the title and used in the output file name.
        cols: number of grid columns.

    Side effects: writes ./heatmaps_{test_dataset}/mosaic_{layer_name}.png (creating the
    directory if needed; `test_dataset` is the notebook-level setting), displays the
    figure, then closes it.
    """
    num_images = len(images)
    if num_images == 0:
        return
    # Ceiling division: floor division dropped a partial last row, and the loop then
    # raised IndexError whenever num_images was not a multiple of cols.
    rows = (num_images + cols - 1) // cols

    # squeeze=False always returns a 2-D array of Axes, so flatten() also works for 1x1.
    fig, axes = plt.subplots(rows, cols, figsize=(2*16*cols/12, 2*9*rows/12), squeeze=False)
    axes = axes.flatten()

    for ax, image, heatmap_data in zip(axes, images, heatmaps):
        heatmap_data = np.asarray(heatmap_data)
        low, high = np.min(heatmap_data), np.max(heatmap_data)
        if high > low:
            normalized_heatmap = (heatmap_data - low) / (high - low)
        else:
            normalized_heatmap = np.zeros(heatmap_data.shape)

        ax.imshow(image, cmap='gray', interpolation='nearest', aspect='auto')
        ax.imshow(normalized_heatmap, cmap='hot', alpha=0.75, interpolation='nearest', aspect='auto')

    # Hide axis decorations on every panel, including unused padding panels.
    for ax in axes:
        ax.axis('off')

    fig.suptitle(f"Heatmaps for {layer_name}", fontsize=16)
    fig.tight_layout()

    # savefig does not create directories; the old code raised FileNotFoundError here.
    output_dir = f'./heatmaps_{test_dataset}'
    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(f'{output_dir}/mosaic_{layer_name}.png', bbox_inches='tight')

    plt.show()
    # This runs once per layer; release the figure explicitly.
    plt.close(fig)

# Loop through each layer and create a mosaic
for layer_name, heatmaps in kl_dict_upscaled.items():
    create_mosaic_heatmaps(test_images, heatmaps, layer_name)

In [None]:
def create_mosaic_heatmaps_individual(images, heatmaps, layer_name,id_img,cols=1):
    """Overlay one heatmap on one image and save it in that image's output directory.

    Despite the plural parameter names (kept so existing calls keep working), this
    draws a single panel: ``images`` is one image and ``heatmaps`` is its 2-D heatmap.
    The previous version zipped the two as if they were lists (iterating image rows)
    and indexed a single ``Axes`` object, so every call raised.

    Args:
        images: one image, drawn underneath.
        heatmaps: the 2-D heatmap for that image; min-max normalised to [0, 1]. A
            heatmap with no positive value range is drawn as all zeros.
        layer_name: shown in the title and used in the output file name.
        id_img: image index, used in the output directory name.
        cols: unused; kept for call compatibility.

    Side effects: writes
    ./heatmaps_{test_dataset}_individual/image{id_img}/mosaic_{layer_name}.png
    (`test_dataset` is the notebook-level setting). The figure is closed rather than
    shown, so the layer x image loop does not accumulate open figures.
    """
    heatmap_data = np.asarray(heatmaps)
    low, high = np.min(heatmap_data), np.max(heatmap_data)
    if high > low:
        normalized_heatmap = (heatmap_data - low) / (high - low)
    else:
        normalized_heatmap = np.zeros(heatmap_data.shape)

    fig, ax = plt.subplots(1, 1, figsize=(16,9))
    ax.imshow(images, cmap='gray', interpolation='nearest', aspect='auto')
    ax.imshow(normalized_heatmap, cmap='hot', alpha=0.75, interpolation='nearest', aspect='auto')
    ax.axis('off')

    fig.suptitle(f"Heatmaps for {layer_name}", fontsize=16)
    fig.tight_layout()

    directory = f'./heatmaps_{test_dataset}_individual/image{id_img}'
    os.makedirs(directory, exist_ok=True)
    fig.savefig(f'{directory}/mosaic_{layer_name}.png', bbox_inches='tight')

    plt.close(fig)

# One figure per (layer, image). The old call used the misspelled name
# `create_mosaic_heatmap_individual`, which raised NameError.
for layer_name, heatmaps in kl_dict_upscaled.items():
    for idx, image in enumerate(test_images):
        heatmap = heatmaps[idx]
        create_mosaic_heatmaps_individual(image, heatmap, layer_name, idx)


In [None]:
from PIL import Image
def apply_mask(images, masks):
    """Mask each image with its label, preview it, and save it as a PNG.

    For every index i the image is multiplied by ``masks[i]`` (non-zero = keep). The
    mask and the masked image are shown side by side, and the masked image is written to
    ./heatmaps_{test_dataset}_individual/image{i}/masked_test_image.png
    (`test_dataset` is the notebook-level setting).

    Args:
        images: sequence of equally-shaped HxWxC images; values are scaled by 255 for
            saving, so they are assumed to be floats in [0, 1] (as mpimg.imread returns
            for PNGs -- TODO confirm for other inputs).
        masks: sequence of HxW masks, cast to bool.
    """
    images = np.array(images, dtype=float)
    masks = np.array(masks, dtype=bool)
    masked_images = np.zeros_like(images)

    for i in range(len(images)):
        # Apply the mask to all channels via broadcasting.
        masked_images[i] = images[i] * masks[i][:, :, None]

        # Round, then clip, before casting: astype() truncates, and a float32 PNG value
        # v/255 can come back as v - 1e-6 after *255, which would be saved as v - 1.
        img_uint8 = np.clip(np.rint(masked_images[i] * 255), 0, 255).astype(np.uint8)
        img = Image.fromarray(img_uint8)

        # Preview the mask next to the masked image.
        fig, (mask_ax, image_ax) = plt.subplots(1, 2, figsize=(10, 5))
        mask_ax.imshow(masks[i], cmap='gray')
        mask_ax.set_title(f'Mask {i}')
        mask_ax.axis('off')
        image_ax.imshow(img)
        image_ax.set_title(f'Masked Image {i}')
        image_ax.axis('off')
        plt.show()
        plt.close(fig)

        # exist_ok avoids the exists()/makedirs() check-then-create race.
        directory = f'./heatmaps_{test_dataset}_individual/image{i}'
        os.makedirs(directory, exist_ok=True)
        img.save(f'{directory}/masked_test_image.png')

apply_mask(test_images, test_labels)

In [None]:
from PIL import Image

# Save each original input image next to its heatmaps. The directory is created here
# as well, so this cell no longer depends on the masking cell having run first.
for i, image in enumerate(test_images):
    directory = f'./heatmaps_{test_dataset}_individual/image{i}'
    os.makedirs(directory, exist_ok=True)
    # Round, then clip, before the uint8 cast: astype() truncates, so a float value
    # such as 199.99999 (from v/255 * 255) would otherwise be saved as 199.
    img = Image.fromarray(np.clip(np.rint(np.asarray(image) * 255), 0, 255).astype(np.uint8))
    img.save(f'{directory}/input_image.png')