In [2]:
import pyvista as pv
import os

# Project colour palette (hex RGB strings). Keys mix Dutch and English:
# 'roze' = pink, 'groen' = green, 'paars' = purple. Both greys are dark slate
# tones; 'grey_light' is the lighter of the two.
custom_palette = dict(
    roze='#eb8fd8',
    groen='#b9d4b4',
    paars='#ba94e9',
    blue='#4C8BE2',
    orange='#E27A3F',
    grey_light='#1F3240',
    grey_dark='#16242F',
)

# Functions

In [3]:
def _make_label_grid(labels_xyz, spacing=(1, 1, 1)):
    """Wrap an integer label volume in (X, Y, Z) order in a PyVista image grid.

    The labels are stored as point data named "labels", flattened in Fortran
    order to match VTK's point ordering (X varies fastest).
    """
    # Newer PyVista releases renamed UniformGrid to ImageData; support both.
    image_data_cls = getattr(pv, "ImageData", None) or pv.UniformGrid
    grid = image_data_cls()
    grid.dimensions = labels_xyz.shape
    grid.spacing = spacing
    grid.origin = (0, 0, 0)
    grid.point_data["labels"] = labels_xyz.flatten(order="F")
    return grid


def _add_class_surfaces(plotter, grid, subplot_index, title, class_names, class_colors):
    """Draw one translucent surface per foreground class of `grid` into subplot (0, subplot_index).

    Class 0 (background) is skipped. The legend is only added when at least one
    class surface was drawn, because PyVista's add_legend raises when no
    labelled actor exists (e.g. an all-background prediction).
    """
    plotter.subplot(0, subplot_index)
    labels = grid.point_data["labels"]
    drew_any = False
    for class_label in range(1, len(class_names)):
        mask = labels == class_label
        if not mask.any():
            continue  # class absent from this volume
        # Contour a binary mask of this class at 0.5 rather than the raw label map at
        # the label value: on the raw map every higher label also counts as "inside"
        # and vertices land exactly on voxel centres, which can pull in neighbouring
        # organs and yields degenerate triangles.
        grid.point_data["class_mask"] = mask.astype("float32")
        surface = grid.contour(isosurfaces=[0.5], scalars="class_mask")
        if surface.n_points == 0:
            continue
        plotter.add_mesh(surface, color=class_colors[class_label], opacity=0.6,
                         label=class_names[class_label])
        drew_any = True
    if drew_any:
        plotter.add_legend(bcolor='white')
    plotter.add_axes()
    plotter.set_background('white')
    plotter.add_title(title)


def render_3d_segmentation(pred_volume, gt_volume, volume_id, output_dir, spacing=(1, 1, 1)):
    """Render ground truth and prediction side by side and save one screenshot per view.

    Args:
        pred_volume: Tensor of shape (K, D, H, W); the class per voxel is its argmax over K.
        gt_volume: Tensor of shape (K, D, H, W); the class per voxel is its argmax over K.
        volume_id: Identifier inserted into the screenshot file names.
        output_dir: Directory the PNG files are written to (not created here).
        spacing: Grid spacing in PyVista (X, Y, Z) = (W, H, D) order; defaults to unit voxels.

    Writes ``Patient_{volume_id}_{view}_view.png`` for the isometric, front, side
    and top views.
    """
    # Collapse the class axis and reorder (D, H, W) -> (W, H, D), i.e. PyVista's (X, Y, Z).
    # Tensor methods are used so this cell does not depend on `torch`/`np`
    # being imported by a later cell.
    gt_labels = gt_volume.argmax(dim=0).permute(2, 1, 0).cpu().numpy()
    pred_labels = pred_volume.argmax(dim=0).permute(2, 1, 0).cpu().numpy()

    # Class names and colors, indexed by label value (adjust as needed)
    class_names = ['Background', 'Esophagus', 'Heart', 'Trachea', 'Aorta']
    class_colors = [
        'gray',  # Background (not drawn)
        custom_palette['roze'],  # Esophagus
        custom_palette['groen'],  # Heart
        custom_palette['paars'],  # Trachea
        custom_palette['blue']  # Aorta
    ]

    plotter = pv.Plotter(shape=(1, 2), window_size=(1600, 800), off_screen=True)
    try:
        _add_class_surfaces(plotter, _make_label_grid(gt_labels, spacing), 0,
                            'Ground Truth', class_names, class_colors)
        _add_class_surfaces(plotter, _make_label_grid(pred_labels, spacing), 1,
                            'Prediction', class_names, class_colors)

        # Link the subplot cameras so both panels show the same view in each screenshot.
        plotter.link_views()

        camera_positions = {
            'isometric': 'iso',
            'front': 'xz',
            'side': 'yz',
            'top': 'xy',
        }
        for view_name, camera_pos in camera_positions.items():
            plotter.camera_position = camera_pos
            plotter.render()
            screenshot_path = os.path.join(output_dir, f'Patient_{volume_id}_{view_name}_view.png')
            plotter.screenshot(screenshot_path)
            print(f"Saved {view_name} view to {screenshot_path}")
    finally:
        # Release the off-screen render window even if rendering fails.
        plotter.close()
    print(f"Rendered 3D segmentation views for {volume_id} saved to {output_dir}")

