In [1]:
from DeepMRI import DeepMRI
import SegAN_IO_arch as seganio
import dataset_helpers as dh
from NegotiationTools import NegTools
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

In [2]:
# Shared NegTools instance; the majority-voting cells below call its compute_majority_voting().
negtools = NegTools()

In [3]:
# Paths and checkpoint naming.
MODELS_PATH = 'models/'
checkpoint_basename = 'best_dice_score_'  # the checkpoint's epoch number is appended to this prefix

# Model geometry and reproducibility.
output_labels = 1        # single output label; the evaluation cells assume binary, single-label output
model_input_size = 160   # networks are fed a model_input_size x model_input_size center crop
seed = 1234567890

# Dataset names handed to dh.load_dataset, keyed by split.
dataset = dict(
    training='brats2015_training_crop_mri',
    validation='brats2015_validation_crop_mri',
    testing='brats2015_testing_crop_mri',
)

# Network trained on all four MRI modalities together.
full_network = {
    'path': 'Segan_IO_TF2_brats_ALL',
    'epoch': 1123,
    'modalities': ["MR_T1", "MR_T1c", "MR_T2", "MR_Flair"],
}

# One network per modality, keyed by modality name.
#   'performance': recorded score of each network (presumably a Dice score -- TODO confirm)
#   'weight':      vote weight used by the weighted majority-voting cells
single_modalities = {
    'MR_T1': {'path': 'Segan_IO_TF2_brats_on_T1', 'epoch': 861, 'modalities': ["MR_T1"],
              "performance": 0.5304, "weight": 0.25},
    'MR_T1c': {'path': 'Segan_IO_TF2_brats_on_T1c', 'epoch': 1011, 'modalities': ["MR_T1c"],
               "performance": 0.5822, "weight": 0.25},
    'MR_T2': {'path': 'Segan_IO_TF2_brats_on_T2', 'epoch': 182, 'modalities': ["MR_T2"],
              "performance": 0.7439, "weight": 0.5},
    'MR_Flair': {'path': 'Segan_IO_TF2_brats_on_FLAIR', 'epoch': 168, 'modalities': ["MR_Flair"],
                 "performance": 0.8040, "weight": 0.5},
}

# Networks transferred from the FLAIR model to the other modalities (there is no MR_Flair entry).
# Fields have the same meaning as in single_modalities.
transfer_from_flair = {
    'MR_T1': {'path': 'Transfer_Brats_Flair_to_T1_freeze_all', 'epoch': 1122, 'modalities': ["MR_T1"],
              "performance": 0.5701, "weight": 0.25},
    'MR_T1c': {'path': 'Transfer_Brats_Flair_to_T1c_freeze_all', 'epoch': 751, 'modalities': ["MR_T1c"],
               "performance": 0.5463, "weight": 0.25},
    'MR_T2': {'path': 'Transfer_Brats_Flair_to_T2_freeze_all', 'epoch': 649, 'modalities': ["MR_T2"],
              "performance": 0.7946, "weight": 0.5},
}

# Test split with one channel per entry of full_network['modalities'] (presumably stacked in that
# order -- the evaluation cells index channels by it). shuffle=False keeps batches in a fixed order.
crop_shape = [model_input_size, model_input_size, len(full_network['modalities'])]
test_data = dh.load_dataset(
    dataset['testing'],
    mri_type=full_network['modalities'],
    ground_truth_column_name='OT',
    clip_labels_to=output_labels,
    center_crop=crop_shape,
    batch_size=64,
    prefetch_buffer=1,
    infinite=False,
    cache=False,
    shuffle=False,
)

In [4]:
# Load the all-modalities network and one single-modality network per input channel.
# Each config dict gains a 'checkpoint' (checkpoint prefix without '.index') and a 'model' entry.
for network in [full_network] + [single_modalities[modality] for modality in full_network['modalities']]:
    # Checkpoints are named '<basename><epoch>-<counter>'. Requiring the '-' right after the epoch
    # keeps e.g. epoch 182 from also matching a checkpoint saved at epoch 1820.
    pattern = MODELS_PATH + network['path'] + '/' + checkpoint_basename + str(network['epoch']) + '-*.index'
    # glob returns matches in arbitrary order; sort so the choice is deterministic.
    candidates = sorted(glob.glob(pattern))
    if not candidates:
        raise FileNotFoundError('No checkpoint matching ' + pattern)
    # Strip only the trailing '.index' suffix, not every occurrence in the path.
    network['checkpoint'] = candidates[0][:-len('.index')]
    network['model'] = DeepMRI(batch_size=64, size=model_input_size, mri_channels=len(network['modalities']), output_labels=output_labels, model_name=network['path'])
    network['model'].build_model(load_model=network['checkpoint'], seed=seed, arch=seganio)

