In [1]:
# Move one directory up (to the repository root) so the relative paths used
# below (./Data, ./results) and the local `utils` package resolve.
# NOTE: not idempotent — every re-run of this cell moves one more level up.
%cd ..

/root/huy/BrainSegmentation


In [2]:
import os
from tqdm import tqdm
import numpy as np
import dagshub
# Initialise DagsHub with MLflow integration (mlflow=True); presumably this
# points MLflow at the DagsHub tracking server so the "runs:/<run_id>/..."
# model URIs loaded below can be resolved.
dagshub.init(repo_owner='huytrnq', repo_name='BrainSegmentation', mlflow=True)

import torch
from torch.utils.data import DataLoader
import torchio as tio
import warnings
# Hide UserWarnings emitted by torchio so inference output stays readable.
warnings.filterwarnings("ignore", category=UserWarning, module="torchio")

# Project-local helpers: datasets, MLflow-backed predictor, metrics, NIfTI export.
from utils.dataset import BrainMRIDataset, BrainMRISliceDataset
from utils.predict import Predictor
from utils.metric import dice_score_3d, hausdorff_distance, average_volumetric_difference
from utils.utils import export_to_nii, evaluate_segmentation

In [8]:
# --- Configuration -----------------------------------------------------------
ROOT_DIR = './Data'   # dataset root; 'val' and 'test' sub-folders are read below
BATCH_SIZE = 16       # batch size used for patch / slice inference
NUM_CLASSES = 4       # label classes, including background label 0
NUM_WORKERS = 16      # DataLoader worker processes


def select_device() -> str:
    """Return the best available torch device name: 'mps', then 'cuda', else 'cpu'.

    Uses the documented ``torch.backends.mps.is_available()`` check; the
    ``getattr`` guard keeps this working on builds without the MPS backend.
    """
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return 'mps'
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'


DEVICE = select_device()

In [9]:
## Transforms: rescale intensities to [0, 1], then z-score normalise.
val_transform = tio.Compose(
    [
        tio.RescaleIntensity((0, 1)),
        tio.ZNormalization(),
    ]
)
## Dataset: validation subjects under <ROOT_DIR>/val.
val_dir = os.path.join(ROOT_DIR, 'val')
val_dataset = BrainMRIDataset(val_dir, transform=val_transform)

## Patch-Based

In [5]:
# Load the patch-based ensemble members from MLflow. run_ids and patch_sizes are
# paired 1:1 (the patch size each run's Predictor is configured with).
predictor_patchs = []
run_ids = ['94c24db36be94ebe947cdaf160c07409', '1bbaa5b686a1493bbe6a6fd83bfed272', 'dacc0d9816cc4ec5859f3e227a8bba9c', '5cabc56f7b374afd8b1c8ae12a190312', '518b7a88dad84ad788ba9d82ad81b4bd']
patch_sizes = [64, 64, 64, 128, 128]
# zip() silently truncates on a length mismatch, which would quietly drop members.
assert len(run_ids) == len(patch_sizes), 'run_ids and patch_sizes must have the same length'
for run_id, patch_size in zip(run_ids, patch_sizes):
    predictor_patch = Predictor(mlflow_model_uri=f"runs:/{run_id}/model", device=DEVICE, patch_size=patch_size)
    predictor_patchs.append(predictor_patch)

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

In [10]:
# Evaluate each patch-based model on the validation set and keep its per-voxel
# class probabilities for the ensemble further down.
patch_predictions = []
masks = torch.stack([subject['mask'][tio.DATA] for subject in val_dataset], dim=0).squeeze(1)

# Sliding-window overlap per model (half of that model's patch size).
overlaps = [32, 32, 32, 64, 64]
# zip() would silently drop models if the lists fell out of sync.
assert len(overlaps) == len(predictor_patchs) == len(run_ids), 'need one overlap and run id per model'
for run_id, predictor_patch, overlap in zip(run_ids, predictor_patchs, overlaps):
    print('\nRun ID: ', run_id)
    predictions = []
    for subject in tqdm(val_dataset):
        prediction = predictor_patch.predict_patches(subject, batch_size=BATCH_SIZE, overlap=overlap, proba=True)
        predictions.append(prediction)

    # Stack per-subject probability maps into one batch tensor; dim=1 is then
    # the class axis (assumes predict_patches returns class-first maps — confirm).
    predictions = torch.stack(predictions, dim=0)
    patch_predictions.append(predictions)
    r_dict = evaluate_segmentation(torch.argmax(predictions, dim=1), masks, num_classes=NUM_CLASSES)


