In [1]:
# Move to the working directory one level up from the kernel's current directory,
# presumably the project root, so that the relative paths used below ('./Data',
# './results') and the `utils` package imports resolve.
# NOTE(review): `%cd ..` is relative, so re-running this cell moves up again;
# run it exactly once per kernel session.
%cd ..

/root/huy/BrainSegmentation


In [None]:
# --- Standard library ---
import os
import warnings

# --- Third-party ---
import numpy as np
import torch
import torchio as tio
from tqdm import tqdm
import dagshub

# Point MLflow at the DagsHub-hosted tracking server (mlflow=True) so the
# `runs:/<run_id>/model` URIs used later can be resolved.
dagshub.init(repo_owner='huytrnq', repo_name='BrainSegmentation', mlflow=True)

# Hide torchio UserWarnings from here on (registered after torchio is imported).
warnings.filterwarnings("ignore", category=UserWarning, module="torchio")

# --- Local project utilities ---
from utils.dataset import BrainMRIDataset, BrainMRISliceDataset
from utils.predict import Predictor
from utils.metric import dice_score_3d
from utils.utils import export_to_nii

In [4]:
# --- Paths and inference settings ---
ROOT_DIR = './Data'      # expects 'val' and 'test' sub-folders (see dataset cells)
BATCH_SIZE = 16          # patches per forward pass in predict_patches
NUM_CLASSES = 4          # segmentation labels 0..3 (keys of the Dice dicts below)
NUM_WORKERS = 16         # DataLoader workers for the full-volume loader

# Pick the best available accelerator. `torch.backends.mps.is_available()` is the
# documented MPS check and exists on every platform/build; `torch.mps.is_available`
# is not part of every PyTorch release's `torch.mps` API and can raise AttributeError.
if torch.backends.mps.is_available():
    DEVICE = 'mps'
elif torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

## Patch-Based

In [6]:
## Transforms
# Intensity preprocessing for every validation subject: rescale to [0, 1], then
# z-score normalisation.
# NOTE(review): presumably mirrors the training-time transforms; confirm.
val_transform = tio.Compose(
    [
        tio.RescaleIntensity((0, 1)),
        tio.ZNormalization(),
    ]
)
## Dataset
val_dir = os.path.join(ROOT_DIR, 'val')
val_dataset = BrainMRIDataset(val_dir, transform=val_transform)

In [10]:
# Load the patch-based models from their MLflow runs. Each run id is paired,
# by position, with the patch size handed to its Predictor.
run_ids = [
    '94c24db36be94ebe947cdaf160c07409',
    '1bbaa5b686a1493bbe6a6fd83bfed272',
    'dacc0d9816cc4ec5859f3e227a8bba9c',
    '5cabc56f7b374afd8b1c8ae12a190312',
]
patch_sizes = [64, 64, 64, 128]

# zip() silently drops trailing items when the lists differ in length, which would
# quietly shrink the ensemble. Fail loudly instead.
assert len(run_ids) == len(patch_sizes), (
    f"{len(run_ids)} run ids but {len(patch_sizes)} patch sizes"
)

predictor_patchs = [
    Predictor(mlflow_model_uri=f"runs:/{run_id}/model", device=DEVICE, patch_size=patch_size)
    for run_id, patch_size in zip(run_ids, patch_sizes)
]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

In [11]:
# Per-model probability volumes for the validation set; averaged in the ensemble cell.
predictions = []
# Ground-truth label maps for every validation subject, stacked along a new leading axis.
masks = torch.stack([subject['mask'][tio.DATA] for subject in val_dataset], dim=0)

# Sliding-window overlap for each patch model, in the same order as `predictor_patchs`.
overlaps = [32, 32, 32, 64]
# zip() would silently skip models (or overlaps) if the lengths disagree.
assert len(overlaps) == len(predictor_patchs), (
    f"{len(predictor_patchs)} models but {len(overlaps)} overlaps"
)

for model_idx, (predictor_patch, overlap) in enumerate(zip(predictor_patchs, overlaps)):
    # One probability volume per subject (proba=True), stacked into a batch.
    patch_predictions = torch.stack(
        [
            predictor_patch.predict_patches(subject, batch_size=BATCH_SIZE, overlap=overlap, proba=True)
            for subject in tqdm(val_dataset)
        ],
        dim=0,
    )
    predictions.append(patch_predictions)

    # Hard labels via arg-max over dim=1, presumably the class axis of an
    # (N, C, D, H, W) tensor. squeeze(1) presumably drops the masks' singleton channel.
    dice = dice_score_3d(torch.argmax(patch_predictions, dim=1), masks.squeeze(1), num_classes=NUM_CLASSES)
    print(f"Model {model_idx} (overlap={overlap}) Dice score: {dice}")

100%|██████████| 5/5 [02:34<00:00, 30.93s/it]


Dice score: {0: 0.9976367950439453, 1: 0.9245912432670593, 2: 0.9475903511047363, 3: 0.943122386932373}


100%|██████████| 5/5 [02:35<00:00, 31.13s/it]


Dice score: {0: 0.9976142644882202, 1: 0.923819363117218, 2: 0.9470298886299133, 3: 0.942984938621521}


100%|██████████| 5/5 [02:33<00:00, 30.66s/it]


Dice score: {0: 0.997552752494812, 1: 0.9232271313667297, 2: 0.9455758333206177, 3: 0.941463828086853}


100%|██████████| 5/5 [00:15<00:00,  3.10s/it]


Dice score: {0: 0.9977170825004578, 1: 0.9252330660820007, 2: 0.9494802355766296, 3: 0.9438269734382629}


