In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
from tqdm import tqdm

module_path = os.path.abspath(os.path.join('../..'))

if module_path not in sys.path:
    sys.path.append(module_path)

def wl_to_lh(window, level):
    """Convert a window/level pair into (low, high) display bounds.

    Args:
        window: Width of the displayed intensity range.
        level: Centre of the displayed intensity range.

    Returns:
        Tuple ``(low, high)`` equal to ``level`` minus / plus half of ``window``.
    """
    half_window = window / 2
    return level - half_window, level + half_window

def display_image(img, phys_size=None, window=None, level=None, existing_ax=None):
    """Show ``img`` in grayscale using window/level intensity clipping.

    Args:
        img: Image data passed to ``imshow`` (and to ``np.min`` / ``np.max``
            when ``window`` or ``level`` has to be derived).
        phys_size: Optional indexable pair; when given the image is drawn with
            ``extent=(0, phys_size[0], phys_size[1], 0)``.
        window: Width of the displayed intensity range. Defaults to the full
            range of ``img`` (max - min).
        level: Centre of the displayed intensity range. Defaults to
            ``min(img) + window / 2``.
        existing_ax: Axes to draw on. When omitted, a new 14x8 figure is
            created and shown immediately with ``plt.show()``.
    """
    if window is None:
        window = np.max(img) - np.min(img)
    if level is None:
        level = window / 2 + np.min(img)

    low, high = wl_to_lh(window, level)

    # Only a figure created here gets shown here; a caller-supplied axes is
    # left for the caller to render.
    owns_figure = existing_ax is None
    if owns_figure:
        _, target_ax = plt.subplots(figsize=(14, 8))
    else:
        target_ax = existing_ax

    extent = None if phys_size is None else (0, phys_size[0], phys_size[1], 0)
    target_ax.imshow(img, clim=(low, high), extent=extent, cmap='gray')

    if owns_figure:
        plt.show()
        
def print_stats(arr):
    """Print summary statistics of ``arr`` on three lines.

    Line 1: mean and standard deviation, line 2: minimum and maximum,
    line 3: ``arr.shape``.
    """
    mean, std = np.mean(arr), np.std(arr)
    print(mean, ', ', std)

    lowest, highest = np.min(arr), np.max(arr)
    print(lowest, '-', highest)

    print(arr.shape)

In [None]:
from nnood.utils.default_configuration import get_default_configuration

def prepare_test_trainer(network_type, dset_name, task, network_trainer_type, fold):
    """Build a trainer for one fold, load its final checkpoint and ready it for evaluation.

    The first four arguments are forwarded to ``get_default_configuration``
    (called with ``silent=True``); ``fold`` is passed to the trainer class it
    returns.

    Returns:
        The initialised trainer with ``no_print`` set, its network switched via
        ``.eval()``, and ``track_auroc`` / ``track_metrics`` / ``track_ap`` all
        set to True.
    """
    config = get_default_configuration(network_type, dset_name, task, network_trainer_type, silent=True)
    plans_file, output_folder_name, dataset_directory, stage, trainer_class, task_class = config

    trainer_options = dict(
        output_folder=output_folder_name,
        dataset_directory=dataset_directory,
        stage=stage,
        unpack_data=True,
        deterministic=False,
        fp16=True,
        load_dataset_ram=False,
    )
    trainer = trainer_class(plans_file, fold, task_class, **trainer_options)
    trainer.no_print = True

    # Initialise in training mode: that is what gets the datasets loaded.
    trainer.initialize(training=True)
    trainer.load_final_checkpoint(train=True)
    trainer.network.eval()

    # Enable every online-evaluation tracker.
    trainer.track_auroc = True
    trainer.track_metrics = True
    trainer.track_ap = True

    return trainer

In [None]:
# Trainer for fold 0 of the FPI task on the male PA chest X-ray dataset.
curr_trainer = prepare_test_trainer(
    network_type='fullres',
    dset_name='chestXray14_PA_male',
    task='FPI',
    network_trainer_type='nnOODTrainerDS',
    fold=0,
)

