In [1]:
# Move to working directory
%cd ..

/root/huy/BrainSegmentation


In [None]:
import os
from tqdm import tqdm
import numpy as np
import dagshub
dagshub.init(repo_owner='huytrnq', repo_name='BrainSegmentation', mlflow=True)

import torch
import torchio as tio
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="torchio")

from utils.dataset import BrainMRIDataset, BrainMRISliceDataset
from utils.predict import Predictor
from utils.metric import dice_score_3d, hausdorff_distance, average_volumetric_difference
from utils.utils import export_to_nii

In [3]:
ROOT_DIR = './Data'
BATCH_SIZE = 16
NUM_CLASSES = 4
NUM_WORKERS = 16
DEVICE = 'mps' if torch.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'

## Patch-Based

In [4]:
## Transforms
val_transform = tio.Compose([
    tio.RescaleIntensity((0, 1)),
    tio.ZNormalization(),
])
## Dataset
val_dataset = BrainMRIDataset(os.path.join(ROOT_DIR, 'val'), transform=val_transform)

In [15]:
# Load model using MLflow
predictor_patchs = []
run_ids = ['94c24db36be94ebe947cdaf160c07409', '1bbaa5b686a1493bbe6a6fd83bfed272', 'dacc0d9816cc4ec5859f3e227a8bba9c', '5cabc56f7b374afd8b1c8ae12a190312', '518b7a88dad84ad788ba9d82ad81b4bd']
# run_ids = ['94c24db36be94ebe947cdaf160c07409']
patch_sizes = [64, 64, 64, 128, 128]
for run_id, patch_size  in zip(run_ids, patch_sizes):
    predictor_patch = Predictor(mlflow_model_uri=f"runs:/{run_id}/model", device=DEVICE, patch_size=patch_size)
    predictor_patchs.append(predictor_patch)

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

In [36]:
predictions = []
masks = torch.stack([subject['mask'][tio.DATA] for subject in val_dataset], dim=0)

overlaps = [32, 32, 32, 64, 64]
for i, (predictor_patch, overlap) in enumerate(zip(predictor_patchs, overlaps)):
    print('Run ID: ', run_ids[i])
    patch_predictions = []
    for subject in tqdm(val_dataset):
        prediction = predictor_patch.predict_patches(subject, batch_size=BATCH_SIZE, overlap=overlap, proba=True)
        patch_predictions.append(prediction)

    # Stack all patch_predictions
    patch_predictions = torch.stack(patch_predictions, dim=0)
    predictions.append(patch_predictions)
    dice = dice_score_3d(torch.argmax(patch_predictions, dim=1), masks.squeeze(1), num_classes=NUM_CLASSES)
    hd = hausdorff_distance(torch.argmax(patch_predictions, dim=1),masks.squeeze(1),num_classes=NUM_CLASSES, include_background=True)
    avd = average_volumetric_difference(torch.argmax(patch_predictions, dim=1), masks.squeeze(1), num_classes=NUM_CLASSES)
    print(f"Dice score: {dice}")
    print("Mean Dice score: ", np.mean(list(dice.values())[1:]))
    print(f"Hausdorff distance: {hd}")
    print("Mean Hausdorff distance: ", np.mean(list(hd.values())[1:]))
    print(f"Average volumetric difference: {avd}")
    print("Mean Average volumetric difference: ", np.mean(list(avd.values())[1:]))
    print("=====================================\n")

Run ID:  94c24db36be94ebe947cdaf160c07409


100%|██████████| 5/5 [03:08<00:00, 37.71s/it]


Dice score: {0: 0.9976367950439453, 1: 0.9245912432670593, 2: 0.9475903511047363, 3: 0.943122386932373}
Mean Dice score:  0.9384346604347229
Hausdorff distance: {0: 16.673818588256836, 1: 10.618033409118652, 2: 8.398383140563965, 3: 10.22941780090332}
Mean Hausdorff distance:  9.748611450195312
Average volumetric difference: {0: 0.0014792930540379686, 1: 0.02077716098307194, 2: 0.011930382056820154, 3: 0.007784525781830998}
Mean Average volumetric difference:  0.013497356273907697
Run ID:  1bbaa5b686a1493bbe6a6fd83bfed272


