In [1]:
from DeepMRI import DeepMRI
import SegAN_IO_arch as seganio
import dataset_helpers as dh
from NegotiationTools import NegTools
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

In [2]:
# Single NegTools instance shared by the majority-voting cells below
# (they call negtools.compute_majority_voting on per-slice proposals).
negtools = NegTools() 

In [3]:
# ---- Paths and model hyper-parameters ----------------------------------------
MODELS_PATH = 'models/'                     # one sub-folder per trained model lives here
checkpoint_basename = 'best_dice_score_0_'  # prefix of the checkpoint files that get restored
output_labels = 1                           # number of output label channels (voting cells assume 1)
model_input_size = 160                      # side length of the square centre crop fed to the networks
seed = 1234567890

# ---- Dataset split names understood by dh.load_dataset -----------------------
dataset = dict(
    training='brats2019_training_crop_mri',
    validation='brats2019_validation_crop_mri',
    testing='brats2019_testing_crop_mri',
)

# ---- Reference network trained on all four modalities ------------------------
full_network = dict(path='Segan_IO_TF2_brats2019_ALL', epoch=133, modalities=["t1", "t1ce", "t2", "flair"])

# ---- One network per modality, trained from scratch --------------------------
# "weight" is the vote weight used by the weighted majority-voting cells; "performance"
# is kept for reference and is not read by any of the cells shown here.
single_modalities = {
    't1':    dict(path='Segan_IO_TF2_brats2019_T1',    epoch=124, modalities=["t1"],    performance=0.54, weight=0.25),
    't1ce':  dict(path='Segan_IO_TF2_brats2019_T1ce',  epoch=597, modalities=["t1ce"],  performance=0.59, weight=0.25),
    't2':    dict(path='Segan_IO_TF2_brats2019_T2',    epoch=136, modalities=["t2"],    performance=0.67, weight=0.5),
    'flair': dict(path='Segan_IO_TF2_brats2019_FLAIR', epoch=89,  modalities=["flair"], performance=0.76, weight=0.5),
}

# ---- Networks transferred from the FLAIR model (no FLAIR entry: it is the source) ----
transfer_from_flair = {
    't1':   dict(path='Transfer_Brats2019_Flair_to_t1',              epoch=144, modalities=["t1"],   performance=0.5,  weight=0.25),
    't1ce': dict(path='Transfer_Brats2019_Flair_to_t1ce_freeze_all', epoch=300, modalities=["t1ce"], performance=0.56, weight=0.25),
    't2':   dict(path='Transfer_Brats2019_Flair_to_t2_freeze_all',   epoch=191, modalities=["t2"],   performance=0.72, weight=0.5),
}

# ---- Evaluation data -----------------------------------------------------------
# NOTE(review): the evaluation cells log and save their results as 'testing', but the
# data loaded here is the *validation* split. If checkpoints were selected on this
# split the reported scores are optimistic — confirm which split is intended.
n_input_channels = len(full_network['modalities'])
test_data = dh.load_dataset(
    dataset['validation'],
    mri_type=full_network['modalities'],
    ground_truth_column_name='seg',
    clip_labels_to=output_labels,
    center_crop=[model_input_size, model_input_size, n_input_channels],
    batch_size=64,
    prefetch_buffer=1,
    infinite=False,
    cache=False,
    shuffle=False,
)

In [4]:
# Load the all-modality network followed by one network per single modality.
# Checkpoints are saved as '<basename><epoch>-<counter>'. Matching on '<epoch>-' (not a
# bare '<epoch>*' prefix) stops e.g. epoch 89 from also matching an epoch-890 checkpoint,
# which glob would return in arbitrary order.
models_to_load = [full_network] + [single_modalities[modality] for modality in full_network['modalities']]
for entry in models_to_load:
    pattern = f"{MODELS_PATH}{entry['path']}/{checkpoint_basename}{entry['epoch']}-*.index"
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"No checkpoint matching {pattern}")
    # Restore from the checkpoint prefix, i.e. the '.index' file name without its suffix.
    entry['checkpoint'] = matches[0][:-len('.index')]
    entry['model'] = DeepMRI(batch_size=64, size=model_input_size, mri_channels=len(entry['modalities']), output_labels=output_labels, model_name=entry['path'])
    entry['model'].build_model(load_model=entry['checkpoint'], seed=seed, arch=seganio)

full_network['model'].mri_types = full_network['modalities'] # workaround for using log_step without loading a dataset

Using architecture: SegAN_IO_arch
Loading models/Segan_IO_TF2_brats2019_ALL/best_dice_score_0_133-18
Resuming model from: models/Segan_IO_TF2_brats2019_ALL/best_dice_score_0_133-18, next epoch: 134
Using architecture: SegAN_IO_arch
Loading models/Segan_IO_TF2_brats2019_T1/best_dice_score_0_124-10
Resuming model from: models/Segan_IO_TF2_brats2019_T1/best_dice_score_0_124-10, next epoch: 125
Using architecture: SegAN_IO_arch
Loading models/Segan_IO_TF2_brats2019_T1ce/best_dice_score_0_597-17
Resuming model from: models/Segan_IO_TF2_brats2019_T1ce/best_dice_score_0_597-17, next epoch: 598
Using architecture: SegAN_IO_arch
Loading models/Segan_IO_TF2_brats2019_T2/best_dice_score_0_136-11
Resuming model from: models/Segan_IO_TF2_brats2019_T2/best_dice_score_0_136-11, next epoch: 137
Using architecture: SegAN_IO_arch
Loading models/Segan_IO_TF2_brats2019_FLAIR/best_dice_score_0_89-10
Resuming model from: models/Segan_IO_TF2_brats2019_FLAIR/best_dice_score_0_89-10, next epoch: 90