In [None]:
import torch

def run_test_batch(trnr):
    """Run a single iteration of ``trnr`` on its validation generator without gradients.

    Calls ``trnr.run_iteration(trnr.val_gen, False, True)`` inside
    ``torch.no_grad()`` and discards the return value.
    NOTE(review): the two positional flags are presumably "do backprop" and
    "run online evaluation" -- confirm against the trainer's signature.
    """
    no_grad_ctx = torch.no_grad()
    with no_grad_ctx:
        trnr.run_iteration(trnr.val_gen, False, True)

In [None]:
# Make sure both AP and AUROC tracking are enabled before evaluating.
# The attribute is `track_auroc` (as set in prepare_test_trainer); the former
# spelling `trac_auroc` only created an unused attribute on the trainer.
curr_trainer.track_ap = True
curr_trainer.track_auroc = True

# Run 100 no-grad validation iterations, then finalise the online evaluation.
for _ in tqdm(range(100)):
    run_test_batch(curr_trainer)

curr_trainer.finish_online_evaluation()

In [None]:
# Evaluate every dataset/task combination over folds 0-4 (40 validation
# iterations each). For each combination the per-fold AUROC and AP values are
# stored under 'all', with their mean under 'avg' and standard deviation under
# 'std'.
all_results_auroc = {}
all_results_ap = {}

for dset in ['chestXray14_PA_male', 'chestXray14_PA_female']:
    print('Dataset', dset)
    all_results_auroc[dset] = {}
    all_results_ap[dset] = {}

    for t in ['FPI', 'CutPaste', 'PII', 'NSA', 'NSAMixed']:
        print('Task', t)

        # Aliases for this task's entries; they are the same dict objects that
        # live inside the two result dictionaries.
        auroc_entry = all_results_auroc[dset][t] = {'all': []}
        ap_entry = all_results_ap[dset][t] = {'all': []}

        for i in range(5):
            tmp_trainer = prepare_test_trainer('fullres', dset, t, 'nnOODTrainerDS', i)

            for _ in tqdm(range(40), desc=f'Fold {i}'):
                run_test_batch(tmp_trainer)

            fold_res = tmp_trainer.finish_online_evaluation()
            auroc_entry['all'].append(fold_res['AUROC'])
            ap_entry['all'].append(fold_res['AP'])

        auroc_entry['avg'] = np.mean(auroc_entry['all'])
        auroc_entry['std'] = np.std(auroc_entry['all'])

        ap_entry['avg'] = np.mean(ap_entry['all'])
        ap_entry['std'] = np.std(ap_entry['all'])

        print('Average AUROC', auroc_entry['avg'])
        print('Average AP', ap_entry['avg'])
        print()
            
            

In [None]:
# Display the AP results: per dataset and task, the per-fold values ('all'),
# their mean ('avg') and standard deviation ('std').
all_results_ap

In [None]:
from nnood.utils.file_operations import save_json, load_json

# Write the AUROC and AP result dictionaries to JSON files. The names are
# relative and carry an 'AP09' suffix, so they differ from the
# 'trainer_*_results.json' files loaded for comparison further down.
save_json(all_results_auroc, 'trainer_auroc_resultsAP09.json')
save_json(all_results_ap, 'trainer_ap_resultsAP09.json')

In [None]:

# Load the result files without the 'AP09' suffix -- presumably from an
# earlier run -- so they can be compared with the current results.
all_results_auroc_old = load_json('trainer_auroc_results.json')
all_results_ap_old = load_json('trainer_ap_results.json')

In [None]:
# Print the loaded ("Old") and freshly computed ("New") average AP side by
# side for every dataset and task present in the new results.
for d_set, new_tasks in all_results_ap.items():
    print(d_set, '\n')
    for t, new_stats in new_tasks.items():
        print(t)
        print('Old: ', all_results_ap_old[d_set][t]['avg'])
        print('New: ', new_stats['avg'])
        print()

In [None]:
# Display the full AP results dictionary (current run).
all_results_ap

In [None]:
# Display the full AUROC results dictionary (current run).
all_results_auroc