100%|██████████| 5/5 [03:10<00:00, 38.18s/it]


Dice score: {0: 0.9976142644882202, 1: 0.923819363117218, 2: 0.9470298886299133, 3: 0.942984938621521}
Mean Dice score:  0.9379447301228842
Hausdorff distance: {0: 16.114290237426758, 1: 8.879219055175781, 2: 7.602534294128418, 3: 7.2602949142456055}
Mean Hausdorff distance:  7.914016087849935
Average volumetric difference: {0: 0.0017271902530647289, 1: 0.011670631773570308, 2: 0.0145210166330571, 3: 0.00736380206881939}
Mean Average volumetric difference:  0.011185150158482265
Run ID:  dacc0d9816cc4ec5859f3e227a8bba9c


100%|██████████| 5/5 [03:08<00:00, 37.66s/it]


Dice score: {0: 0.997552752494812, 1: 0.9232271313667297, 2: 0.9455758333206177, 3: 0.941463828086853}
Mean Dice score:  0.9367555975914001
Hausdorff distance: {0: 15.5580472946167, 1: 8.98314380645752, 2: 8.427544593811035, 3: 7.313634395599365}
Mean Hausdorff distance:  8.241440931955973
Average volumetric difference: {0: 0.0014532187510752862, 1: 0.029738302933786175, 2: 0.01643509254680129, 3: 0.0013895253440545546}
Mean Average volumetric difference:  0.01585430694154734
Run ID:  5cabc56f7b374afd8b1c8ae12a190312


100%|██████████| 5/5 [00:18<00:00,  3.75s/it]


Dice score: {0: 0.9977170825004578, 1: 0.9252330660820007, 2: 0.9494802355766296, 3: 0.9438269734382629}
Mean Dice score:  0.9395134250322977
Hausdorff distance: {0: 14.598358154296875, 1: 12.326481819152832, 2: 8.476136207580566, 3: 8.453357696533203}
Mean Hausdorff distance:  9.751991907755533
Average volumetric difference: {0: 0.000978059676645272, 1: 0.018781390430626734, 2: 0.007684422591830999, 3: 0.005771305960514641}
Mean Average volumetric difference:  0.01074570632765746
Run ID:  518b7a88dad84ad788ba9d82ad81b4bd


100%|██████████| 5/5 [00:19<00:00,  3.99s/it]


Dice score: {0: 0.997692883014679, 1: 0.9277375340461731, 2: 0.9454866647720337, 3: 0.9381093978881836}
Mean Dice score:  0.9371111989021301
Hausdorff distance: {0: 15.660593032836914, 1: 14.820938110351562, 2: 58.56317138671875, 3: 7.6742119789123535}
Mean Hausdorff distance:  27.01944049199422
Average volumetric difference: {0: 0.0014271991112215403, 1: 0.019376156489302458, 2: 0.02003545258280387, 3: 0.009583403927734685}
Mean Average volumetric difference:  0.016331670999947003


In [37]:
### Ensemble
ensemble_predictions = torch.stack(predictions, dim=0)
ensemble_predictions = ensemble_predictions.mean(dim=0)
dice = dice_score_3d(torch.argmax(ensemble_predictions, dim=1), masks.squeeze(1), num_classes=NUM_CLASSES)
print(f"Ensemble Dice score: {dice}")
print("Ensemble Mean Dice score: ", np.mean(list(dice.values())[1:]))

Ensemble Dice score: {0: 0.9978324174880981, 1: 0.9304620623588562, 2: 0.9516526460647583, 3: 0.9470009803771973}
Ensemble Mean Dice score:  0.9430385629336039


