In [None]:
# Re-import after kernel reset
import nibabel as nib
import numpy as np
import plotly.graph_objects as go
import os
import albumentations as A


class ImageReader:
    def __init__(self, root: str, img_size: int = 256, normalize: bool = False):
        pad_size = 256 if img_size > 256 else 224
        self.resize = A.Compose([
            A.PadIfNeeded(min_height=pad_size, min_width=pad_size, value=0),
            A.Resize(img_size, img_size)
        ])
        self.normalize = normalize
        self.root = root

    def read_file(self, path: str, with_seg: bool = False) -> dict:
        raw_image = nib.load(path).get_fdata()
        scan_type = path.split('_')[-1].split('.')[0]
        processed_frames, masks = [], []

        if with_seg:
            mask_path = path.replace(scan_type, 'seg')
            raw_mask = nib.load(mask_path).get_fdata()
        else:
            raw_mask = None

        for frame_idx in range(raw_image.shape[2]):
            frame = raw_image[:, :, frame_idx]
            if with_seg:
                mask = raw_mask[:, :, frame_idx]
                resized = self.resize(image=frame, mask=mask)
                processed_frames.append(resized['image'])
                masks.append(resized['mask'])
            else:
                resized = self.resize(image=frame)
                processed_frames.append(resized['image'])

        scan_data = np.stack(processed_frames, 0)

        if self.normalize and scan_data.max() > 0:
            scan_data = scan_data / scan_data.max()
            scan_data = scan_data.astype(np.float32)

        result = {'scan': scan_data, 'orig_shape': raw_image.shape}
        if with_seg:
            result['seg'] = np.stack(masks, 0)

        return result

    def load_patient_scan(self, patient_id: str, scan_type: str = 'flair', with_seg: bool = False) -> dict:
        scan_path = os.path.join(self.root, patient_id, f"{patient_id}_{scan_type}.nii")
        return self.read_file(scan_path, with_seg=with_seg)


class ImageViewer3d:
    def __init__(self, reader: ImageReader, mri_downsample: int = 2, threshold: float = 0.3
        self.downsample = mri_downsample
        self.threshold = threshold

    def get_3d_scan(self, patient_id: str, scan_type: str = 'flair') -> go.Figure:
        data = self.reader.load_patient_scan(patient_id, scan_type=scan_type, with_seg=True)
        volume = data['scan']
        mask = data.get('seg', None)

        # Normalize and threshold
        norm = (volume - np.min(volume)) / (np.max(volume) - np.min(volume))
        binary_mask = norm > self.threshold
        subsample = (slice(None, None, self.downsample),) * 3
        vol = norm[subsample]
        m = binary_mask[subsample]
        x, y, z = np.where(m)
        intensities = vol[x, y, z]

        fig = go.Figure()

        fig.add_trace(go.Scatter3d(
            x=x, y=y, z=z,
            mode='markers',
            marker=dict(size=3, color=intensities, colorscale='Viridis', opacity=0.7),
            name=f'{scan_type.upper()} Scan'
        ))

        # Optional: Add segmentation mask overlays
        if mask is not None:
            seg = mask[subsample]
            for label, name, color in zip([1, 2, 4], ['Tumor Core', 'Edema', 'Enhancing Tumor'],
                                          ['Reds', 'Blues', 'Greens']):
                sx, sy, sz = np.where(seg == label)
                fig.add_trace(go.Scatter3d(
                    x=sx, y=sy, z=sz,
                    mode='markers',
                    marker=dict(size=2, color=sx, colorscale=color, opacity=0.6),
                    name=name
                ))

        fig.update_layout(
            title=f"3D MRI + Tumor Segmentation – {patient_id}",
            scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
            margin=dict(l=0, r=0, b=0, t=40),
            legend=dict(x=0.8, y=0.9)
        )

        return fig



In [5]:
training_root = "../data/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData"
reader = ImageReader(root=training_root, img_size=128, normalize=True)
viewer = ImageViewer3d(reader, mri_downsample=2, threshold=0.3)
fig = viewer.get_3d_scan("BraTS20_Training_001", scan_type='flair')
fig.show()


  A.PadIfNeeded(min_height=pad_size, min_width=pad_size, value=0),
