Comparação de Modelos

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import csv
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
import random
random.seed(5)
from scipy import signal,stats
import os
import seaborn as sns

In [None]:
from models import SpectroViT
from metrics import calculate_shape_score, calculate_mse
from utils_for_evaluation import *

Definições

In [None]:
# Inference device. Set USE_CUDA = True to run on a GPU when one is available;
# the default keeps every model and tensor on the CPU.
USE_CUDA = False
device = torch.device('cuda' if USE_CUDA and torch.cuda.is_available() else 'cpu')

In [None]:
# Spectrogram hop sizes under comparison; one trained model exists per value
# (see the model_hop_<hop> checkpoint paths in the next cell).
hop_size = [10, 14, 18, 22, 26, 30, 34]
window_size = 256
# Symmetric Hann window whose length follows window_size instead of repeating 256.
window = signal.windows.hann(window_size, sym=True)

In [None]:
def load_spectrovit(hop: int) -> torch.nn.Module:
    """Build a SpectroViT on `device` and load its best checkpoint for one hop size.

    The checkpoint is read from
    ../model_hop_<hop>_3500GT/models/model_hop_<hop>_3500GT_best.pt with its tensors
    mapped to the CPU; load_state_dict then copies them into the model's parameters.

    NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
    """
    checkpoint_path = f'../model_hop_{hop}_3500GT/models/model_hop_{hop}_3500GT_best.pt'
    model = SpectroViT().to(device)
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    return model


spectrovit_10 = load_spectrovit(10)
spectrovit_14 = load_spectrovit(14)
spectrovit_18 = load_spectrovit(18)
spectrovit_22 = load_spectrovit(22)
spectrovit_26 = load_spectrovit(26)
spectrovit_30 = load_spectrovit(30)
spectrovit_34 = load_spectrovit(34)

In [None]:
# Evaluation configuration. Module.eval() switches a network to inference mode and
# returns the module itself, so list_models holds the spectrovit_* objects.
list_models = [
    net.eval()
    for net in (spectrovit_10, spectrovit_14, spectrovit_18, spectrovit_22,
                spectrovit_26, spectrovit_30, spectrovit_34)
]
# Names in the same order as list_models and hop_size.
name_model = ['spectrovit_10', 'spectrovit_14', 'spectrovit_18', 'spectrovit_22',
              'spectrovit_26', 'spectrovit_30', 'spectrovit_34']
path_to_test_data = '../dataset_test_multiplenoise_from_SGT_4000_to_5000.h5'
# Every model is evaluated with the same dataset class.
dataset_list = ['DatasetSpgramSyntheticData'] * len(name_model)

Inferência: aquisição de métricas objetivas

In [None]:
# Objective metrics for every (input hop size, model) pair. Later cells read the result
# as dict_metrics[str(hop)][metric_name][model_index], models in name_model order.
metrics_call_kwargs = dict(
    path_to_test_data=path_to_test_data,
    list_models=list_models,
    name_models=name_model,
    hop_size=hop_size,
    window_size=window_size,
    window=window,
    device=device,
    dataset_list=dataset_list,
)
dict_metrics = get_metrics_for_different_models(**metrics_call_kwargs)

Inferência: aquisição de objetos para visualização

In [None]:
# Raw objects for the visual comparisons below. Judging from how later cells index them
# (TODO confirm against utils_for_evaluation):
#   predictions[str(hop)][model_name] -> 2-D array, row per test sample
#   target_concat, ppm_concat         -> 2-D arrays with the same row layout
#   input_spgrams[str(hop)]           -> 4-D array indexed [sample, channel, :, :]
inference_call_kwargs = dict(
    path_to_test_data=path_to_test_data,
    list_models=list_models,
    name_models=name_model,
    hop_size=hop_size,
    window_size=window_size,
    window=window,
    device=device,
    dataset_list=dataset_list,
)
predictions, target_concat, ppm_concat, input_spgrams = get_inference_instances(**inference_call_kwargs)


In [None]:
# Value distributions of the reconstructed spectra, split over two panels
# (hops 10-18 on the left, 22-34 on the right), each with the target overlaid.
hops_left, hops_right = (10, 14, 18), (22, 26, 30, 34)
target_flat = target_concat.flatten()
x = [predictions[str(h)]['spectrovit_' + str(h)].flatten() for h in hops_left] + [target_flat]
y = [predictions[str(h)]['spectrovit_' + str(h)].flatten() for h in hops_right] + [target_flat]

labels_x = ['vit10', 'vit14', 'vit18', 'target']
colors_x = ['lightcoral', 'peru', 'greenyellow', 'black']
labels_y = ['vit22', 'vit26', 'vit30', 'vit34', 'target']
colors_y = ['turquoise', 'royalblue', 'violet', 'pink', 'black']