Run ID:  94c24db36be94ebe947cdaf160c07409


100%|██████████| 5/5 [01:55<00:00, 23.01s/it]


Dice scores: {0: 0.9976367950439453, 1: 0.9245912432670593, 2: 0.9475903511047363, 3: 0.943122386932373}
Mean Dice score: 0.9384346604347229
Hausdorff Distances: {0: 16.673818588256836, 1: 10.618033409118652, 2: 8.398383140563965, 3: 10.22941780090332}
Mean Hausdorff Distance: 9.748611450195312
Average Volumetric Differences: {0: 0.0014792930540379686, 1: 0.02077716098307194, 2: 0.011930382056820154, 3: 0.007784525781830998}
Mean Average Volumetric Difference: 0.013497356273907697

Run ID:  1bbaa5b686a1493bbe6a6fd83bfed272


100%|██████████| 5/5 [02:01<00:00, 24.39s/it]


Dice scores: {0: 0.9976142644882202, 1: 0.923819363117218, 2: 0.9470298886299133, 3: 0.942984938621521}
Mean Dice score: 0.9379447301228842
Hausdorff Distances: {0: 16.114290237426758, 1: 8.879219055175781, 2: 7.602534294128418, 3: 7.2602949142456055}
Mean Hausdorff Distance: 7.914016087849935
Average Volumetric Differences: {0: 0.0017271902530647289, 1: 0.011670631773570308, 2: 0.0145210166330571, 3: 0.00736380206881939}
Mean Average Volumetric Difference: 0.011185150158482265

Run ID:  dacc0d9816cc4ec5859f3e227a8bba9c


100%|██████████| 5/5 [01:51<00:00, 22.26s/it]


Dice scores: {0: 0.997552752494812, 1: 0.9232271313667297, 2: 0.9455758333206177, 3: 0.941463828086853}
Mean Dice score: 0.9367555975914001
Hausdorff Distances: {0: 15.5580472946167, 1: 8.98314380645752, 2: 8.427544593811035, 3: 7.313634395599365}
Mean Hausdorff Distance: 8.241440931955973
Average Volumetric Differences: {0: 0.0014532187510752862, 1: 0.029738302933786175, 2: 0.01643509254680129, 3: 0.0013895253440545546}
Mean Average Volumetric Difference: 0.01585430694154734

Run ID:  5cabc56f7b374afd8b1c8ae12a190312


100%|██████████| 5/5 [00:11<00:00,  2.34s/it]


Dice scores: {0: 0.9977170825004578, 1: 0.9252330660820007, 2: 0.9494802355766296, 3: 0.9438269734382629}
Mean Dice score: 0.9395134250322977
Hausdorff Distances: {0: 14.598358154296875, 1: 12.326481819152832, 2: 8.476136207580566, 3: 8.453357696533203}
Mean Hausdorff Distance: 9.751991907755533
Average Volumetric Differences: {0: 0.000978059676645272, 1: 0.018781390430626734, 2: 0.007684422591830999, 3: 0.005771305960514641}
Mean Average Volumetric Difference: 0.01074570632765746

Run ID:  518b7a88dad84ad788ba9d82ad81b4bd


100%|██████████| 5/5 [00:11<00:00,  2.27s/it]


Dice scores: {0: 0.997692883014679, 1: 0.9277375340461731, 2: 0.9454866647720337, 3: 0.9381093978881836}
Mean Dice score: 0.9371111989021301
Hausdorff Distances: {0: 15.660593032836914, 1: 14.820938110351562, 2: 58.56317138671875, 3: 7.6742119789123535}
Mean Hausdorff Distance: 27.01944049199422
Average Volumetric Differences: {0: 0.0014271991112215403, 1: 0.019376156489302458, 2: 0.02003545258280387, 3: 0.009583403927734685}
Mean Average Volumetric Difference: 0.016331670999947003


