# **3D Heart Segmentation**: From MRI to Reality

This notebook demonstrates the process of **reconstructing 3D heart models** from volumetric MRI data by applying the **Ball-Pivoting Algorithm (BPA)** for surface mesh generation. It leverages Open3D for mesh processing, including **smoothing, thresholding, and visualization** of anatomical structures to enhance segmentation quality and provide better diagnostic insights.


# Setup & Imports

In [152]:
!pip install -q pyvista exposure SimpleITK open3d trimesh

**Module imports**

In [153]:
# System & OS
import os
import sys
import json
import random
import glob
import re

# Google Drive (Colab)
from google.colab import drive

# Math & Stats
import numpy as np
import pandas as pd
from scipy.stats import f_oneway, kruskal
from scipy.ndimage import gaussian_filter, zoom, shift, label, center_of_mass
from scipy import ndimage

# Deep Learning (PyTorch)
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import torchvision.models.video as models


# Image Processing
import nibabel as nib
import SimpleITK as sitk
import cv2
from skimage.filters import threshold_otsu
from skimage import morphology
from skimage import measure
from skimage import exposure
from skimage.exposure import equalize_hist, equalize_adapthist

# Visualization
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import matplotlib.patches as mpatches
import pyvista as pv
import vtk
import trimesh
import numpy as np
import plotly.figure_factory as ff
from matplotlib import cm

import trimesh
import numpy as np
import plotly.figure_factory as ff
import matplotlib.pyplot as plt

In [154]:
# Hex colour palette: four red/pink shades, four purple/blue shades, then white.
# NOTE(review): not referenced by any other cell visible in this notebook.
color_palette = ["#FF4C4C",
                 "#FF7373",
                 "#FF9999",
                 "#FFB6C1",
                 "#6600CC",
                 "#9900FF",
                 "#4C00FF",
                 "#0066FF",
                 "#FFFFFF"]

# Mount Google Drive (Colab only) so the dataset below is reachable.
drive.mount('/content/drive')

# Dataset root on the mounted drive, and the sample image/mask pair used in
# the "Visualize example" section.
BASE_PATH = "/content/drive/MyDrive/Projects/Heart Reconstruction/dataset/"
SAMPLE_IMAGE_PATH = os.path.join(BASE_PATH, 'train/images/patient092.nii.gz')
SAMPLE_MASK_PATH = os.path.join(BASE_PATH, 'train/masks/patient092_mask.nii.gz')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# **1. Process Sampled MRI**

#### Processing functions

In [155]:
def detect_and_correct_flip(img):
    """Return the voxel data of ``img``, flipped on X and Y if its affine says so.

    The qform and sform affines are read from ``img.header``. If the direction
    part (upper-left 3x3) of either one points along (-X, -Y, +Z), the data is
    flipped along axes 0 and 1.

    The 3x3 block of an affine also carries the voxel spacing, so each column
    is normalised to unit length before the comparison. Without that step only
    images with exactly 1-unit voxels could ever match ``diag(-1, -1, 1)``.

    Args:
        img: Object exposing ``header.get_qform()``, ``header.get_sform()`` and
            ``get_fdata()`` (e.g. a nibabel NIfTI image).

    Returns:
        numpy.ndarray: The voxel data, flipped along axes 0 and 1 when a
        flipped orientation is detected, otherwise unchanged.
    """
    flipped_xy = np.diag([-1.0, -1.0, 1.0])

    def _is_xy_flipped(affine):
        # Missing affine: nothing to detect.
        if affine is None:
            return False
        directions = np.asarray(affine, dtype=float)[:3, :3]
        spacing = np.linalg.norm(directions, axis=0)
        # A zero-length column means the affine is unset/degenerate.
        if not np.all(spacing > 0):
            return False
        return bool(np.allclose(directions / spacing, flipped_xy))

    header = img.header
    qform = header.get_qform()
    sform = header.get_sform()
    data = img.get_fdata()

    if _is_xy_flipped(qform) or _is_xy_flipped(sform):
        return np.flip(data, axis=(0, 1))  # Flip X and Y axes

    return data

