In [1]:
import pandas as pd
import glob
from DeepMRI import DeepMRI
import SegAN_arch as segan_original
import SegAN_IO_arch as seganio
import tensorflow as tf
import numpy as np

models_path = 'models/'
checkpoint_basename = 'best_dice_score_'
batch_size = 64
def load_model(model_name, model_epoch, architecture, dataset,channels, seed=1234567890 ):
    """Build a DeepMRI model and restore it from a saved "best dice score" checkpoint.

    The checkpoint is located by globbing
    ``models_path + model_name + '/' + checkpoint_basename + str(model_epoch) + '*.index'``.

    Args:
        model_name: Name of the model; also the sub-directory of ``models_path``
            that holds its checkpoints.
        model_epoch: Epoch number embedded in the checkpoint file name.
        architecture: Architecture module forwarded to ``build_model`` as ``arch``.
        dataset: Currently unused; kept so existing calls keep working.
        channels: Number of MRI channels, forwarded as ``mri_channels``.
        seed: Seed forwarded to ``build_model``.

    Returns:
        The DeepMRI instance after ``build_model`` has been called.

    Raises:
        FileNotFoundError: If no checkpoint index file matches the pattern.
    """
    index_suffix = '.index'
    pattern = models_path + model_name + '/' + checkpoint_basename + str(model_epoch) + '*' + index_suffix
    # glob returns matches in arbitrary order; sort so the choice is deterministic.
    candidates = sorted(glob.glob(pattern))
    if not candidates:
        raise FileNotFoundError('No checkpoint found matching ' + pattern)
    # Strip only the trailing '.index' (str.replace would also touch any
    # '.index' occurring earlier in the path).
    model_checkpoint = candidates[0][:-len(index_suffix)]
    model = DeepMRI(batch_size=batch_size, size=160, mri_channels=channels, model_name=model_name)
    # Forward the caller's seed instead of a hard-coded literal.
    model.build_model(load_model=model_checkpoint, seed=seed, arch=architecture)
    return model



In [2]:
# One single-modality (1 channel) model per MRI sequence, plus one model
# trained on all four sequences at once (4 channels).
t1 = load_model(
    model_name="Transfer_Brats_Flair_to_T1_freeze_all",
    model_epoch=1122,
    architecture=seganio,
    dataset='brats',
    channels=1,
)
t1c = load_model(
    model_name="Segan_IO_TF2_brats_on_T1c",
    model_epoch=1011,
    architecture=seganio,
    dataset='brats',
    channels=1,
)
t2 = load_model(
    model_name="Transfer_Brats_Flair_to_T2_freeze_all",
    model_epoch=649,
    architecture=seganio,
    dataset='brats',
    channels=1,
)
flair = load_model(
    model_name="Segan_IO_TF2_brats_on_FLAIR",
    model_epoch=168,
    architecture=seganio,
    dataset='brats',
    channels=1,
)
full = load_model(
    model_name="Segan_IO_TF2_brats_ALL",
    model_epoch=1123,
    architecture=seganio,
    dataset='brats',
    channels=4,
)

Using architecture: SegAN_IO_arch
Loaded history from models/Transfer_Brats_Flair_to_T1_freeze_all/log_train.csv
Loaded history from models/Transfer_Brats_Flair_to_T1_freeze_all/log_valid.csv
Loading models/Transfer_Brats_Flair_to_T1_freeze_all/best_dice_score_1122-84
Loaded model from: models/Transfer_Brats_Flair_to_T1_freeze_all/best_dice_score_1122-84, next epoch: 1397
Using architecture: SegAN_IO_arch
Loaded history from models/Segan_IO_TF2_brats_on_T1c/log_train.csv
Loaded history from models/Segan_IO_TF2_brats_on_T1c/log_valid.csv
Loading models/Segan_IO_TF2_brats_on_T1c/best_dice_score_1011-48
Loaded model from: models/Segan_IO_TF2_brats_on_T1c/best_dice_score_1011-48, next epoch: 1439
Using architecture: SegAN_IO_arch
Loaded history from models/Transfer_Brats_Flair_to_T2_freeze_all/log_train.csv
Loaded history from models/Transfer_Brats_Flair_to_T2_freeze_all/log_valid.csv
Loading models/Transfer_Brats_Flair_to_T2_freeze_all/best_dice_score_649-59
Loaded model from: models/Tran

In [3]:
# Load the 'brats' dataset on the 4-channel model with all four modalities.
# The evaluation loop below slices row['mri'] by channel index 0..3 in this
# same order (T1, T1c, T2, Flair) to feed the single-modality models.
full.load_dataset(dataset='brats', mri_types=['MR_T1', 'MR_T1c', 'MR_T2', 'MR_Flair'])

Loading dataset brats with modalities MR_T1,MR_T1c,MR_T2,MR_Flair
Done.


In [26]:
import matplotlib.pyplot as plt

