In [1]:
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
from muse.utils.fitting import masker
import json, os, glob

In [2]:
# Directory expected to contain one sub-folder per benchmark run, each holding a
# `*benchmark.json` log. Set the MUSE_PLOT0_DIR environment variable to run the
# notebook on another machine; the author's local path is the fallback.
base_path = os.environ.get('MUSE_PLOT0_DIR', '/Users/khcho/MUSE/diffraction/time_series/plot0/')
os.chdir(base_path)
import sys
# Guard so re-running this cell does not append the same entry again.
if base_path not in sys.path:
    sys.path.append(base_path)

In [3]:
def read_sdc_log(file):
    """Load an SDC benchmark log stored as JSON.

    Parameters
    ----------
    file : str or path-like
        Path to a ``*benchmark.json`` file.

    Returns
    -------
    The decoded JSON document (used as a ``{key: value}`` mapping by the
    loading loop of this notebook).
    """
    # Distinct handle name so the `file` path argument is not shadowed; JSON
    # is UTF-8 by specification, so do not depend on the platform locale.
    with open(file, 'r', encoding='utf-8') as fh:
        return json.load(fh)

In [4]:
lines = ['Fe XIX 108.355', 'Fe IX 171.073', 'Fe XV 284.163']
# Coordinate labels of the result table. The data shape is derived from these
# lists so adding or removing a label cannot desynchronise data and coords.
res_coords = {
    'step': ['step0', 'step1', 'step2'],
    'line': lines,
    'metric': ['outlier', '1sig', '2sig', '3sig'],
    'parameter': ['net_flux', 'velocity', 'linewidth'],
    'diffr': ['w/ diff.', 'diff. rm', 'only_core'],
}
res_dims = list(res_coords)  # dimension order == coordinate order above
# NaN marks every combination that has not (yet) been filled from a benchmark log.
res = xr.Dataset(
    {'num': (res_dims, np.full([len(res_coords[d]) for d in res_dims], np.nan))},
    coords=res_coords,
)

In [5]:
# Short metric label used in `res` -> substring that identifies that metric
# inside the benchmark-log keys.
metric_str = dict(zip(
    ['outlier', '1sig', '2sig', '3sig'],
    ['outliers', 'sigma1', 'sigma2', 'sigma3'],
))

In [6]:
# Benchmark logs sit one directory below `base_path` (without recursive=True,
# `**` matches a single directory level like `*`). Sorted so the listing, and
# everything iterated from it, is reproducible across filesystems.
files = sorted(glob.glob(os.path.join(base_path, '**', '*benchmark.json')))
files

['/Users/khcho/MUSE/diffraction/time_series/plot0/plot_diff_removed_time/sdc_benchmark.json',
 '/Users/khcho/MUSE/diffraction/time_series/plot0/plot_diff_time/sdc_benchmark.json',
 '/Users/khcho/MUSE/diffraction/time_series/plot0/plot_only_core_time/sdc_benchmark.json']

In [7]:
def _classify_run(path):
    """Map a benchmark file path to its ``(step, diffr)`` labels in `res`.

    Only the part of `path` below `base_path` is inspected, so substrings that
    happen to occur in the machine-specific base directory cannot change the
    classification.
    """
    run = os.path.relpath(path, base_path)
    step_labels = res.step.values
    if 'sub_cont' in run:
        step = step_labels[2]
    elif 'inv' in run:
        step = step_labels[1]
    else:
        step = step_labels[0]
    if 'only_core' in run:
        diffr = 'only_core'
    elif 'removed' in run:
        diffr = 'diff. rm'
    else:
        diffr = 'w/ diff.'
    return step, diffr


def _lookup_metric(log, words):
    """Return the value of the `log` key that contains every string in `words`.

    Several matching keys: report them and keep the first one (dict order).
    No matching key: report it and return NaN, so that table cell stays
    empty instead of the whole cell aborting with a bare IndexError.
    """
    matches = [value for key, value in log.items() if all(word in key for word in words)]
    if len(matches) > 1:
        print('duplicated', words)
    if not matches:
        print('missing', words)
        return np.nan
    return matches[0]