In [4]:
import nibabel as nib
import numpy as np
import torch
from scipy.ndimage import zoom

def load_nifti_as_tensor(nifti_path, num_classes):
    """
    Loads a NIfTI label map and converts it into a one-hot encoded PyTorch tensor.

    The image is reoriented to the closest canonical (RAS+) orientation,
    transposed to (D, H, W) = (Z, Y, X) order, and resampled with
    nearest-neighbour interpolation so every voxel has the size of the smallest
    original voxel dimension. This removes aspect-ratio distortion without
    inventing new labels.

    Args:
        nifti_path (str): Path to the NIfTI file (.nii or .nii.gz).
        num_classes (int): Total number of classes (including background).

    Returns:
        tensor_volume (torch.Tensor): float32 one-hot tensor of shape (K, D, H, W)
            with K = num_classes. Voxels whose label lies outside
            [0, num_classes) are all-zero (a warning is emitted).

    Raises:
        ValueError: If the image data is not three-dimensional.
    """
    import warnings

    # Load the NIfTI file. The shape is known from the header, so the voxel data
    # is decoded only once, from the reoriented image below.
    nifti_img = nib.load(nifti_path)

    # Reorient the image to RAS+ (closest canonical orientation)
    canonical_img = nib.as_closest_canonical(nifti_img)
    volume_data = canonical_img.get_fdata()

    print(f"Loaded NIfTI file from {nifti_path}")
    print(f"Original volume shape: {nifti_img.shape}")
    print(f"Data type: {volume_data.dtype}")

    # Spatial voxel sizes only (any further header zooms are non-spatial)
    voxel_sizes = canonical_img.header.get_zooms()[:3]
    print(f"Voxel sizes (after reorientation): {voxel_sizes}")

    aff = canonical_img.affine
    axes_codes = nib.orientations.aff2axcodes(aff)
    print(f"Axes codes after reorientation: {axes_codes}")

    if volume_data.ndim != 3:
        raise ValueError(f"Expected a 3D label volume, got data of shape {volume_data.shape}")

    # After reorientation the array axes are (X, Y, Z); reverse them to (Z, Y, X) = (D, H, W)
    volume_data = np.transpose(volume_data, (2, 1, 0))
    print(f"Volume shape after transpose: {volume_data.shape}")

    # Labels are integer-valued: round (guards against float noise from header
    # scaling) and cast before resampling so zoom works on a compact int array.
    label_volume = np.rint(volume_data).astype(np.int32)

    # Resample to isotropic voxels, using the smallest voxel size as the target
    target_voxel_size = min(voxel_sizes)
    scaling_factors = (
        voxel_sizes[2] / target_voxel_size,  # scaling factor for D (Z-axis)
        voxel_sizes[1] / target_voxel_size,  # scaling factor for H (Y-axis)
        voxel_sizes[0] / target_voxel_size   # scaling factor for W (X-axis)
    )
    print(f"Scaling factors: {scaling_factors}")

    # Nearest-neighbour interpolation (order=0) never creates new label values
    volume_data_resampled = zoom(label_volume, zoom=scaling_factors, order=0)
    print(f"Volume shape after resampling: {volume_data_resampled.shape}")

    unique_labels = np.unique(volume_data_resampled)
    print(f"Unique labels in the resampled volume: {unique_labels}")
    out_of_range = unique_labels[(unique_labels < 0) | (unique_labels >= num_classes)]
    if out_of_range.size:
        warnings.warn(
            f"Labels {out_of_range.tolist()} fall outside [0, {num_classes}) "
            "and will be all-zero in the one-hot encoding"
        )

    # One-hot encode: channel k is 1 where the label equals k
    D, H, W = volume_data_resampled.shape
    tensor_volume = torch.zeros((num_classes, D, H, W), dtype=torch.float32)
    labels_t = torch.from_numpy(volume_data_resampled)
    for k in range(num_classes):
        # The boolean mask is cast to float32 on assignment; no per-class float copy.
        tensor_volume[k] = labels_t == k

    return tensor_volume