# Workaround for using log_step without loading a dataset: the full network's model is the one
# the evaluation cells use for compute_metrics / log_step / log_epoch.
full_network['model'].mri_types = full_network['modalities']

Using architecture: SegAN_IO_arch
Loading models/Segan_IO_TF2_brats_ALL/best_dice_score_1123-51
Resuming model from: models/Segan_IO_TF2_brats_ALL/best_dice_score_1123-51, next epoch: 1124
Using architecture: SegAN_IO_arch
Loading models/Segan_IO_TF2_brats_on_T1/best_dice_score_861-45
Resuming model from: models/Segan_IO_TF2_brats_on_T1/best_dice_score_861-45, next epoch: 862
Using architecture: SegAN_IO_arch
Loading models/Segan_IO_TF2_brats_on_T1c/best_dice_score_1011-48
Resuming model from: models/Segan_IO_TF2_brats_on_T1c/best_dice_score_1011-48, next epoch: 1012
Using architecture: SegAN_IO_arch
Loading models/Segan_IO_TF2_brats_on_T2/best_dice_score_182-25
Resuming model from: models/Segan_IO_TF2_brats_on_T2/best_dice_score_182-25, next epoch: 183
Using architecture: SegAN_IO_arch
Loading models/Segan_IO_TF2_brats_on_FLAIR/best_dice_score_168-29
Resuming model from: models/Segan_IO_TF2_brats_on_FLAIR/best_dice_score_168-29, next epoch: 169


In [5]:
# Baseline: evaluate the network trained on all four modalities on the test split.
full_model = full_network['model']
eval_logger = pd.DataFrame()
print("Warning, this is going to work only if the output is binary and single-label")
for batch in test_data:
    batch_predictions = full_model.generator(batch['mri']).numpy()
    batch_metrics = full_model.compute_metrics(batch['seg'], batch_predictions)
    eval_logger = full_model.log_step(eval_logger, batch, None, batch_metrics)
# No training happens here: zero-fill the generator/discriminator loss columns
# (presumably expected by log_epoch -- confirm in DeepMRI).
for loss_column in ('loss_g', 'loss_d'):
    eval_logger[loss_column] = 0
mv_results_single = full_model.log_epoch(eval_logger, 'testing', 0, None)
mv_results_single.to_csv('results/brats2015_performances_all_modalities.csv')



In [6]:
# Unweighted majority voting between the single-modality networks.
# NOTE(review): MR_Flair is left out of this vote (only T1, T1c and T2 vote), while the weighted
# single-modality cell lets all four networks vote; the output CSV name does not reflect this.
# Confirm the exclusion is intended (e.g. for comparability with the FLAIR-transfer vote).
eval_logger = pd.DataFrame()
print("Warning, this is going to work only if the output is binary and single-label")
# Pair every voter with its channel index in batch['mri'].
voting_channels = [(channel, modality) for channel, modality in enumerate(full_network['modalities'])
                   if modality != 'MR_Flair']
scorer = full_network['model']  # used only for its compute_metrics / log_step / log_epoch helpers
for batch in test_data:
    per_model = [
        single_modalities[modality]['model'].generator(batch['mri'][:, :, :, channel][..., tf.newaxis]).numpy()
        for channel, modality in voting_channels
    ]
    # Voters are stacked on axis 1, right after the batch axis.
    predictions = np.stack(per_model, axis=1)
    foreground = predictions[..., 0]
    # Proposals have to be "one hot" on the last axis: [background, foreground].
    proposals = np.stack([1. - foreground, foreground], axis=-1)
    # Vote per batch item, then keep the foreground channel as a trailing singleton axis.
    voted = [negtools.compute_majority_voting(mri_slice, 'maximum').astype(np.float32) for mri_slice in proposals]
    mv_predictions = np.stack(voted, axis=0)[..., 1, np.newaxis]
    metrics = scorer.compute_metrics(batch['seg'], mv_predictions)
    eval_logger = scorer.log_step(eval_logger, batch, None, metrics)