In [219]:
def align_mask(image, mask):
    """Translate ``mask`` so its centre of mass coincides with that of ``image``.

    Args:
        image: Array whose intensity-weighted centre of mass is the target.
        mask: Array to translate; must have the same number of dimensions.

    Returns:
        The shifted mask (nearest-neighbour resampling, ``order=0``, so no new
        values are interpolated). If either centre of mass is undefined, for
        example because an input sums to zero, ``mask`` is returned unchanged.
    """
    # center_of_mass never returns None: an input that sums to zero produces
    # NaN coordinates instead, so test for finiteness rather than None.
    with np.errstate(divide="ignore", invalid="ignore"):
        img_com = np.asarray(center_of_mass(image), dtype=float)
        mask_com = np.asarray(center_of_mass(mask), dtype=float)

    if not (np.all(np.isfinite(img_com)) and np.all(np.isfinite(mask_com))):
        return mask  # Skip alignment if no valid center found

    shift_vector = img_com - mask_com
    return shift(mask, shift_vector, order=0)

In [186]:
def log_bias_correction(image, epsilon=1e-6):
    """Apply a ``log1p`` intensity transform to a min-max normalised image.

    The image is first scaled to roughly [0, 1), transformed with ``log1p``,
    and then stretched back onto the range of the *normalised* image. The
    result therefore stays in about [0, 1); it is not mapped back to the
    original intensity range of ``image``.

    Args:
        image: NumPy array of intensities.
        epsilon: Small constant added to the range in the first normalisation
            to avoid dividing by zero.

    Returns:
        numpy.ndarray: Transformed image with the dtype of the normalised
        image. A constant input yields an all-zero array instead of NaNs.
    """
    normalized = (image - image.min()) / (image.max() - image.min() + epsilon)
    corrected = np.log1p(normalized)

    corrected_min = corrected.min()
    corrected_span = corrected.max() - corrected_min
    if corrected_span == 0:
        # Constant input: the rescale below would be 0 / 0 and fill the
        # result with NaN, so return the (all-zero) normalised image.
        return np.zeros_like(normalized)

    norm_min = normalized.min()
    norm_span = normalized.max() - norm_min
    corrected = (corrected - corrected_min) / corrected_span * norm_span + norm_min

    return corrected.astype(normalized.dtype)

In [157]:
# Function to denoise image
def denoise_image(img):
    """Smooth ``img`` with SimpleITK's ``CurvatureFlow`` filter.

    Runs 5 iterations with a time step of 0.125 and returns the filter output.

    NOTE(review): ``img`` is presumably a ``sitk.Image`` rather than a NumPy
    array -- confirm against callers; none are visible in this notebook.
    """
    return sitk.CurvatureFlow(image1=img, timeStep=0.125, numberOfIterations=5)

In [160]:
def scale_z_for_visualization(volume, voxel_spacing):
    """Resample a 3D volume along its last axis using the voxel spacing.

    The last axis is stretched by ``voxel_spacing[2]`` divided by the mean of
    the first two spacings; the first two axes keep their size. The applied
    factors are printed.

    Args:
        volume: 3D array to resample.
        voxel_spacing: Sequence with at least three spacing values.

    Returns:
        The resampled volume (linear interpolation, ``order=1``).
    """
    in_plane_spacing = np.mean(voxel_spacing[:2])
    z_factor = voxel_spacing[2] / in_plane_spacing
    scale_factors = [1, 1, z_factor]
    print(f"Scaling factors applied: {scale_factors}")

    # order=1 -> linear interpolation for smooth scaling
    return ndimage.zoom(volume, zoom=scale_factors, order=1)

