In [None]:
import os
import pickle as pkl
import matplotlib.pyplot as plt
from matplotlib.colors import from_levels_and_colors, to_rgb
import numpy as np
import pandas as pd

In [None]:
# Loss-function variants being compared; entry i labels the model in model_dir_list[i].
model_scale_list = ['MSE','MSE+FDL','FDL']

# Root folder of the trained LFADS ECoG models. Set the LFADS_ECOG_MODEL_ROOT
# environment variable to run this notebook against a different location.
model_root_dir = os.environ.get(
    'LFADS_ECOG_MODEL_ROOT',
    r'D:\Users\mickey\Data\models\pyt\lfads\gw_250\lfads_ecog',
)
# The three runs share one directory-name stem; only the suffix (loss variant) differs.
model_dir_base_name = 'cenc0_cont0_fact64_genc1024_gene1024_glat1024_nch42_seqlen50_ulat0_orion-varstd'
model_dir_list = [
    os.path.join(model_root_dir, model_dir_base_name + suffix)
    for suffix in ('', '-fdl', '-fdlonly')
]
# Later cells pair labels and model directories by position.
if len(model_scale_list) != len(model_dir_list):
    raise ValueError('model_scale_list and model_dir_list must have the same length.')

# Join sub-paths component-wise so the separator is correct on every OS
# (a literal 'figs\\...' string only works on Windows).
psd_data_file_list = [os.path.join(s, 'figs', 'psd_data_dict.pkl') for s in model_dir_list]
performance_table_file_list = [os.path.join(s, 'performance_table.csv') for s in model_dir_list]

# Fail fast and list every missing input at once. An explicit raise, unlike
# `assert`, still fires when Python runs with -O.
missing_files = [
    p for p in psd_data_file_list + performance_table_file_list
    if not os.path.exists(p)
]
if missing_files:
    raise FileNotFoundError('Missing input files:\n' + '\n'.join(missing_files))

In [None]:
def get_model_parameters_from_perf_table_path(perf_table_file):
    """Return the name of the directory that contains a performance-table file.

    Despite the function name, nothing is parsed: the parent directory's name
    (which in this notebook appears to encode the model hyperparameters) is
    returned verbatim.

    Args:
        perf_table_file: path to a file such as ``<model_dir>/performance_table.csv``.

    Returns:
        The last component of the file's parent directory.

    Raises:
        ValueError: if the path has no parent-directory component.
    """
    # dirname/basename also honour os.path.altsep (e.g. '/' on Windows), unlike
    # splitting on os.path.sep alone.
    model_dir_str = os.path.basename(os.path.dirname(perf_table_file))
    if not model_dir_str:
        raise ValueError(f'{perf_table_file!r} has no parent directory component.')
    return model_dir_str

def read_and_concat_csv(table_file_list):
    """Read several CSV files and stack them into a single DataFrame.

    Every row read from a file is tagged with a ``model_dir_name`` column holding
    the name of that file's parent directory, via
    ``get_model_parameters_from_perf_table_path``.

    Args:
        table_file_list: iterable of CSV file paths.

    Returns:
        A DataFrame with the rows of all files in input order and a fresh
        0..N-1 index.

    Raises:
        ValueError: if ``table_file_list`` is empty.
    """
    table_file_list = list(table_file_list)
    if not table_file_list:
        raise ValueError('table_file_list is empty; nothing to concatenate.')
    table_list = [
        pd.read_csv(table_file).assign(
            model_dir_name=get_model_parameters_from_perf_table_path(table_file))
        for table_file in table_file_list
    ]
    # Each read_csv frame restarts its index at 0; ignore_index avoids duplicate
    # index labels in the concatenated result.
    return pd.concat(table_list, ignore_index=True)

# Stack every model's performance table into one frame (presumably one row per model).
perf_table_all = read_and_concat_csv(table_file_list=performance_table_file_list)

In [None]:
# Overlay each loss variant's mean reconstruction PSD on the data PSD.
# NOTE: pickle.load can execute arbitrary code -- only load files you produced yourself.
fig, ax = plt.subplots(1,1,dpi=150)
linestyle_list = ['-','-.',':']
if not psd_data_file_list:
    raise ValueError('psd_data_file_list is empty; nothing to plot.')
if not len(psd_data_file_list) == len(model_scale_list) <= len(linestyle_list):
    raise ValueError('Need one label per PSD file and at least as many line styles as models.')

psd_dict_list = []
for psd_data_file, model_label, linestyle in zip(psd_data_file_list, model_scale_list, linestyle_list):
    with open(psd_data_file,'rb') as f:
        psd_dict = pkl.load(f)
    psd_dict_list.append(psd_dict)
    ax.plot(psd_dict['f_psd'], 10*np.log10(psd_dict['recon_psd_mean']),
            label=str(model_label), color='tab:orange', linestyle=linestyle)

# The data trace comes from the last-loaded file, as before, but now explicitly
# rather than through a leaked loop variable. Presumably all models were
# evaluated on the same data, so any file would do -- TODO confirm.
data_psd_dict = psd_dict_list[-1]
ax.plot(data_psd_dict['f_psd'], 10*np.log10(data_psd_dict['data_psd_mean']),
        color='tab:blue', label='data')
ax.legend(loc=0, title='Loss Function')
ax.set_xlim(0,100)
ax.set_xlabel('freq. (Hz)')
ax.set_ylabel('PSD (dB)')
ax.set_title('ECoG reconstruction, Frequency Domain Loss');

In [None]:
# Export the PSD comparison figure to the quals-paper folder as both raster and vector.
quals_report_directory_path = r'G:\My Drive\publications\Quals paper'
for figure_file_name in ('psd_fdl_comparison.png', 'psd_fdl_comparison.svg'):
    fig.savefig(os.path.join(quals_report_directory_path, figure_file_name))

In [None]:
# Compare reconstruction MSE and correlation (mean with 2.5/97.5 percentile CI)
# across the loss-function variants.
fig, ax = plt.subplots(1,1,dpi=150)
# x positions and tick labels assume one table row per model, in model_dir_list order.
if len(perf_table_all) != len(model_dir_list):
    raise ValueError(
        f'Expected {len(model_dir_list)} performance rows, got {len(perf_table_all)}.')
x_axis = np.arange(len(model_dir_list))
for metric, color, label in (('mse', 'tab:red', 'MSE'), ('corr', 'tab:purple', 'Corr.')):
    metric_mean = perf_table_all[f'{metric}_mean'].to_numpy()
    # matplotlib's 2xN yerr is [distance below, distance above] each point.
    metric_yerr = [metric_mean - perf_table_all[f'{metric}_2.5ci'].to_numpy(),
                   perf_table_all[f'{metric}_97.5ci'].to_numpy() - metric_mean]
    ax.errorbar(x_axis, metric_mean, metric_yerr,
                capsize=2.5, color=color, marker='o', label=label)
ax.legend(loc=0)
ax.set_xticks(x_axis)
ax.set_xticklabels(model_scale_list)
ax.set_xlabel('Loss Function')
ax.set_ylabel('Metric Measurement')
ax.set_title('Reconstruction Error and Corr. v. Loss Function');

In [None]:
# Export the performance comparison figure to the quals-paper folder as both raster and vector.
quals_report_directory_path = r'G:\My Drive\publications\Quals paper'
for figure_file_name in ('perf_fdl_comparison.png', 'perf_fdl_comparison.svg'):
    fig.savefig(os.path.join(quals_report_directory_path, figure_file_name))