for loss_column in ('loss_g', 'loss_d'):
    eval_logger[loss_column] = 0
mv_results_single = scorer.log_epoch(eval_logger, 'testing', 0, None)
mv_results_single.to_csv('results/brats2015_majority_voting_performances_single_modalities.csv')



In [5]:
# Weighted majority voting between all four single-modality networks.
# Each network's vote is scaled by its configured 'weight', normalised so the weights sum to 1.
# NOTE(review): unlike the unweighted single-modality cell, MR_Flair votes here as well.
eval_logger = pd.DataFrame()
print("Warning, this is going to work only if the output is binary and single-label")
# Voters and their normalised weights do not depend on the batch, so compute them once.
voting_channels = list(enumerate(full_network['modalities']))
voter_weights = np.asarray([single_modalities[modality]['weight'] for _, modality in voting_channels])
voter_weights = voter_weights / np.sum(voter_weights)
for row in test_data:
    # Voters are stacked on axis 1, right after the batch axis.
    predictions = np.stack([
        single_modalities[modality]['model'].generator(row['mri'][:,:,:,channel][..., tf.newaxis]).numpy()
        for channel, modality in voting_channels
    ], axis=1)
    # Proposals have to be "one hot" on the last axis: [background, foreground].
    proposals = np.stack([1.-predictions[...,0], predictions[...,0]], axis=-1)
    # Broadcast each voter's scalar weight to (n_voters, proposals.shape[2], proposals.shape[3], 2),
    # the shape of one batch item of `proposals` (assuming 4-D generator output).
    weights = np.tile(voter_weights[:, None, None, None], [1, proposals.shape[2], proposals.shape[3], 2])
    mv_predictions = np.stack([negtools.compute_majority_voting(mri_slice, 'maximum', weights=weights).astype(np.float32) for mri_slice in proposals], axis=0)[..., 1, np.newaxis]
    metrics = full_network['model'].compute_metrics(row['seg'], mv_predictions)
    eval_logger = full_network['model'].log_step(eval_logger, row, None, metrics)
eval_logger['loss_g'] = 0
eval_logger['loss_d'] = 0
mv_results_single = full_network['model'].log_epoch(eval_logger, 'testing', 0, None)
mv_results_single.to_csv('results/brats2015_weighted_majority_voting_performances_single_modalities.csv')



In [8]:
# Load the FLAIR-transfer networks. transfer_from_flair has no MR_Flair entry, so that channel
# is skipped. Each config dict gains a 'checkpoint' (prefix without '.index') and a 'model' entry.
for modality in [m for m in full_network['modalities'] if m != 'MR_Flair']:
    network = transfer_from_flair[modality]
    # Checkpoints are named '<basename><epoch>-<counter>'. Requiring the '-' right after the epoch
    # keeps e.g. epoch 751 from also matching a checkpoint saved at epoch 7510.
    pattern = MODELS_PATH + network['path'] + '/' + checkpoint_basename + str(network['epoch']) + '-*.index'
    # glob returns matches in arbitrary order; sort so the choice is deterministic.
    candidates = sorted(glob.glob(pattern))
    if not candidates:
        raise FileNotFoundError('No checkpoint matching ' + pattern)
    # Strip only the trailing '.index' suffix, not every occurrence in the path.
    network['checkpoint'] = candidates[0][:-len('.index')]
    network['model'] = DeepMRI(batch_size=64, size=model_input_size, mri_channels=len(network['modalities']), output_labels=output_labels, model_name=network['path'])
    network['model'].build_model(load_model=network['checkpoint'], seed=seed, arch=seganio)

Using architecture: SegAN_IO_arch
Loading models/Transfer_Brats_Flair_to_T1_freeze_all/best_dice_score_1122-84
Resuming model from: models/Transfer_Brats_Flair_to_T1_freeze_all/best_dice_score_1122-84, next epoch: 1123
Using architecture: SegAN_IO_arch
Loading models/Transfer_Brats_Flair_to_T1c_freeze_all/best_dice_score_751-56
Resuming model from: models/Transfer_Brats_Flair_to_T1c_freeze_all/best_dice_score_751-56, next epoch: 752
Using architecture: SegAN_IO_arch
Loading models/Transfer_Brats_Flair_to_T2_freeze_all/best_dice_score_649-59
Resuming model from: models/Transfer_Brats_Flair_to_T2_freeze_all/best_dice_score_649-59, next epoch: 650