fig, ax = plt.subplots(1, 2, figsize=(16, 4))
for panel, series, panel_colors, panel_labels in ((ax[0], x, colors_x, labels_x),
                                                 (ax[1], y, colors_y, labels_y)):
    # Labelled outline plus a translucent fill of the same histograms.
    panel.hist(series, bins=80, density=True, histtype='step', color=panel_colors, label=panel_labels)
    panel.hist(series, bins=80, density=True, histtype='stepfilled', color=panel_colors, alpha=0.25)
    panel.legend(prop={'size': 10})
    panel.set_title('Histograma: Espectro de GABA')
    panel.set_xlabel('Valores no Espectro de GABA')
    panel.set_ylabel('Distribuição dos Valores do Espectro \n no conjunto de Teste')
    panel.set_yscale('log')

In [None]:
from matplotlib.gridspec import GridSpec

# Three panels on a 1x4 grid: column 0 zooms on the GABA peak, column 1 on the Glx
# peak, and columns 2-3 hold the value histograms of all reconstructions.
fig = plt.figure(figsize=(16, 4))
gs = GridSpec(1, 4, figure=fig)
ax1 = fig.add_subplot(gs[0])    # column 0
ax2 = fig.add_subplot(gs[1])    # column 1
ax3 = fig.add_subplot(gs[2:])   # columns 2 and 3

# (axis, upper ppm edge, lower ppm edge, title); only test sample 0 is drawn and the
# x-axis runs from high to low ppm.
peak_panels = [
    (ax1, 3.2, 2.8, 'Espectros de GABA \n Pico de GABA'),
    (ax2, 3.9, 3.6, 'Espectros de GABA \n Pico de Glx'),
]
for peak_ax, ppm_high, ppm_low, title in peak_panels:
    # Nearest ppm indices to the window edges. The sup:inf slice assumes ppm decreases
    # along the axis (it would be empty otherwise) -- TODO confirm.
    position_sup = np.abs(ppm_concat[0, :] - ppm_high).argmin()
    position_inf = np.abs(ppm_concat[0, :] - ppm_low).argmin()
    ppm_window = ppm_concat[0, position_sup:position_inf]
    for i, hop in enumerate(hop_size):
        peak_ax.plot(ppm_window,
                     predictions[str(int(hop))][name_model[i]][0, position_sup:position_inf],
                     label='vit' + str(int(hop)))
    peak_ax.plot(ppm_window, target_concat[0, position_sup:position_inf], label='real')
    peak_ax.set_title(title)
    peak_ax.set_xlabel('Desloc. Químico (ppm)')
    peak_ax.set_ylabel('Espectro Normalizado')
    peak_ax.set_xlim(ppm_high, ppm_low)
    peak_ax.legend(loc='upper right')

for i, hop in enumerate(hop_size):
    ax3.hist(predictions[str(int(hop))]['spectrovit_' + str(int(hop))].flatten(),
             bins=100, density=True, alpha=0.35, label='vit' + str(int(hop)))
ax3.hist(target_concat.flatten(), bins=100, density=True, alpha=0.35, label='real')
ax3.legend(loc='upper right')
ax3.set_title('Distribuição Espectros \n Reconstruídos e Real')
ax3.set_yscale('log')
ax3.set_xlabel('Valores do Espectro')
ax3.set_ylabel('Densidade')

plt.tight_layout()


Análise:

In [None]:
# Per-hop comparison of value distributions: each panel overlays one model's
# reconstructions (blue) on the target spectra (red), with a log-scaled density.
fig, ax = plt.subplots(2, 4, figsize=(16, 8))
for i, hop in enumerate(hop_size):
    panel = ax.flat[i]
    model_key = 'spectrovit_' + str(int(hop))
    panel.hist(predictions[str(int(hop))][model_key].flatten(),
               bins=100, density=True, alpha=0.35, color='b', label='prediction')
    panel.hist(target_concat.flatten(),
               bins=100, density=True, alpha=0.35, color='r', label='target')
    panel.set_yscale('log')
    panel.set_title(model_key)
    panel.legend(loc='upper right')
# Seven hop sizes fill seven of the eight slots; blank out the last one.
ax.flat[-1].axis('off')

In [None]:
# (min, max) of the target spectra, then of each model's reconstructions.
print('range target:', np.min(target_concat), np.max(target_concat))
for hop in hop_size:
    hop_pred = predictions[str(int(hop))]['spectrovit_' + str(int(hop))]
    print('range predictions spectrovit_' + str(int(hop)) + ':', np.min(hop_pred), np.max(hop_pred))

In [None]:
# Overlaid value distributions of every model's reconstructions (target not shown).
# A legend is drawn so the per-hop labels are actually visible.
fig, ax = plt.subplots()
for i, hop in enumerate(hop_size):
    ax.hist(predictions[str(int(hop))]['spectrovit_' + str(int(hop))].flatten(),
            bins=100, density=True, alpha=0.35, label='spectrovit_' + str(int(hop)))
ax.set_yscale('log')
ax.set_xlabel('Valores do Espectro')
ax.set_ylabel('Densidade')
ax.legend(loc='upper right')