In [249]:
def enhance_heart_structure(image, mask):
    """Boost edges in ``image`` and grow ``mask`` towards nearby strong edges.

    The gradient magnitude of ``image`` is computed with SimpleITK and min-max
    normalised to [0, 1]. Half of it is added to a copy of the image, and the
    mask is extended to voxels that lie within two dilation steps of it and
    have a normalised gradient above 0.2.

    Args:
        image: NumPy array of intensities; the ``+=`` below requires a
            floating-point dtype.
        mask: Array of the same shape; values > 0 count as foreground.

    Returns:
        tuple: ``(enhanced_image, enhanced_mask)``. The mask is binary (0/1)
        and cast to the dtype of the input ``mask``.
    """
    # Convert image to SimpleITK format and take its gradient magnitude.
    sitk_image = sitk.GetImageFromArray(image)
    gradient_magnitude = sitk.GetArrayFromImage(sitk.GradientMagnitude(sitk_image))

    # Normalize to [0,1] for scaling. A flat image has zero gradient range;
    # dividing by it would fill the output with NaN, so use zeros instead.
    grad_min = gradient_magnitude.min()
    grad_span = gradient_magnitude.max() - grad_min
    if grad_span == 0:
        gradient_magnitude = np.zeros_like(gradient_magnitude)
    else:
        gradient_magnitude = (gradient_magnitude - grad_min) / grad_span

    # Enhance image: increase intensity in regions with high gradient magnitude.
    enhanced_image = image.copy()
    enhanced_image += 0.5 * gradient_magnitude  # Amplify edges by a factor

    # Enhance mask boundaries: expand segmentation where gradient is high.
    foreground = mask > 0
    dilated_mask = ndimage.binary_dilation(foreground, iterations=2)
    enhanced_mask = np.logical_or(foreground, np.logical_and(gradient_magnitude > 0.2, dilated_mask))

    return enhanced_image, enhanced_mask.astype(mask.dtype)


In [274]:
def normalize_intensity(image):
    """Min-max scale ``image``; 1e-8 is added to the range to avoid 0 / 0."""
    lowest = np.min(image)
    value_range = np.max(image) - lowest
    return (image - lowest) / (value_range + 1e-8)

In [273]:
def remove_low_intensity(image, threshold=0.01):
    """Zero every value of ``image`` that is strictly below ``threshold``.

    The array is modified in place and the same object is returned.
    """
    below_threshold = image < threshold
    image[below_threshold] = 0  # Suppress low-intensity values
    return image

In [276]:
def process(image, mask):
    """Run the preprocessing pipeline on an image/mask pair.

    Steps, in order: orientation correction, log transform, suppression of low
    intensities, min-max normalisation, edge enhancement (image and mask),
    resampling of the last axis by the voxel spacing, and centre-of-mass
    alignment of the mask to the image.

    Args:
        image: Image object exposing ``header`` and ``get_fdata()`` (loaded
            with nibabel in this notebook).
        mask: Mask object with the same interface.

    Returns:
        tuple: ``(image, mask)`` as ``float32`` tensors with a leading
        dimension of size 1 added in front of the volume dimensions.
    """
    # Read the spacing first: `image` is rebound to a plain array just below.
    voxel_spacing = image.header.get_zooms()

    # detect_and_correct_flip always returns the voxel data as a NumPy array,
    # so no separate NIfTI -> array conversion is needed after these calls.
    image = detect_and_correct_flip(image)
    mask = detect_and_correct_flip(mask)

    image = log_bias_correction(image)
    image = remove_low_intensity(image, threshold=0.01)
    image = normalize_intensity(image)
    image, mask = enhance_heart_structure(image, mask)

    # Both arrays are resampled with the image's voxel spacing.
    # NOTE(review): scale_z_for_visualization interpolates linearly, so mask
    # voxels near the boundary can end up with fractional values.
    image = scale_z_for_visualization(image, voxel_spacing)
    mask = scale_z_for_visualization(mask, voxel_spacing)

    # Align mask
    mask = align_mask(image, mask)

    # Convert to tensors
    image = torch.tensor(image, dtype=torch.float32).unsqueeze(0)  # Add channel dim
    mask = torch.tensor(mask, dtype=torch.float32).unsqueeze(0)

    return image, mask