In [38]:
hd = hausdorff_distance(torch.argmax(ensemble_predictions, dim=1), masks.squeeze(1), num_classes=NUM_CLASSES, include_background=True)
print("Hausdorff Distances:", hd)
print("Average Hausdorff Distance:", np.mean(list(hd.values())[1:]))

Hausdorff Distances: {0: 15.296363830566406, 1: 8.556015968322754, 2: 7.9177985191345215, 3: 7.009108543395996}
Average Hausdorff Distance: 7.827641010284424


In [39]:
avd = average_volumetric_difference(torch.argmax(ensemble_predictions, dim=1), masks.squeeze(1), num_classes=NUM_CLASSES)
print(f"Ensemble Average volumetric difference: {avd}")
print("Mean Ensemble Average volumetric difference: ", np.mean(list(avd.values())[1:]))

Ensemble Average volumetric difference: {0: 0.0014182343613559638, 1: 0.02130584192411703, 2: 0.012823655353856137, 3: 0.004749061047008058}
Mean Ensemble Average volumetric difference:  0.012959519441660407


## Full Volume Based

In [40]:
## Transforms
val_transform = tio.Compose([
    tio.RescaleIntensity((0, 1)),
    tio.ZNormalization(),
])
## Dataset
val_dataset = BrainMRIDataset(os.path.join(ROOT_DIR, 'val'), transform=val_transform)
val_loader = tio.SubjectsLoader(val_dataset, batch_size=1, shuffle=False, num_workers=NUM_WORKERS)

In [5]:
# Load model using MLflow
predictor_full = Predictor(mlflow_model_uri="runs:/1c98c526ea884b768e491a03985c8f22/model", device=DEVICE)

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

In [25]:
## Predict
predictions_full = predictor_full.predict_full_volume(val_loader, proba=False)

In [23]:
dice_score_3d(torch.argmax(predictions_full.squeeze(1), dim=1), masks.squeeze(1), num_classes=NUM_CLASSES)

{0: 0.9966910481452942,
 1: 0.8851995468139648,
 2: 0.923554539680481,
 3: 0.9096390008926392}

## Cross Validation

## Prediction

In [19]:
## Transforms
test_transform = tio.Compose([
    tio.RescaleIntensity((0, 1)),
    tio.ZNormalization(),
])
## Dataset
test_dataset = BrainMRIDataset(os.path.join(ROOT_DIR, 'test'), transform=test_transform)

In [20]:
predictions = []

overlaps = [32, 32, 32, 64]
for predictor_patch, overlap in zip(predictor_patchs, overlaps):
    patch_predictions = []
    for subject in tqdm(test_dataset):
        prediction = predictor_patch.predict_patches(subject, batch_size=BATCH_SIZE, overlap=overlap, proba=True)
        patch_predictions.append(prediction)

    # Stack all patch_predictions
    patch_predictions = torch.stack(patch_predictions, dim=0)
    predictions.append(patch_predictions)

### Ensemble
ensemble_predictions = torch.stack(predictions, dim=0)
ensemble_predictions = ensemble_predictions.mean(dim=0)
ensemble_predictions = torch.argmax(ensemble_predictions, dim=1)

  0%|          | 0/3 [00:00<?, ?it/s]

100%|██████████| 3/3 [00:59<00:00, 19.93s/it]
100%|██████████| 3/3 [00:57<00:00, 19.08s/it]
100%|██████████| 3/3 [01:02<00:00, 20.69s/it]
100%|██████████| 3/3 [00:06<00:00,  2.29s/it]


In [21]:
for i, subject in enumerate(test_dataset):
    affine = subject['image'].affine
    spacing = subject['image'].spacing
    name = subject['image'].path.name
    export_to_nii(ensemble_predictions[i].numpy().astype(np.int16), f'./results/{name}', spacing, affine)

Saved NIfTI file to ./results/IBSR_02.nii.gz
Saved NIfTI file to ./results/IBSR_10.nii.gz
Saved NIfTI file to ./results/IBSR_15.nii.gz
