In [None]:
# Standard library imports
import sys
import os
import time
import pickle
import pathlib

# Third party imports
import yaml
import numpy as np
import matplotlib.pyplot as plt
import subprocess
import torch.cuda as cuda
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

# own scripts
from data import Data
from trainer import Trainer
from auxiliary_functions import load_config
from loss import GNLL
from map_functions import export_phantom_maps, import_phantom_maps, mask_phantom_noise_fractions, get_discrete_colors

# Evaluate the NN on the phantom

In [None]:
# MANUAL DEFINITIONS — the only cell meant to be edited by hand.

# --- switches ---
new_maps = True      # True: build the maps with the network and export them; False: import previously exported maps
save_figures = True  # True: write the figures produced below to disk

# --- input names ---
# Network config name, given WITHOUT its '.yaml' ending.
# (Alternative used before: 'net_config_2_grid_search_36')
config_name = 'net_config_10_continuous_data'
# Phantom pickle file ('.p'), looked up in data/phantom/.
phantom_filename = '2021-02-01_thesis_phantom_rb1_01_18.p'

In [None]:
# AUTO DEFINITIONS AND LOADING

# Maps file name: <phantom file stem>_<config name>_maps.p
# splitext strips only the trailing extension; str.replace('.p', ...) would also rewrite any
# '.p' earlier in the name and would leave the name unchanged if it did not contain '.p'.
map_filename = os.path.splitext(phantom_filename)[0] + '_' + config_name + '_maps.p'
# load config file outputs/<config_name>/<config_name>.yaml
config = load_config(os.path.join('outputs', config_name, '{}.yaml'.format(config_name)))
# create data object
data = Data()
# load phantom data from data/phantom/ (the config is passed along)
data.load_phantom(os.path.join('data', 'phantom', phantom_filename), config)
# create the trainer object that holds the NN
trainer = Trainer(data)

# load the saved "best model" weights of this config
trainer.load_net(os.path.join('outputs', config_name, '{}_best_model_save.pt'.format(config_name)))

In [None]:
# evaluate numbers
# Calls Trainer.eval_wasabi() on the trainer/NN loaded above.
# NOTE(review): what it reports or returns is defined in trainer.py (not visible here) — verify there.
trainer.eval_wasabi()

In [None]:
# SAVE maps to pickle if new_maps is True, else LOAD.
# Either way `maps` ends up holding the NN maps together with their '<param>_uncert' maps.
maps_path = os.path.join('data', 'phantom', 'maps', map_filename)
maps = (export_phantom_maps(config, trainer, data, maps_path)
        if new_maps
        else import_phantom_maps(maps_path))

In [None]:
# Parameter names (e.g. b0_shift, b1_inhom, t1, t2): every map key that is not an uncertainty map.
params = [key for key in maps.keys() if 'uncert' not in key]

In [None]:
# Difference map per parameter: reference map (stored with the phantom) minus NN map.
diffs = dict(
    (param_name, data.raw_data['phantom'][param_name] - maps[param_name])
    for param_name in params
)

In [None]:
# PLOT references, NN maps, differences and uncertainties per parameter

%matplotlib qt
from mpl_toolkits.axes_grid1 import ImageGrid

n_par = data.n_tgt_params
n_col = 4

fig = plt.figure(figsize=(19,13), facecolor='white')
axs = ImageGrid(fig, 211,
                nrows_ncols = (n_par, n_col),
                 direction="row",
                 axes_pad = 0.12,
                 add_all=True,
                 label_mode = "all",
                 share_all = False,
                 cbar_location="right",
                 cbar_mode="edge",
                 cbar_size="10%",
                 cbar_pad=0.08
                )