files = sorted(glob.glob(os.path.join(base_path, '**', '*benchmark.json')))
for file in files:
    log = read_sdc_log(file)
    step, diffr = _classify_run(file)
    print(file, diffr)
    for parameter in res.parameter.values:
        for metric in res.metric.values:
            for line in res.line.values:
                # A log key is matched when it contains the parameter name, the
                # metric's log substring and the spectral-line label.
                words = [parameter, metric_str[metric], line]
                res['num'].loc[{
                    'parameter': parameter,
                    'metric': metric,
                    'line': line,
                    'step': step,
                    'diffr': diffr,
                }] = _lookup_metric(log, words)

/Users/khcho/MUSE/diffraction/time_series/plot0/plot_diff_removed_time/sdc_benchmark.json diff. rm
/Users/khcho/MUSE/diffraction/time_series/plot0/plot_diff_time/sdc_benchmark.json w/ diff.
/Users/khcho/MUSE/diffraction/time_series/plot0/plot_only_core_time/sdc_benchmark.json only_core


In [8]:
def _format_cell_value(value):
    """Format one table value for display inside its cell.

    Magnitudes below 1e3 get one decimal (``'12.3'``); larger magnitudes are
    written as a one-decimal mantissa plus integer exponent (``'4.6E4'``).
    The magnitude is used for the exponent so negative values are handled
    (``np.log10`` of a negative number is NaN).
    """
    value = float(value)
    if abs(value) < 1e3:
        return f'{value:0.1f}'
    dig = np.floor(np.log10(abs(value)))
    return f'{value/(10**dig):0.1f}E{int(dig)}'


def _wrap_label(val, max_len=10):
    """Return `val` as text, split onto two lines at its last space if longer than `max_len`.

    Non-string labels are converted with ``str``; labels without an interior
    space are kept on one line.
    """
    text = str(val)
    if len(text) > max_len:
        cut = text.rfind(' ')
        if cut > 0:
            return text[:cut] + '\n' + text[cut + 1:]
    return text


