# Visualize 3D WT/TC/ET (GT vs Pred) with Brain Context

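3D marching-cubes view of the WT / TC / ET tumor regions, built purely from the GT and predicted segmentation masks. An optional translucent brain surface (the non-zero voxels of an MRI volume; the loader reads `t1.nii.gz` by default despite the FLAIR naming) can be added for context. Each surface is cropped to its mask's bounding box and strided down to save RAM, and the axes can be hidden for a cleaner view.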


In [44]:
from pathlib import Path
import numpy as np
import nibabel as nib
from skimage import measure
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ---- Paths & options (edit as needed) ----

# Inputs
DATA_ROOT = Path('../data/processed/3d/labeled')  # per-case folders: <CASE_ID>/mask.nii.gz
PRED_ROOT = Path('../experiments/brats3d_vnet_sup/inference_diceloss/preds')  # <CASE_ID>_pred.nii.gz files
CASE_ID = 'Brain_077'  # case to visualize

# Tumor meshes (WT / TC / ET)
SAMPLE_STEP = 2  # voxel stride applied before marching cubes; must be >= 1
BBOX_MARGIN = 2  # voxels kept around each region's bbox so the surface is not clipped

# Brain context surface
SHOW_BRAIN = True  # threshold the case's intensity volume (> 0) and draw it as context
BRAIN_SAMPLE_STEP = 4  # coarser stride than the tumor meshes, to keep memory down
BRAIN_COLOR = "#97b0cf"
BRAIN_OPACITY = 0.15

# Scene styling
SHOW_AXES = False  # False hides axes, ticks and grids for a clean look


In [45]:
# Helpers

def load_seg(path: Path):
    """Read a NIfTI label mask and return it together with its voxel spacing.

    Labels are expected to be 0, 1, 2, 4; label 4 is remapped to 3 so the
    classes are contiguous (a mask that already uses 3 passes through as-is).

    Args:
        path: Path to the segmentation file.

    Returns:
        ``(arr, spacing)``: ``arr`` is an ``int16`` label volume and
        ``spacing`` is a tuple of the header zooms for the first three axes.
    """
    nii = nib.load(str(path))
    # get_fdata() yields floats; round before casting so a stored value such
    # as 0.9999 is not truncated to the wrong integer label.
    arr = np.rint(nii.get_fdata()).astype(np.int16)
    arr[arr == 4] = 3
    spacing = tuple(float(z) for z in nii.header.get_zooms()[:3])
    return arr, spacing


def load_flair(case_dir: Path, filename: str = 't1.nii.gz') -> np.ndarray:
    """Load one intensity volume of a case; used only to derive a brain mask.

    NOTE: despite the function name, the default reads ``t1.nii.gz`` (the file
    this notebook has always loaded). Pass ``filename`` to use another volume,
    e.g. a FLAIR file if the case folder has one.

    Args:
        case_dir: Case folder containing the NIfTI volume.
        filename: Name of the volume file inside ``case_dir``.

    Returns:
        The volume as a ``float32`` array.
    """
    nii = nib.load(str(case_dir / filename))
    # Request float32 directly rather than building a float64 array and then
    # casting it, which avoids the large float64 intermediate.
    return nii.get_fdata(dtype=np.float32)


def make_roi_masks(seg: np.ndarray):
    """Split a label volume into three nested boolean region masks.

    Keys, in insertion order:
      'WT' - every non-zero label
      'TC' - labels 1 or 3
      'ET' - label 3 only
    """
    whole = seg > 0
    core = (seg == 1) | (seg == 3)
    enhancing = seg == 3
    return {'WT': whole, 'TC': core, 'ET': enhancing}


def compute_bbox(mask: np.ndarray):
    """Return the tight bounding box of the non-zero entries of ``mask``.

    Returns ``None`` when ``mask`` has no non-zero entry; otherwise
    ``(lo, hi)`` arrays where ``lo`` holds the inclusive per-axis minimum
    index and ``hi`` the exclusive per-axis maximum (max index + 1).
    """
    hits = np.argwhere(mask)
    if not hits.size:
        return None
    return hits.min(axis=0), hits.max(axis=0) + 1


def crop_and_downsample(mask: np.ndarray, step: int, margin: int):
    """Crop a 3-axis mask to its bbox (grown by ``margin``) and stride it by ``step``.

    Args:
        mask: Volume indexed with three axes.
        step: Keep every ``step``-th voxel along each axis; values <= 1 keep all.
        margin: Voxels added on each side of the bbox, clipped to the volume.

    Returns:
        ``(cropped, offset)`` where ``offset`` is a float32 array holding the
        index of the crop's first voxel in ``mask``, or ``(None, None)`` when
        ``mask`` has no non-zero voxel.
    """
    bbox = compute_bbox(mask)
    if bbox is None:
        return None, None
    lo, hi = bbox
    # Grow the box by `margin` on every side without stepping outside the volume.
    z0, y0, x0 = (max(0, v - margin) for v in lo)
    z1, y1, x1 = (min(n, v + margin) for n, v in zip(mask.shape, hi))
    window = mask[z0:z1, y0:y1, x0:x1]
    if step > 1:
        window = window[::step, ::step, ::step]
    origin = np.array([z0, y0, x0], dtype=np.float32)
    return window, origin


def mask_to_mesh_trace(mask: np.ndarray, spacing, color: str, name: str, opacity: float, step: int, margin: int):
    """Turn a binary mask into a Plotly ``Mesh3d`` iso-surface.

    The mask is bbox-cropped (plus ``margin``) and strided by ``step`` via
    ``crop_and_downsample``, then meshed with marching cubes at level 0.5.
    Vertices are expressed in the units of ``spacing`` and placed in the full
    volume's frame (the crop offset is added back).

    Args:
        mask: 3-axis mask (non-zero = inside); may be ``None``.
        spacing: Per-axis voxel size, in the mask's axis order.
        color, name, opacity: Forwarded to ``go.Mesh3d``.
        step: Subsampling stride; values below 1 are treated as 1.
        margin: Extra voxels kept around the bbox before subsampling.

    Returns:
        A ``go.Mesh3d`` trace, or ``None`` if the mask (or its strided crop)
        is empty.
    """
    if mask is None or not np.any(mask):
        return None
    step = max(1, int(step))
    prepared, offset = crop_and_downsample(mask, step=step, margin=margin)
    if prepared is None or not np.any(prepared):
        return None
    # Zero-pad one (strided) voxel on every side. Without it, marching cubes
    # leaves the surface open wherever the mask touches the crop border (always
    # the case with margin == 0, as used for the brain) and fails on crops
    # smaller than 2x2x2. Padding moves the grid origin back by one step.
    padded = np.pad(prepared.astype(np.float32), 1, mode='constant', constant_values=0)
    offset = offset - step
    spacing_arr = np.asarray(spacing, dtype=np.float64)
    verts, faces, _, _ = measure.marching_cubes(padded, level=0.5, spacing=tuple(spacing_arr * step))
    verts = verts + offset * spacing_arr
    # Array axis 0 is plotted on z and axis 2 on x (axis order reversed);
    # triangle vertex order is reversed as well, presumably to keep face
    # orientation consistent after that mirror.
    return go.Mesh3d(
        x=verts[:, 2], y=verts[:, 1], z=verts[:, 0],
        i=faces[:, 2], j=faces[:, 1], k=faces[:, 0],
        color=color, opacity=opacity, name=name,
        flatshading=True, showscale=False,
    )


def build_figure(
    gt_seg: np.ndarray,
    pred_seg: np.ndarray,
    spacing,
    sample_step: int = 1,
    margin: int = 2,
    brain_mask=None,
    show_axes: bool = True,
    case_id=None,
):
    """Build a 1x3 Plotly figure comparing GT and predicted WT / TC / ET meshes.

    Each column shows one region: the GT surface (one shared green for all
    regions) overlaid with the prediction surface (a distinct color per
    region), optionally drawn over a translucent brain surface.

    Args:
        gt_seg, pred_seg: Label volumes of identical shape.
        spacing: Per-axis voxel size used for every mesh.
        sample_step: Stride for the tumor meshes (clamped to >= 1).
        margin: Bbox margin in voxels for the tumor meshes (clamped to >= 0).
        brain_mask: Optional mask drawn as context in every column, using the
            module-level BRAIN_SAMPLE_STEP, BRAIN_COLOR and BRAIN_OPACITY.
        show_axes: If False, hide axis lines, ticks and titles.
        case_id: Label used in the figure title; defaults to the module-level
            CASE_ID.

    Returns:
        The assembled ``go.Figure``.

    Raises:
        ValueError: If ``gt_seg`` and ``pred_seg`` have different shapes
            (their meshes would not line up).
    """
    if gt_seg.shape != pred_seg.shape:
        raise ValueError(
            f'GT and prediction shapes differ: {gt_seg.shape} vs {pred_seg.shape}'
        )
    if case_id is None:
        case_id = CASE_ID

    # GT: one shared green for all three regions.
    GT_COLOR = "#0CDA3C"
    # Prediction: a distinct color per region.
    PRED_COLORS = {
        "WT": "#EA41F9",  # magenta
        "TC": "#DA6675",  # rose
        "ET": "#F8F41E",  # yellow
    }

    tumor_step = max(1, int(sample_step))
    tumor_margin = max(0, int(margin))

    fig = make_subplots(
        rows=1, cols=3,
        specs=[[{'type': 'scene'}] * 3],
        subplot_titles=['WT (GT + Pred)', 'TC (GT + Pred)', 'ET (GT + Pred)'],
        horizontal_spacing=0.05,
    )

    gt_masks = make_roi_masks(gt_seg)
    pred_masks = make_roi_masks(pred_seg)

    # The brain surface is identical in every column, so mesh it once instead
    # of re-running marching cubes on the whole brain for each subplot.
    brain_trace = None
    if brain_mask is not None:
        brain_trace = mask_to_mesh_trace(
            mask=brain_mask,
            spacing=spacing,
            color=BRAIN_COLOR,
            name='Brain',
            opacity=BRAIN_OPACITY,
            step=max(1, int(BRAIN_SAMPLE_STEP)),
            margin=0,
        )

    for col, roi in enumerate(['WT', 'TC', 'ET'], start=1):
        # Brain context first so the tumor surfaces are drawn over it.
        if brain_trace is not None:
            # A copy per subplot; the shared legend group with a single visible
            # entry toggles the brain in all columns at once.
            fig.add_trace(
                go.Mesh3d(brain_trace, legendgroup='Brain', showlegend=(col == 1)),
                row=1, col=col,
            )

        gt_trace = mask_to_mesh_trace(
            mask=gt_masks[roi],
            spacing=spacing,
            color=GT_COLOR,
            name=f'GT {roi}',
            opacity=0.55,
            step=tumor_step,
            margin=tumor_margin,
        )
        pred_trace = mask_to_mesh_trace(
            mask=pred_masks[roi],
            spacing=spacing,
            color=PRED_COLORS.get(roi, "#F916CB"),
            name=f'Pred {roi}',
            opacity=0.45,
            step=tumor_step,
            margin=tumor_margin,
        )
        for trace in (gt_trace, pred_trace):
            if trace is not None:
                # Plotly hides Mesh3d traces from the legend by default; show
                # them so the legend configured below is actually populated.
                trace.update(showlegend=True)
                fig.add_trace(trace, row=1, col=col)

    fig.update_layout(
        height=720,
        width=1600,
        title_text=f'Case: {case_id} – 3D WT / TC / ET (GT vs Pred)',
        legend=dict(
            orientation='h',
            yanchor='bottom',
            y=0.02,
            x=0.5,
            xanchor='center'
        ),
    )

    def axis_cfg(label):
        # Axis title only when axes are shown; otherwise everything is hidden.
        return dict(
            title=label if show_axes else '',
            showticklabels=show_axes,
            visible=show_axes,
            showgrid=False,
            zeroline=False,
        )

    # Without row/col, update_scenes applies to all three subplot scenes.
    fig.update_scenes(
        aspectmode='data',
        xaxis=axis_cfg('X'),
        yaxis=axis_cfg('Y'),
        zaxis=axis_cfg('Z'),
        dragmode='turntable',
        bgcolor='white',
    )
    return fig


In [46]:
# Load GT / prediction masks and, optionally, the volume for the brain context
case_dir = DATA_ROOT / CASE_ID
pred_path = PRED_ROOT / f'{CASE_ID}_pred.nii.gz'

gt_seg, spacing = load_seg(case_dir / 'mask.nii.gz')
pred_seg, pred_spacing = load_seg(pred_path)
# Every mesh is drawn with the GT spacing; flag a prediction saved with a
# different voxel size, since it would be mis-scaled in the plot.
if not np.allclose(spacing, pred_spacing):
    print(f'WARNING: pred spacing {pred_spacing} != GT spacing {spacing}; using GT spacing')

brain_mask = None
if SHOW_BRAIN:
    flair = load_flair(case_dir)  # NOTE: load_flair reads t1.nii.gz by default
    # Simple brain mask: every voxel with positive intensity.
    brain_mask = flair > 0
    del flair  # only the boolean mask is needed from here on; free the volume
    print(f'Brain mask voxels: {int(brain_mask.sum())}')

print(f'GT shape: {gt_seg.shape}, unique={np.unique(gt_seg)}')
print(f'Pred shape: {pred_seg.shape}, unique={np.unique(pred_seg)}')


Brain mask voxels: 1298084
GT shape: (194, 194, 155), unique=[0 1 2 3]
Pred shape: (194, 194, 155), unique=[0 1 2 3]


In [47]:
# Render: one 3-column figure (WT / TC / ET) with GT and prediction overlaid
fig = build_figure(
    gt_seg,
    pred_seg,
    spacing=spacing,
    brain_mask=brain_mask,
    sample_step=SAMPLE_STEP,
    margin=BBOX_MARGIN,
    show_axes=SHOW_AXES,
)
fig.show()
