In [None]:
import copy
import glob
import numpy as np
import os
import pandas as pd
import pickle
import scipy
import scipy.stats
import tqdm
import warnings

In [None]:
import yt
import trident
import unyt

In [None]:
import kalepy as kale

In [None]:
import trove
import verdict

In [None]:
import matplotlib
import matplotlib.pyplot as plt
# Apply the shared plotting style. The default is an absolute local path, so
# allow overriding it via the CLEAN_BOLD_MPLSTYLE environment variable, and
# fall back to matplotlib's defaults (with a warning) instead of crashing
# when the style cannot be found.
mplstyle = os.environ.get(
    'CLEAN_BOLD_MPLSTYLE',
    '/Users/zhafen/repos/clean-bold/clean-bold.mplstyle',
)
try:
    matplotlib.style.use( mplstyle )
except OSError as err:
    warnings.warn( 'Could not apply matplotlib style {!r} ({}); using defaults.'.format( mplstyle, err ) )
import palettable
import matplotlib.patheffects as path_effects

In [None]:
import helpers

# Parameters

In [None]:
# Model variations compared in this notebook (keys into the trove config)
variations = [ 'original', 'high-z' ]

In [None]:
# Notebook parameters, later forwarded to trove.link_params_to_config.
params = {}

# Analysis
params.update( {
    'prop_keys': [ 'vlos', 'T', 'nH', 'Z' ],
    'vel_prop_keys': [ 'vlos', 'T', 'nH', 'Z', 'NHI' ],
    'broaden_models': True,
    '1D_dist_estimation': 'kde',
    '1D_dist_estimation_data': 'histogram',
    '2D_dist_estimation': 'histogram',
    'export_data_for_proposal': False,
} )

# Plotting Choices
params.update( {
    'smooth_2D_dist': 0.5,
    'upsample_2D_dist': 3,
    '2D_dist_data_display': 'histogram',
    'contour_levels': [ 90, 50 ],
    'contour_linewidths': [ 1, 3 ],
    'show_plots_in_nb': False,
} )

## Analysis

In [None]:
# Options for each flavor of correlation coefficient, keyed by flavor name.
# An empty dict means "use the defaults".
correlation_coefficients = {
    'one-sided': dict(),
    'log one-sided': dict( logscale=True, subtract_mean=True ),
    'two-sided': dict( one_sided=False ),
    'linear': dict( one_sided=False, subtract_mean=True ),
    'log': dict( logscale=True, one_sided=False, subtract_mean=True ),
}

In [None]:
# Axis limits for each property
lims = {
    'vlos': [ -300, 300 ],
    'T': [ 1e2, 2.5e6 ],
    'nH': [ 1e-7, 100 ],
    'Z': [ 1e-3, 30 ],
    'NHI': [ 1e9, 1e17 ],
}
# Automatic-limits flag for every property in `lims`; all are off.
autolims = dict.fromkeys( lims, False )

In [None]:
# Limits for the 1D distribution panels (presumably the y range of
# dN_HI/dx). Every property except vlos shares the same range.
lims_1D = {
    'vlos': [ 3e9, 2e16 ],
    **{ key: [ 1e12, 1e20 ] for key in ( 'T', 'nH', 'Z', 'NHI' ) },
}

In [None]:
# Step size per property (presumably bin widths; for the log-scaled
# properties likely in log10 units — confirm).
dvs = {
    'vlos': 5.,
    **dict.fromkeys( ( 'T', 'nH', 'Z', 'NHI' ), 0.05 ),
}

In [None]:
# Whether each property is treated in log scale: every property except vlos.
logscale = { key: key != 'vlos' for key in ( 'vlos', 'T', 'nH', 'Z', 'NHI' ) }
variation_colors = {}

## Plotting

In [None]:
# Marker color and legend label per variation; 'high-z' is shown as "revised".
variation_plotting_params = {
    variation_key: { 'color': color, 'label': label }
    for variation_key, color, label in (
        ( 'original', helpers.modeled_color, 'modeled' ),
        ( 'high-z', helpers.revised_color, 'revised' ),
    )
}