In [None]:
# Empirical CDFs of each model's reconstructed values (solid blue) vs. the target
# values (dotted red). The target ECDF does not depend on the hop size, so it is
# computed once instead of on every iteration.
data_sorted_tgt = np.sort(target_concat.flatten())
cumulative_tgt = np.arange(1, len(data_sorted_tgt) + 1) / len(data_sorted_tgt)

fig, ax = plt.subplots(2, 4, figsize=(16, 8))
for i, hop in enumerate(hop_size):
    data_sorted_pred = np.sort(predictions[str(int(hop))]['spectrovit_' + str(int(hop))].flatten())
    cumulative_pred = np.arange(1, len(data_sorted_pred) + 1) / len(data_sorted_pred)
    ax.flat[i].plot(data_sorted_pred, cumulative_pred, linestyle='solid', color='b', label='predictions')
    ax.flat[i].plot(data_sorted_tgt, cumulative_tgt, linestyle='dotted', color='r', label='target')
    ax.flat[i].set_title('spectrovit_' + str(int(hop)))
    ax.flat[i].set_xlabel('Valores do Espectro')
    ax.flat[i].set_ylabel('Probabilidade acumulada')
    ax.flat[i].legend(loc='lower right')
# Hide grid slots that have no hop size to show.
for spare_ax in ax.ravel()[len(hop_size):]:
    spare_ax.axis('off')

In [None]:
# Prediction-vs-target histograms restricted to the 4.0-2.5 ppm window
# (indices nearest to those ppm values on the first sample's ppm axis).
position_sup = np.abs(ppm_concat[0, :] - 4).argmin()
position_inf = np.abs(ppm_concat[0, :] - 2.5).argmin()
window_cols = slice(position_sup, position_inf)
target_window = target_concat[:, window_cols].flatten()
fig, ax = plt.subplots(2, 4, figsize=(16, 8))
for i, hop in enumerate(hop_size):
    panel = ax.flat[i]
    pred_window = predictions[str(int(hop))]['spectrovit_' + str(int(hop))][:, window_cols].flatten()
    panel.hist(pred_window, bins=100, density=True, alpha=0.35, color='b', label='prediction')
    panel.hist(target_window, bins=100, density=True, alpha=0.35, color='r', label='target')
    panel.set_yscale('log')
    panel.set_title('spectrovit_' + str(int(hop)))
    panel.legend(loc='upper right')
# The eighth slot of the 2x4 grid is unused.
ax.flat[-1].axis('off')

In [None]:
# First test sample over the 4.0-2.5 ppm window: reconstruction by the largest-hop
# model vs. the target. Hop and window are set explicitly here instead of relying
# on `hop`, `position_sup` and `position_inf` leaking from earlier cells.
hop_to_plot = hop_size[-1]
position_sup = np.abs(ppm_concat[0, :] - 4).argmin()
position_inf = np.abs(ppm_concat[0, :] - 2.5).argmin()
fig, ax = plt.subplots()
ax.plot(ppm_concat[0, position_sup:position_inf],
        predictions[str(int(hop_to_plot))]['spectrovit_' + str(int(hop_to_plot))][0, position_sup:position_inf],
        label='spectrovit_' + str(int(hop_to_plot)))
ax.plot(ppm_concat[0, position_sup:position_inf], target_concat[0, position_sup:position_inf], label='target')
ax.set_xlim(4, 2.5)
ax.set_xlabel('Desloc. Químico (ppm)')
ax.set_ylabel('Espectro Normalizado')
ax.legend(loc='upper right')

In [None]:
# Summary statistics of the flattened target spectra and of each model's
# reconstructions, printed as a text table (a 'target' row, then one row per model).
stat_functions = {
    'mean': np.mean,
    'std': np.std,
    'median': np.median,
    'skew': stats.skew,
    'kurtosis': stats.kurtosis,
}
target_flat = target_concat.flatten()
tgt_stats = {key: fn(target_flat) for key, fn in stat_functions.items()}
pred_stats = {key: [] for key in stat_functions}
for hop in hop_size:
    pred_flat = predictions[str(int(hop))]['spectrovit_' + str(int(hop))].flatten()
    for key, fn in stat_functions.items():
        pred_stats[key].append(fn(pred_flat))

print('|..............Mean.......|...........STD..........|..........Median........|..............Skew.......|..........Kurtosis......|')
str_aux = '| ' + ''.join('target        ' + '{:.4E}'.format(value) + '|' for value in tgt_stats.values())
print(str_aux)
for j, hop in enumerate(hop_size):
    row_label = 'spectrovit_' + str(int(hop)) + ' '
    str_aux = '| ' + ''.join(row_label + '{:.4E}'.format(pred_stats[key][j]) + '|' for key in tgt_stats)
    print(str_aux)

In [None]:
metrics_names = ['LossVal', 'MSEVal', 'SNRVal', 'FWHMVal', 'ShScVal']
# order_models output is reversed for these two metrics (presumably higher is better
# for them), so that every ranking below is listed in the same direction.
reversed_metrics = ('SNRVal', 'ShScVal')