In [11]:
# Stack per-model probabilities into one tensor (model axis first). Guarded so
# re-running this cell doesn't try to torch.stack an already-stacked Tensor.
if isinstance(patch_predictions, list):
    patch_predictions = torch.stack(patch_predictions, dim=0)

In [15]:
### Ensemble: average class probabilities over the patch models (dim 0),
### then pick the most likely class per voxel.
mean_patch_probs = patch_predictions.mean(dim=0)
ensemble_predictions = mean_patch_probs.argmax(dim=1)
patch_r_dict = evaluate_segmentation(ensemble_predictions, masks, num_classes=NUM_CLASSES)

Dice scores: {0: 0.9978324174880981, 1: 0.9304620623588562, 2: 0.9516526460647583, 3: 0.9470009803771973}
Mean Dice score: 0.9430385629336039
Hausdorff Distances: {0: 15.296363830566406, 1: 8.556015968322754, 2: 7.9177985191345215, 3: 7.009108543395996}
Mean Hausdorff Distance: 7.827641010284424
Average Volumetric Differences: {0: 0.0014182343613559638, 1: 0.02130584192411703, 2: 0.012823655353856137, 3: 0.004749061047008058}
Mean Average Volumetric Difference: 0.012959519441660407


## Full Volume Based

In [6]:
## Same intensity pipeline as the patch-based section.
val_transform = tio.Compose([tio.RescaleIntensity((0, 1)), tio.ZNormalization()])
val_dataset = BrainMRIDataset(os.path.join(ROOT_DIR, 'val'), transform=val_transform)
# batch_size=1 and no shuffling: predictions come back in dataset order, matching `masks`.
val_loader = tio.SubjectsLoader(val_dataset, batch_size=1, shuffle=False, num_workers=NUM_WORKERS)

# Ground-truth label volumes, one per subject, with the channel dim squeezed out.
masks = torch.stack(
    [subj['mask'][tio.DATA] for subj in val_dataset],
    dim=0,
).squeeze(1)

In [13]:
# Full-volume model from a single MLflow run (no patch_size passed).
FULL_RUN_ID = "44a7b8c0aacc44f3ab0491abdc1c7826"
predictor_full = Predictor(mlflow_model_uri=f"runs:/{FULL_RUN_ID}/model", device=DEVICE)

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

In [14]:
## Predict class probabilities for every validation volume, then score the argmax labels.
full_probs = predictor_full.predict_full_volume(val_loader, proba=True)
predictions_full = full_probs.squeeze(1)
full_labels = torch.argmax(predictions_full, dim=1)
full_r_dict = evaluate_segmentation(full_labels, masks, num_classes=NUM_CLASSES)

Dice scores: {0: 0.9969556927680969, 1: 0.8819906115531921, 2: 0.9205700755119324, 3: 0.9057410955429077}
Mean Dice score: 0.9027672608693441
Hausdorff Distances: {0: 17.789249420166016, 1: 29.957195281982422, 2: 13.71537971496582, 3: 10.623468399047852}
Mean Hausdorff Distance: 18.0986811319987
Average Volumetric Differences: {0: 0.0011520250208354958, 1: 0.0008855405762505216, 2: 0.0009730856329555721, 3: 0.025873939804655778}
Mean Average Volumetric Difference: 0.00924418867128729


## Cross Validation

In [34]:
# Cross-validation models: one MLflow run holding a model per fold under
# models/fold_1 ... models/fold_<KFOLD>, each used with patch_size=128.
KFOLD = 5
patch_size = 128
cv_run_id = '47240a9f9b9248e089e3ccefc97616d6'
predictor_cvs = [
    Predictor(mlflow_model_uri=f"runs:/{cv_run_id}/models/fold_{fold}", device=DEVICE, patch_size=patch_size)
    for fold in range(1, KFOLD + 1)
]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