<a id="3d-visualization"></a>
# **2. Visualizing the Heart Model**



#### **Visualization functions**

In [164]:
def visualize_clusters(high_intensity_coords, labels):
    """Show a 3D scatter plot with one trace per cluster label.

    Args:
        high_intensity_coords: Array indexed as ``[point, axis]`` with at
            least three columns (x, y, z).
        labels: Array with one cluster label per point; points labelled -1
            are treated as noise and are not drawn.
    """
    fig = go.Figure()

    # One trace per distinct label, skipping the noise label.
    cluster_ids = [cluster_id for cluster_id in np.unique(labels) if cluster_id != -1]
    for cluster_id in cluster_ids:
        members = high_intensity_coords[labels == cluster_id]
        trace = go.Scatter3d(
            x=members[:, 0],
            y=members[:, 1],
            z=members[:, 2],
            mode="markers",
            marker=dict(size=3, opacity=0.6),
            name=f"Cluster {cluster_id}",
        )
        fig.add_trace(trace)

    axis_titles = dict(xaxis_title="X", yaxis_title="Y", zaxis_title="Z")
    fig.update_layout(
        height=600,
        width=600,
        title="3D High-Intensity Clusters",
        scene=axis_titles,
    )

    fig.show()

In [165]:
def plot_mri_slice(image, mask, slice_idx=None, cmap=["gray", "Reds"], title=None, alpha=None):
    """Plot the same slice of two volumes side by side.

    Args:
        image: Volume shown in the left panel (titled "Before").
        mask: Volume shown in the right panel (titled "After").
        slice_idx: Index along the last axis; defaults to the middle slice
            of ``image``.
        cmap: Two colormap names, for the left and right panel respectively.
        title: Optional figure-level title.
        alpha: Opacity passed to both ``imshow`` calls.
    """
    if slice_idx is None:
        # Default to the middle slice along the last axis.
        slice_idx = image.shape[-1] // 2

    fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(7, 4))

    left_ax.imshow(image[:, :, slice_idx], cmap=cmap[0], alpha=alpha)
    left_ax.set_title("Before")

    right_ax.imshow(mask[:, :, slice_idx], cmap=cmap[1], alpha=alpha)
    right_ax.set_title("After")

    fig.suptitle(title, fontsize=16, weight="bold")
    fig.tight_layout(rect=[0, 0.01, 1, 1])
    plt.subplots_adjust(wspace=0.3)  # Horizontal spacing between the panels
    plt.show()


Main steps:
* **Segmentation refinement** – Focused on key structures (LV, RV, Myocardium) with morphological closing.
* **3D Mesh Generation** – Extracted surface points with Marching Cubes and triangulated them with the Ball-Pivoting Algorithm (BPA) to reconstruct the heart model.
* **Surface Smoothing** – Reduced MRI artifacts using Gaussian filtering for clarity.


##### **Enhancing volume**

In [166]:
def apply_mask_to_volume(volume, mask):
    """Min-max normalise ``volume`` and keep only the voxels where ``mask > 0``.

    Args:
        volume: NumPy array or torch tensor of intensities.
        mask: NumPy array or torch tensor of the same shape; values > 0
            select the voxels to keep.

    Returns:
        numpy.ndarray: Normalised intensities inside the mask, zero elsewhere.
        A constant volume yields an all-zero array instead of NaNs.
    """
    if torch.is_tensor(volume):
        volume = volume.cpu().numpy()
    if torch.is_tensor(mask):
        mask = mask.cpu().numpy()

    v_min = volume.min()
    v_span = volume.max() - v_min
    # Constant volume: divide by 1 rather than 0 so the result is 0, not NaN.
    denominator = v_span if v_span != 0 else 1
    volume = (volume - v_min) / denominator

    inside = mask > 0
    masked_volume = np.zeros_like(volume)
    masked_volume[inside] = volume[inside]

    return masked_volume

