In [None]:
import copy
import glob
import numpy as np
import os
import pandas as pd
import pickle
import scipy
import scipy.stats
import tqdm
import warnings

In [None]:
import yt
import trident
import unyt

In [None]:
import kalepy as kale

In [None]:
import trove
import verdict

In [None]:
import matplotlib
import matplotlib.pyplot as plt
# Workaround: matplotlib currently needs a plot call before the selected
# style sheet takes effect, so draw an empty plot first, then apply the style.
plt.plot()
# NOTE(review): absolute, user-specific path — this will fail on other
# machines; consider making it configurable.
matplotlib.style.use( '/Users/zhafen/repos/clean-bold/clean-bold-mnras.mplstyle' )
import palettable
import matplotlib.patheffects as path_effects

In [None]:
import helpers

# Parameters

In [None]:
# Input summary (relative to the notebook) and output directory for paper figures.
# NOTE(review): figure_dir is an absolute, user-specific path.
summary_data_fp, figure_dir = (
    './data/polished_data/summary.h5',
    '/Users/zhafen/drafts/cgm_modeling_challenge_paper/figures',
)

In [None]:
# Plot parameters.
#   jitter_width: half-width of the horizontal spread of points within each
#                 sample column.
pm = dict(
    jitter_width = 0.2,
)

# Load Data

In [None]:
# Load the summary of source values and estimates for every sample.
# NOTE(review): create_nonexistent=True presumably yields an empty Dict when
# the file is missing, which would only surface later as a KeyError — confirm.
summary = verdict.Dict.from_hdf5(
    summary_data_fp,
    create_nonexistent = True,
)

# Averages

In [None]:
# Seed the generator so any random jitter drawn from `rng` is reproducible
# under Restart & Run All.
RNG_SEED = 42
rng = np.random.default_rng( RNG_SEED )

In [None]:
# One full-width row per plotted property: metallicity, temperature, density.
mosaic = [ [ prop, ] for prop in ( 'Z', 'T', 'nH' ) ]

In [None]:
# Build the figure: one panel per entry of `mosaic`, each half as tall as it is wide.
n_rows, n_cols = len( mosaic ), len( mosaic[0] )
panel_width = matplotlib.rcParams['figure.figsize'][0]
fig = plt.figure(
    figsize = ( n_cols * panel_width, n_rows * panel_width / 2. ),
    facecolor = 'w',
)

# Hide the ticks and spines of the figure-level axes returned by plt.gca(),
# so that only the mosaic panels are visible.
main_ax = plt.gca()
main_ax.tick_params( left=False, labelleft=False, bottom=False, labelbottom=False )
for spine in main_ax.spines.values():
    spine.set_visible( False )

ax_dict = fig.subplot_mosaic( mosaic, gridspec_kw={ 'hspace': 0.15 } )

# Each panel has one x column per sample; the top row also describes the
# cloud model used in each sample.
sample_positions = [ 0, 1, 2 ]
sample_names = [ 'sample0', 'sample1', 'sample2' ]
sample_descriptions = [ 'uniform clouds', 'multiple uniform clouds', 'cloud distribution' ]

for ax_key, ax in ax_dict.items():
    ax.set_xlim( -0.5, 2.5 )
    ax.set_xticks( sample_positions )
    ax.set_xticklabels( sample_names, fontname='monospace' )

    if ax.get_subplotspec().is_first_row():
        # Secondary x-axis along the top, aligned with the bottom one.
        top_ax = ax.twiny()
        top_ax.set_xlim( ax.get_xlim() )
        top_ax.set_xticks( sample_positions )
        top_ax.set_xticklabels( sample_descriptions, fontsize='x-small' )
        top_ax.tick_params( top=False, pad=0 )

    ax.tick_params( bottom=False, pad=0 )
    ax.set_ylabel( helpers.property_labels[ax_key] )
    
# Fill every property panel. Each panel has one column per sample:
#   sample0 -- per-sightline source values vs. blinded and revised MLE estimates;
#   sample1 -- per-cloud source values vs. per-cloud MLE estimates.
# Light-grey vertical lines connect each source value to its estimate(s).

# sample1 sightline (position in keys_array() order) whose really hot,
# low-metallicity component is already discussed in the text (per the original
# note, sightline 50), and the source components kept when drawing its
# connecting lines. NOTE(review): confirm index 6 still maps to sightline 50
# if the summary file changes.
SAMPLE1_EXCLUDED_SIGHTLINE_INDEX = 6
SAMPLE1_EXCLUDED_KEPT_COMPONENTS = [ 0, 2 ]


def _draw_connector( ax, x, y_a, y_b ):
    '''Draw a light-grey vertical segment at x from y_a to y_b, behind the markers.'''
    ax.plot(
        [ x, x ],
        [ y_a, y_b ],
        color = '0.8',
        zorder = -10,
        linewidth = 1,
    )