fontsize = 10.5
for i, param in enumerate(params):
    if param == 'b0_shift':
        name = 'b0_shift', '${\Delta}$B$_0$'
    elif param == 'b1_inhom':
        param, name = 'b1_inhom', 'rel. B$_1$'
    elif param == 't1':
        param, name = 't1', 'T$_1$'
    elif param == 't2':
        param, name = 't2', 'T$_2$'

    ref_map = data.raw_data['phantom'][param]
    nn_map = maps[param]
    nn_uncert = maps[param + '_uncert']
    diff = diffs[param]
    if param == 'b1_inhom':
        ref_map = ref_map * 100
        nn_map = nn_map * 100
        nn_uncert = nn_uncert * 100
        diff = diff * 100

    ref_ = axs[i*n_col+0].imshow(ref_map)
    axs[i*n_col+0].cax.colorbar(ref_)

    clims = ref_.get_clim()

    nn_ = axs[i*n_col+1].imshow(nn_map, clim=clims)
    axs[i*n_col+1].cax.colorbar(nn_)

    diff_ = axs[i*n_col+2].imshow(np.abs(diff)*10, clim=clims)
    axs[i*n_col+2].cax.colorbar(diff_)

    uncert_ = axs[i*n_col+3].imshow(nn_uncert*10, clim=clims)
    axs[i*n_col+3].cax.colorbar(uncert_)

axs[0].set_ylabel('${\Delta}$B$_0$ [ppm]')
axs[4].set_ylabel('rel. B$_1$ [%]')
axs[8].set_ylabel('T$_1$ [s]')
for i in [0, 4, 8]:
    axs[i].yaxis.label.set_size(fontsize=fontsize)
axs[0].set_title('Reference map', size=fontsize)
axs[1].set_title('NN map', size=fontsize)
axs[2].set_title('Difference (absx10)', size=fontsize)
axs[3].set_title('Uncertainty (x10)', size=fontsize)

axs.cbar_axes[0].set_yticks([-0.2, 0.0, 0.2])

for ax in axs:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.spines['bottom'].set_color('w')
    ax.spines['top'].set_color('w')
    ax.spines['right'].set_color('w')
    ax.spines['left'].set_color('w')

if save_figures:
    plt.suptitle(map_filename.replace('.p', ''))
    plt.savefig(os.path.join('data', 'phantom', 'maps', 'figs', map_filename.replace('.p', '.png')), bbox_inches='tight', facecolor='white')
plt.show()

In [None]:
# STATISTICAL CALCULATIONS

# mask noise and fractions with zeros; below, a zero therefore marks an excluded pixel
masked_diffs = {k: mask_phantom_noise_fractions(v, data) for k, v in diffs.items()}
masked_diffs_nonzero = {k: v[np.nonzero(v)] for k, v in masked_diffs.items()}
# boolean mask of the pixels that survived masking (non-zero difference)
valid_masks = {k: v != 0 for k, v in masked_diffs.items()}

# calculate simple statistics over the unmasked (non-zero) differences
mins = {k: np.min(v) for k, v in masked_diffs_nonzero.items()}
maxs = {k: np.max(v) for k, v in masked_diffs_nonzero.items()}
means = {k: np.mean(v) for k, v in masked_diffs_nonzero.items()}
medians = {k: np.median(v) for k, v in masked_diffs_nonzero.items()}
quantiles_25 = {k: np.quantile(v, 0.25) for k, v in masked_diffs_nonzero.items()}
quantiles_75 = {k: np.quantile(v, 0.75) for k, v in masked_diffs_nonzero.items()}

# calculate locations (np.where index tuples into the full-size maps)
min_locs = {k: np.where(v == mins[k]) for k, v in masked_diffs.items()}
max_locs = {k: np.where(v == maxs[k]) for k, v in masked_diffs.items()}
# Split the unmasked pixels into: below the 1st quartile / above the 3rd quartile / in between.
# Masked pixels (zeros) are excluded explicitly: otherwise they would be compared against the
# quartiles too and fall into a group (into "other" whenever the quartiles straddle 0).
# Values exactly on a quartile belong to "other" instead of being dropped from every group.
lowest_locs = {k: np.where(valid_masks[k] & (v < quantiles_25[k]))
               for k, v in masked_diffs.items()}
highest_locs = {k: np.where(valid_masks[k] & (v > quantiles_75[k]))
                for k, v in masked_diffs.items()}
other_locs = {k: np.where(valid_masks[k] & (v >= quantiles_25[k]) & (v <= quantiles_75[k]))
              for k, v in masked_diffs.items()}