In [None]:
# Axis labels (with units) for each property
labels = {
    'vlos': r'$v_{\rm LOS}$ [km/s]',
    'T': r'T [K]',
    'nH': r'$n_{\rm H}$ [cm$^{-3}$]',
    'Z': r'$Z$ [$Z_{\odot}$]',
    'NHI': r'$N_{\rm H\,I}$ [cm$^{-2}$]',
}
# Labels for the 1D distributions, dN_HI / dx, where x is v_LOS itself or
# the log of the property.
labels_1D = {
    key: r'$\frac{ d N_{\rm H\,I} }{d ' + denominator + r'}$'
    for key, denominator in [
        ( 'vlos', r'v_{\rm LOS}' ),
        ( 'T', r'\log T' ),
        ( 'nH', r'\log n_{\rm H}' ),
        ( 'Z', r'\log Z' ),
        ( 'NHI', r'\log N_{\rm H\,I}' ),
    ]
}
# Correlation-coefficient labels "r( <label with the units stripped> )",
# plus one for the combined coefficient.
r_labels = {
    key: r'$r($ ' + label.partition( '[' )[0] + r'$)$'
    for key, label in labels.items()
}
r_labels['all'] = r'$r($ all $)$'

In [None]:
# Marker shape per coefficient flavor: triangles for one-sided, a diamond for
# two-sided, circles for the mean-subtracted ones.
correlation_markers = {
    **dict.fromkeys( ( 'one-sided', 'log one-sided' ), '^' ),
    'two-sided': 'D',
    **dict.fromkeys( ( 'linear', 'log' ), 'o' ),
}
# Marker size per flavor; the diamond is drawn slightly smaller.
correlation_sizes = dict.fromkeys( correlation_markers, 100 )
correlation_sizes['two-sided'] = 80
correlations_plotted = [ 'linear', 'log' ]

In [None]:
# Lower-triangular panel layout: diagonal cells are single properties and
# off-diagonal cells are named '<row property>_<column property>'.
# '.' is matplotlib's default empty-cell marker for subplot_mosaic.
mosaic = [
    [ 'vlos',    'legend', '.',    '.' ],
    [ 'T_vlos',  'T',      '.',    '.' ],
    [ 'nH_vlos', 'nH_T',   'nH',   '.' ],
    [ 'Z_vlos',  'Z_T',    'Z_nH', 'Z' ],
]
velocity_mosaic = [ [ 'nH_vlos', 'vlos' ], [ 'Z_vlos', 'T_vlos' ] ]

In [None]:
panel_length = 4.0  # size of one panel in the figure's figsize units (inches)

In [None]:
# A qualitative palette (palettable `mpl_colors`) and a diverging colormap
# (palettable `mpl_colormap`) used for correlation values.
cmap, corr_cmap = (
    palettable.cartocolors.qualitative.Safe_10.mpl_colors,
    palettable.cartocolors.diverging.Temps_2_r.mpl_colormap,
)

In [None]:
# Linear normalization mapping values in [0, 1] onto the colormap range
corr_norm = matplotlib.colors.Normalize( 0, 1 )

In [None]:
def one_color_linear_cmap( color, name, f_white=0.95, f_saturated=1.0, ):
    '''Build a linear colormap from a single base color.

    The colormap runs from a whiter version of `color` to a more
    saturated version of `color`.

    Args:
        color: Any matplotlib color specification (RGB/RGBA tuple, hex
            string, named color). An alpha channel, if present, is ignored.
        name (str): Name given to the returned colormap.
        f_white (float): How far the start color is pushed toward white.
            Its saturation is scaled by (1 - f_white), and its value is moved
            a fraction f_white of the way to 1.
        f_saturated (float): Fraction of the way the end color's saturation
            is moved toward 1.

    Returns:
        matplotlib.colors.LinearSegmentedColormap: Two-color linear map.
    '''

    # to_rgb accepts any color spec and drops alpha. rgb_to_hsv on its own
    # only accepts exactly three channels.
    color_hsv = matplotlib.colors.rgb_to_hsv( matplotlib.colors.to_rgb( color ) )

    # Start color: desaturate and brighten toward white
    start_color_hsv = color_hsv.copy()
    start_color_hsv[1] -= f_white * start_color_hsv[1]
    start_color_hsv[2] += f_white * ( 1. - start_color_hsv[2] )
    start_color = matplotlib.colors.hsv_to_rgb( start_color_hsv )

    # End color: push saturation toward its maximum
    end_color_hsv = color_hsv.copy()
    end_color_hsv[1] += f_saturated * ( 1. - end_color_hsv[1] )
    end_color = matplotlib.colors.hsv_to_rgb( end_color_hsv )

    return matplotlib.colors.LinearSegmentedColormap.from_list( name, [ start_color, end_color ] )