def plot_param_space_table(target, col_dim, row_dim, save_name='param_space.png', 
                           cmap='jet', vmin=None, vmax=None, dpi=200):
    """Render a multi-dimensional DataArray as an annotated, colour-coded table.

    The dimensions in `row_dim` are flattened into table rows and those in
    `col_dim` into table columns (C order: the last listed dimension varies
    fastest and forms the innermost header level). Every finite cell is
    coloured with `cmap` and annotated with its value; non-finite cells are
    left unannotated. Non-dimension coordinates still attached to `target`
    (e.g. left by ``.sel``) are listed as a text block left of the table.

    Parameters
    ----------
    target : xarray.DataArray
        Data to tabulate; its dimensions must be exactly ``row_dim + col_dim``
        (required by the transpose/reshape below).
    col_dim, row_dim : list of str
        Dimension names mapped to columns / rows.
    save_name : str or None
        Output image path. When None the figure is neither saved nor closed.
    cmap : str or matplotlib colormap
        Colour map for the cells.
    vmin, vmax : float, optional
        Colour-scale limits. Default to the NaN-ignoring min/max of `target`
        (left to matplotlib's autoscaling if that is not finite).
    dpi : int
        Figure resolution; also converts the pixel layout below to inches.
    """
    import matplotlib.patheffects as patheffects

    # Preferred top-to-bottom order of the leftover scalar coordinates in the
    # text block; coordinates not listed here are placed after these.
    custom_order = ['diffr', 'parameter', 'metric', 'line', 'GT_noise', 'step',
                    'fill_value', 'threshold']

    col_dim_len = [len(target[i]) for i in col_dim]
    row_dim_len = [len(target[j]) for j in row_dim]
    tot_col = np.prod(col_dim_len)
    tot_row = np.prod(row_dim_len)
    # Width / height of one table cell in data units.
    xu = 1.5
    yu = 1.2
    # Room left of the table for row headers (2 cell widths per row dimension)
    # and above it for column headers (1 cell height per column dimension).
    xlim = np.array([-len(row_dim)*2, tot_col])*xu
    ylim = np.array([0, tot_row+len(col_dim)])*yu
    if vmin is None:
        vmin = float(target.min())
        vmin = vmin if np.isfinite(vmin) else None
    if vmax is None:
        vmax = float(target.max())
        vmax = vmax if np.isfinite(vmax) else None

    array = target.transpose(*row_dim, *col_dim).data.reshape(tot_row, tot_col)
    # Pixel layout: `fact` px per data unit plus fixed margins; the wide left
    # margin holds the coordinate text block and the colour bar.
    fact = 60
    ax_xsize = fact*np.diff(xlim)[0]
    ax_ysize = fact*np.diff(ylim)[0]
    xmar = 380
    ymar = 60
    fig_xsize = ax_xsize + xmar
    fig_ysize = ax_ysize + ymar
    fig, ax = plt.subplots(figsize=(fig_xsize/dpi, fig_ysize/dpi), dpi=dpi)
    ax.axis('off')
    # vmin/vmax must be passed here, otherwise user-supplied limits are ignored.
    table = ax.imshow(array,
                      extent=[0, tot_col*xu, 0, tot_row*yu],
                      cmap=cmap,
                      vmin=vmin,
                      vmax=vmax,
                      interpolation='nearest')

    rank = {name: i for i, name in enumerate(custom_order)}
    unused_coord = sorted((coord for coord in target.coords if coord not in target.dims),
                          key=lambda name: rank.get(name, len(custom_order)))
    title = '\n'.join(f'{name}: {target[name].values}' for name in unused_coord)

    ax_pos = [(xmar-30)/fig_xsize, (ymar-30)/fig_ysize,
              ax_xsize/fig_xsize, ax_ysize/fig_ysize]
    ax.set(position=ax_pos, xlim=xlim, ylim=ylim, aspect='equal')
    fig.text(30/fig_xsize, ax_pos[1]+ax_pos[3],
             title, va='top', ha='left', fontsize=10)

    # Cell annotations: white halo keeps black text readable on any colour.
    for i in range(tot_row):
        for j in range(tot_col):
            if np.isfinite(array[i, j]):
                ax.text((j+0.5)*xu, (tot_row-i-0.5)*yu, _format_cell_value(array[i, j]),
                        size=8, va='center', ha='center', c='black', clip_on=False,
                        path_effects=[patheffects.withStroke(linewidth=2, foreground='white',
                                                             capstyle="round")])

    # Thin grid between all cells.
    for i in range(tot_col+1):
        ax.plot([i*xu]*2, [0, tot_row*yu], c='k', lw=0.5, clip_on=False, solid_capstyle='butt')
    for j in range(tot_row+1):
        ax.plot([0, tot_col*xu], [j*yu]*2, c='k', lw=0.5, clip_on=False, solid_capstyle='butt')

    # Column headers, innermost (last) dimension first; rules get thicker for
    # outer dimensions so group boundaries stand out.
    for n1 in range(len(col_dim)):
        dim_ind = len(col_dim)-n1-1
        repeat = np.prod(col_dim_len[:dim_ind])
        delta_x = tot_col/repeat/col_dim_len[dim_ind]
        block_w = np.prod(col_dim_len[dim_ind:])
        for n2 in range(int(repeat)):
            for n3, val in enumerate(target[col_dim[dim_ind]].values):
                x0 = n3*delta_x + n2*block_w
                ax.text((x0 + 0.5*delta_x)*xu, (tot_row+0.5+n1)*yu,
                        val, va='center', ha='center')
                ax.plot([x0*xu]*2, [0, (tot_row+n1+1)*yu], lw=n1+0.5, c='k')
        # `xlim` is already in scaled units, so the rule ends at tot_col*xu.
        ax.plot([0, tot_col*xu], [(tot_row+1+n1)*yu]*2, lw=0.5+n1, c='k')

    # Row headers, innermost (last) dimension closest to the table.
    for n1 in range(len(row_dim)):
        dim_ind = len(row_dim)-n1-1
        repeat = np.prod(row_dim_len[:dim_ind])
        delta_y = tot_row/repeat/row_dim_len[dim_ind]
        block_h = np.prod(row_dim_len[dim_ind:])
        for n2 in range(int(repeat)):
            for n3, val in enumerate(target[row_dim[dim_ind]].values):
                y0 = tot_row - n3*delta_y - n2*block_h
                ax.text((-1-n1*2)*xu, (y0 - 0.5*delta_y)*yu,
                        _wrap_label(val), va='center', ha='center')
                ax.plot(np.array([-(n1+1)*2, tot_col])*xu,
                        np.array([y0]*2)*yu,
                        lw=n1+0.5, c='k')
        ax.plot([-2*(n1+1)*xu]*2, [0, tot_row*yu], lw=0.5+n1, c='k')

    # Outer frame, one step thicker than the outermost row-header rule.
    ax.plot(xlim[[0, 0, 1, 1, 0]], ylim[[0, 1, 1, 0, 0]], c='k',
            lw=len(row_dim)+0.5, clip_on=False)

    # Top-left corner cell, split diagonally: innermost row dimension name at
    # bottom-left, innermost column dimension name at top-right.
    ax.text((-2+0.1)*xu, (tot_row+0.1)*yu,
            row_dim[-1],
            va='bottom', ha='left', size=7)
    ax.text(-0.1*xu, (tot_row+0.9)*yu,
            col_dim[-1],
            va='top', ha='right', size=7)
    ax.plot(np.array([-2, 0])*xu, np.array([tot_row+1, tot_row])*yu, c='k', lw=0.5)
    ax.plot(np.array([-2, -2])*xu, np.array([tot_row+1, tot_row])*yu, c='k', lw=0.5)
    ax.plot(np.array([0, -2])*xu, np.array([tot_row+1, tot_row+1])*yu, c='k', lw=0.5)

    # 30 x 200 px colour bar, 70 px left of the table axes, bottom-aligned with it.
    cax = fig.add_axes([ax_pos[0] - 70/fig_xsize, ax_pos[1],
                        30/fig_xsize, 200/fig_ysize])
    cbar = fig.colorbar(table, cax=cax)
    cbar.ax.yaxis.set_ticks_position('left')
    cbar.ax.yaxis.set_label_position('left')
    if save_name is not None:
        fig.savefig(save_name, dpi=dpi)
        # Close exactly this figure; callers generate many tables in a loop.
        plt.close(fig)
    

In [9]:
# One table per (parameter, metric, line) combination:
# rows = `diffr` variants, columns = `step`.
save_path = os.path.join(base_path, 'tables')
# Only "already exists" is tolerated; permission or path errors surface.
os.makedirs(save_path, exist_ok=True)

col_dim = ['step']
row_dim = ['diffr']
for param in res.parameter.values:
    for met in res.metric.values:
        for line in res.line.values:
            target = res['num'].sel(parameter=param, metric=met, line=line)
            plot_param_space_table(target, col_dim, row_dim,
                                   save_name=os.path.join(save_path, f'{param}_{met}_{line}.png'))


In [10]:
# One table per metric: rows = line x `diffr`, columns = parameter x `step`.
save_path = os.path.join(base_path, 'tables2')
# Only "already exists" is tolerated; permission or path errors surface.
os.makedirs(save_path, exist_ok=True)

col_dim = ['parameter', 'step']
row_dim = ['line', 'diffr']
for met in res.metric.values:
    target = res['num'].sel(metric=met)
    plot_param_space_table(target, col_dim, row_dim,
                           save_name=os.path.join(save_path, f'{met}.png'))
