In [None]:
import os
import sys
import cProfile
import numpy as np
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import scipy.stats as stats
import n2j.inference.infer_utils as iutils
from n2j.inference.inference_manager import InferenceManager
from n2j.config_utils import get_config_modular
import corner
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import matplotlib

# Start from matplotlib's stock settings so re-running this cell is idempotent,
# then switch both regular text and mathtext to the STIX fonts at size 20.
# LaTeX rendering (text.usetex) is left at its default.
plt.rcParams.update(plt.rcParamsDefault)
matplotlib.rcParams.update({
    'font.family': 'STIXGeneral',
    'font.size': 20,
    'mathtext.fontset': 'stix',
})

In [None]:
# Overlay the Omega posteriors of experiments E1-E3 on a single corner plot.
# NOTE: experiments 1 and 2 are named the other way round in the paper, so
# experiment 1 is labelled "E2" and experiment 2 is labelled "E1" below.

# NOTE(review): machine-specific absolute path; point this at your checkout.
REPO_DIR = '/home/jwp/stage/sl/node-to-joy'

# Plotting order; colors[i] and sample_labels[i] describe experiments[i].
experiments = [3, 1, 2]
# Not read below (each chain is located via InferenceManager.out_dir); kept for later cells.
chain_paths = [f'{REPO_DIR}/experiments/E{num}/inference_2/omega_chain.h5' for num in experiments]

# ColorBrewer sequential "Reds" (3-class), darkest first:
# https://colorbrewer2.org/#type=sequential&scheme=Reds&n=3
colors = ['#de2d26', '#fc9272', '#fee0d2']
sample_labels = ['E3 (low noise)', 'E2 (medium noise)', 'E1 (high noise)']

fig2 = plt.figure(figsize=[10, 10])

for i, exp in enumerate(experiments):
    cfg = get_config_modular([f'{REPO_DIR}/experiments_configs/config_E{exp}_local.yml'])
    infer_obj = InferenceManager(checkpoint_dir=cfg['trainer']['checkpoint_dir'],
                                 **cfg['inference_manager'])

    # Chains drawn for this experiment, in plotting order, as
    # (chain file, colour, `discard` = initial steps to drop, presumably burn-in).
    chains_to_plot = []
    if exp == 1:
        # Summary-statistic baselines (unweighted N, inverse-distance-weighted N),
        # drawn once, before experiment 1's own posterior.
        baseline_dir = f'{REPO_DIR}/experiments/E{exp}/inference_2'
        chains_to_plot.append((f'{baseline_dir}/omega_chain_N.h5', '#b5cf6b', 90))
        chains_to_plot.append((f'{baseline_dir}/omega_chain_N_inv_dist.h5', '#637939', 100))
    chains_to_plot.append((os.path.join(infer_obj.out_dir, 'omega_chain.h5'), colors[i], 90))

    for chain_path, color, discard in chains_to_plot:
        # kwargs are rebuilt on every call rather than shared between plots:
        # corner may mutate nested dicts (e.g. hist_kwargs, contour_kwargs) in place.
        infer_obj.visualize_omega_post(
            log_idx=1,
            chain_path=chain_path,
            corner_kwargs=dict(
                range=[[0, 0.08], [0, 0.03]],
                color=color,
                smooth=1.0,
                # Truth values for (mu, sigma), drawn in truth_color.
                truths=np.array([0.04, 0.005]),
                label_kwargs={'fontsize': 30},
                labels=[r'$\mu (\{\kappa\})$', r'$\sigma (\{\kappa\})$'],
                labelpad=-0.1,
                fill_contours=True,
                plot_datapoints=False,
                plot_contours=True,
                show_titles=True,
                # 2D contours enclosing 68% and 95% of the samples.
                levels=[0.68, 0.95],
                truth_color='k',
                contour_kwargs=dict(linestyles='solid', colors=color),
                quiet=True,
                title_fmt=".2g",
                fig=fig2,
                title_kwargs={'fontsize': 18},
                use_math_text=True,
                # alpha applies to the whole histogram patch (outline and fill).
                hist_kwargs=dict(density=True, histtype='step', fill=True,
                                 alpha=0.7, linewidth=2),
                hist2d_kwargs={'alpha': 1},
            ),
            chain_kwargs=dict(flat=True, thin=1, discard=discard),
        )

legend_elements = [
    Patch(facecolor='#b5cf6b', alpha=1.0, label=r'N'),
    Patch(facecolor='#637939', alpha=1.0, label=r'inverse-dist N'),
    Patch(facecolor=colors[2], alpha=1.0, label=r'E1 (low SNR)'),
    Patch(facecolor=colors[1], alpha=1.0, label=r'E2 (medium SNR)'),
    Patch(facecolor=colors[0], alpha=1.0, label=r'E3 (high SNR)'),
    Line2D([0], [0], color='k', lw=2, label='Truth'),
]
# plt.legend attaches to the current axes; loc=[0, 1.3] is the legend's lower-left
# corner in that axes' coordinates, which places it above the axes.
plt.legend(handles=legend_elements, fontsize=20, loc=[0, 1.3])
for ax in fig2.get_axes():
    ax.tick_params(axis='both', labelsize=18)
