In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

def _extract_slice(title, img_tensor, b_idx, c_idx, d_idx):
    """Validate one tensor and return its 2D slice ``img_tensor[b_idx, c_idx, :, :, depth]`` as a NumPy array.

    ``depth`` is ``d_idx`` when given, otherwise the middle index of this tensor's last dimension.

    Raises:
      ValueError: If ``img_tensor`` is not 5-dimensional.
    """
    if img_tensor.dim() != 5:
        raise ValueError(f"Expected a 5D tensor for {title}, but got {img_tensor.dim()}D tensor.")
    # Resolve the default per tensor (without touching d_idx) so tensors with
    # different depths each get their own middle slice.
    depth = img_tensor.shape[-1] // 2 if d_idx is None else d_idx
    return img_tensor[b_idx, c_idx, :, :, depth].detach().cpu().numpy()


def plot_slices(imgs_dict, save_path=None, show=True, b_idx=0, c_idx=0, d_idx=None):
    """
    Plots slices of provided 5D torch tensors in a subplot grid, optionally saving the figure.

    Args:
      imgs_dict (dict): Dictionary with keys as the names to show in the figure and values as the 5D torch tensors to be sliced.
                        Expected shape of tensor: (B, C, X, Y, D).
      save_path (str, optional): Path to save the produced image to (e.g. ending in .png). Not saved if falsy.
      show (bool, optional): If True, display the figure with ``plt.show()``; otherwise the figure is closed.
      b_idx (int, optional): Index for the batch slice. Default is 0.
      c_idx (int, optional): Index for the channel slice. Default is 0.
      d_idx (int, optional): Index for the depth (D) slice. If not provided, the middle slice of each tensor is used.

    Raises:
      ValueError: If ``imgs_dict`` is empty or any value is not a 5D tensor.
    """
    if len(imgs_dict) == 0:
        raise ValueError("No images provided for plotting.")

    # Validate and slice every tensor before creating the figure so that a bad
    # input raises without leaving a half-drawn figure open.
    slices = [
        (title, _extract_slice(title, img_tensor, b_idx, c_idx, d_idx))
        for title, img_tensor in imgs_dict.items()
    ]
    num_images = len(slices)

    # Layout: at most 3 columns, as many rows as needed.
    num_cols = min(3, num_images)
    num_rows = (num_images + num_cols - 1) // num_cols

    # squeeze=False guarantees a 2D array of Axes even for a 1x1 grid, where
    # plt.subplots would otherwise return a single Axes without .flatten().
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(5 * num_cols, 5 * num_rows), squeeze=False
    )
    axes = axes.flatten()

    for ax, (title, slice_img) in zip(axes, slices):
        ax.imshow(slice_img, cmap='gray')
        ax.set_title(title)
        ax.axis('off')  # Turn off axis ticks

    # Hide any unused subplots in the last row.
    for ax in axes[num_images:]:
        ax.axis('off')

    fig.tight_layout()

    if save_path:
        # Save this specific figure rather than whatever pyplot considers current.
        fig.savefig(save_path)
        print(f"Figure saved to {save_path}")

    # pyplot.show() takes no positional flag; honour `show` explicitly and
    # close the figure otherwise so repeated calls do not accumulate figures.
    if show:
        plt.show()
    else:
        plt.close(fig)