In [9]:
# Baseline: the network trained on all four modalities, evaluated directly (no voting).
eval_logger = pd.DataFrame()
print("Warning, this is going to work only if the output is binary and single-label")
model_all = full_network['model']
for batch in test_data:
    batch_preds = model_all.generator(batch['mri']).numpy()
    batch_metrics = model_all.compute_metrics(batch['seg'], batch_preds)
    eval_logger = model_all.log_step(eval_logger, batch, None, batch_metrics)
# No training happens here; zero loss columns are added before log_epoch
# (presumably it expects them to be present).
eval_logger['loss_g'] = 0
eval_logger['loss_d'] = 0
mv_results_single  = model_all.log_epoch(eval_logger, 'testing', 0, None)
mv_results_single.to_csv('results/brats2019_performances_all_modalities.csv')



In [5]:
# Majority voting between the single-modality networks.
# NOTE(review): the FLAIR network is left out of this vote (3 voters), whereas the
# weighted variant votes with all four modalities — confirm this is intended.
eval_logger = pd.DataFrame()
print("Warning, this is going to work only if the output is binary and single-label")
# (channel index in the full 4-channel input, modality name) for every voter.
voters = [(channel, name) for channel, name in enumerate(full_network['modalities']) if name != 'flair']
for batch in test_data:
    per_model = [
        single_modalities[name]['model'].generator(batch['mri'][:, :, :, channel][..., tf.newaxis]).numpy()
        for channel, name in voters
    ]
    predictions = np.stack(per_model, axis=1)
    # Proposals must be "one hot" on the last axis: [background, foreground].
    foreground = predictions[..., 0]
    proposals = np.stack([1. - foreground, foreground], axis=-1)
    voted = [negtools.compute_majority_voting(mri_slice, 'treshold').astype(np.float32) for mri_slice in proposals]
    mv_predictions = np.stack(voted, axis=0)[..., 1, np.newaxis]
    metrics = full_network['model'].compute_metrics(batch['seg'], mv_predictions)
    eval_logger = full_network['model'].log_step(eval_logger, batch, None, metrics)
# No training happens here; zero loss columns are added before log_epoch.
eval_logger['loss_g'] = 0
eval_logger['loss_d'] = 0
mv_results_single  = full_network['model'].log_epoch(eval_logger, 'testing', 0, None)
mv_results_single.to_csv('results/brats2019_majority_voting_performances_single_modalities.csv')



In [5]:
# Weighted majority voting between the single-modality networks (all four modalities vote).
eval_logger = pd.DataFrame()
print("Warning, this is going to work only if the output is binary and single-label")
# Vote weights are fixed per model, so gather and normalise them once instead of once per batch.
vote_weights = np.asarray([single_modalities[modality]['weight'] for modality in full_network['modalities']])
vote_weights = vote_weights / np.sum(vote_weights)
for row in test_data:
    # Channel m of the input is fed to the network trained on that modality.
    predictions = np.stack(
        [single_modalities[modality]['model'].generator(row['mri'][:, :, :, m][..., tf.newaxis]).numpy()
         for m, modality in enumerate(full_network['modalities'])],
        axis=1)
    # Proposals must be "one hot" on the last axis: [background, foreground].
    proposals = np.stack([1. - predictions[..., 0], predictions[..., 0]], axis=-1)
    # One weight per voter, repeated over every pixel and both classes of a single slice.
    weights = np.tile(vote_weights[:, None, None, None], [1, proposals.shape[2], proposals.shape[3], 2])
    mv_predictions = np.stack(
        [negtools.compute_majority_voting(mri_slice, 'maximum', weights=weights).astype(np.float32)
         for mri_slice in proposals],
        axis=0)[..., 1, np.newaxis]

    metrics = full_network['model'].compute_metrics(row['seg'], mv_predictions)
    eval_logger = full_network['model'].log_step(eval_logger, row, None, metrics)
# No training happens here; zero loss columns are added before log_epoch.
eval_logger['loss_g'] = 0
eval_logger['loss_d'] = 0
mv_results_single  = full_network['model'].log_epoch(eval_logger, 'testing', 0, None)
mv_results_single.to_csv('results/brats2019_weighted_majority_voting_performances_single_modalities.csv')