results_for_their_right_hop = {}
for metric in metrics_names:
    # Model idx scored on the inputs built with hop_size[idx], the hop it is named after.
    own_hop_values = [dict_metrics[str(int(hop_size[idx]))][metric][idx] for idx in range(len(name_model))]
    aux_model, aux_value = order_models(list_metric=own_hop_values, model_names=name_model)
    if metric in reversed_metrics:
        results_for_their_right_hop[metric] = [aux_model[::-1], np.flip(aux_value)]
    else:
        results_for_their_right_hop[metric] = [aux_model, aux_value]

print('|..............LossVal.......|..............MSEVal........|..............SNRVal........|..............FWHMVal.......|..............ShScVal.......|')
for line_idx in range(len(name_model)):
    cells = []
    for metric in metrics_names:
        ranked_models, ranked_values = results_for_their_right_hop[metric]
        cells.append(ranked_models[line_idx] + ' | ' + '{:.4E}'.format(ranked_values[line_idx]) + ' | ')
    str_aux = '| ' + ''.join(cells)
    print(str_aux)

In [None]:
# Tally points per model from the own-hop rankings (LossVal is left out). Within each
# metric the first-ranked model gets as many points as there are distinct values, and
# the award drops by one each time the value changes from the previous entry.
score_model_for_their_right_hop = dict.fromkeys(name_model, 0)
for metric in metrics_names:
    if metric == 'LossVal':
        continue
    ranked_models, ranked_values = results_for_their_right_hop[metric]
    add_to_score = len(np.unique(ranked_values))
    for model_idx in range(len(ranked_models)):
        if model_idx > 0 and ranked_values[model_idx] != ranked_values[model_idx - 1]:
            add_to_score -= 1
        score_model_for_their_right_hop[ranked_models[model_idx]] += add_to_score

aux = [score_model_for_their_right_hop[model_name] for model_name in name_model]
models_scored_for_their_right_hop, models_score_value_for_their_right_hop = order_models(list_metric=aux, model_names=name_model)
# Print the order_models result back to front.
for i in range(len(models_scored_for_their_right_hop) - 1, -1, -1):
    print(models_scored_for_their_right_hop[i] + ': ' + str(models_score_value_for_their_right_hop[i]))

In [None]:
# Heatmaps of every metric except the loss. Rows: hop size used to build the test
# inputs (hop_size order). Columns: models (name_model order).
hop_labels = [str(int(hop)) for hop in hop_size]
model_labels = ['vit' + label for label in hop_labels]
fig, ax = plt.subplots(2, 2, figsize=(8, 6))
for panel, metric in zip(ax.flat, metrics_names[1:]):
    grid = np.array([dict_metrics[str(int(hop))][metric] for hop in hop_size])
    im = panel.imshow(grid, cmap='magma')
    fig.colorbar(im, ax=panel, fraction=0.046, pad=0.04)
    panel.set_xticks(ticks=np.arange(len(model_labels)), labels=model_labels, rotation=45)
    panel.set_yticks(ticks=np.arange(len(hop_labels)), labels=hop_labels)
    panel.set_xlabel('model')
    panel.set_ylabel('input hop size')
    panel.set_title(metric)
plt.tight_layout()

In [None]:
# Noise estimate and GABA maximum computed on each model's reconstructions.
for hop in hop_size:
    hop_pred = predictions[str(int(hop))]['spectrovit_' + str(int(hop))]
    print('noise estimation on prediction of' + ' spectrovit_' + str(int(hop)) + ':',
          noise_est(ppm_concat, hop_pred),
          ' GABA max:',
          get_max_gaba(ppm_concat, hop_pred))

In [None]:
# For every input hop size, rank all models on every metric. SNRVal and ShScVal
# rankings are reversed, matching the own-hop ranking cell above.
results_for_each_hop = {}
for hop in hop_size:
    hop_key = str(int(hop))
    per_metric = {}
    for metric in metrics_names:
        aux_model, aux_value = order_models(list_metric=dict_metrics[hop_key][metric], model_names=name_model)
        if metric in ('SNRVal', 'ShScVal'):
            per_metric[metric] = [aux_model[::-1], np.flip(aux_value)]
        else:
            per_metric[metric] = [aux_model, aux_value]
    results_for_each_hop[hop_key] = per_metric

In [None]:
# Per-metric points accumulated over all input hop sizes. For each hop the first
# model in the ranking gets as many points as there are distinct metric values, and
# the award drops by one each time the value changes from the previous entry
# (ties therefore share points, assuming order_models returns sorted values).
# One slot per metric, sized from metrics_names instead of a hard-coded 5.
score_model_combined_hops = {name: [0] * len(metrics_names) for name in name_model}
for metric_idx, metric in enumerate(metrics_names):
    for hop in hop_size:
        ranked_models, ranked_values = results_for_each_hop[str(int(hop))][metric]
        add_to_score = len(np.unique(ranked_values))
        for model_idx, model_ref in enumerate(ranked_models):
            if model_idx > 0 and ranked_values[model_idx] != ranked_values[model_idx - 1]:
                add_to_score -= 1
            score_model_combined_hops[model_ref][metric_idx] += add_to_score

