In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display
import ipywidgets as widgets
import os
import time
import pickle


In [2]:
def plot_in_notebook(array, num_slices=10, save_path=None, custom_name=None, speed_factor=2,
                     show_in_notebook=False, vmin=0, vmax=1):
    """
    Animate a 3D numpy array as slab-averaged slices along its shortest axis.

    Each frame shows the mean of ``num_slices`` consecutive slices roughly centred
    on the current slice index (the window is clipped at the volume edges),
    rendered in grayscale. Dimensions 0, 1, 2 are named 'z', 'x', 'y'; these names
    are used for the axis labels, the default title and the saved file name.

    Parameters:
    - array: 3D numpy array to plot (boolean arrays are averaged as 0/1).
    - num_slices: Number of slices to average in each frame (values < 1 are treated as 1).
    - save_path: Directory to save the animation as an .mp4 via the 'ffmpeg' writer.
      If None, nothing is saved.
    - custom_name: Optional name used as the (static) plot title and appended to the
      saved file name.
    - speed_factor: Playback speed multiplier; the frame interval is
      max(10, int(100 / speed_factor)) milliseconds.
    - show_in_notebook: If True, display the animation as JS/HTML inside an
      ipywidgets Output area.
    - vmin, vmax: Grayscale limits passed to imshow. The defaults (0, 1) suit masks
      and data normalised to [0, 1]. Pass None for either to use the array's global
      (NaN-ignoring) min/max, so every frame shares one scale.

    Raises:
    - ValueError: if ``array`` is not 3-dimensional or has an empty dimension.
    """
    array = np.asarray(array)
    if array.ndim != 3:
        raise ValueError(f"plot_in_notebook expects a 3D array, got shape {array.shape}")
    if 0 in array.shape:
        raise ValueError(f"plot_in_notebook cannot animate an array with an empty axis: shape {array.shape}")

    axis_names = ['z', 'x', 'y']

    # Animate along the shortest axis; the remaining two axes form the image plane.
    shortest_axis = int(np.argmin(array.shape))
    depth = array.shape[shortest_axis]
    row_axis, col_axis = (a for a in range(3) if a != shortest_axis)

    # Use a window of exactly `num_slices` slices. The previous
    # `frame +/- num_slices // 2` form averaged one extra slice for even sizes.
    num_slices = max(1, int(num_slices))
    half_window = num_slices // 2

    def slab_mean(frame):
        """Mean over the edge-clipped window of slices around `frame` (never empty)."""
        start = max(0, frame - half_window)
        stop = min(depth, frame - half_window + num_slices)
        index = [slice(None)] * 3
        index[shortest_axis] = slice(start, stop)
        return array[tuple(index)].mean(axis=shortest_axis)

    def frame_title(frame):
        """Custom name if given, otherwise the axis and slice index."""
        if custom_name:
            return f'{custom_name}'
        return f'Axis {shortest_axis} ({axis_names[shortest_axis]}) Slice {frame}'

    # imshow would autoscale from the first frame only, so resolve None to a
    # fixed global range up front to keep brightness comparable across frames.
    if vmin is None:
        vmin = float(np.nanmin(array))
    if vmax is None:
        vmax = float(np.nanmax(array))

    fig, ax = plt.subplots(figsize=(6, 4))
    # imshow draws the first remaining dimension vertically and the second
    # horizontally, so the labels follow that order.
    image = ax.imshow(slab_mean(0), cmap='gray', origin='lower', vmin=vmin, vmax=vmax)
    ax.set_xlabel(axis_names[col_axis])
    ax.set_ylabel(axis_names[row_axis])
    ax.set_title(frame_title(0))

    def update_plot(frame):
        # Update the existing image in place instead of clearing and rebuilding the axes.
        image.set_data(slab_mean(frame))
        ax.set_title(frame_title(frame))
        return [image]

    # 100 ms is the base interval; a higher speed_factor means a shorter interval.
    interval = max(10, int(100 / speed_factor))

    # blit=False: the title changes each frame but is not a returned artist, so
    # blitting would leave it stale on interactive backends.
    ani = FuncAnimation(
        fig, update_plot, frames=depth, blit=False, repeat=False, interval=interval
    )
    plt.close(fig)  # Prevent a static copy of the figure from being displayed

    # Save the animation if save_path is provided
    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)

        base_name = f'animation_axis_{axis_names[shortest_axis]}'
        if custom_name:
            base_name += f'_{custom_name}'

        video_filename = os.path.join(save_path, f'{base_name}.mp4')
        ani.save(video_filename, writer='ffmpeg', dpi=100)
        print(f"Saved animation for axis {axis_names[shortest_axis]} to {video_filename}")

    # Optionally display the animation in the notebook
    if show_in_notebook:
        output = widgets.Output()
        with output:
            display(HTML(ani.to_jshtml()))
        display(output)


In [3]:
def convert_mask_to_numpy(mask_dict, object_key=1):
    """
    Convert a nested per-frame mask dictionary into one boolean 3D array.

    Each value is expected to look like ``{object_key: mask}``, where ``mask``
    carries a leading singleton dimension, i.e. shape ``(1, height, width)``.
    The ``object_key`` level and that leading dimension are dropped, and the
    frames are stacked in ascending key order.

    Parameters:
    -----------
    mask_dict : dict
        Nested dictionary with format {frame: {object_key: array}}
    object_key : hashable, optional
        Inner-dictionary key to extract from every frame. Defaults to 1, the key
        this function previously hard-coded.

    Returns:
    --------
    numpy.ndarray
        Boolean array with shape (frames, height, width). Non-boolean masks are
        cast with ``astype(bool)``, so any non-zero value becomes True.

    Raises:
    -------
    ValueError
        If ``mask_dict`` is empty, or if the per-frame masks do not all share
        the same shape.
    """
    if not mask_dict:
        raise ValueError("mask_dict is empty; cannot infer the mask shape")

    # [0] drops the leading singleton dimension of each stored mask.
    frames = [np.asarray(mask_dict[frame][object_key])[0] for frame in sorted(mask_dict)]

    # np.stack refuses mismatched shapes, whereas assignment into a
    # preallocated array could silently broadcast them.
    return np.stack(frames).astype(bool, copy=False)


In [4]:
# Example usage in Jupyter Lab:
# Location of the pickled segmentation results. Set the ABLA_RESULTS_PKL
# environment variable to point elsewhere instead of editing this path.
RESULTS_PKL = os.environ.get(
    'ABLA_RESULTS_PKL',
    '/home/matiasgp/Desktop/ABLA/results/targ65_full.pkl',
)
if not os.path.isfile(RESULTS_PKL):
    raise FileNotFoundError(
        f"Segmentation results not found at {RESULTS_PKL!r}; "
        "set the ABLA_RESULTS_PKL environment variable to the .pkl file's location."
    )

# Load and convert the mask.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open(RESULTS_PKL, 'rb') as f:
    seg_results = pickle.load(f)

mask_array = convert_mask_to_numpy(seg_results['mask'])

# Animate the mask and the tomogram with identical settings: each frame
# averages 20 neighbouring slices and plays at 4x the base speed.
for volume, name in (
    (mask_array, 'mask_visualization'),
    (seg_results["tomo"], 'tomo_visualization'),
):
    plot_in_notebook(
        array=volume,
        num_slices=20,
        custom_name=name,
        speed_factor=4,
        show_in_notebook=True,
    )



FileNotFoundError: [Errno 2] No such file or directory: '/home/matiasgp/Desktop/ABLA/results/targ65_full.pkl'