# Imports and MRI Data Loading

In [None]:
# Imports for preprocessing; scikit-learn model imports still need to be added
import os
import random
import warnings
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

In [None]:
# Read every patient folder under root_dir and collect its NIfTI volumes
root_dir = '/path/to/dataset/'  #Change this path

# (substring looked for in the lower-cased filename, key used for storage).
# Order matters: '_t1ce' is tested before '_t1' because '_t1' is a prefix of it.
MODALITY_TAGS = (
    ('_t1ce', 'T1CE'),
    ('_t1', 'T1'),
    ('_t2', 'T2'),
    ('_flair', 'FLAIR'),
    ('_seg', 'SEG'),
)

patients = {}
masked_patients = {}
for patient_id in sorted(os.listdir(root_dir)):
    patient_path = os.path.join(root_dir, patient_id)
    if os.path.isdir(patient_path):
        modalities = {}
        for file in os.listdir(patient_path):
            if not file.endswith('.nii'):
                continue

            # Every .nii file is loaded as float32, recognised or not
            path = os.path.join(patient_path, file)
            img = nib.load(path).get_fdata().astype(np.float32)

            # First matching tag wins; unmatched files are dropped
            lowered = file.lower()
            for tag, key in MODALITY_TAGS:
                if tag in lowered:
                    modalities[key] = img
                    break

        # Only patients with all five volumes are kept
        if len(modalities) != 5:
            print(f"Skipped {patient_id} (missing modalities)")
        else:
            patients[patient_id] = modalities
            print(f"Loaded {patient_id} → {list(modalities.keys())}")

print(f"\n Total patients loaded: {len(patients)}")


# Visualization

In [None]:
def visualize_random_patient(patients_dict, slice_idx=None):
    """Plot one slice of each volume of a randomly selected patient.

    Args:
        patients_dict: Mapping of patient ID to a dict with the keys
            'T1', 'T1CE', 'T2', 'FLAIR' and 'SEG'; every value is indexed
            as a 3-D array.
        slice_idx: Index along the third axis to display. When None, the
            midpoint of the 'T1' volume's third axis is used.
    """
    chosen_id = random.choice(list(patients_dict.keys()))
    volumes = patients_dict[chosen_id]

    # Fall back to the middle slice
    if slice_idx is None:
        slice_idx = volumes['T1'].shape[2] // 2

    # One panel per volume; only the segmentation gets the 'jet' colormap
    panel_names = ('T1', 'T1CE', 'T2', 'FLAIR', 'SEG')
    fig, axes = plt.subplots(1, 5, figsize=(22, 5))
    for ax, name in zip(axes, panel_names):
        colormap = 'jet' if name == 'SEG' else 'gray'
        plane = volumes[name][:, :, slice_idx]
        ax.imshow(plane.T, cmap=colormap, origin='lower')
        ax.set_title(name)
        ax.axis('off')

    plt.suptitle(f"Patient: {chosen_id} | Slice {slice_idx}", fontsize=16)
    plt.tight_layout()
    # Leave room above the panels for the suptitle
    plt.subplots_adjust(top=0.85)
    plt.show()

In [None]:
# Plot a randomly chosen loaded patient; slice_idx is left at its default,
# so the middle slice of the T1 volume is shown.
visualize_random_patient(patients)

In [None]:
# Overlay the segmentation of one hard-coded patient on its FLAIR slice.
# Raises KeyError if that patient was not loaded.
slice_idx = 77
flair = patients['BraTS20_Training_232']['FLAIR']
seg = patients['BraTS20_Training_232']['SEG']

flair_slice = flair[:, :, slice_idx]
seg_slice = seg[:, :, slice_idx]

# Hide voxels whose segmentation value is 0 (presumably the background
# label -- confirm) so the semi-transparent overlay only tints labelled
# voxels instead of washing out the whole FLAIR image.
seg_overlay = np.ma.masked_where(seg_slice == 0, seg_slice)

plt.figure(figsize=(12, 6))
# Same orientation (.T, origin='lower') as the other plots in this notebook
plt.imshow(flair_slice.T, cmap='gray', origin='lower')
# vmin/vmax are pinned so the label colours match an unmasked 0..max scale
plt.imshow(seg_overlay.T, cmap='jet', alpha=0.5,
           vmin=0, vmax=seg_slice.max(), origin='lower')
plt.title('FLAIR + Segmentation')
plt.axis('off')
plt.show()

# Apply Brain Mask

In [None]:
masked_patients = {}
structural_mods = ('T1', 'T1CE', 'T2', 'FLAIR')

for patient_id, data in patients.items():

    # A voxel is in the mask when it is > 0 in at least one of the four volumes
    mask = np.logical_or.reduce([data[name] > 0 for name in structural_mods])

    # The volumes are stored as loaded (the mask is not multiplied in here);
    # the mask itself is wrapped in a NIfTI image with an identity affine.
    masked_patients[patient_id] = {
        "masked_modalities": {name: data[name] for name in structural_mods},
        "mask_img": nib.Nifti1Image(mask.astype(np.uint8), np.eye(4)),
        "SEG": data["SEG"],
    }

    print(f"{patient_id}: Brain voxels = {mask.sum():,}")