In [36]:
# Evaluate every cross-validation fold model on the validation set and collect
# one stacked probability tensor per fold for the CV ensemble.
# (Previously cv_predictions was reset inside the loop and the stale
# `predictions` from the patch-based cell was stacked/evaluated instead.)
cv_predictions = []
masks = torch.stack([subject['mask'][tio.DATA] for subject in val_dataset], dim=0).squeeze(1)

overlap = 64
for fold, predictor_cv in enumerate(predictor_cvs, start=1):
    print(f'\nFold {fold}/{len(predictor_cvs)}')
    fold_predictions = []
    for subject in tqdm(val_dataset):
        prediction = predictor_cv.predict_patches(subject, batch_size=BATCH_SIZE, overlap=overlap, proba=True)
        fold_predictions.append(prediction)

    # Stack this fold's per-subject probability maps into one batch tensor.
    fold_predictions = torch.stack(fold_predictions, dim=0)
    cv_predictions.append(fold_predictions)
    r_dict = evaluate_segmentation(torch.argmax(fold_predictions, dim=1), masks, num_classes=NUM_CLASSES)
    print("=====================================\n")
    print("=====================================\n")

100%|██████████| 5/5 [00:13<00:00,  2.70s/it]


Dice scores: {0: 0.9975025057792664, 1: 0.9196634292602539, 2: 0.9450166821479797, 3: 0.9386361837387085}
Mean Dice score: 0.9344387650489807
Hausdorff Distances: {0: 14.652447700500488, 1: 13.96502685546875, 2: 8.125859260559082, 3: 8.808526039123535}
Mean Hausdorff Distance: 10.299804051717123
Average Volumetric Differences: {0: 0.0018275517210720347, 1: 0.040060798307691506, 2: 0.015015512280160325, 3: 0.009682899400406349}
Mean Average Volumetric Difference: 0.021586403329419395



100%|██████████| 5/5 [00:12<00:00,  2.40s/it]


Dice scores: {0: 0.9975084066390991, 1: 0.920093834400177, 2: 0.9472190737724304, 3: 0.9449449777603149}
Mean Dice score: 0.9374192953109741
Hausdorff Distances: {0: 16.00098991394043, 1: 17.73281478881836, 2: 7.530290126800537, 3: 8.366315841674805}
Mean Hausdorff Distance: 11.2098069190979
Average Volumetric Differences: {0: 0.0021549290804923804, 1: 0.03413957176798653, 2: 0.018204313331296407, 3: 0.009855168704517858}
Mean Average Volumetric Difference: 0.020733017934600265



100%|██████████| 5/5 [00:12<00:00,  2.52s/it]


Dice scores: {0: 0.9974645376205444, 1: 0.9166332483291626, 2: 0.9462777972221375, 3: 0.9426630139350891}
Mean Dice score: 0.9351913531621298
Hausdorff Distances: {0: 14.819279670715332, 1: 13.953651428222656, 2: 8.467630386352539, 3: 8.676583290100098}
Mean Hausdorff Distance: 10.365955034891764
Average Volumetric Differences: {0: 0.0022740399948648856, 1: 0.03707375099078676, 2: 0.0075469522338769725, 3: 0.033792073792645864}
Mean Average Volumetric Difference: 0.026137592339103195



100%|██████████| 5/5 [00:12<00:00,  2.54s/it]


Dice scores: {0: 0.9975630044937134, 1: 0.9168203473091125, 2: 0.9461957216262817, 3: 0.9386836290359497}
Mean Dice score: 0.9338998993237814
Hausdorff Distances: {0: 15.70568561553955, 1: 11.619071960449219, 2: 8.108404159545898, 3: 8.50373649597168}
Mean Hausdorff Distance: 9.410404205322266
Average Volumetric Differences: {0: 0.0017708934086594135, 1: 0.04344435633038007, 2: 0.004285155558785965, 3: 0.030128934761654257}
Mean Average Volumetric Difference: 0.025952815550273428



100%|██████████| 5/5 [00:13<00:00,  2.71s/it]


