In [1]:
import numpy as np
import os
import monai
import pandas as pd
import napari
import matplotlib as plt
from ipywidgets import interact
%matplotlib inline

Matplotlib is building the font cache; this may take a moment.


In [2]:
root_dir = '..'
load_dir = 'data/train_labels.csv'
data_path = os.path.join(root_dir,load_dir)
labels = pd.read_csv(data_path)

In [3]:
def gaussian_heatmap(shape, points, radii):
    """
    Create a 3D Gaussian heatmap.

    Parameters:
    shape (tuple): Shape of the heatmap (D, H, W)
    points (list): List of (x, y, z) coordinates for the Gaussians
    radii (list or float): List of radii or single radius value for all points

    Returns:
    numpy.ndarray: 3D heatmap with Gaussian distributions
    """
    if isinstance(radii, (int, float)):
        radii = [radii] * len(points)

    D, H, W = shape
    z, y, x = np.meshgrid(np.arange(D), np.arange(H), np.arange(W), indexing='ij')
    heatmap = np.zeros(shape, dtype=np.float32)

    for (px, py, pz), radius in zip(points, radii):
        sigma = radius / 3.0
        gaussian = np.exp(-((x - px)**2 + (y - py)**2 + (z - pz)**2) / (2 * sigma**2))
        heatmap = np.maximum(heatmap, gaussian)

    return heatmap

In [None]:
heatmap = gaussian_heatmap((300,959,928), [[169,546,603]], 40 )

In [5]:
subset = heatmap[100:200,500:600,550:650]


In [None]:
from matplotlib.widgets import Slider

def visualize_volume(volume):
    D, H, W = volume.shape

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    plt.subplots_adjust(bottom=0.25)

    img_xy = axs[0].imshow(volume[D // 2], cmap='hot')
    axs[0].set_title(f"Axial Slice {D // 2}")

    img_xz = axs[1].imshow(volume[:, H // 2, :], cmap='hot')
    axs[1].set_title(f"Coronal Slice {H // 2}")

    img_yz = axs[2].imshow(volume[:, :, W // 2], cmap='hot')
    axs[2].set_title(f"Sagittal Slice {W // 2}")

    ax_slider_xy = plt.axes([0.2, 0.1, 0.2, 0.03])
    slider_xy = Slider(ax_slider_xy, 'Axial', 0, D-1, valinit=D // 2, valstep=1)

    ax_slider_xz = plt.axes([0.45, 0.1, 0.2, 0.03])
    slider_xz = Slider(ax_slider_xz, 'Coronal', 0, H-1, valinit=H // 2, valstep=1)

    ax_slider_yz = plt.axes([0.7, 0.1, 0.2, 0.03])
    slider_yz = Slider(ax_slider_yz, 'Sagittal', 0, W-1, valinit=W // 2, valstep=1)

    def update_xy(val):
        slice_idx = int(slider_xy.val)
        img_xy.set_data(volume[slice_idx])
        axs[0].set_title(f"Axial Slice {slice_idx}")
        fig.canvas.draw_idle()

    def update_xz(val):
        slice_idx = int(slider_xz.val)
        img_xz.set_data(volume[:, slice_idx, :])
        axs[1].set_title(f"Coronal Slice {slice_idx}")
        fig.canvas.draw_idle()

    def update_yz(val):
        slice_idx = int(slider_yz.val)
        img_yz.set_data(volume[:, :, slice_idx])
        axs[2].set_title(f"Sagittal Slice {slice_idx}")
        fig.canvas.draw_idle()

    slider_xy.on_changed(update_xy)
    slider_xz.on_changed(update_xz)
    slider_yz.on_changed(update_yz)

    plt.show()

In [10]:
visualize_volume(subset)

TypeError: 'module' object is not callable