In [None]:
# Preview the masked FLAIR middle slice of a randomly chosen patient
pid = random.choice(list(masked_patients.keys()))
mask = masked_patients[pid]["mask_img"].get_fdata().astype(bool)
flair = masked_patients[pid]["masked_modalities"]["FLAIR"]

z = flair.shape[2] // 2
# Mask only the displayed slice rather than the whole volume
masked_slice = flair[:, :, z] * mask[:, :, z]

fig, ax = plt.subplots(figsize=(7, 7))
ax.imshow(masked_slice.T, cmap='gray', origin='lower')
ax.set_title(f"{pid} | Brain Mask")
ax.axis("off")
plt.show()

# Intensity Clipping

In [None]:
def intensity_clipping(volume, mask=None, lower=1, upper=99):
    """Clip in-mask intensities of a volume to percentile bounds.

    The percentiles are computed from the voxels selected by ``mask`` and
    only those voxels are clipped. Voxels outside the mask keep their
    original value, so a zero background stays zero and a later
    ``volume > 0`` test still separates it from the masked region.

    Args:
        volume: Array of intensities.
        mask: Boolean array selecting the voxels to use; any other dtype is
            converted with a truthiness test. Defaults to ``volume > 0``.
        lower: Lower percentile (0-100).
        upper: Upper percentile (0-100).

    Returns:
        A new array with the same shape as ``volume``. If the mask selects
        no voxel, an unmodified copy of ``volume`` is returned.
    """
    volume = np.asarray(volume)

    if mask is None:
        mask = volume > 0
    else:
        mask = np.asarray(mask, dtype=bool)

    brain_voxels = volume[mask]

    # Percentiles of an empty selection are undefined; return the input as-is
    if brain_voxels.size == 0:
        return volume.copy()

    p_low, p_high = np.percentile(brain_voxels, [lower, upper])

    clipped = np.clip(volume, p_low, p_high)
    # np.clip also raised out-of-mask voxels to p_low; restore them
    outside = ~mask
    clipped[outside] = volume[outside]

    return clipped

In [None]:
# Demonstrate intensity clipping on a single patient.
# The patient and its own brain mask are looked up explicitly instead of
# reusing whatever `data` / `mask` happen to be left over from earlier cells,
# and the result goes into a separate dict so no loaded volume is overwritten.
demo_pid = next(iter(masked_patients))
demo_entry = masked_patients[demo_pid]
demo_mask = demo_entry["mask_img"].get_fdata().astype(bool)

clipped_demo = {
    mod: intensity_clipping(demo_entry["masked_modalities"][mod], mask=demo_mask)
    for mod in ['T1', 'T1CE', 'T2', 'FLAIR']
}

# Z Score Normalization

In [None]:
def zscore_normalization(volume, mask=None):
    """Z-score a volume using the mean and std of its masked voxels.

    The statistics come from the voxels selected by ``mask``; the shift and
    scale are then applied to the whole volume, including voxels outside
    the mask.

    Args:
        volume: Array of intensities.
        mask: Boolean array selecting the voxels used for the statistics;
            any other dtype is converted with a truthiness test. Defaults
            to ``volume > 0``.

    Returns:
        A new array ``(volume - mean) / std``. If the masked voxels are
        constant (std == 0) only the mean is subtracted. If the mask selects
        no voxel, an unmodified copy of ``volume`` is returned.
    """
    volume = np.asarray(volume)

    if mask is None:
        mask = volume > 0  # avoid background
    else:
        mask = np.asarray(mask, dtype=bool)

    brain_voxels = volume[mask]

    # Mean/std of an empty selection are NaN and would poison every voxel
    if brain_voxels.size == 0:
        return volume.copy()

    mean = brain_voxels.mean()
    std = brain_voxels.std()

    if std == 0:
        # Constant region: centring is all that can be done without dividing by zero
        return volume - mean

    return (volume - mean) / std

In [None]:
# Demonstrate z-score normalization on a single patient, using the volumes
# currently stored for it in masked_patients.
# The patient and its own brain mask are looked up explicitly instead of
# reusing whatever `data` / `mask` happen to be left over from earlier cells,
# and the result goes into a separate dict so no loaded volume is overwritten.
demo_pid = next(iter(masked_patients))
demo_entry = masked_patients[demo_pid]
demo_mask = demo_entry["mask_img"].get_fdata().astype(bool)

zscored_demo = {
    mod: zscore_normalization(demo_entry["masked_modalities"][mod], mask=demo_mask)
    for mod in ["T1", "T1CE", "T2", "FLAIR"]
}

# Preprocess All Patients (Clipping + Z-Score) for the Multi-Channel Feature Vector

In [None]:

voxel_data = {}   # stores X matrix per patient