In [12]:
### Ensemble
# Soft voting: stack the per-model probabilities on a new leading "model" axis and
# average over it, then score the arg-max labels against the validation masks.
ensemble_predictions = torch.stack(predictions, dim=0).mean(dim=0)
dice = dice_score_3d(
    torch.argmax(ensemble_predictions, dim=1),
    masks.squeeze(1),
    num_classes=NUM_CLASSES,
)
print(f"Ensemble Dice score: {dice}")

Ensemble Dice score: {0: 0.9977818727493286, 1: 0.928473949432373, 2: 0.9510732889175415, 3: 0.9468572735786438}


In [13]:
# Mean Dice over every label except 0 (presumably background). Filtering by key
# rather than slicing by position avoids depending on the dict's insertion order.
mean_foreground_dice = np.mean([score for label, score in dice.items() if label != 0])
mean_foreground_dice

0.9421348373095194

## Full-Volume Based

In [19]:
## Transforms
# Same intensity pipeline as the patch-based section, rebuilt so this section
# can be run on its own.
val_transform = tio.Compose([tio.RescaleIntensity((0, 1)), tio.ZNormalization()])

## Dataset + loader
# One subject per batch, in a fixed order (no shuffling).
val_dataset = BrainMRIDataset(os.path.join(ROOT_DIR, 'val'), transform=val_transform)
val_loader = tio.SubjectsLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

In [5]:
# Load the full-volume model from its MLflow run.
FULL_VOLUME_RUN_ID = "1c98c526ea884b768e491a03985c8f22"
predictor_full = Predictor(
    mlflow_model_uri=f"runs:/{FULL_VOLUME_RUN_ID}/model",
    device=DEVICE,
)

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

In [25]:
## Predict
# Run the full-volume model over every batch of `val_loader`.
# NOTE(review): called with proba=False, yet the Dice cell still applies
# argmax(dim=1) after squeeze(1). Confirm the shape and meaning of the returned values.
predictions_full = predictor_full.predict_full_volume(
    val_loader,
    proba=False,
)

In [23]:
# Ground-truth validation masks, rebuilt from this section's `val_dataset` so the
# score no longer depends on `masks` left over from the patch-based section.
val_masks = torch.stack([subject['mask'][tio.DATA] for subject in val_dataset], dim=0)

# Per-label Dice of the full-volume model (arg-max over dim=1 after dropping dim 1).
dice_score_3d(
    torch.argmax(predictions_full.squeeze(1), dim=1),
    val_masks.squeeze(1),
    num_classes=NUM_CLASSES,
)

{0: 0.9966910481452942,
 1: 0.8851995468139648,
 2: 0.923554539680481,
 3: 0.9096390008926392}

## Ensemble Inference on the Test Set

In [19]:
## Transforms
# Identical intensity preprocessing to the validation pipelines above.
test_transform = tio.Compose(
    [
        tio.RescaleIntensity((0, 1)),
        tio.ZNormalization(),
    ]
)
## Dataset
test_dir = os.path.join(ROOT_DIR, 'test')
test_dataset = BrainMRIDataset(test_dir, transform=test_transform)

In [20]:
# Per-model probability volumes for the test set. Kept under its own name so the
# validation `predictions` list (read by the validation ensemble cell) is not
# overwritten if cells are re-run out of order.
test_predictions = []

# Sliding-window overlap for each patch model, in the same order as `predictor_patchs`.
overlaps = [32, 32, 32, 64]
# zip() would silently skip models (or overlaps) if the lengths disagree.
assert len(overlaps) == len(predictor_patchs), (
    f"{len(predictor_patchs)} models but {len(overlaps)} overlaps"
)

for predictor_patch, overlap in zip(predictor_patchs, overlaps):
    # One probability volume per test subject (proba=True), stacked into a batch.
    patch_predictions = torch.stack(
        [
            predictor_patch.predict_patches(subject, batch_size=BATCH_SIZE, overlap=overlap, proba=True)
            for subject in tqdm(test_dataset)
        ],
        dim=0,
    )
    test_predictions.append(patch_predictions)

### Ensemble
# Soft voting: average over the model axis, then take the arg-max label along dim=1.
ensemble_predictions = torch.argmax(torch.stack(test_predictions, dim=0).mean(dim=0), dim=1)

  0%|          | 0/3 [00:00<?, ?it/s]

100%|██████████| 3/3 [00:59<00:00, 19.93s/it]
100%|██████████| 3/3 [00:57<00:00, 19.08s/it]
100%|██████████| 3/3 [01:02<00:00, 20.69s/it]
100%|██████████| 3/3 [00:06<00:00,  2.29s/it]


In [21]:
# Write one label volume per test subject as NIfTI, reusing each input image's
# spacing and affine, under the input image's file name.
RESULTS_DIR = './results'
# Create the output folder up front in case export_to_nii does not create missing
# directories itself.
os.makedirs(RESULTS_DIR, exist_ok=True)

# Predictions are indexed positionally below, so the counts must agree.
assert ensemble_predictions.shape[0] == len(test_dataset), (
    f"{ensemble_predictions.shape[0]} predictions for {len(test_dataset)} test subjects"
)

for i, subject in enumerate(test_dataset):
    image = subject['image']
    # .cpu() makes .numpy() safe even if the tensor lives on an accelerator.
    labels = ensemble_predictions[i].cpu().numpy().astype(np.int16)
    export_to_nii(labels, os.path.join(RESULTS_DIR, image.path.name), image.spacing, image.affine)

Saved NIfTI file to ./results/IBSR_02.nii.gz
Saved NIfTI file to ./results/IBSR_10.nii.gz
Saved NIfTI file to ./results/IBSR_15.nii.gz