Dice scores: {0: 0.9974303245544434, 1: 0.9126714468002319, 2: 0.9439105987548828, 3: 0.9401819109916687}
Mean Dice score: 0.9322546521822611
Hausdorff Distances: {0: 17.27603530883789, 1: 17.70833969116211, 2: 7.98715353012085, 3: 9.23741340637207}
Mean Hausdorff Distance: 11.644302209218344
Average Volumetric Differences: {0: 0.0018804109474135733, 1: 0.0495374041759247, 2: 0.014048391270173833, 3: 0.01312601130040405}
Mean Average Volumetric Difference: 0.025570602248834198



In [37]:
# Stack per-fold probabilities into one tensor (fold axis first). Guarded so
# re-running this cell doesn't try to torch.stack an already-stacked Tensor.
if isinstance(cv_predictions, list):
    cv_predictions = torch.stack(cv_predictions, dim=0)

In [40]:
### Ensemble: average the fold probabilities (kept for later ensembles),
### then score the per-voxel argmax.
cv_ensemble_predictions = cv_predictions.mean(dim=0)
cv_labels = cv_ensemble_predictions.argmax(dim=1)
cv_r_dict = evaluate_segmentation(cv_labels, masks, num_classes=NUM_CLASSES)

Dice scores: {0: 0.9976717829704285, 1: 0.924658477306366, 2: 0.9498182535171509, 3: 0.9454354047775269}
Mean Dice score: 0.9399707118670145
Hausdorff Distances: {0: 14.265218734741211, 1: 13.700634956359863, 2: 7.729846000671387, 3: 7.69365930557251}
Mean Hausdorff Distance: 9.708046754201254
Average Volumetric Differences: {0: 0.0019875233093745296, 1: 0.03683584456731647, 2: 0.010816985769258613, 3: 0.021276111876108648}
Mean Average Volumetric Difference: 0.022976314070894576


## Slice-Based

In [8]:
# Slice-based settings (re-declared here, presumably so this section can run alone).
BATCH_SIZE, NUM_CLASSES = 16, 4
N_TEST = 5  # number of validation volumes; used to un-flatten the stacked slices

In [10]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from utils.transforms import RobustZNormalization

# 2D slice pipeline for the slice-based models.
# NOTE(review): Normalize with mean=0, std=1, max_pixel_value=1.0 is effectively
# an identity (apart from a float cast); the real scaling is presumably done by
# the project-local RobustZNormalization — confirm.
slice_steps = [
    A.Normalize(mean=(0,), std=(1,), max_pixel_value=1.0, p=1.0),
    RobustZNormalization(),
    ToTensorV2(),
]
test_transform = A.Compose(slice_steps, additional_targets={'mask': 'mask'})

  check_for_updates()


### Axial