def compute_majority_voting(proposals_np, binary_strategy, rng=None):
    """Compute the majority vote between the given agent proposals.

    Args:
        proposals_np: Array-like of shape ``[agents, ..., labels]`` holding one
            score per label for every agent.
        binary_strategy: How each agent's scores are turned into votes.
            ``'threshold'`` (the historical misspelling ``'treshold'`` is still
            accepted): every label whose score is > 0.5 gets a vote.
            ``'maximum'``: only the label(s) with the highest score get a vote.
        rng: Optional random source exposing ``uniform(size=...)`` (for example
            ``np.random.default_rng(seed)``) used to break ties reproducibly.
            Defaults to the global ``np.random`` state.

    Returns:
        Boolean array of shape ``proposals_np.shape[1:]``, True for the winning
        label at each position. Ties between labels are broken at random so
        that a single label wins.

    Raises:
        ValueError: If ``binary_strategy`` is not a supported strategy.
    """
    random_source = np.random if rng is None else rng

    def tie_braking(majority, axis=-1):
        # Give every candidate label a random score and keep the best one.
        scores = majority.astype(np.uint8) * random_source.uniform(size=majority.shape)
        winners = np.equal(scores, scores.max(axis=axis, keepdims=True))
        # Non-candidates score 0 and would "win" wherever the best score is 0
        # (e.g. no candidate at all); masking with `majority` removes them.
        return np.logical_and(winners, majority)

    proposals_np = np.asarray(proposals_np)
    if binary_strategy in ('threshold', 'treshold'):
        binary_predictions = np.greater(proposals_np, 0.5)
    elif binary_strategy == 'maximum':
        binary_predictions = np.equal(proposals_np, proposals_np.max(axis=-1, keepdims=True))
    else:
        raise ValueError(
            "binary_strategy must be 'threshold' or 'maximum', got {!r}".format(binary_strategy))

    # Count the votes each label received across agents (axis 0).
    votes = np.count_nonzero(binary_predictions, axis=0)
    majority = np.equal(votes, votes.max(axis=-1, keepdims=True))
    return tie_braking(majority)
# Per-slice evaluation of the 4-channel model against majority voting between
# the single-modality models (all four, and each leave-one-modality-out set).
metric_columns = ['loss_g', 'loss_d', 'sensitivity', 'specificity', 'false_positive_rate',
                  'precision', 'dice_score', 'balanced_accuracy', 'smooth_dice_loss', 'mae']

# Collect one frame per (batch, method) and concatenate once at the end:
# DataFrame.append was removed in pandas 2.0 and growing a frame inside a loop
# is quadratic.
result_frames = []

for i, row in enumerate(full.test_dataset()):
    # Channel order follows the mri_types passed to load_dataset:
    # 0 = T1, 1 = T1c, 2 = T2, 3 = Flair.
    x_t1 = row['mri'][..., 0][..., tf.newaxis]
    x_t1c = row['mri'][..., 1][..., tf.newaxis]
    x_t2 = row['mri'][..., 2][..., tf.newaxis]
    x_flair = row['mri'][..., 3][..., tf.newaxis]
    x_full = row['mri']
    gt = row['seg'].numpy()

    preds_full = full.generator(x_full)
    preds_t1 = t1.generator(x_t1)
    preds_t1c = t1c.generator(x_t1c)
    preds_t2 = t2.generator(x_t2)
    preds_flair = flair.generator(x_flair)

    # Convert in [batch, h, w, 1, label] format (label 0 = background, 1 = foreground)
    t1_softmaxed = np.stack([1-preds_t1.numpy(), preds_t1.numpy()], axis=-1)
    t1c_softmaxed = np.stack([1-preds_t1c.numpy(), preds_t1c.numpy()], axis=-1)
    t2_softmaxed = np.stack([1-preds_t2.numpy(), preds_t2.numpy()], axis=-1)
    flair_softmaxed = np.stack([1-preds_flair.numpy(), preds_flair.numpy()], axis=-1)

    # Stack the agents on a new leading axis and drop only the singleton channel
    # axis. A bare .squeeze() would also drop the batch axis whenever a batch
    # contains a single slice (e.g. the last, partial batch).
    stacked = np.stack([t1_softmaxed, t1c_softmaxed, t2_softmaxed, flair_softmaxed]).squeeze(axis=-2)
    stacked_not1 = np.stack([t1c_softmaxed, t2_softmaxed, flair_softmaxed]).squeeze(axis=-2)
    stacked_not1c = np.stack([t1_softmaxed, t2_softmaxed, flair_softmaxed]).squeeze(axis=-2)
    stacked_not2 = np.stack([t1_softmaxed, t1c_softmaxed, flair_softmaxed]).squeeze(axis=-2)
    stacked_noflair = np.stack([t1_softmaxed, t1c_softmaxed, t2_softmaxed]).squeeze(axis=-2)

    # 'treshold' (sic) is the spelling compute_majority_voting checks for.
    mv = compute_majority_voting(stacked, binary_strategy='treshold')
    mv_not1 = compute_majority_voting(stacked_not1, binary_strategy='treshold')
    mv_not1c = compute_majority_voting(stacked_not1c, binary_strategy='treshold')
    mv_not2 = compute_majority_voting(stacked_not2, binary_strategy='treshold')
    mv_noflair = compute_majority_voting(stacked_noflair, binary_strategy='treshold')

    # For the voting methods keep only the foreground label (index 1) and
    # restore the trailing channel axis.
    for method, predictions in {'Full Training': preds_full.numpy(),
                                'Majority Voting': mv[..., 1, np.newaxis],
                                'Majority Voting No T1': mv_not1[..., 1, np.newaxis],
                                'Majority Voting No T1c': mv_not1c[..., 1, np.newaxis],
                                'Majority Voting No T2': mv_not2[..., 1, np.newaxis],
                                'Majority Voting No Flair': mv_noflair[..., 1, np.newaxis]}.items():
        eval_metrics = full.compute_metrics(row['seg'], predictions.astype(np.float32), 0, 0)
        # atleast_1d keeps a one-slice batch stackable after squeeze().
        metric_values = [np.atleast_1d(met.numpy().squeeze()).astype(np.float32) for met in eval_metrics]
        temp_results = pd.DataFrame(np.stack(metric_values, axis=1), columns=metric_columns)
        temp_results = temp_results.reset_index()
        # Turn the within-batch position into a global slice number
        # (assumes every batch before this one holds batch_size slices).
        temp_results['index'] = temp_results['index'] + i*batch_size
        temp_results = temp_results.rename(columns={'index':'slice'})
        temp_results['method'] = method
        result_frames.append(temp_results)