def _gather_clouds( ys_by_sightline, xs_centers, convert=None ):
    '''Flatten per-sightline cloud values into scatter-ready arrays.

    Args:
        ys_by_sightline: Mapping from sightline key to that sightline's cloud
            values, iterated in ``keys_array()`` order.
        xs_centers: x position of each sightline, in the same order.
        convert (callable, optional): Applied to each sightline's entry before
            it is converted to an array.

    Returns:
        per_sightline (list of np.ndarray): Cloud values for each sightline.
        xs_flat (np.ndarray): x position of every cloud.
        ys_flat (np.ndarray): Value of every cloud, aligned with xs_flat.
    '''
    per_sightline = []
    xs_flat = []
    for j, sl in enumerate( ys_by_sightline.keys_array() ):
        ys_sl = ys_by_sightline[sl]
        if convert is not None:
            ys_sl = convert( ys_sl )
        # Normalise to ndarray so the closest-match step can use [:, np.newaxis]
        # on source and estimated values alike.
        ys_sl = np.asarray( ys_sl )
        per_sightline.append( ys_sl )
        xs_flat += [ xs_centers[j], ] * len( ys_sl )
    return per_sightline, np.array( xs_flat ), np.concatenate( per_sightline )


xs_set = False
for ax_key, ax in ax_dict.items():

    ########################################################
    # sample0
    ########################################################

    ys_source = summary['sample0']['source'][ax_key]

    # Evenly spread the sightlines across the column. Computed once (first panel)
    # and reused so all panels share the same x positions.
    if not xs_set:
        xs = np.linspace( -pm['jitter_width'], pm['jitter_width'], ys_source.size )
        xs_set = True

    ax.scatter(
        xs,
        ys_source,
        s = plt.rcParams['lines.markersize'] * 2,
        color = 'k',
        label = 'source',
    )

    ys_blinded = summary['sample0']['estimated']['blinded']['mle'][ax_key]
    ax.scatter(
        xs,
        ys_blinded,
        s = plt.rcParams['lines.markersize'],
        color = helpers.blinded_color,
        label = 'estimated\u2014blinded',
    )

    ys_revised = summary['sample0']['estimated']['revised']['mle'][ax_key]
    ax.scatter(
        xs,
        ys_revised,
        s = plt.rcParams['lines.markersize'],
        color = helpers.revised_color,
        label = 'estimated\u2014revised',
    )

    # Span source, blinded and revised values for each sightline.
    for x, y_s, y_b, y_r in zip( xs, ys_source, ys_blinded, ys_revised ):
        ys_trio = np.array([ y_s, y_b, y_r ])
        _draw_connector( ax, x, ys_trio.min(), ys_trio.max() )

    if ax.get_subplotspec().is_first_row():
        ax.legend(
            loc = 'upper left',
            prop = { 'size': 6, },
            edgecolor = 'none',
        )

    ########################################################
    # sample1
    ########################################################

    xs1 = xs + 1.

    # Actual (source) cloud values.
    ys_actual, xs_actual_flat, ys_actual_flat = _gather_clouds(
        summary['sample1']['source'][ax_key],
        xs1,
    )
    ax.scatter(
        xs_actual_flat,
        ys_actual_flat,
        s = plt.rcParams['lines.markersize'] * 2,
        color = 'k',
    )

    # Estimated cloud values.
    ys_estimated, xs_estimated_flat, ys_estimated_flat = _gather_clouds(
        summary['sample1']['estimated']['maximum likelihood estimate'][ax_key],
        xs1,
        convert = lambda entry: entry.array(),
    )
    ax.scatter(
        xs_estimated_flat,
        ys_estimated_flat,
        s = plt.rcParams['lines.markersize'],
        color = helpers.blinded_color,
    )

    # Connect each source cloud to the estimated cloud closest to it in this property.
    for i, x in enumerate( xs1 ):
        ys_act = ys_actual[i]
        if i == SAMPLE1_EXCLUDED_SIGHTLINE_INDEX:
            # Drop the component discussed in the text. Fancy indexing makes a
            # new array, so ys_actual itself is left unchanged.
            ys_act = ys_act[SAMPLE1_EXCLUDED_KEPT_COMPONENTS]
        ys_est = ys_estimated[i]

        inds_matched = np.argmin( np.abs( ys_act[:, np.newaxis] - ys_est ), axis=1 )
        for y_act, y_est in zip( ys_act, ys_est[inds_matched] ):
            _draw_connector( ax, x, y_act, y_est )

    ########################################################
    # sample2
    ########################################################

    # TODO: sample2 (cloud distribution) is not plotted yet; only its x
    # positions are set up.
    xs2 = xs + 2.

In [None]:
# Inspect the sample1 entry of the summary (debugging aid; displays the full entry).
summary['sample1']