to_print_results = {}
for metric_idx, metric in enumerate(metrics_names):
    totals = [score_model_combined_hops[model][metric_idx] for model in name_model]
    aux_model, aux_value = order_models(list_metric=totals, model_names=name_model)
    # Reversed so that (assuming order_models sorts ascending) the top score comes first.
    to_print_results[metric] = [aux_model[::-1], np.flip(aux_value)]

print('|.......LossVal.......|.......MSEVal........|.......SNRVal........|.......FWHMVal.....|.......ShScVal.......|')
for line_idx in range(len(name_model)):
    str_aux = '| '
    for metric in metrics_names:
        ranked_models, ranked_scores = to_print_results[metric]
        score_text = str(ranked_scores[line_idx])
        space = '  | ' if len(score_text) == 2 else ' | '
        str_aux += ranked_models[line_idx] + ' | ' + score_text + space
    print(str_aux)

In [None]:
# Overall score per model: sum of its per-metric points, leaving out LossVal (slot 0).
total_score_models = [np.sum(np.array(score_model_combined_hops[model][1:])) for model in name_model]
aux_model, aux_value = order_models(list_metric=total_score_models, model_names=name_model)
# Print the order_models result back to front.
for i in range(len(aux_model) - 1, -1, -1):
    print(aux_model[i] + ': ' + str(aux_value[i]))

In [None]:
from matplotlib.gridspec import GridSpec
fig = plt.figure(figsize=(16, 4))
gs = GridSpec(1, 4, figure=fig)

# Layout on the 1x4 grid: GABA-peak zoom in column 0, Glx-peak zoom in column 1,
# histogram of every spectrum value across columns 2 and 3.
ax1, ax2, ax3 = fig.add_subplot(gs[0]), fig.add_subplot(gs[1]), fig.add_subplot(gs[2:])

# title -> (axis, upper ppm edge, lower ppm edge); only test sample 0 is drawn.
zoom_specs = {
    'Espectros de GABA \n Pico de GABA': (ax1, 3.2, 2.8),
    'Espectros de GABA \n Pico de Glx': (ax2, 3.9, 3.6),
}
for peak_title, (peak_ax, ppm_high, ppm_low) in zoom_specs.items():
    # Nearest ppm indices to the window edges; the sup:inf slice assumes ppm
    # decreases along the axis (it would be empty otherwise) -- TODO confirm.
    position_sup = np.abs(ppm_concat[0, :] - ppm_high).argmin()
    position_inf = np.abs(ppm_concat[0, :] - ppm_low).argmin()
    for i, hop in enumerate(hop_size):
        peak_ax.plot(ppm_concat[0, position_sup:position_inf],
                     predictions[str(int(hop))][name_model[i]][0, position_sup:position_inf],
                     label='vit' + str(int(hop)))
    peak_ax.plot(ppm_concat[0, position_sup:position_inf], target_concat[0, position_sup:position_inf], label='real')
    peak_ax.set_title(peak_title)
    peak_ax.set_xlabel('Desloc. Químico (ppm)')
    peak_ax.set_ylabel('Espectro Normalizado')
    peak_ax.set_xlim(ppm_high, ppm_low)
    peak_ax.legend(loc='upper right')

for i, hop in enumerate(hop_size):
    model_key = 'spectrovit_' + str(int(hop))
    ax3.hist(predictions[str(int(hop))][model_key].flatten(),
             bins=100, density=True, alpha=0.35, label='vit' + str(int(hop)))
ax3.hist(target_concat.flatten(), bins=100, density=True, alpha=0.35, label='real')
ax3.legend(loc='upper right')
ax3.set_title('Distribuição Espectros \n Reconstruídos e Real')
ax3.set_yscale('log')
ax3.set_xlabel('Valores do Espectro')
ax3.set_ylabel('Densidade')

plt.tight_layout()


In [None]:
# Input spectrogram of the first test sample for each hop size, gray scale clipped
# to [-0.04, 0.04] and zoomed to rows 50-150 and columns 0-50.
# One column per hop size instead of a hard-coded 7.
fig, ax = plt.subplots(1, len(hop_size), figsize=(16, 4), squeeze=False)
for panel, hop in zip(ax.flat, hop_size):
    panel.imshow(input_spgrams[str(int(hop))][0, 0, :, :], cmap='gray', vmin=-0.04, vmax=0.04, aspect='auto')
    panel.set_title('hop size: ' + str(int(hop)))
    panel.set_ylim(150, 50)
    panel.set_xlim(0, 50)
plt.tight_layout()
#plt.savefig('compare_models_synthdata_spgrams')


In [None]:
colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
          'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan',
          'gold', 'darkviolet', 'lime', 'dodgerblue']
idx_to_plot = 0                    # test sample shown in every panel
hop_idx_to_plot = [0, 2, 4, 6]     # indices into hop_size, one figure row each
# [low, high] ppm windows; one reconstruction column per window.
regions = [[2.8, 3.2], [3.6, 3.9], [1.98, 2.05], [10, 10.8]]
n_cols = 1 + len(regions)          # spectrogram column + one column per region