##### **3D Reconstruction**

In [167]:
def create_slice_views(volume):
    """Build a 1x3 heatmap figure of the three middle slices of a 3D volume.

    Args:
        volume: 3D NumPy array or torch tensor.

    Returns:
        A Plotly figure with panels titled 'Axial View', 'Sagittal View' and
        'Coronal View' (the figure is returned, not shown).
    """
    if isinstance(volume, torch.Tensor):
        volume = volume.cpu().numpy()

    x_mid, y_mid, z_mid = (volume.shape[axis] // 2 for axis in range(3))

    view_titles = ('Axial View', 'Sagittal View', 'Coronal View')
    fig = make_subplots(
        rows=1, cols=3,
        specs=[[{'type': 'heatmap'}, {'type': 'heatmap'}, {'type': 'heatmap'}]],
        subplot_titles=view_titles
    )

    # (slice, trace name) per column. Only the last slice is transposed.
    # NOTE(review): which array axis corresponds to which anatomical plane
    # depends on the volume's orientation -- confirm the label/axis pairing.
    panels = (
        (volume[:, :, z_mid], 'Axial'),
        (volume[:, y_mid, :], 'Sagittal'),
        (volume[x_mid, :, :].T, 'Coronal'),
    )
    for col, (plane, trace_name) in enumerate(panels, start=1):
        heatmap = go.Heatmap(z=plane, colorscale='gray',
                             showscale=False, name=trace_name)
        fig.add_trace(heatmap, row=1, col=col)

    fig.update_layout(
        height=400,
        width=1200,
        showlegend=False,
        title_text="MRI Visualization",
        font=dict(size=20),
        title_x=0.5
    )

    return fig

In [168]:
def show_figures_side_by_side(*figs, titles=None, headline=None):
    """Combine several Plotly figures into one row of 3D scene subplots.

    Args:
        *figs: Figures whose traces are copied, one figure per column.
        titles: Optional sequence of subplot titles, one per figure.
        headline: Optional overall title, rendered in bold.

    Returns:
        The combined Plotly figure (returned, not shown).

    Raises:
        ValueError: If no figure is given, or if the number of titles does
            not match the number of figures.
    """
    from plotly.subplots import make_subplots

    n_figs = len(figs)
    if n_figs == 0:
        raise ValueError("At least one figure must be provided")
    if titles is not None and len(titles) != n_figs:
        raise ValueError(f"Number of titles ({len(titles)}) must match number of figures ({n_figs})")

    combined_fig = make_subplots(
        rows=1, cols=n_figs,
        subplot_titles=titles,
        specs=[[{'type': 'scene'} for _ in range(n_figs)]]
    )

    # Copy every trace of figure k into column k.
    for col, source_fig in enumerate(figs, 1):
        for trace in source_fig.data:
            combined_fig.add_trace(trace, row=1, col=col)

    def _hidden_axis():
        return dict(showticklabels=False, showgrid=False, zeroline=False)

    def _scene_layout():
        return dict(
            xaxis=_hidden_axis(),
            yaxis=_hidden_axis(),
            zaxis=_hidden_axis(),
            aspectmode='cube',
            camera=dict(eye=dict(x=1, y=1, z=1))
        )

    # Layout keys are 'scene' for the first subplot, then 'scene2', 'scene3', ...
    scene_updates = {
        ('scene' if col == 1 else f'scene{col}'): _scene_layout()
        for col in range(1, n_figs + 1)
    }

    combined_fig.update_layout(
        **scene_updates,
        height=500,
        width=min(600 * n_figs, 1600),
        showlegend=False,
        title_text=f"<b>{headline}</b>" if headline else None,
        title_x=0.5,
        font=dict(size=20)
    )

    return combined_fig

In [169]:
def compute_gradient_magnitude(image):
    """Return the gradient magnitude of ``image``, min-max normalised to [0, 1].

    Args:
        image: NumPy array accepted by ``sitk.GetImageFromArray``.

    Returns:
        numpy.ndarray: Normalised gradient magnitude. When the gradient has no
        range (e.g. a flat image) an all-zero array is returned, because the
        normalisation would otherwise divide by zero and produce NaNs.
    """
    sitk_image = sitk.GetImageFromArray(image)
    gradient_magnitude = sitk.GetArrayFromImage(sitk.GradientMagnitude(sitk_image))

    # Normalize gradient for visualization
    grad_min = gradient_magnitude.min()
    grad_span = gradient_magnitude.max() - grad_min
    if grad_span == 0:
        return np.zeros_like(gradient_magnitude)

    return (gradient_magnitude - grad_min) / grad_span

In [170]:
import numpy as np
import trimesh

def compute_vertex_gradients(mesh_path, volume, gradient_magnitude):
    """Load a mesh and sample ``gradient_magnitude`` at each of its vertices.

    Vertex coordinates are rounded to integers, clipped to the index range of
    ``volume`` and used to index ``gradient_magnitude``. The sampled values are
    min-max normalised (1e-8 is added to the range).

    Args:
        mesh_path: Path of the mesh file to load with trimesh.
        volume: Array whose shape bounds the voxel indices.
        gradient_magnitude: 3D array indexed with the clipped voxel indices.

    Returns:
        tuple: ``(vertices, faces, vertex_gradients)`` as NumPy arrays.
    """
    mesh = trimesh.load_mesh(mesh_path)
    mesh.fill_holes()  # Ask trimesh to fill holes in the surface

    vertices = np.array(mesh.vertices)
    faces = np.array(mesh.faces)

    # Nearest voxel for every vertex, kept inside the volume bounds.
    max_index = np.array(volume.shape) - 1
    voxel_indices = np.clip(np.round(vertices).astype(int), 0, max_index)
    ix, iy, iz = voxel_indices[:, 0], voxel_indices[:, 1], voxel_indices[:, 2]

    sampled = gradient_magnitude[ix, iy, iz]

    # Normalize for better visualization
    vertex_gradients = (sampled - sampled.min()) / (sampled.max() - sampled.min() + 1e-8)

    return vertices, faces, vertex_gradients


In [298]:
import open3d as o3d
import plotly.figure_factory as ff
import trimesh

def create_mesh(volume, path, threshold):
    """Reconstruct a surface mesh from ``volume`` and write it to ``path``.

    Marching cubes extracts iso-surface vertices at ``threshold``; only the
    vertices are kept. They become an Open3D point cloud whose normals are
    estimated, and the Ball-Pivoting Algorithm triangulates the cloud using
    two ball radii derived from the mean nearest-neighbour distance.

    Args:
        volume: 3D array passed to ``measure.marching_cubes``.
        path: Output file path for the mesh.
        threshold: Iso-surface level.
    """
    verts, _faces, _normals, _values = measure.marching_cubes(volume, level=threshold)

    # Marching-cubes vertices -> Open3D point cloud with estimated normals.
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(verts)
    normal_search = o3d.geometry.KDTreeSearchParamHybrid(radius=5, max_nn=30)
    pcd.estimate_normals(search_param=normal_search)

    # BPA radii: 3x and 9x the average nearest-neighbour distance.
    avg_dist = np.mean(pcd.compute_nearest_neighbor_distance())
    radius = 3 * avg_dist
    radii = o3d.utility.DoubleVector([radius, radius * 3])

    bpa_mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(pcd, radii)
    o3d.io.write_triangle_mesh(path, bpa_mesh)

def create_mesh_fig(path):
    """Load the mesh stored at ``path`` and return it as a Plotly trisurf figure."""
    mesh = trimesh.load_mesh(path)
    mesh.fill_holes()

    vertices = np.array(mesh.vertices)
    faces = np.array(mesh.faces)
    xs, ys, zs = vertices[:, 0], vertices[:, 1], vertices[:, 2]

    return ff.create_trisurf(
        x=xs,
        y=ys,
        z=zs,
        simplices=faces,
        colormap="Viridis",
        show_colorbar=True
    )

def fast_trisurf_fig(vertices, faces, vertex_gradients, colormap_name="Viridis"):
    """Build a Plotly trisurf figure coloured by per-face gradient values.

    Args:
        vertices: Array indexed as ``[vertex, axis]`` with x, y, z columns.
        faces: Integer array of vertex indices, one row per face.
        vertex_gradients: One scalar per vertex.
        colormap_name: Colormap handed to ``ff.create_trisurf``.

    Returns:
        The Plotly figure.
    """
    # Per-face value: mean over the face's vertices.
    per_face = np.mean(vertex_gradients[faces], axis=1)

    # Stretch between the 5th and 95th percentiles (ignores outliers), boost
    # contrast by 1.5x, then clamp to [0, 1].
    low, high = np.percentile(per_face, [5, 95])
    per_face = (per_face - low) / (high - low + 1e-8)
    per_face = np.clip(per_face * 1.5, 0, 1)

    xs, ys, zs = vertices[:, 0], vertices[:, 1], vertices[:, 2]

    # The numeric per-face values are passed directly as color_func.
    return ff.create_trisurf(
        x=xs,
        y=ys,
        z=zs,
        simplices=faces,
        colormap=colormap_name,
        show_colorbar=True,
        color_func=per_face
    )


#### **Visualize example**

**Extra Information**

To provide a detailed anatomical view, the MRI scan is visualized in multiple orientations:

* **Axial View** – Slices the body horizontally, from head to toe.
* **Sagittal View** – Divides the body into left and right portions.
* **Coronal View** – Divides the body into front (anterior) and back (posterior) portions.


In [299]:
# Load the sample MRI volume and its segmentation mask with nibabel.
volume = nib.load(SAMPLE_IMAGE_PATH)
mask = nib.load(SAMPLE_MASK_PATH)

In [300]:
# Preprocess the pair; `volume` and `mask` are rebound to the returned tensors,
# so this cell cannot be re-run without re-loading the images first.
volume, mask = process(volume, mask)

Scaling factors applied: [1, 1, np.float32(2.9714327)]
Scaling factors applied: [1, 1, np.float32(2.9714327)]


In [301]:
# Drop the leading dimension added by process() and convert back to NumPy.
# Not idempotent: re-running fails because the names then hold NumPy arrays.
volume = volume.squeeze(0).cpu().numpy()
mask = mask.squeeze(0).cpu().numpy()

In [302]:
# Morphological closing of the mask with a radius-2 ball structuring element;
# `mask` is rebound to the result.
from skimage.morphology import binary_closing, ball
mask = binary_closing(mask, ball(2))

In [303]:
# Min-max normalised volume intensities inside the mask, zero everywhere else.
segmentation = apply_mask_to_volume(volume, mask)

In [304]:
# Most points to get a full surface:
# use the lower of the Otsu threshold and the 60th percentile of positive voxels.
volume_thresh_otsu = threshold_otsu(volume)
volume_thresh = min(np.percentile(volume[volume > 0], 60), volume_thresh_otsu)
print(f"volume_thresh {volume_thresh}")

# Only important points:
# use the higher of the Otsu threshold and the 75th percentile of positive voxels.
mask_thresh_otsu = threshold_otsu(segmentation)
mask_thresh = max(np.percentile(segmentation[segmentation > 0], 75), mask_thresh_otsu)
print(f"mask_thresh {mask_thresh}")

volume_thresh 0.2567586302757263
mask_thresh 0.4671536684036255


In [305]:
# Gaussian smoothing (sigma=1) of both volumes; these are the arrays that are meshed.
smoothed_segmentation = gaussian_filter(segmentation, sigma=1)
smoothed_volume =  gaussian_filter(volume, sigma=1)

In [306]:
# Normalised gradient-magnitude maps, later sampled at the mesh vertices for colouring.
gradient_magnitude_image = compute_gradient_magnitude(smoothed_volume)
gradient_magnitude_segmentation = compute_gradient_magnitude(smoothed_segmentation)

In [307]:
# Output locations of the two PLY meshes, inside the dataset folder on Drive.
mesh_path = os.path.join(BASE_PATH, 'volume_mesh.ply')
seg_mesh_path = os.path.join(BASE_PATH, 'segmentation_mesh.ply')

In [308]:
# Reconstruct both surfaces and write them to disk (any existing files at these paths are rewritten).
create_mesh(smoothed_volume, mesh_path, threshold=volume_thresh)
create_mesh(smoothed_segmentation, seg_mesh_path, threshold=mask_thresh)

In [309]:
# Reload each saved mesh and sample its gradient map at the vertices.
vertices, faces, vertex_gradients = compute_vertex_gradients(mesh_path, smoothed_volume, gradient_magnitude_image)
vertices_seg, faces_seg, vertex_gradients_seg = compute_vertex_gradients(seg_mesh_path, smoothed_segmentation, gradient_magnitude_segmentation)

In [310]:
# One gradient-coloured trisurf figure per mesh.
mesh_fig = fast_trisurf_fig(vertices, faces, vertex_gradients)
seg_mesh_fig = fast_trisurf_fig(vertices_seg, faces_seg, vertex_gradients_seg)

In [311]:
# The combined figure is the cell's last expression, so the notebook renders it.
show_figures_side_by_side(mesh_fig, seg_mesh_fig, titles=['Heart', 'Segmentation'], headline="3D Reconstruction")

Output hidden; open in https://colab.research.google.com to view.

**Figures Creation**

In [272]:
# Middle-slice views of `volume`; relies on `volume` already being the
# processed NumPy array produced by the earlier cells.
views_fig = create_slice_views(volume)
views_fig.show()

# References
* Dataset from ACDC challenge:
  https://humanheart-project.creatis.insa-lyon.fr/database/

  *O. Bernard, A. Lalande, C. Zotti, F. Cervenansky, et al.
"Deep Learning Techniques for Automatic MRI Cardiac Multi-structures Segmentation and Diagnosis: Is the Problem Solved ?" in IEEE Transactions on Medical Imaging, vol. 37, no. 11, pp. 2514-2525, Nov. 2018
doi: 10.1109/TMI.2018.2837502*
* Z-scale importance article:
  https://pmc.ncbi.nlm.nih.gov/articles/PMC4648228/
* Systolic and Diastolic end frame:
  https://www.researchgate.net/figure/End-diastolic-ED-and-End-systolic-ES-frames-with-corresponding-electrocardiogram_fig1_329903092
* Heart conditions:
  * https://emedicine.medscape.com/article/152696-overview
  * https://pmc.ncbi.nlm.nih.gov/articles/PMC4110607
  * https://www.ahajournals.org/doi/10.1161/CIRCIMAGING.123.016090
  * https://www.ncbi.nlm.nih.gov/books/NBK553855/
* MRI Scan Normalization:
  * https://medium.com/@susanne.schmid/image-normalization-in-medical-imaging-f586c8526bd1#:~:text=Min%2DMax%20normalization%2C%20also%20known,not%20affect%20the%20image%20itself.
  * https://arxiv.org/pdf/2406.01736#:~:text=Z%2Dscore%20normalization%20is%20a,compare%20different%20Page%209%20images

* Visualization with Open3D: https://orbi.uliege.be/bitstream/2268/254933/1/TDS_generate_3D_meshes_with_python.pdf