In [6]:
# Load the FLAIR -> {t1, t1ce, t2} transfer networks. FLAIR has no entry in
# transfer_from_flair because it is the source model.
for modality in [m for m in full_network['modalities'] if m != 'flair']:
    entry = transfer_from_flair[modality]
    # Checkpoints are saved as '<basename><epoch>-<counter>'. Matching on '<epoch>-'
    # avoids picking up another epoch that merely shares the prefix (e.g. 30 vs 300).
    pattern = f"{MODELS_PATH}{entry['path']}/{checkpoint_basename}{entry['epoch']}-*.index"
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"No checkpoint matching {pattern}")
    # Restore from the checkpoint prefix, i.e. the '.index' file name without its suffix.
    entry['checkpoint'] = matches[0][:-len('.index')]
    entry['model'] = DeepMRI(batch_size=64, size=model_input_size, mri_channels=len(entry['modalities']), output_labels=output_labels, model_name=entry['path'])
    entry['model'].build_model(load_model=entry['checkpoint'], seed=seed, arch=seganio)

Using architecture: SegAN_IO_arch
Loading models/Transfer_Brats2019_Flair_to_t1/best_dice_score_0_144-25
Resuming model from: models/Transfer_Brats2019_Flair_to_t1/best_dice_score_0_144-25, next epoch: 145
Using architecture: SegAN_IO_arch
Loading models/Transfer_Brats2019_Flair_to_t1ce_freeze_all/best_dice_score_0_300-25
Resuming model from: models/Transfer_Brats2019_Flair_to_t1ce_freeze_all/best_dice_score_0_300-25, next epoch: 301
Using architecture: SegAN_IO_arch
Loading models/Transfer_Brats2019_Flair_to_t2_freeze_all/best_dice_score_0_191-19
Resuming model from: models/Transfer_Brats2019_Flair_to_t2_freeze_all/best_dice_score_0_191-19, next epoch: 192


In [None]:
# Majority voting between the FLAIR-transfer networks (t1, t1ce, t2).
eval_logger = pd.DataFrame()
print("Warning, this is going to work only if the output is binary and single-label")
# Pair each voter with its channel index in the full 4-channel input. Enumerating the
# filtered list instead would only pick the right channel while FLAIR happens to be last.
voters = [(channel, modality) for channel, modality in enumerate(full_network['modalities']) if modality != 'flair']
for row in test_data:
    predictions = np.stack(
        [transfer_from_flair[modality]['model'].generator(row['mri'][:, :, :, channel][..., tf.newaxis]).numpy()
         for channel, modality in voters],
        axis=1)

    # Proposals must be "one hot" on the last axis: [background, foreground].
    proposals = np.stack([1. - predictions[..., 0], predictions[..., 0]], axis=-1)
    mv_predictions = np.stack(
        [negtools.compute_majority_voting(mri_slice, 'maximum').astype(np.float32) for mri_slice in proposals],
        axis=0)[..., 1, np.newaxis]

    metrics = full_network['model'].compute_metrics(row['seg'], mv_predictions)
    eval_logger = full_network['model'].log_step(eval_logger, row, None, metrics)
# No training happens here; zero loss columns are added before log_epoch.
eval_logger['loss_g'] = 0
eval_logger['loss_d'] = 0
mv_results_single  = full_network['model'].log_epoch(eval_logger, 'testing', 0, None)
mv_results_single.to_csv('results/brats2019_majority_voting_performances_transfer.csv')

In [7]:
# Weighted majority voting between the FLAIR-transfer networks (t1, t1ce, t2).
eval_logger = pd.DataFrame()
print("Warning, this is going to work only if the output is binary and single-label")
# Pair each voter with its channel index in the full input. Enumerating a filtered list
# would only select the right channel while FLAIR happens to be the last modality.
voters = [(channel, modality) for channel, modality in enumerate(full_network['modalities']) if modality != 'flair']
# Vote weights are fixed per model, so normalise them once instead of once per batch.
vote_weights = np.asarray([transfer_from_flair[modality]['weight'] for _, modality in voters])
vote_weights = vote_weights / np.sum(vote_weights)
for row in test_data:
    predictions = np.stack(
        [transfer_from_flair[modality]['model'].generator(row['mri'][:, :, :, channel][..., tf.newaxis]).numpy()
         for channel, modality in voters],
        axis=1)

    # Proposals must be "one hot" on the last axis: [background, foreground].
    proposals = np.stack([1. - predictions[..., 0], predictions[..., 0]], axis=-1)
    # One weight per voter, repeated over every pixel and both classes of a single slice.
    weights = np.tile(vote_weights[:, None, None, None], [1, proposals.shape[2], proposals.shape[3], 2])
    mv_predictions = np.stack(
        [negtools.compute_majority_voting(mri_slice, 'maximum', weights=weights).astype(np.float32)
         for mri_slice in proposals],
        axis=0)[..., 1, np.newaxis]

    metrics = full_network['model'].compute_metrics(row['seg'], mv_predictions)
    eval_logger = full_network['model'].log_step(eval_logger, row, None, metrics)
# No training happens here; zero loss columns are added before log_epoch.
eval_logger['loss_g'] = 0
eval_logger['loss_d'] = 0
mv_results_single  = full_network['model'].log_epoch(eval_logger, 'testing', 0, None)
mv_results_single.to_csv('results/brats2019_weighted_majority_voting_performances_transfer.csv')