fig, ax = plt.subplots(len(hop_idx_to_plot), n_cols, figsize=(16, 5 * len(hop_idx_to_plot)))
for k, hop_idx in enumerate(hop_idx_to_plot):
    hop_key = str(hop_size[hop_idx])
    row_start = n_cols * k

    spgram_ax = ax.flat[row_start]
    spgram_ax.imshow(input_spgrams[hop_key][idx_to_plot, 0, :, :], cmap='gray', vmin=-0.04, vmax=0.04, aspect='auto')
    spgram_ax.set_title('hop size: ' + hop_key)
    spgram_ax.set_ylim(150, 50)
    spgram_ax.set_xlim(0, 50)

    for j, (ppm_low, ppm_high) in enumerate(regions):
        region_ax = ax.flat[row_start + j + 1]
        position_sup = np.abs(ppm_concat[idx_to_plot, :] - ppm_high).argmin()
        position_inf = np.abs(ppm_concat[idx_to_plot, :] - ppm_low).argmin()
        ppm_window = ppm_concat[idx_to_plot, position_sup:position_inf]
        for i, model in enumerate(name_model):
            region_ax.plot(ppm_window,
                           predictions[hop_key][model][idx_to_plot, position_sup:position_inf],
                           label=model[-2:], color=colors[i])
        region_ax.legend(loc='lower right' if j < 2 else 'upper right', fontsize=6, ncols=5)
        # Target spectrum in black (target_concat is the ground truth returned by
        # get_inference_instances; there is no variable named `target`).
        region_ax.plot(ppm_window, target_concat[idx_to_plot, position_sup:position_inf], color='black')
        region_ax.set_xlim(ppm_high, ppm_low)
plt.tight_layout()
#plt.savefig('compare_models_synthdata_reconstructionplots')

In [None]:
# Mean peak-height error per metabolite, input hop size and model:
#   height_diff_per_hop[metabolite][str(hop)][model_index]
#     = mean over test samples of (max |prediction| - max |target|) in the ppm window.
# The ppm windows and target peak heights depend only on (region, sample), so they
# are computed once per region rather than once per hop size.
height_diff_per_hop = {'GABA': {}, 'Glx': {}, 'NAA': {}}
regions = [[2.8, 3.2], [3.6, 3.9], [1.98, 2.05]]  # [low, high] ppm, same order as the keys
for metabolite, (ppm_low, ppm_high) in zip(height_diff_per_hop, regions):
    windows = []
    tgt_heights = []
    for q in range(ppm_concat.shape[0]):
        position_sup = np.abs(ppm_concat[q, :] - ppm_high).argmin()
        position_inf = np.abs(ppm_concat[q, :] - ppm_low).argmin()
        windows.append((position_sup, position_inf))
        tgt_heights.append(np.max(np.abs(target_concat[q, position_sup:position_inf])))
    for hop in hop_size:
        preds_at_hop = predictions[str(int(hop))]
        height_diff_per_hop[metabolite][str(int(hop))] = [
            np.mean(np.array([
                np.max(np.abs(preds_at_hop[model][q, start:stop])) - tgt_heights[q]
                for q, (start, stop) in enumerate(windows)
            ]))
            for model in name_model
        ]

In [None]:
# One panel per metabolite: mean peak-height difference of each model, one series per
# input hop size. Values above the dotted zero line mean the predicted peak magnitude
# exceeds the target's.
fig, ax = plt.subplots(1, 3, figsize=(16, 4), sharex='col')
for panel, key in zip(ax.flat, height_diff_per_hop):
    for hop in hop_size:
        panel.scatter(name_model, height_diff_per_hop[key][str(int(hop))], label='hop:' + str(int(hop)))
    # Zero reference sized from name_model instead of a hard-coded 7.
    panel.plot(name_model, np.zeros(len(name_model)), color='black', linestyle='dotted')
    panel.set_xticks(ticks=name_model, labels=name_model, rotation=45, fontsize=7)
    panel.set_ylabel('max|pred| - max|target| (mean)')
    panel.legend(loc='lower right', fontsize=7)
    panel.set_title(key)
plt.tight_layout()

In [None]:
# For each model and metabolite, count over all input hop sizes how often the mean
# peak-height difference is >= 0 vs. < 0, printed as "pos/neg".
print('|..........GABA........|..............Glx.....|..............NAA.....|')
for line_idx, model_name in enumerate(name_model):
    cells = []
    for key in height_diff_per_hop:
        diffs = [height_diff_per_hop[key][str(int(hop))][line_idx] for hop in hop_size]
        counter_pos = sum(1 for d in diffs if d >= 0)
        counter_neg = len(diffs) - counter_pos
        widths = (len(str(counter_pos)), len(str(counter_neg)))
        if widths == (1, 1):
            space = '   |'
        elif widths in ((1, 2), (2, 1)):
            space = '  |'
        else:
            space = ' |'
        cells.append(model_name + ' | ' + str(counter_pos) + '/' + str(counter_neg) + space)
    str_aux = '|' + ''.join(cells)
    print(str_aux)

