# Collect results
Here we perform MRI/US fusion using the trained VoxelMorph models and collect the results

In [1]:
from notebooks.setup import test_generator, model, config, latest_checkpoint
from notebooks.utils import dice_coeff
import itertools
import numpy as np

Instructions for updating:
Use fn_output_signature instead
Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.


vxm info: mutual information loss is experimental


Prepare storing function

In [2]:
from tqdm.notebook import tqdm
from skimage import measure

def create_and_store_results(model, generator, filename):
    results = {
        "prostate_dice": [], "target_dice": [], "target_error": [],
        "prostate_dice_def": [], "target_dice_def": [], "target_error_def": [],
        "prostate_dice_pre": [], "target_dice_pre": [], "target_error_pre": []
    }
    for idx in tqdm(range(len(generator))):
        test_input, test_output, mr_targets, us_targets, _ = generator(idx)

        test_pred = model.predict(test_input)
        prostate_dice_val = dice_coeff(test_pred[2], test_output[2])
        results["prostate_dice"].append(prostate_dice_val)
        prostate_dice_val = dice_coeff(test_pred[2], test_input[2])
        results["prostate_dice_def"].append(prostate_dice_val)
        prostate_dice_val = dice_coeff(test_input[2], test_output[2])
        results["prostate_dice_pre"].append(prostate_dice_val)

        for i, (mr_target, us_target) in enumerate(itertools.zip_longest(mr_targets, us_targets)):
            if mr_target is None or us_target is None:
                print(f" /!\  test data [{idx}] has unpaired targets [{i}]")
                continue

            for select_target in range(min(mr_target.shape[-1], us_target.shape[-1])):
                test_input_target = mr_target[np.newaxis, ..., [select_target]]
                test_output_target = us_target[np.newaxis, ..., [select_target]]
                test_pred_target = model.apply_transform(test_input[0], test_input[1], test_input_target)

                # in "error" we use  *0.5 because images have an isotropic spacing of 0.5mm

                # Dice between target in MR_def and US
                target_dice_val = dice_coeff(test_pred_target, test_output_target)
                target_error_val = np.linalg.norm((measure.centroid(test_pred_target) - measure.centroid(test_output_target))) * 0.5
                results["target_dice"].append(target_dice_val)
                results["target_error"].append(target_error_val)

                # Dice between target in MR and MR_def
                target_dice_val = dice_coeff(test_pred_target, test_input_target)
                target_error_val = np.linalg.norm((measure.centroid(test_pred_target) - measure.centroid(test_input_target))) * 0.5
                results["target_dice_def"].append(target_dice_val)
                results["target_error_def"].append(target_error_val)

                # Dice between target in MR and US
                target_dice_val = dice_coeff(test_input_target, test_output_target)
                target_error_val = np.linalg.norm((measure.centroid(test_input_target) - measure.centroid(test_output_target))) * 0.5
                results["target_dice_pre"].append(target_dice_val)
                results["target_error_pre"].append(target_error_val)

    # convert everything to ndarray
    for k in results.keys():
        results[k] = np.array(results[k])

    # store results
    try:
        np.savez_compressed(filename, results)
    except Exception:
        print("Error saving the results!")

    return results

### Run the evaluation

Run for a single model

In [4]:
# instantiate the network and load the model
model.load_weights(latest_checkpoint)

filename = f"../../results/results_{config['lambda_param']}_{config['gamma_param']}"
create_and_store_results(model, test_generator, filename)

Instructions for updating:
Use fn_output_signature instead
Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.


vxm info: mutual information loss is experimental


  0%|          | 0/96 [00:00<?, ?it/s]

 /!\  test data [48] has unpaired targets [0]
 /!\  test data [66] has unpaired targets [0]