# Clip, then z-score, every volume of every patient using that patient's mask.
# NOTE: this replaces the stored volumes, so re-running the cell processes
# already-processed data again.
for pid, entry in masked_patients.items():

    mask = entry["mask_img"].get_fdata().astype(bool)
    mods = entry["masked_modalities"]

    for mod in ("T1", "T1CE", "T2", "FLAIR"):
        # 1. Intensity clipping
        clipped = intensity_clipping(mods[mod], mask=mask)
        # 2. Z-score normalization
        mods[mod] = zscore_normalization(clipped, mask=mask)

    # `mods` is the dict held by masked_patients[pid], so the item
    # assignments above have already updated the stored entry.

# Crop Volumes

In [None]:
def center_crop(volume, crop_size=128):
    """Return a centred window of ``crop_size`` voxels along every axis.

    Args:
        volume: Array to crop (any number of axes).
        crop_size: Requested edge length of the window.

    Returns:
        A view of ``volume``. Along an axis that is shorter than
        ``crop_size`` the window is limited to what the axis holds.
    """
    half = crop_size // 2

    window = []
    for dim in volume.shape:
        # Clamp at 0: a negative start would be read as "from the end"
        # and silently select the wrong region.
        start = max(dim // 2 - half, 0)
        # start + crop_size (not centre + half) keeps the full width for
        # odd crop sizes; slicing stops at the end of the axis by itself.
        window.append(slice(start, start + crop_size))

    return volume[tuple(window)]

In [None]:
cropped_patients = {}

def get_bbox(mask):
    """Return the inclusive index bounds of the nonzero region of a 3-D mask.

    The result is ordered (min, max) for axis 0, then axis 1, then axis 2.
    A mask without any nonzero element raises ValueError from the reduction.
    """
    axis0, axis1, axis2 = np.nonzero(mask)
    return (
        axis0.min(), axis0.max(),
        axis1.min(), axis1.max(),
        axis2.min(), axis2.max(),
    )

for pid, entry in masked_patients.items():
    mods = entry["masked_modalities"]
    seg  = entry["SEG"]
    mask = entry["mask_img"].get_fdata().astype(bool)

    # One bounding box per patient, taken from the brain mask
    zmin, zmax, ymin, ymax, xmin, xmax = get_bbox(mask)

    # Bounds are inclusive, hence the +1 on every upper limit
    box = (
        slice(zmin, zmax + 1),
        slice(ymin, ymax + 1),
        slice(xmin, xmax + 1),
    )

    # The same window is applied to volumes, segmentation and mask
    cropped_mods = {mod: mods[mod][box] for mod in ("T1", "T1CE", "T2", "FLAIR")}
    cropped_seg = seg[box]
    cropped_mask = mask[box]

    cropped_patients[pid] = {
        "masked_modalities": cropped_mods,
        "mask": cropped_mask,
        "SEG": cropped_seg,
    }

    print(f"{pid} cropped to {cropped_mods['FLAIR'].shape}")

In [None]:
import numpy as np

roi_patients = {}

for pid, entry in cropped_patients.items():
    mods = entry["masked_modalities"]
    mask = entry["mask"]     # brain mask in cropped space
    flair = mods["FLAIR"]
    t1ce = mods["T1CE"]

    # 75th percentile of the T1CE values inside the brain mask
    t1ce_thr = np.percentile(t1ce[mask], 75)

    # ROI = in-mask voxels with FLAIR > 0, plus every voxel whose T1CE value
    # exceeds the threshold (the second term is not limited to the mask)
    roi = ((flair > 0) & mask) | (t1ce > t1ce_thr)

    roi_patients[pid] = dict(
        roi=roi,
        SEG=entry["SEG"],
        masked_modalities=mods,
        mask=mask,
    )

    print(f"{pid}: ROI voxels = {roi.sum()} / {mask.sum()}")

In [None]:
scaler = StandardScaler()
# PCA is currently disabled; to enable it, create
# PCA(n_components=3, random_state=42) and fit it in the loop below.

voxel_data = {}
feature_order = ("T1", "T1CE", "T2", "FLAIR")

for pid, entry in roi_patients.items():
    roi = entry["roi"]
    mods = entry["masked_modalities"]

    # One row per ROI voxel, one column per volume (in feature_order)
    X = np.array([mods[name][roi] for name in feature_order]).T

    # The scaler is re-fitted on each patient's own voxels
    X_std = scaler.fit_transform(X)

    # The key is named "X_pca" although it holds the standardized,
    # non-PCA features while PCA is disabled
    voxel_data[pid] = {
        "X_pca": X_std,
        "roi": roi,
        "SEG": entry["SEG"],
    }

    print(f"{pid}: Using NON-PCA standardized features → {X_std.shape}")

# Dimensionality Reduction

In [None]:
# Sanity-check the stored feature matrix of the first patient:
# print the per-column mean and standard deviation.
pid = list(voxel_data)[0]
X_pca = voxel_data[pid]["X_pca"]

print(f"\nPCA QC for {pid}")
for label, stat in (
    ("PCA component means", X_pca.mean(axis=0)),
    ("PCA component stds ", X_pca.std(axis=0)),
):
    print(label, stat)