In [None]:
# Mean peak-height difference of each model on the inputs of its own hop size
# (model idx <-> hop_size[idx]), per metabolite.
print('|..............GABA..........|..............Glx...........|..............NAA...........|')
for line_idx, model_name in enumerate(name_model):
    own_hop_key = str(int(hop_size[line_idx]))
    str_aux = '|'
    for key in height_diff_per_hop:
        value = height_diff_per_hop[key][own_hop_key][line_idx]
        # Positive values print without a minus sign, so they get one extra pad space.
        space = '  |' if value > 0 else ' |'
        str_aux += model_name + ' | ' + '{:.4E}'.format(value) + space
    print(str_aux)

In [None]:
# Pairwise agreement between the models' reconstructions. Row `model` takes that
# model's predictions on its own hop-size inputs (hop_size[model_idx]) as reference;
# column `model_aux` is another model's prediction of the same inputs.
n_models = len(name_model)
proximity_shape_score = np.empty((n_models, n_models))
proximity_MSE = np.empty((n_models, n_models))
for model_idx, model in enumerate(name_model):
    preds_at_hop = predictions[str(int(hop_size[model_idx]))]
    aux_tgt = preds_at_hop[model]
    for model_idx_aux, model_aux in enumerate(name_model):
        if model_aux == model:
            # Comparing a model with itself: perfect shape score, zero MSE.
            proximity_shape_score[model_idx, model_idx_aux] = 1
            proximity_MSE[model_idx, model_idx_aux] = 0
            continue
        aux = preds_at_hop[model_aux]
        aux_ss = []
        aux_mse = []
        for q in range(ppm_concat.shape[0]):
            aux_ss.append(calculate_shape_score(x=aux[q, :], y=aux_tgt[q, :], ppm=ppm_concat[q, :]))
            aux_mse.append(calculate_mse(x=aux[q, :], y=aux_tgt[q, :], ppm=ppm_concat[q, :]))
        proximity_shape_score[model_idx, model_idx_aux] = np.mean(np.array(aux_ss))
        proximity_MSE[model_idx, model_idx_aux] = np.mean(np.array(aux_mse))

# Each model's predictions over every input hop size, concatenated once (rather than
# re-concatenating inside the loop), then stacked one row per model (name_model order)
# and correlated.
preds_flatten = {
    name: np.concatenate([predictions[str(int(hop))][name].flatten() for hop in hop_size])
    for name in name_model
}
preds_stacked = np.vstack([preds_flatten[name] for name in name_model])
correlation_matrix = np.corrcoef(preds_stacked)

In [None]:
# Heatmap of the pairwise shape scores between model predictions (ticks = hop sizes).
hop_tick_labels = [str(int(hop)) for hop in hop_size]
plt.figure(figsize=(8, 4))
sns.heatmap(proximity_shape_score, annot=True, annot_kws={"size": 8}, fmt='.6', cmap='magma',
            xticklabels=hop_tick_labels, yticklabels=hop_tick_labels)
plt.title('Shape Score Between Models Predictions')
plt.tight_layout()

In [None]:
# For each model, the two models whose predictions score highest against it. The
# ranking is reversed and read from position 1, skipping position 0 -- presumably the
# model itself (self-score 1), assuming order_models sorts ascending.
for line_idx, row_model in enumerate(name_model):
    aux_model, aux_value = order_models(list_metric=proximity_shape_score[line_idx, :], model_names=name_model)
    nearest = aux_model[::-1]
    print(row_model + ': ' + nearest[1] + ' | ' + nearest[2])

In [None]:
# Heatmap of the pairwise MSE between model predictions (ticks = hop sizes).
hop_tick_labels = [str(int(hop)) for hop in hop_size]
plt.figure(figsize=(10, 4))
sns.heatmap(proximity_MSE, annot=True, annot_kws={"size": 8}, fmt='.3E', cmap='magma',
            xticklabels=hop_tick_labels, yticklabels=hop_tick_labels)
plt.title('MSE Between Models Predictions')
plt.tight_layout()

In [None]:
# For each model, the two models with the smallest MSE against it; position 0 of the
# ranking is skipped (presumably the model itself, MSE 0).
for line_idx, row_model in enumerate(name_model):
    aux_model, aux_value = order_models(list_metric=proximity_MSE[line_idx, :], model_names=name_model)
    print(row_model + ': ' + aux_model[1] + ' | ' + aux_model[2])

In [None]:
# Heatmap of the correlation between the models' concatenated predictions.
hop_tick_labels = [str(int(hop)) for hop in hop_size]
plt.figure(figsize=(8, 4))
sns.heatmap(correlation_matrix, annot=True, annot_kws={"size": 8}, fmt='.4', cmap='magma',
            xticklabels=hop_tick_labels, yticklabels=hop_tick_labels)