{'prostate_dice': array([0.91869912, 0.75918459, 0.89592761, 0.89619003, 0.85775538,
        0.89246712, 0.86910642, 0.89455009, 0.90213038, 0.87024531,
        0.89112001, 0.91365095, 0.81571669, 0.90591561, 0.89551297,
        0.88877972, 0.8509952 , 0.74769889, 0.90416269, 0.71163294,
        0.88402234, 0.89212798, 0.86995622, 0.86565515, 0.88999425,
        0.9020369 , 0.87806071, 0.8983652 , 0.89945231, 0.89123902,
        0.89315729, 0.92085957, 0.91208281, 0.88400813, 0.87932055,
        0.8594467 , 0.84390872, 0.91545001, 0.90845353, 0.87402146,
        0.88481258, 0.92294359, 0.84043093, 0.86798186, 0.92201922,
        0.88982705, 0.87607058, 0.92469503, 0.88897921, 0.90330394,
        0.90463912, 0.89403541, 0.9039655 , 0.86564326, 0.90648353,
        0.88911593, 0.90111953, 0.88543659, 0.90071152, 0.89482004,
        0.90556911, 0.9144566 , 0.89932506, 0.89435959, 0.86180107,
        0.91710287, 0.85359133, 0.82356925, 0.85689804, 0.87366227,
        0.9013515 , 0.87947305,

Run for all the models

In [4]:
import tensorflow as tf
from pathlib import Path
from notebooks.setup import size
from notebooks.utils import prepare_model

model = prepare_model(inshape=size, sim_param=0, lambda_param=0, gamma_param=0)
for checkpoints_dir in Path(f"../../models").iterdir():
    info = checkpoints_dir.name.split("_")
    if len(info) > 3:
        print("special model, ignoring")
        continue
    _, lambda_param, gamma_param = info
    latest_checkpoint = tf.train.latest_checkpoint(checkpoints_dir / "checkpoints")
    model.load_weights(latest_checkpoint)

    filename = f"../../results/results_{lambda_param}_{gamma_param}"
    create_and_store_results(model, test_generator, filename)



vxm info: mutual information loss is experimental


  0%|          | 0/96 [00:00<?, ?it/s]



  / M[(0,) * image.ndim])  # weighted sum of all points


 /!\  test data [48] has unpaired targets [0]
 /!\  test data [66] has unpaired targets [0]


  0%|          | 0/96 [00:00<?, ?it/s]

 /!\  test data [48] has unpaired targets [0]
 /!\  test data [66] has unpaired targets [0]


  0%|          | 0/96 [00:00<?, ?it/s]

 /!\  test data [48] has unpaired targets [0]
 /!\  test data [66] has unpaired targets [0]


  0%|          | 0/96 [00:00<?, ?it/s]

 /!\  test data [48] has unpaired targets [0]
 /!\  test data [66] has unpaired targets [0]


  0%|          | 0/96 [00:00<?, ?it/s]

 /!\  test data [48] has unpaired targets [0]
 /!\  test data [66] has unpaired targets [0]


  0%|          | 0/96 [00:00<?, ?it/s]

 /!\  test data [48] has unpaired targets [0]
 /!\  test data [66] has unpaired targets [0]


  0%|          | 0/96 [00:00<?, ?it/s]

 /!\  test data [48] has unpaired targets [0]
 /!\  test data [66] has unpaired targets [0]
special model, ignoring
special model, ignoring


  0%|          | 0/96 [00:00<?, ?it/s]

 /!\  test data [48] has unpaired targets [0]
 /!\  test data [66] has unpaired targets [0]


  0%|          | 0/96 [00:00<?, ?it/s]

 /!\  test data [48] has unpaired targets [0]
 /!\  test data [66] has unpaired targets [0]


  0%|          | 0/96 [00:00<?, ?it/s]

 /!\  test data [48] has unpaired targets [0]
 /!\  test data [66] has unpaired targets [0]
special model, ignoring
special model, ignoring


  0%|          | 0/96 [00:00<?, ?it/s]

 /!\  test data [48] has unpaired targets [0]
 /!\  test data [66] has unpaired targets [0]


  0%|          | 0/96 [00:00<?, ?it/s]

 /!\  test data [48] has unpaired targets [0]
 /!\  test data [66] has unpaired targets [0]


  0%|          | 0/96 [00:00<?, ?it/s]

 /!\  test data [48] has unpaired targets [0]
 /!\  test data [66] has unpaired targets [0]


  0%|          | 0/96 [00:00<?, ?it/s]

 /!\  test data [48] has unpaired targets [0]
 /!\  test data [66] has unpaired targets [0]


  0%|          | 0/96 [00:00<?, ?it/s]

 /!\  test data [48] has unpaired targets [0]
 /!\  test data [66] has unpaired targets [0]


  0%|          | 0/96 [00:00<?, ?it/s]

 /!\  test data [48] has unpaired targets [0]
 /!\  test data [66] has unpaired targets [0]
special model, ignoring


  0%|          | 0/96 [00:00<?, ?it/s]

 /!\  test data [48] has unpaired targets [0]
 /!\  test data [66] has unpaired targets [0]


  0%|          | 0/96 [00:00<?, ?it/s]

 /!\  test data [48] has unpaired targets [0]
 /!\  test data [66] has unpaired targets [0]


  0%|          | 0/96 [00:00<?, ?it/s]

 /!\  test data [48] has unpaired targets [0]
 /!\  test data [66] has unpaired targets [0]
