In [None]:
import os
import pickle as pkl
import matplotlib.pyplot as plt
from matplotlib.colors import from_levels_and_colors, to_rgb
import numpy as np
import pandas as pd

In [None]:
# Unit count of each model variant; used for the 'Unit Count' legend entries and
# x-axis values in the plots below. Entries match the genc/gene/glat numbers in
# the corresponding directory name and are paired with model_dir_list by position.
model_scale_list = [128, 256, 512, 1024]
# One model output directory per entry of model_scale_list, in the same order.
model_dir_list = [
    r'D:\Users\mickey\Data\models\pyt\lfads\gw_250_fl80u100\lfads_ecog\cenc0_cont0_fact64_genc128_gene128_glat128_nch41_seqlen50_ulat0_orion-varstd',
    r'D:\Users\mickey\Data\models\pyt\lfads\gw_250_fl80u100\lfads_ecog\cenc0_cont0_fact64_genc256_gene256_glat256_nch41_seqlen50_ulat0_orion-varstd',
    r'D:\Users\mickey\Data\models\pyt\lfads\gw_250_fl80u100\lfads_ecog\cenc0_cont0_fact64_genc512_gene512_glat512_nch41_seqlen50_ulat0_orion-varstd',
    r'D:\Users\mickey\Data\models\pyt\lfads\gw_250_fl80u100\lfads_ecog\cenc0_cont0_fact64_genc1024_gene1024_glat1024_nch41_seqlen50_ulat0_orion-varstd-redo',
]
# The two lists are only ever matched up by index, so a missing or extra entry
# would mislabel traces or fail much later; catch that here instead.
assert len(model_scale_list) == len(model_dir_list), (
    f'{len(model_scale_list)} model scales listed for {len(model_dir_list)} model directories.'
)
# Per-model input files: the pickled PSD dictionary and the performance table CSV.
psd_data_file_list = [os.path.join(s,r'figs\psd_data_dict.pkl') for s in model_dir_list]
performance_table_file_list = [os.path.join(s,r'performance_table.csv') for s in model_dir_list]
# Fail early, naming the offending path, if any expected input file is missing.
for psd_data_file in psd_data_file_list:
    assert os.path.exists(psd_data_file), f'{psd_data_file} not found.'
for perf_table_file in performance_table_file_list:
    assert os.path.exists(perf_table_file), f'{perf_table_file} not found.'

In [None]:
def get_model_parameters_from_perf_table_path(perf_table_file):
    """Return the name of the directory that directly contains the given file.

    The path is broken up on ``os.path.sep`` and the second-to-last component
    is returned. For the performance tables used in this notebook that
    component is the model directory name.

    Parameters
    ----------
    perf_table_file : str
        Path to a performance table file, written with the platform separator.

    Returns
    -------
    str
        The path component immediately before the file name.

    Raises
    ------
    IndexError
        If the path contains no separator, i.e. there is no parent component.
    """
    # Only the last two separators matter: [leading path, parent dir, file name].
    trailing_parts = perf_table_file.rsplit(os.path.sep, 2)
    return trailing_parts[-2]

def read_and_concat_csv(table_file_list):
    """Load several CSV tables and stack them into one DataFrame.

    Each table gets a ``model_dir_name`` column holding the name of the
    directory the file was read from, so rows can be traced to their model.

    Parameters
    ----------
    table_file_list : list of str
        Paths of the CSV files to read, in the order they should be stacked.

    Returns
    -------
    pandas.DataFrame
        All tables concatenated row-wise. ``ignore_index`` is not set, so each
        file's row labels are kept as read.
    """
    def _load_labeled_table(csv_path):
        # Resolve the directory label first, then read and tag the table.
        dir_label = get_model_parameters_from_perf_table_path(csv_path)
        labeled_table = pd.read_csv(csv_path)
        labeled_table['model_dir_name'] = dir_label
        return labeled_table

    return pd.concat([_load_labeled_table(csv_path) for csv_path in table_file_list])

# Stack the performance tables of all listed models into one DataFrame; rows
# follow the order of performance_table_file_list.
perf_table_all = read_and_concat_csv(performance_table_file_list)

In [None]:
# Overlay the mean reconstruction PSD of every model on one axis, plus the data PSD.
fig, ax = plt.subplots(1,1,dpi=150)
# One RGB row per model, linearly interpolated from 'gainsboro' to 'tab:orange'.
cmap = np.linspace(to_rgb('gainsboro'),to_rgb('tab:orange'),len(psd_data_file_list))
for idx, (psd_data_file, line_color) in enumerate(zip(psd_data_file_list, cmap)):
    # NOTE: pickle.load can execute arbitrary code; only open files you trust.
    with open(psd_data_file,'rb') as f:
        psd_dict = pkl.load(f)
    freqs = psd_dict['f_psd']
    recon_psd_db = 10*np.log10(psd_dict['recon_psd_mean'])
    ax.plot(freqs, recon_psd_db, label=str(model_scale_list[idx]), color=line_color)
# psd_dict still holds the last file loaded in the loop, so the data trace is
# taken from that file.
data_freqs = psd_dict['f_psd']
data_psd_db = 10*np.log10(psd_dict['data_psd_mean'])
ax.plot(data_freqs, data_psd_db, color='tab:blue', label='data')
ax.legend(loc=0,title='Unit Count')
ax.set(xlim=(0,100), xlabel='freq. (Hz)', ylabel='PSD (dB)')
ax.set_title('ECoG reconstruction, varying unit count')

In [None]:
# Write the current figure (`fig`) to the report folder as both PNG and SVG.
quals_report_directory_path = r'G:\My Drive\publications\Quals paper'
for figure_file_name in ('psd_vary_unit_count_fl80u100.png',
                         'psd_vary_unit_count_fl80u100.svg'):
    fig.savefig(os.path.join(quals_report_directory_path, figure_file_name))

In [None]:
# Plot mean MSE and mean correlation against unit count, with error bars
# spanning the 2.5 and 97.5 CI columns of the performance table.
fig, ax = plt.subplots(1,1,dpi=150)
for metric_name, line_color, legend_label in (
    ('mse', 'tab:red', 'MSE'),
    ('corr', 'tab:purple', 'Corr.'),
):
    metric_mean = perf_table_all[f'{metric_name}_mean']
    # errorbar expects yerr rows ordered [lower, upper], each a non-negative
    # distance from the plotted value: the lower bar reaches down to the 2.5 CI
    # bound and the upper bar reaches up to the 97.5 CI bound.
    lower_err = metric_mean - perf_table_all[f'{metric_name}_2.5ci']
    upper_err = perf_table_all[f'{metric_name}_97.5ci'] - metric_mean
    ax.errorbar(model_scale_list, metric_mean,
                [lower_err, upper_err],
                capsize=2.5, color=line_color, marker='o',
                label=legend_label)
ax.legend(loc=0)
ax.set_xlabel('Unit Count')
ax.set_ylabel('Metric Measurement')
ax.set_title('Reconstruction Error and Corr. v. Unit Count')

In [None]:
# Write the current figure (`fig`) to the report folder as both PNG and SVG.
quals_report_directory_path = r'G:\My Drive\publications\Quals paper'
for figure_file_name in ('perf_vary_unit_count_fl80u100.png',
                         'perf_vary_unit_count_fl80u100.svg'):
    fig.savefig(os.path.join(quals_report_directory_path, figure_file_name))