## Process analysis parameters

In [None]:
# Load parameters: one trove parameter set per variation, keyed by variation.
# NOTE(review): the config path is an absolute local path.
pms = {
    variation_key: trove.link_params_to_config(
        '/Users/zhafen/analysis/cgm_modeling_challenge/sample2.trove',
        script_id = 'nb.2',
        variation = variation_key,
        global_variation = '',
        **params
    )
    for variation_key in variations
}
# Default parameter set: the first variation's
pm = next( iter( pms.values() ) )

# Load Data

In [None]:
# Data structure for storing correlations: one verdict.Dict per variation,
# read from <data_dir>/correlation.h5 (create_nonexistent=True is passed so a
# missing file is presumably created rather than raising — see verdict).
correlations_all = dict()
for variation, pm in pms.items():
    correlations_all[variation] = verdict.Dict.from_hdf5(
        os.path.join( pm['data_dir'], 'correlation.h5' ),
        create_nonexistent = True,
    )

# Plots

## Compare Sample2 initial to Sample2 revised

In [None]:
# Number of sightlines and their x positions within each panel (evenly
# spaced over [-0.25, 0.25]). The count is read from every variation instead
# of from the loop variable left over from the loading cell. The counts must
# agree, because the same x positions are reused for all variations.
sightline_counts = {
    variation_key: len( variation_correlations['linear']['ndim'] )
    for variation_key, variation_correlations in correlations_all.items()
}
assert len( set( sightline_counts.values() ) ) == 1, \
    'Variations have differing numbers of sightlines: {}'.format( sightline_counts )
n_sls = list( sightline_counts.values() )[-1]
xs = np.linspace( -0.5, 0.5, n_sls ) / 2

In [None]:
# 2 x 2 grid of panels, one per property; each panel spans two mosaic columns.
clean_mosaic = [
    [ left, left, right, right ]
    for left, right in ( ( 'vlos', 'T' ), ( 'nH', 'Z' ) )
]

In [None]:
# Setup Figure
n_rows_clean = len( clean_mosaic )
# Figure width in units of panel_length. It is not the number of mosaic
# columns (clean_mosaic has 4).
n_cols_clean = 3.5
fig = plt.figure(
    figsize = tuple( n * panel_length for n in ( n_cols_clean, n_rows_clean ) ),
    facecolor = 'w',
)
ax_dict = fig.subplot_mosaic( clean_mosaic, gridspec_kw={ 'wspace': 0.7 } )
# The legend is drawn on the vlos panel rather than on a dedicated axis
ax_dict['legend'] = ax_dict['vlos']