In [11]:
# Axial view: slices taken along axis 0 of each validation volume.
axial_dataset = BrainMRISliceDataset(
    os.path.join(ROOT_DIR, 'val'),
    slice_axis=0,
    transform=test_transform,
    cache=True,
    ignore_background=False,
)
axial_loader = DataLoader(axial_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
# Gather slice masks in loader order and fold them back into (N_TEST, slices, 128, 256).
axial_masks = [m for (_, m, _, _) in axial_loader]
axial_labels = torch.cat(axial_masks, dim=0).squeeze(1).reshape(N_TEST, -1, 128, 256)

In [12]:
# Axial 2D model; proba=True requests per-class probabilities.
axial_predictor = Predictor(
    mlflow_model_uri="runs:/bb8ff770bd7f495e9151a575eda3624a/model",
    device=DEVICE,
)
# `predice_slices` is spelled as in the project's Predictor API.
axial_probs = axial_predictor.predice_slices(axial_loader, proba=True, plane='axial')

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

### Coronal

In [13]:
# Coronal view: slices taken along axis 1 of each validation volume.
coronal_dataset = BrainMRISliceDataset(
    os.path.join(ROOT_DIR, 'val'),
    slice_axis=1,
    transform=test_transform,
    cache=True,
    ignore_background=False,
)
coronal_loader = DataLoader(coronal_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
# Fold slices into (N_TEST, slices, 256, 256), then permute the slice axis back to index 2.
coronal_masks = [m for (_, m, _, _) in coronal_loader]
coronal_labels = torch.cat(coronal_masks, dim=0).squeeze(1).reshape(N_TEST, -1, 256, 256).permute(0, 2, 1, 3)

In [14]:
# Coronal 2D model; proba=True requests per-class probabilities.
coronal_predictor = Predictor(
    mlflow_model_uri="runs:/e05a0eacc46146c9a56e70a185e35eed/model",
    device=DEVICE,
)
# `predice_slices` is spelled as in the project's Predictor API.
coronal_probs = coronal_predictor.predice_slices(coronal_loader, proba=True, plane='coronal')

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

### Sagittal

In [15]:
# Sagittal view: slices taken along axis 2 of each validation volume.
sagittal_dataset = BrainMRISliceDataset(
    os.path.join(ROOT_DIR, 'val'),
    slice_axis=2,
    transform=test_transform,
    cache=True,
    ignore_background=False,
)
sagittal_loader = DataLoader(sagittal_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
# Fold slices into (N_TEST, slices, 256, 128), then move the slice axis to the end.
sagittal_masks = [m for (_, m, _, _) in sagittal_loader]
sagittal_labels = torch.cat(sagittal_masks, dim=0).squeeze(1).reshape(N_TEST, -1, 256, 128).permute(0, 2, 3, 1)

In [16]:
# Sagittal 2D model; proba=True requests per-class probabilities.
sagittal_predictor = Predictor(
    mlflow_model_uri="runs:/bdc118531a4a4e0fb4ffa9dc8fb0d83d/model",
    device=DEVICE,
)
# `predice_slices` is spelled as in the project's Predictor API.
sagittal_probs = sagittal_predictor.predice_slices(sagittal_loader, proba=True, plane='sagittal')

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

### Ensemble

In [26]:
# Equal-weight soft vote over the three slice-plane models, then per-voxel argmax.
plane_probs = (axial_probs, coronal_probs, sagittal_probs)
slice_probs = (plane_probs[0] + plane_probs[1] + plane_probs[2]) / len(plane_probs)
slice_predictions = slice_probs.argmax(dim=1)

In [18]:
# Score the slice ensemble against the reassembled axial labels (background included).
slice_r_dict = evaluate_segmentation(
    slice_predictions,
    axial_labels,
    NUM_CLASSES,
    include_background=True,
)

Dice scores: {0: 0.9975493550300598, 1: 0.9201601147651672, 2: 0.9435482025146484, 3: 0.9394735097885132}
Mean Dice score: 0.9343939423561096
Hausdorff Distances: {0: 17.761503219604492, 1: 11.132668495178223, 2: 9.03249740600586, 3: 7.8488569259643555}
Mean Hausdorff Distance: 9.338007609049479
Average Volumetric Differences: {0: 0.0003300011886492948, 1: 0.000978059740933412, 2: 0.004361559538723224, 3: 0.015637277030636852}
Mean Average Volumetric Difference: 0.006992298770097829


## Ensemble Validation

### Patch-Based + Slice-Based

In [30]:
# Ensemble the 5 patch-model probability maps with the slice-based soft vote.
# The slice ensemble enters as a single extra member, so each of the six
# sources gets weight 1/6 in the mean.
patch_slice_probs = torch.cat([patch_predictions, slice_probs.unsqueeze(0)], dim=0).mean(dim=0)
final_predictions = torch.argmax(patch_slice_probs, dim=1)
# Own result name: reusing `slice_r_dict` would overwrite the slice-only metrics.
patch_slice_r_dict = evaluate_segmentation(final_predictions, masks, NUM_CLASSES, include_background=True)

Dice scores: {0: 0.9978084564208984, 1: 0.9305034875869751, 2: 0.9511961936950684, 3: 0.9458341598510742}
Mean Dice score: 0.9425112803777059
Hausdorff Distances: {0: 15.172632217407227, 1: 8.432571411132812, 2: 7.747807502746582, 3: 7.054324150085449}
Mean Hausdorff Distance: 7.744901021321614
Average Volumetric Differences: {0: 0.0014154465428002055, 1: 0.016217287866558062, 2: 0.010215410938377148, 3: 0.009693133220452578}
Mean Average Volumetric Difference: 0.012041944008462596


### Patch-Based + Cross Validation

In [41]:
# Ensemble the patch-model probability maps with the cross-validation ensemble
# probabilities (the CV ensemble counts as a single extra member in the mean).
patch_cv_probs = torch.cat([patch_predictions, cv_ensemble_predictions.unsqueeze(0)], dim=0).mean(dim=0)
final_predictions = torch.argmax(patch_cv_probs, dim=1)
# Own result name: reusing `slice_r_dict` would overwrite the slice-only metrics.
patch_cv_r_dict = evaluate_segmentation(final_predictions, masks, NUM_CLASSES, include_background=True)

Dice scores: {0: 0.9978310465812683, 1: 0.9303766489028931, 2: 0.9519494771957397, 3: 0.9474298357963562}
Mean Dice score: 0.9432519872983297
Hausdorff Distances: {0: 15.282182693481445, 1: 8.485294342041016, 2: 7.990594387054443, 3: 7.049042701721191}
Mean Hausdorff Distance: 7.841643810272217
Average Volumetric Differences: {0: 0.0015141954490938863, 1: 0.02392281258229021, 2: 0.01226780929907095, 3: 0.00797044017933748}
Mean Average Volumetric Difference: 0.014720354020232878


## Prediction

In [8]:
## Test-time transforms mirror the validation pipeline.
test_steps = [
    tio.RescaleIntensity((0, 1)),
    tio.ZNormalization(),
]
test_transform = tio.Compose(test_steps)
## Dataset: subjects under <ROOT_DIR>/test.
test_dataset = BrainMRIDataset(os.path.join(ROOT_DIR, 'test'), transform=test_transform)

In [9]:
# Run every patch model over the test set and soft-vote their probabilities.
# Test-specific names keep the validation tensors (`patch_predictions`,
# `predictions`) intact, so the validation ensemble cells can still be re-run.
test_model_probs = []

overlaps = [32, 32, 32, 64, 64]
# zip() would silently drop models if the lists fell out of sync.
assert len(overlaps) == len(predictor_patchs), 'need one overlap per patch model'
for predictor_patch, overlap in zip(predictor_patchs, overlaps):
    subject_probs = [
        predictor_patch.predict_patches(subject, batch_size=BATCH_SIZE, overlap=overlap, proba=True)
        for subject in tqdm(test_dataset)
    ]
    # One stacked tensor per model, subjects along dim 0.
    test_model_probs.append(torch.stack(subject_probs, dim=0))

### Ensemble: mean over models (dim 0), then argmax over the class axis.
ensemble_predictions = torch.stack(test_model_probs, dim=0).mean(dim=0)
ensemble_predictions = torch.argmax(ensemble_predictions, dim=1)

  0%|          | 0/3 [00:00<?, ?it/s]

100%|██████████| 3/3 [00:34<00:00, 11.58s/it]
100%|██████████| 3/3 [00:37<00:00, 12.39s/it]
100%|██████████| 3/3 [00:36<00:00, 12.21s/it]
100%|██████████| 3/3 [00:04<00:00,  1.56s/it]
100%|██████████| 3/3 [00:04<00:00,  1.58s/it]


In [10]:
# Write each test prediction as an int16 NIfTI label map named after its input
# file, reusing the input image's affine and spacing for the output geometry.
os.makedirs('./results', exist_ok=True)  # export target may not exist on a fresh checkout
for i, subject in enumerate(test_dataset):
    image = subject['image']
    # .cpu() is a no-op for CPU tensors and required before .numpy() otherwise.
    label_map = ensemble_predictions[i].cpu().numpy().astype(np.int16)
    export_to_nii(label_map, f'./results/{image.path.name}', image.spacing, image.affine)

Saved NIfTI file to ./results/IBSR_02.nii.gz
Saved NIfTI file to ./results/IBSR_10.nii.gz
Saved NIfTI file to ./results/IBSR_15.nii.gz