# Run: load the two label volumes and render them

In [None]:
# Label maps to compare. NOTE(review): the "prediction" is another patient's
# ground truth, presumably a stand-in until real predictions are exported.
gt_nifti_path = '../../data/segthor_original/train/Patient_01/GT.nii.gz'
pred_nifti_path = '../../data/segthor_original/train/Patient_02/GT.nii.gz'
num_classes = 5  # background + Esophagus, Heart, Trachea, Aorta

# Load both volumes as one-hot tensors of shape (K, D, H, W)
gt_volume = load_nifti_as_tensor(gt_nifti_path, num_classes)
pred_volume = load_nifti_as_tensor(pred_nifti_path, num_classes)

volume_id = 'volume_identifier'  # used in the screenshot file names (adjust as needed)
output_dir = '.'
os.makedirs(output_dir, exist_ok=True)
render_3d_segmentation(pred_volume, gt_volume, volume_id, output_dir=output_dir)


Loaded NIfTI file from ../../data/segthor_original/train/Patient_01/GT.nii.gz
Original volume shape: (512, 512, 229)
Data type: float64
Voxel sizes (after reorientation): (0.9765625, 0.9765625, 2.0)
Axes codes after reorientation: ('R', 'A', 'S')
Volume shape after transpose: (229, 512, 512)
Scaling factors: (2.048, 1.0, 1.0)
Volume shape after resampling: (469, 512, 512)
Unique labels in the resampled volume: [0 1 2 3 4]
Loaded NIfTI file from ../../data/segthor_original/train/Patient_02/GT.nii.gz
Original volume shape: (512, 512, 246)
Data type: float64
Voxel sizes (after reorientation): (0.976562, 0.976562, 2.5)
Axes codes after reorientation: ('R', 'A', 'S')
Volume shape after transpose: (246, 512, 512)
Scaling factors: (2.5600011, 1.0, 1.0)
Volume shape after resampling: (630, 512, 512)
Unique labels in the resampled volume: [0 1 2 3 4]




# Deprecated code (kept for reference, not executed)

In [ ]:
# THIS IS ONLY USED FOR PICKLE OBJECTS OF THE PREDICTIONS IN VOLUME

# def animate_3d_volume(volume_predictions, volume_ground_truths, evaluate_dir):
#     # Assuming volume_predictions and volume_ground_truths are defined as before
#     ids = [1, 13, 22, 28, 30]
#     for volume_id in ids:
#         # Retrieve the list of slices for the given volume_id
#         pred_slices = volume_predictions[volume_id]  # List of (slice_idx, pred_slice)
#         gt_slices = volume_ground_truths[volume_id]  # List of (slice_idx, gt_slice)
# 
#         # Sort the slices by slice_idx
#         pred_slices_sorted = sorted(pred_slices, key=lambda x: x[0])
#         gt_slices_sorted = sorted(gt_slices, key=lambda x: x[0])
# 
#         # Stack the slices into a Tensor volume
#         pred_volume = torch.stack([slice_data for idx, slice_data in pred_slices_sorted], dim=1)
#         gt_volume = torch.stack([slice_data for idx, slice_data in gt_slices_sorted], dim=1)
# 
#         # Now pred_volume and gt_volume are Tensors of shape (K, D, H, W)
#         # where K is the number of classes, D is the depth (number of slices)
# 
#         # Call the rendering function
#         # render_3d_segmentation(pred_volume, gt_volume, volume_id)
#         render_3d_segmentation(pred_volume, gt_volume, volume_id, output_dir=evaluate_dir)

In [None]:
# ANIMATION CODE