plt.title('Correlation Matrix')
plt.tight_layout()

In [None]:
# For each model, the two most-correlated other models (reversed ranking, skipping
# position 0, presumably the model itself with correlation 1).
for line_idx, row_model in enumerate(name_model):
    aux_model, aux_value = order_models(list_metric=correlation_matrix[line_idx, :], model_names=name_model)
    most_correlated = aux_model[::-1]
    print(row_model + ': ' + most_correlated[1] + ' | ' + most_correlated[2])

In [None]:
# Correlation between the flattened input spectrograms of the different hop sizes,
# one row per hop size in hop_size order.
input_rows = [input_spgrams[str(int(hop))].flatten() for hop in hop_size]
correlation_matrix_inputs = np.corrcoef(np.array(input_rows))

In [None]:
# Heatmap of the correlation between input spectrograms (rows/columns follow hop_size).
# Titled distinctly from the model-prediction correlation heatmap above.
hop_tick_labels = [str(int(hop)) for hop in hop_size]
plt.figure(figsize=(8, 4))
sns.heatmap(correlation_matrix_inputs, annot=True, annot_kws={"size": 8}, fmt='.4', cmap='magma',
            xticklabels=hop_tick_labels, yticklabels=hop_tick_labels)
plt.xlabel('input hop size')
plt.ylabel('input hop size')
plt.title('Correlation Matrix Between Input Spectrograms')
plt.tight_layout()

In [None]:
# For each input hop size, the two most-correlated other hop sizes. Rows are hop
# sizes, labelled here with the matching model names (name_model[i] <-> hop_size[i]).
for line_idx, row_model in enumerate(name_model):
    aux_model, aux_value = order_models(list_metric=correlation_matrix_inputs[line_idx, :], model_names=name_model)
    most_correlated = aux_model[::-1]
    print(row_model + ': ' + most_correlated[1] + ' | ' + most_correlated[2])

In [None]:
colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
          'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan',
          'gold', 'darkviolet', 'lime', 'dodgerblue']
# Title prefix per metric; unknown metrics fall back to their raw name.
metric_titles = {'LossVal': 'Loss Function', 'MSEVal': 'MSE Function', 'SNRVal': 'SNR Function',
                 'FWHMVal': 'FWHM', 'ShScVal': 'Shape Score'}
# One panel per metric. metrics_names has five entries, so the grid is sized from it
# (a 2x2 grid has only four panels).
n_grid_cols = 3
n_grid_rows = -(-len(metrics_names) // n_grid_cols)
fig, ax = plt.subplots(n_grid_rows, n_grid_cols, figsize=(7 * n_grid_cols, 5 * n_grid_rows), squeeze=False)
for j, metric in enumerate(metrics_names):
    panel = ax.flat[j]
    # One line per input hop size, x-axis = models.
    for i, hop in enumerate(hop_size):
        panel.plot(name_model, dict_metrics[str(int(hop))][metric], marker='o', label=str(int(hop)), color=colors[i])
    panel.legend(loc='upper right' if j < 2 else 'lower right', fontsize=8, ncols=5)
    panel.set_xticks(ticks=name_model, labels=name_model, rotation=45, fontsize=7)
    panel.set_xlabel('models')
    panel.set_title(metric_titles.get(metric, metric) + ': input hop size x model performance')
for spare_ax in ax.ravel()[len(metrics_names):]:
    spare_ax.axis('off')
plt.tight_layout()
#plt.savefig('compare_models_synthdata_inputhop_x_model')

In [None]:
colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
          'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan',
          'gold', 'darkviolet', 'lime', 'dodgerblue']
# Title prefix per metric; unknown metrics fall back to their raw name.
metric_titles = {'LossVal': 'Loss Function', 'MSEVal': 'MSE Function', 'SNRVal': 'SNR Function',
                 'FWHMVal': 'FWHM', 'ShScVal': 'Shape Score'}
# One panel per metric; the grid is sized from metrics_names (five entries do not
# fit in a 2x2 grid).
n_grid_cols = 3
n_grid_rows = -(-len(metrics_names) // n_grid_cols)
fig, ax = plt.subplots(n_grid_rows, n_grid_cols, figsize=(7 * n_grid_cols, 5 * n_grid_rows), squeeze=False)
for j, metric in enumerate(metrics_names):
    panel = ax.flat[j]
    # One line per model: its metric value on the inputs of each hop size.
    for i, model in enumerate(name_model):
        per_hop = [dict_metrics[str(int(hop))][metric][i] for hop in hop_size]
        panel.plot(hop_size, per_hop, label='Vit' + model[11:], color=colors[i])
    panel.legend(loc='upper right' if j < 2 else 'lower right', fontsize=8, ncols=5)
    panel.set_xlabel('input hop size')
    panel.set_title(metric_titles.get(metric, metric) + ': model performance x input hop size')
for spare_ax in ax.ravel()[len(metrics_names):]:
    spare_ax.axis('off')
plt.tight_layout()
#plt.savefig('compare_models_synthdata_model_x_inputhop')