fig2.savefig('contour_E123_ss.pdf', pad_inches=0.05, bbox_inches='tight')
fig2.savefig('contour_E123_ss.png', pad_inches=0.05, bbox_inches='tight', dpi=100)


In [None]:
# Corner plot of experiment 1 alone (labelled "E2" in the paper, since experiments
# 1 and 2 are named the other way round there) together with its two
# summary-statistic baselines. The output file name uses the internal number (E1).
plt.close('all')

# NOTE(review): machine-specific absolute path; point this at your checkout.
REPO_DIR = '/home/jwp/stage/sl/node-to-joy'

# Plotting order; colors[i] and sample_labels[i] describe experiments[i].
experiments = [3, 1, 2]
# Not read below (the chain is located via InferenceManager.out_dir); kept for later cells.
chain_paths = [f'{REPO_DIR}/experiments/E{num}/inference_2/omega_chain.h5' for num in experiments]

# ColorBrewer sequential "Reds" (3-class), darkest first:
# https://colorbrewer2.org/#type=sequential&scheme=Reds&n=3
colors = ['#de2d26', '#fc9272', '#fee0d2']
sample_labels = ['E3 (low noise)', 'E2 (medium noise)', 'E1 (high noise)']

fig2 = plt.figure(figsize=[10, 10])

# Only experiments[1] (= experiment 1) is plotted in this figure.
i = 1
exp = experiments[i]
cfg = get_config_modular([f'{REPO_DIR}/experiments_configs/config_E{exp}_local.yml'])
infer_obj = InferenceManager(checkpoint_dir=cfg['trainer']['checkpoint_dir'],
                             **cfg['inference_manager'])

# Chains drawn, in plotting order, as
# (chain file, colour, `discard` = initial steps to drop, presumably burn-in).
baseline_dir = f'{REPO_DIR}/experiments/E{exp}/inference_2'
chains_to_plot = [
    (f'{baseline_dir}/omega_chain_N.h5', '#b5cf6b', 90),            # unweighted N
    (f'{baseline_dir}/omega_chain_N_inv_dist.h5', '#637939', 100),  # inverse-distance-weighted N
    (os.path.join(infer_obj.out_dir, 'omega_chain.h5'), colors[i], 90),
]

for chain_path, color, discard in chains_to_plot:
    # kwargs are rebuilt on every call rather than shared between plots:
    # corner may mutate nested dicts (e.g. hist_kwargs, contour_kwargs) in place.
    infer_obj.visualize_omega_post(
        log_idx=1,
        chain_path=chain_path,
        corner_kwargs=dict(
            range=[[0, 0.08], [0, 0.03]],
            color=color,
            smooth=1.0,
            # Truth values for (mu, sigma), drawn in truth_color.
            truths=np.array([0.04, 0.005]),
            label_kwargs={'fontsize': 30},
            labels=[r'$\mu (\{\kappa\})$', r'$\sigma (\{\kappa\})$'],
            labelpad=-0.1,
            fill_contours=True,
            plot_datapoints=False,
            plot_contours=True,
            show_titles=True,
            # 2D contours enclosing 68% and 95% of the samples.
            levels=[0.68, 0.95],
            truth_color='k',
            contour_kwargs=dict(linestyles='solid', colors=color),
            quiet=True,
            title_fmt=".2g",
            fig=fig2,
            title_kwargs={'fontsize': 18},
            use_math_text=True,
            # alpha applies to the whole histogram patch (outline and fill).
            hist_kwargs=dict(density=True, histtype='step', fill=True,
                             alpha=0.7, linewidth=2),
            hist2d_kwargs={'alpha': 1},
        ),
        chain_kwargs=dict(flat=True, thin=1, discard=discard),
    )

legend_elements = [
    Patch(facecolor='#b5cf6b', alpha=1.0, label=r'N'),
    Patch(facecolor='#637939', alpha=1.0, label=r'inverse-dist N'),
    Patch(facecolor=colors[1], alpha=1.0, label=r'E2 (medium SNR)'),
    Line2D([0], [0], color='k', lw=2, label='Truth'),
]
# plt.legend attaches to the current axes; loc=[0, 1.3] is the legend's lower-left
# corner in that axes' coordinates, which places it above the axes.
plt.legend(handles=legend_elements, fontsize=20, loc=[0, 1.3])
for ax in fig2.get_axes():
    ax.tick_params(axis='both', labelsize=18)
fig2.savefig('contour_E1_ss.pdf', pad_inches=0.05, bbox_inches='tight')
fig2.savefig('contour_E1_ss.png', pad_inches=0.05, bbox_inches='tight', dpi=100)