# import torch
# import pyvista as pv
# import numpy as np
# import os
# 
# def render_3d_segmentation_with_animation(pred_volume, gt_volume, volume_id, output_dir):
#     # pred_volume and gt_volume: Tensors of shape (K, D, H, W)
# 
#     # Convert to numpy arrays and get label volumes
#     pred_volume_np = torch.argmax(pred_volume, dim=0).cpu().numpy()  # Shape: (D, H, W)
#     gt_volume_np = torch.argmax(gt_volume, dim=0).cpu().numpy()
# 
#     # Transpose the volumes to match PyVista's (X, Y, Z) format
#     pred_volume_np = np.transpose(pred_volume_np, (2, 1, 0))  # Now shape is (W, H, D)
#     gt_volume_np = np.transpose(gt_volume_np, (2, 1, 0))
# 
#     # Create the grids
#     grid_pred = pv.UniformGrid()
#     grid_pred.dimensions = pred_volume_np.shape
#     grid_pred.spacing = (1, 1, 1)
#     grid_pred.origin = (0, 0, 0)
#     grid_pred.point_data["labels"] = pred_volume_np.flatten(order="F")
# 
#     grid_gt = pv.UniformGrid()
#     grid_gt.dimensions = gt_volume_np.shape
#     grid_gt.spacing = (1, 1, 1)
#     grid_gt.origin = (0, 0, 0)
#     grid_gt.point_data["labels"] = gt_volume_np.flatten(order="F")
# 
#     # Define class names and colors (adjust as needed)
#     class_names = ['Background', 'Esophagus', 'Heart', 'Trachea', 'Aorta']
#     class_colors = ['gray', 'red', 'green', 'blue', 'yellow']
# 
#     # Create the plotter with subplots
#     p = pv.Plotter(shape=(1, 2), off_screen=True, window_size=(1600, 800))
# 
#     # Function to add class surfaces to a subplot
#     def add_class_surfaces(grid, subplot_index, title):
#         p.subplot(0, subplot_index)
#         for c in range(1, len(class_names)):  # Skip background if desired
#             class_label = c
#             # Threshold the grid to extract the class
#             class_grid = grid.threshold(value=(class_label - 0.1, class_label + 0.1), scalars='labels')
#             if class_grid.n_points == 0:
#                 continue  # Skip if no points for this class
#             # Extract the surface mesh
#             surface = class_grid.contour(isosurfaces=[class_label], scalars='labels')
#             # Add the mesh to the plotter
#             p.add_mesh(surface, color=class_colors[c], opacity=0.6, label=class_names[c])
#         p.add_legend(bcolor='white')
#         p.add_axes()
#         p.set_background('white')
#         p.add_title(title)
# 
#     # Add ground truth to the first subplot
#     add_class_surfaces(grid_gt, subplot_index=0, title='Ground Truth')
# 
#     # Add prediction to the second subplot
#     add_class_surfaces(grid_pred, subplot_index=1, title='Prediction')
# 
#     # Link the views so that camera movements affect both subplots
#     p.link_views()
# 
#     # Set the camera position
#     p.camera_position = 'iso'  # Isometric view
# 
#     # Prepare to capture frames
#     num_frames = 60
#     rotation_angle = 360 / num_frames
# 
#     # Directory to save frames
#     frames_dir = os.path.join(output_dir, f'volume_{volume_id}_frames')
#     os.makedirs(frames_dir, exist_ok=True)
# 
#     # Capture frames
#     for i in range(num_frames):
#         # Rotate the camera
#         p.camera.Azimuth(rotation_angle)
#         # Update rendering
#         p.render()
#         # Save frame
#         frame_path = os.path.join(frames_dir, f'frame_{i:03d}.png')
#         p.screenshot(frame_path)
# 
#     # Close the plotter
#     p.close()
# 
#     # Create an animation (e.g., GIF)
#     animation_path = os.path.join(output_dir, f'volume_{volume_id}_animation.gif')
#     import imageio
#     with imageio.get_writer(animation_path, mode='I', duration=0.05) as writer:
#         for i in range(num_frames):
#             frame_path = os.path.join(frames_dir, f'frame_{i:03d}.png')
#             image = imageio.imread(frame_path)
#             writer.append_data(image)
# 
#     print(f"Animation saved to {animation_path}")
# 
#     # Optionally, remove the frames directory to save space
#     import shutil
#     shutil.rmtree(frames_dir)