In [None]:
# PLOT DIFFERENCE HISTOGRAMS
# One histogram per parameter of the unmasked (non-zero) differences.

colors = get_discrete_colors(len(params))
# squeeze=False keeps a 2-D array even for a single parameter, so ax[i] always works
fig, ax = plt.subplots(len(params), squeeze=False)
ax = ax[:, 0]
for i, p in enumerate(params):
    ax[i].hist(masked_diffs_nonzero[p], bins='auto', color=colors[i])  # arguments are passed to np.histogram
    ax[i].set_xlabel(f'{p} diff (min: {round(mins[p], 3)}, mean: {round(means[p], 3)}, '
                     f'max: {round(maxs[p], 3)})')
    ax[i].set_ylabel('n')
# optional zoom, e.g.: ax[2].set_xlim((-0.15, 0.15))

# Title, layout and saving happen BEFORE plt.show(): with non-interactive backends (e.g. inline)
# show() closes the figure, and a savefig afterwards would write an empty image.
fig_stem = os.path.splitext(map_filename)[0]
if save_figures:
    fig.suptitle(fig_stem)
fig.tight_layout()
if save_figures:
    fig_dir = os.path.join('data', 'phantom', 'figs')
    os.makedirs(fig_dir, exist_ok=True)  # savefig does not create missing directories
    fig.savefig(os.path.join(fig_dir, fig_stem + '_diffs_hist.png'), facecolor='white')
plt.show()

In [None]:
# PLOT SCATTER
# Difference vs. reference value per parameter; pixels are colored by the quartile groups
# (lowest_locs / highest_locs / other_locs) from the statistics cell.

# squeeze=False keeps a 2-D array even for a single parameter, so axs[i] always works
fig, axs = plt.subplots(len(params), squeeze=False)
axs = axs[:, 0]
ticks = {'b0_shift': [-0.02, -0.01, 0, 0.01, 0.02],
         'b1_inhom': [-0.015, -0.01, -0.005, 0, 0.005, 0.01, 0.015],
         't1': [-0.3, -0.15, 0, 0.15, 0.3],
         't2': [-3, -1.5, 0, 1.5]
         }
lims = {'b0_shift': (-0.025, 0.025),
        'b1_inhom': (-0.015, 0.015),
        't1': (-0.4, 0.4),
        't2': (-3, 1.5)
        }
colors = get_discrete_colors(3)
for i, p in enumerate(params):
    ref_values = data.raw_data['phantom'][p]
    # draw order: lowest group, highest group, then the in-between group on top
    for locs, color in ((lowest_locs[p], colors[0]),
                        (highest_locs[p], colors[2]),
                        (other_locs[p], colors[1])):
        axs[i].scatter(ref_values[locs], diffs[p][locs], color=color, s=2)
    # x-range: non-zero reference values padded by 5 % of the bound's magnitude on each side.
    # (Multiplying by 0.95 / 1.05 shrinks the range when a bound is negative, e.g. for B0 shifts.)
    ref_nonzero = ref_values[np.nonzero(ref_values)]
    ref_lo, ref_hi = np.min(ref_nonzero), np.max(ref_nonzero)
    xmin = ref_lo - 0.05 * abs(ref_lo)
    xmax = ref_hi + 0.05 * abs(ref_hi)
    axs[i].set_xlim((xmin, xmax))
    axs[i].set_ylim(lims[p])
    axs[i].set_ylabel('diffs')
    axs[i].set_xlabel(p)
    axs[i].set_yticks(ticks[p])

# Title, layout and saving happen BEFORE plt.show(): with non-interactive backends (e.g. inline)
# show() closes the figure, and a savefig afterwards would write an empty image.
fig_stem = os.path.splitext(map_filename)[0]
if save_figures:
    fig.suptitle(fig_stem)
fig.tight_layout()
if save_figures:
    fig_dir = os.path.join('data', 'phantom', 'figs')
    os.makedirs(fig_dir, exist_ok=True)  # savefig does not create missing directories
    fig.savefig(os.path.join(fig_dir, fig_stem + '_diffs_scatter_xlim.png'), facecolor='white')
plt.show()