# In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
from scipy.ndimage import zoom

def normalize_image(image):
    """Linearly rescale ``image`` so its values span approximately [0, 1].

    The minimum value maps to 0 and the maximum to (just under) 1. A small
    epsilon is added to the value range so a constant input does not divide
    by zero; such an input therefore maps to all zeros.
    """
    lo, hi = image.min(), image.max()
    shifted = image - lo
    return shifted / (hi - lo + 1e-8)

def overlay_localized_heatmap_on_slices(volume, heatmap_volume, view='axial', n_rows=8, n_cols=8,
                                        threshold=0.6, alpha=0.3, title=None, figsize=(10, 10)):
    """Plot a grid of 2-D slices with a red overlay where the heatmap is high.

    Parameters
    ----------
    volume : np.ndarray
        3-D image volume. ``view='axial'`` slices along axis 0, ``'coronal'``
        along axis 1 and ``'sagittal'`` along axis 2.
    heatmap_volume : np.ndarray
        Heatmap with exactly the same shape as ``volume``.
    view : {'axial', 'coronal', 'sagittal'}
        Which axis to slice along (see above).
    n_rows, n_cols : int
        Grid layout. Slices 0 .. n_rows * n_cols - 1 are drawn in order; grid
        cells beyond the number of available slices are left blank.
    threshold : float
        Pixels whose heatmap value, min-max normalised *per slice*, exceeds
        this are highlighted. Because normalisation is per slice, the threshold
        is relative: even a slice with uniformly weak values will usually show
        some highlighted pixels.
    alpha : float
        Opacity of the red overlay.
    title : str or None
        Optional figure-level title, drawn on the same figure as the grid.
    figsize : tuple of float
        Figure size in inches.

    Raises
    ------
    ValueError
        If ``view`` is not recognised, or the two volumes are not 3-D arrays
        of identical shape.
    """
    # Axis permutation that brings the requested slicing axis to the front.
    axis_orders = {'axial': (0, 1, 2), 'coronal': (1, 0, 2), 'sagittal': (2, 0, 1)}
    if view not in axis_orders:
        raise ValueError("Invalid view. Choose from 'axial', 'coronal', or 'sagittal'.")

    volume = np.asarray(volume)
    heatmap_volume = np.asarray(heatmap_volume)
    # A shape mismatch would otherwise either misalign the overlay silently or
    # fail later with an IndexError deep inside the plotting loop.
    if volume.ndim != 3 or volume.shape != heatmap_volume.shape:
        raise ValueError(
            "volume and heatmap_volume must be 3-D arrays of the same shape; "
            f"got {volume.shape} and {heatmap_volume.shape}."
        )

    slices = np.transpose(volume, axis_orders[view])
    heatmap_slices = np.transpose(heatmap_volume, axis_orders[view])
    num_slices = slices.shape[0]

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 or 1xN grid
    # (with the default squeeze=True a 1x1 grid returns a bare Axes object,
    # which has no .flatten()/.ravel()).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    if title is not None:
        fig.suptitle(title, fontsize=16)

    for slice_idx, ax in enumerate(axes.ravel()):
        ax.axis('off')
        if slice_idx >= num_slices:
            continue

        image_slice_norm = normalize_image(slices[slice_idx])
        heatmap_slice_norm = normalize_image(heatmap_slices[slice_idx])
        # Boolean mask of pixels above the relative (per-slice) threshold.
        mask = heatmap_slice_norm > threshold

        # RGBA layer: semi-transparent red where masked, fully transparent elsewhere.
        heatmap_overlay = np.zeros((*heatmap_slice_norm.shape, 4))
        heatmap_overlay[mask] = (1.0, 0.0, 0.0, alpha)

        # Draw the grayscale slice directly; vmin/vmax pin the colormap to the
        # normalised [0, 1] range (cmap is ignored for RGB input, so the old
        # 3-channel stack carried a dead argument).
        ax.imshow(image_slice_norm, cmap='gray', vmin=0, vmax=1)
        ax.imshow(heatmap_overlay)
        ax.set_title(f'{view.capitalize()} Slice {slice_idx}')

    fig.tight_layout()
    plt.show()

# Define paths
# NOTE(review): absolute, user-specific paths; adjust for your environment.
image_data_path = '/jet/home/abradsha/MEDSYN/results/img_64_standard_bulk/temp/consolidation_sample_0.npy'
heatmap_data_path = '/jet/home/abradsha/MEDSYN/results/img_64_standard_bulk/temp/consolidation_sample_0.npy_token_0_[CLS]_heatmaps.npy'

# Load the generated images
data = np.load(image_data_path)
print(f"Loaded generated images with shape: {data.shape}")

# Average axis 0 of the first sample. NOTE(review): this assumes data is laid
# out as (batch, channels, frames, height, width), so that the result is a
# single 3-D (frames, height, width) volume; confirm against the generator.
sample = data[0].mean(axis=0)
if sample.ndim != 3:
    raise ValueError(f"Expected a 3-D (frames, height, width) sample, got shape {sample.shape}")
num_frames, height, width = sample.shape
print(f"Sample shape: {sample.shape}")

# Load heatmaps
heatmaps = np.load(heatmap_data_path)
print(f"Loaded CLS heatmaps with shape: {heatmaps.shape}")
if heatmaps.ndim != 3:
    raise ValueError(f"Expected a 3-D heatmap volume, got shape {heatmaps.shape}")

# Resample the heatmap whenever ANY dimension (including the frame count)
# differs from the image volume. Previously only height/width were compared,
# and the per-frame loop indexed heatmaps[i] for every image frame, which
# crashed or silently dropped frames when the frame counts differed.
if heatmaps.shape != sample.shape:
    # Min-max normalise each heatmap slice independently before interpolating.
    heatmaps_norm = np.stack([normalize_image(h) for h in heatmaps], axis=0)
    zoom_factors = tuple(target / source for target, source in zip(sample.shape, heatmaps.shape))
    # order=1 is (tri)linear interpolation; a factor of 1 along an axis leaves
    # that axis untouched, so this matches per-slice bilinear resizing when
    # only height/width differ.
    heatmap_volume = zoom(heatmaps_norm, zoom_factors, order=1)
else:
    heatmap_volume = heatmaps

print(f"Final heatmap volume shape: {heatmap_volume.shape}")
view = 'axial'

# threshold is relative to each slice's own heatmap range (see the function docstring).
# The title is passed in so it lands on the same figure as the grid instead of
# on a separate, otherwise-empty figure.
overlay_localized_heatmap_on_slices(
    sample, heatmap_volume, view=view, n_rows=8, n_cols=8, threshold=0.6, alpha=0.3,
    title=f"Localized Heatmap Overlay - View: {view.capitalize()}",
)