def r_scatter( ax, ys, c_key, color=None, label_tag=None, x_positions=None ):
    '''Scatter-plot one set of per-sightline correlation coefficients.

    Args:
        ax: Matplotlib axes to draw on.
        ys: Correlation coefficients, one per x position.
        c_key (str): Coefficient flavor. It is a key into
            `correlation_coefficients`, `correlation_markers` and
            `correlation_sizes`, and selects the marker style.
        color: Marker edge color. It is also the face color, unless the
            flavor is log-scaled, in which case the markers are hollow.
        label_tag (str): Prefix of the legend label "<label_tag>, <c_key>".
        x_positions: x coordinates of the points. Defaults to the
            notebook-level `xs`.

    Returns:
        The PathCollection created by `ax.scatter`.
    '''
    if x_positions is None:
        x_positions = xs

    # Log-scaled flavors get hollow markers. Using .get() also handles an
    # explicit 'logscale': False, which previously left `facecolors`
    # unbound and raised UnboundLocalError.
    if correlation_coefficients[c_key].get( 'logscale', False ):
        facecolors = 'none'
    else:
        facecolors = color

    return ax.scatter(
        x_positions,
        ys,
        label = '{}, {}'.format( label_tag, c_key ),
        edgecolors = color,
        facecolors = facecolors,
        marker = correlation_markers[c_key],
        s = correlation_sizes[c_key],
        linewidth = 2,
    )

    
# Overall
# The loop-variable names (`variation`, `pm`, `correlations`, `c_key`, ...) are
# kept as-is, because later code in this cell may read them after the loop.
for variation, pm in pms.items():

    correlations = correlations_all[variation]
    plotting_params = variation_plotting_params[variation]
    marker_color = plotting_params['color']
    marker_label = plotting_params['label']

    # Each property: plot entry (j, j) of every plotted coefficient's matrix.
    # NOTE(review): the first matrix axis presumably indexes sightlines,
    # matching the x positions in `xs` — confirm.
    prop_keys_progress = tqdm.tqdm( pm['prop_keys'], bar_format=pm['bar_format'] )
    for j, x_key in enumerate( prop_keys_progress ):
        ax = ax_dict[x_key]
        for c_key in correlations_plotted:
            diagonal_r = correlations[c_key]['matrix'].array()[:, j, j]
            r_scatter( ax, diagonal_r, c_key, color=marker_color, label_tag=marker_label )
    
        
# Add a legend, built from the labeled artists on the vlos panel
legend = ax_dict['legend'].legend(
    *ax_dict['vlos'].get_legend_handles_labels(),
    loc = 'lower left',
    prop = { 'size': 14 },
    ncol = 2,
    framealpha = 1,
)
        
# Cleanup
# Values shared by every panel are looked up once, explicitly: the last
# variation's parameters and its last plotted coefficient flavor. These are
# the same values this code previously read from variables left over from
# the plotting loop.
last_pm = list( pms.values() )[-1]
tick_correlations = list( correlations_all.values() )[-1][correlations_plotted[-1]]
# Tick labels are the last two characters of each sightline ID
xtick_labels = [ _[-2:] for _ in tick_correlations['ndim'].keys_array() ]

for x_key, ax in ax_dict.items():

    # 'legend' aliases the vlos panel, so skipping it avoids styling vlos twice
    if x_key in [ 'legend', 'empty' ]:
        continue

    # Reference lines at r = -1, 0, 1, drawn behind the data
    for value in [ -1, 0, 1 ]:
        ax.axhline(
            value,
            color = last_pm['background_linecolor'],
            linewidth = 1,
            zorder = -100,
        )

    ax.set_ylabel( r_labels[x_key], fontsize=16 )
    # Only the bottom row of panels gets an x-axis label
    if ax.get_subplotspec().is_last_row():
        ax.set_xlabel( 'sightline ID', fontsize=16 )

    ax.set_xticks( xs )
    ax.set_xticklabels( xtick_labels )

    ax.set_ylim( -0.3, 1.1 )
    
# Save
# The output directory is taken from the last variation's parameters. This is
# what the loop variable `pm` left over from the plotting loop held; looking
# it up explicitly removes that hidden dependency.
# NOTE(review): confirm whether the first variation's figure_dir was intended.
savedir = list( pms.values() )[-1]['figure_dir']
os.makedirs( savedir, exist_ok=True )
savefile = 'correlations.pdf'
save_fp = os.path.join( savedir, savefile )
print( 'Saving figure to {}'.format( save_fp ) )
# Save this cell's figure explicitly rather than pyplot's "current" figure
fig.savefig( save_fp, bbox_inches='tight' )