In [11]:
# Unweighted majority voting between the FLAIR-transfer networks (T1, T1c, T2).
eval_logger = pd.DataFrame()
print("Warning, this is going to work only if the output is binary and single-label")
for row in test_data:
    predictions = list()
    # Enumerate the full modality list so `channel` is the true channel index into row['mri'].
    # Enumerating a Flair-filtered list would misalign the indices if MR_Flair were not the last channel.
    for channel, modality in enumerate(full_network['modalities']):
        if modality == 'MR_Flair':
            continue
        predictions.append(transfer_from_flair[modality]['model'].generator(row['mri'][:,:,:,channel][..., tf.newaxis]).numpy())

    # Voters are stacked on axis 1, right after the batch axis.
    predictions = np.stack(predictions, axis=1)
    # Proposals have to be "one hot" on the last axis: [background, foreground].
    proposals = np.stack([1.-predictions[...,0], predictions[...,0]], axis=-1)
    # NOTE(review): this vote uses the 'treshold' method (spelling must match what NegTools
    # expects), while the single-modality majority-voting cells use 'maximum'. Confirm this is
    # intended, since it makes the single-modality and transfer majority-voting CSVs not directly comparable.
    mv_predictions = np.stack([negtools.compute_majority_voting(mri_slice, 'treshold').astype(np.float32) for mri_slice in proposals], axis=0)[..., 1, np.newaxis]
    metrics = full_network['model'].compute_metrics(row['seg'], mv_predictions)
    eval_logger = full_network['model'].log_step(eval_logger, row, None, metrics)
eval_logger['loss_g'] = 0
eval_logger['loss_d'] = 0
mv_results_single = full_network['model'].log_epoch(eval_logger, 'testing', 0, None)
mv_results_single.to_csv('results/brats2015_majority_voting_performances_transfer.csv')



In [12]:
# Weighted majority voting between the FLAIR-transfer networks (T1, T1c, T2).
# Each network's vote is scaled by its configured 'weight', normalised so the weights sum to 1.
eval_logger = pd.DataFrame()
print("Warning, this is going to work only if the output is binary and single-label")
# Pair every voter with its true channel index in row['mri']. Enumerating a Flair-filtered
# list would misalign the indices if MR_Flair were not the last channel.
voting_channels = [(channel, modality) for channel, modality in enumerate(full_network['modalities'])
                   if modality != 'MR_Flair']
# Voters and their normalised weights do not depend on the batch, so compute them once.
voter_weights = np.asarray([transfer_from_flair[modality]['weight'] for _, modality in voting_channels])
voter_weights = voter_weights / np.sum(voter_weights)
for row in test_data:
    # Voters are stacked on axis 1, right after the batch axis.
    predictions = np.stack([
        transfer_from_flair[modality]['model'].generator(row['mri'][:,:,:,channel][..., tf.newaxis]).numpy()
        for channel, modality in voting_channels
    ], axis=1)
    # Proposals have to be "one hot" on the last axis: [background, foreground].
    proposals = np.stack([1.-predictions[...,0], predictions[...,0]], axis=-1)
    # Broadcast each voter's scalar weight to (n_voters, proposals.shape[2], proposals.shape[3], 2),
    # the shape of one batch item of `proposals` (assuming 4-D generator output).
    weights = np.tile(voter_weights[:, None, None, None], [1, proposals.shape[2], proposals.shape[3], 2])
    mv_predictions = np.stack([negtools.compute_majority_voting(mri_slice, 'maximum', weights=weights).astype(np.float32) for mri_slice in proposals], axis=0)[..., 1, np.newaxis]

    metrics = full_network['model'].compute_metrics(row['seg'], mv_predictions)
    eval_logger = full_network['model'].log_step(eval_logger, row, None, metrics)
eval_logger['loss_g'] = 0
eval_logger['loss_d'] = 0
mv_results_single = full_network['model'].log_epoch(eval_logger, 'testing', 0, None)
mv_results_single.to_csv('results/brats2015_weighted_majority_voting_performances_transfer.csv')