results = pd.concat(result_frames, ignore_index=True) if result_frames else pd.DataFrame()
    
    
        
        
        
        
    

In [31]:
# Persist the per-slice metrics of every method.
# NOTE(review): to_csv does not create the 'results/' directory; it must already exist.
results.to_csv('results/majority_voting_per_slice.csv')

In [34]:
# Average every metric per method and save the summary table.
mean_rows = []
for _, group in results.groupby('method'):
    # numeric_only=True: the 'method' column holds strings. Older pandas
    # dropped it silently when averaging; pandas >= 2.0 raises instead.
    m = group.mean(axis=0, numeric_only=True).to_frame().transpose()
    # Re-attach the non-numeric columns (constant within a group) from the first row.
    for c in group.columns:
        if c not in m.columns:
            m[c] = group.iloc[0][c]
    # The slice number and the two losses are not meaningful as per-method means.
    m = m.drop(columns=['slice', 'loss_d', 'loss_g'])
    mean_rows.append(m)
# DataFrame.append was removed in pandas 2.0; concatenate once instead.
means = pd.concat(mean_rows, ignore_index=True) if mean_rows else pd.DataFrame()
means.to_csv('results/majority_voting_means.csv')

In [35]:
# Display the per-method mean metrics table.
means

Unnamed: 0,sensitivity,specificity,false_positive_rate,precision,dice_score,balanced_accuracy,smooth_dice_loss,mae,method
0,0.820794,0.996146,0.003854,0.914162,0.806111,0.90847,0.194799,0.010591,Full Training
1,0.759633,0.995234,0.004766,0.88233,0.751134,0.877433,0.248866,0.015047,Majority Voting
2,0.742947,0.993624,0.006376,0.873698,0.723704,0.868286,0.276296,0.01735,Majority Voting No Flair
3,0.769401,0.996756,0.003244,0.938799,0.785393,0.883079,0.214607,0.012703,Majority Voting No T1
4,0.789803,0.996223,0.003777,0.921499,0.791359,0.893013,0.208641,0.012736,Majority Voting No T1c
5,0.736526,0.994359,0.005641,0.890648,0.725454,0.865442,0.274546,0.017332,Majority Voting No T2


In [None]:
# Visual check on the LAST batch processed by the evaluation loop: this cell
# reuses the variables that loop left behind (row, preds_*, mv), so it must be
# run after it.
for i in range(preds_t1.shape[0]):
    # Row 1: raw predictions of every single-modality model and of the full model.
    fig, axes = plt.subplots(1, 5, figsize=(16, 9))
    model_panels = [('T1', preds_t1), ('T1c', preds_t1c), ('T2', preds_t2),
                    ('Flair', preds_flair), ('Full Training', preds_full)]
    for ax, (title, preds) in zip(axes, model_panels):
        ax.imshow(preds[i, ..., 0])
        ax.set_title(title)
    plt.show()
    plt.close(fig)  # avoid accumulating two open figures per slice

    # Row 2: ground truth vs. full model vs. majority voting (foreground label).
    fig, axes = plt.subplots(1, 3, figsize=(16, 9))
    axes[0].imshow(row['seg'].numpy()[i, ..., 0])
    axes[0].set_title('Ground truth')
    axes[1].imshow(preds_full[i, ..., 0])
    axes[1].set_title('Full Training')
    # The original referenced `mv_treshold`, which is never defined; the
    # all-modalities threshold vote computed by the evaluation loop is `mv`.
    axes[2].imshow(mv[i, ..., 1])
    axes[2].set_title('Majority Voting')
    plt.show()
    plt.close(fig)

In [None]:
# NOTE(review): `evaluate` is not defined in any cell visible in this notebook,
# so this call looks like a leftover scratch cell that would raise NameError on
# a fresh kernel — TODO confirm and remove, or define/import `evaluate`.
evaluate([1, 1, 1], [1, 1, 1])