In [1]:
import copy
import glob
import numpy as np
import os
import pandas as pd
import scipy
import scipy.stats
import tqdm
import warnings

In [2]:
import yt
import trident
import unyt

In [3]:
import kalepy as kale

In [4]:
import trove

In [5]:
import matplotlib
import matplotlib.pyplot as plt
matplotlib.style.use( '/Users/zhafen/repos/clean-bold/clean-bold.mplstyle' )
import palettable

# Parameters

In [6]:
pm = {
    # Analysis 
    'broaden_models': True,
    '1D_dist_estimation': 'kde',
    '1D_dist_estimation_data': 'histogram',
    '2D_dist_estimation': 'histogram',
    'n_bins_convolve': 16,
    
    # Plotting Choices
    'smooth_2D_dist': 0.5,
    'upsample_2D_dist': 3,
    '2D_dist_data_display': 'histogram',
    'n_bins_1D': 128,
    'n_bins_data_1D': 128,
    'n_bins_2D': 32,
    'n_bins_data_2D': 20,
    'n_sample_turb': 1000,
    'contour_levels': [ 90, 50 ],
    'contour_linewidths': [ 1, 3 ],
}

In [7]:
# Load parameters
pm = trove.link_params_to_config(
    '/Users/zhafen/repos/cgm_modeling_challenge/modeling_challenge.trove',
    script_id = 'nb.2',
    variation = 'sample2',
    **pm
)

In [8]:
redshift = pm['redshift']

In [9]:
data_dir = pm['data_dir']
ray_dir = os.path.join( data_dir, 'rays' )
results_dir = os.path.join( data_dir, 'modeling' )

In [10]:
cmap = palettable.cartocolors.qualitative.Safe_10.mpl_colors

In [11]:
labels = {
    'vlos': r'$v_{\rm LOS}$ [km/s]',
    'T': r'T [K]',
    'nH': r'$n_{\rm H}$ [cm$^{-3}$]',
    'Z': r'$Z$ [$Z_{\odot}$]',
}

In [12]:
lims = {
    'vlos': [ -100, 100 ],
    'T': [ 1e2, 2.5e6 ],
    'nH': [ 1e-7, 100 ],
    'Z': [ 1e-3, 30 ],
}
autolims = {
    'vlos': False,
    'T': False,
    'nH': False,
    'Z': False,
}

In [13]:
lims_1D = {
    'vlos': [ 1e8, 1e18 ],
    'T': [ 1e12, 1e20 ],
    'nH': [ 1e12, 1e20 ],
    'Z': [ 1e12, 1e20 ],
}
labels_1D = {
    'vlos': r'$\frac{ d N_{\rm H\,I} }{d v_{\rm LOS}}$',
    'T': r'$\frac{ d N_{\rm H\,I} }{d \log T}$',
    'nH': r'$\frac{ d N_{\rm H\,I} }{d \log n_{\rm H}}$',
    'Z': r'$\frac{ d N_{\rm H\,I} }{d \log Z}$',
}

In [14]:
dvs = {
    'vlos': 5.,
    'T': 0.05,
    'nH': 0.05,
    'Z': 0.05,
}

In [15]:
logscale = {
    'vlos': False,
    'T': True,
    'nH': True,
    'Z': True,
}

In [16]:
bar_format = pm['bar_format']

# Tools

In [17]:
ldb = trident.LineDatabase(None)

read_sets: Using set file -- 
  /Users/zhafen/repos/linetools/linetools/lists/sets/llist_v1.3.ascii
Loading abundances from Asplund2009
Abundances are relative by number on a logarithmic scale with H=12


In [18]:
def zeff_to_vel( zeff ):
    
    ainv2 = ( ( 1. + zeff ) / ( 1. + redshift ) )**2.
    
    v_div_c = ( ainv2 - 1. ) / ( ainv2 + 1. )
    return ( v_div_c * unyt.c ).to( 'km/s' )

# Load Data

## Modeled

In [19]:
# Get sightline filepaths
sl_fps = []
sls = []
for sl_fp in glob.glob( os.path.join( results_dir, '*' ) ):
    sl_fps.append( sl_fp )
    sls.append( os.path.split( sl_fp )[-1] )

In [20]:
n_sample_turb = pm['n_sample_turb']

In [21]:
sl_datas = []
sl_useds = []
sl_stackeds = []
model_weightss = []
stacked_weightss = []
total_weightss = []
median_NHIs = []
for i, sl in enumerate( sls ):
    
    sl_fp = sl_fps[i]
    
    # Get text files
    col_names = [ 'prob', 'likelihood', 'Z', 'nH', 'T', 'NHI', 'bturb', 'z', ]
    col_units = [ 1., 1., unyt.Zsun, unyt.cm**-3, unyt.K, unyt.cm**-2, unyt.km / unyt.s, 1. ]
    sl_data = {}
    for component_fp in glob.glob( os.path.join( sl_fp, '*' ) ):
        component_key = os.path.splitext( os.path.split( component_fp )[-1] )[0]
        sl_data[component_key] = pd.read_csv( component_fp, sep=' ', names=col_names )
        
    # Add LOS velocity and reformat
    for component_key, df in sl_data.items():

        # Reformat
        new_entry = {}
        for name in col_names:
            values = unyt.unyt_array( df[name].values )
            if name in [ 'nH', 'T', 'Z', 'NHI', ]:
                new_entry[name] = 10.**values
            else:
                new_entry[name] = values

        # Add LOS velocity
        new_entry['vlos'] = zeff_to_vel( df['z'].values )

        # Setup units
        for j, name in enumerate( col_names ):
            new_entry[name] *= col_units[j]

        sl_data[component_key] = new_entry
    col_names.append( 'vlos' )
        
    # Turn samples into a list
    keys = list( sl_data.keys() )
    sl_formatted = {}
    for name in col_names:
        sl_formatted[name] = [ sl_data[_][name] for _ in keys ]

    # Generate modeled sample to plot ("generate" because we're sampling the doppler broadening)
    if pm['broaden_models']:
        sl_tiled = {}
        for name in col_names:
            sl_tiled[name] = []

        for j, vlos_j in enumerate( tqdm.tqdm( sl_formatted['vlos'], bar_format=bar_format ) ):
            sample_dist = scipy.stats.norm( loc=vlos_j, scale=sl_formatted['bturb'][j]/np.sqrt( 2. ) )
            sampled_values = sample_dist.rvs( ( n_sample_turb, vlos_j.size ) )

            for name in col_names:
                if name != 'vlos':
                    arr_tiled = np.hstack( np.tile( sl_formatted[name][j], ( n_sample_turb, 1 ),  ) )
                else:
                    arr_tiled = np.hstack( sampled_values )

                arr_tiled *= sl_formatted[name][j].units

                sl_tiled[name].append( arr_tiled )

        sl_used = sl_tiled
    else:
        sl_used = sl_formatted

    # Minimum weighting is to make sure no component is overweighted due to number of samples
    n_samples_max = np.max([ _.size for _ in sl_used['nH'] ])
    if pm['weighting'] is None:
        model_weights = [ np.full( _.size, n_samples_max / _.size ) for _ in sl_used['nH'] ]
    else:
        model_weights = [ np.full( _.size, np.nanmedian( _ ) * n_samples_max / _.size ) for _ in sl_used[pm['weighting']] ]
    model_weightss.append( model_weights )
    stacked_weightss.append( np.hstack( model_weights ) )
    total_weightss.append( np.array([ _.sum() for _ in model_weights ]) )
    median_NHIs.append( np.array([ np.nanmedian( _ ) for _ in sl_used['NHI'] ]) )

    sl_stacked = {}
    for key, item in sl_used.items():
        sl_stacked[key] = np.hstack( item ) * item[0].units
    
    sl_datas.append( sl_data )
    sl_useds.append( sl_used )
    sl_stackeds.append( sl_stacked )

     100%|██████████| 2/2 [00:01<00:00,  1.22it/s]


## Actual

In [22]:
# Get rays
ray_fps = [ os.path.join( ray_dir, 'ray_{}.h5'.format( _[1:] ) ) for _ in sls ]
rays = [ yt.load( _ ) for _ in ray_fps ]

yt : [INFO     ] 2022-02-02 16:50:32,932 Parameters: current_time              = 0.0
yt : [INFO     ] 2022-02-02 16:50:32,933 Parameters: domain_dimensions         = [1 1 1]
yt : [INFO     ] 2022-02-02 16:50:32,933 Parameters: domain_left_edge          = [0. 0. 0.] kpc
yt : [INFO     ] 2022-02-02 16:50:32,934 Parameters: domain_right_edge         = [95.7322 95.7322 95.7322] kpc
yt : [INFO     ] 2022-02-02 16:50:32,935 Parameters: cosmological_simulation   = 0.0
yt : [INFO     ] 2022-02-02 16:50:32,994 Parameters: current_time              = 0.0
yt : [INFO     ] 2022-02-02 16:50:32,994 Parameters: domain_dimensions         = [1 1 1]
yt : [INFO     ] 2022-02-02 16:50:32,995 Parameters: domain_left_edge          = [0. 0. 0.] kpc
yt : [INFO     ] 2022-02-02 16:50:32,996 Parameters: domain_right_edge         = [95.7322 95.7322 95.7322] kpc
yt : [INFO     ] 2022-02-02 16:50:32,997 Parameters: cosmological_simulation   = 0.0
yt : [INFO     ] 2022-02-02 16:50:33,055 Parameters: current_time   

In [23]:
ray_datas = []
ray_weights = []
for ray in rays:
    # Ray properties
    trident.add_ion_fields(ray, ions=[ 'H I', ], line_database=ldb)
    den = ray.r[('gas', 'number_density')] * 0.75
    ray_data = {
        'nH': den,
        'NH': ( den * ray.r[('gas', 'dl')] ),
        'NHI': ray.r[('gas', 'H_p0_number_density')] * ray.r[('gas', 'dl')],
        'z': ray.r[('gas', 'redshift_eff')],
        'T': ray.r[('gas', 'temperature')],
        'Z': ray.r[('gas', 'metallicity')],
    }
    ray_data['vlos'] = zeff_to_vel( ray_data['z'] )
    
    if pm['sim_weighting'] is None:
        weights = np.ones( ray_data['NH'].shape )
    else:
        weights = copy.copy( ray_data[pm['sim_weighting']].value )
        weights[np.isclose(weights,0.)] = np.nan
    
    ray_datas.append( ray_data )
    ray_weights.append( weights )

yt : [INFO     ] 2022-02-02 16:50:33,500 Allocating for 1.024e+03 particles


## Bins and Limits

In [24]:
all_binss = []
all_dxs = []
all_centerss = []
for i, ray_data in enumerate( tqdm.tqdm( ray_datas, bar_format=bar_format ) ):
    
    sl_stacked = sl_stackeds[i]

    # Make bins, dx, and centers
    all_bins = {}
    all_dx = {}
    all_centers = {}
    for n_bins_key in [ 'n_bins_1D', 'n_bins_2D', 'n_bins_data_1D', 'n_bins_data_2D', 'n_bins_convolve' ]:
        n_bins = pm[n_bins_key]
        bins = {}
        for key, item in lims.items():
            if autolims[key]:
                low = np.nanmin(np.hstack([ sl_stacked[key], ray_data[key] ]))
                high = np.nanmax(np.hstack([ sl_stacked[key], ray_data[key] ]))
            else:
                low = item[0]
                high = item[1]
            if logscale[key]:
                bins[key] = np.logspace( np.log10( low ), np.log10( high ), n_bins )
            else:
                bins[key] = np.linspace( low, high, n_bins )

            bins[key] *= sl_stacked[key].units
        all_bins[n_bins_key] = bins

        dx = {}
        for key, bins_j in bins.items():
            if logscale[key]:
                dx[key] = np.log10( bins_j[1] ) - np.log10( bins_j[0] )
            else:
                dx[key] = float( ( bins_j[1] - bins_j[0] ).value )
        all_dx[n_bins_key] = dx

        centers = {}
        for key, bins_j in bins.items():

            if logscale[key]:
                bins_j = np.log10( bins_j )

            centers[key] = bins_j[:-1] + 0.5 * np.diff( bins_j )

            if logscale[key]:
                centers[key] = 10.**centers[key]
        all_centers[n_bins_key] = centers
        
    all_binss.append( all_bins )
    all_dxs.append( all_dx )
    all_centerss.append( all_centers )

      10%|█         | 1/10 [00:00<00:00, 261.57it/s]


IndexError: list index out of range

# Comparison Metrics

In [None]:
i = 0

In [None]:
ray_data = ray_datas[i]
weights = ray_weights[i]

In [None]:
sl_stacked = sl_stackeds[i]
model_weights = model_weightss[i]
stacked_weights = stacked_weightss[i]

In [None]:
prop_keys = list( labels.keys() )

In [None]:
bins = all_binss[i]['n_bins_convolve']
dx = all_dxs[i]['n_bins_convolve']
centers = all_centerss[i]['n_bins_convolve']

In [None]:
# Format for calculating distributions
ray_data_histdd = []
sl_stacked_histdd = []
bins_histdd = []
for key in prop_keys:
    
    arr_modeled = sl_stacked[key].value
    arr = ray_data[key].value
    bins_key = bins[key]
    
    if logscale[key]:
        arr_modeled = np.log10( arr_modeled )
        arr = np.log10( arr )
        bins_key = np.log10( bins_key )
    
    sl_stacked_histdd.append( arr_modeled )
    ray_data_histdd.append( arr )
    bins_histdd.append( bins_key )
    
sl_stacked_histdd = np.array( sl_stacked_histdd ).transpose()
ray_data_histdd = np.array( ray_data_histdd ).transpose()
bins_histdd = np.array( bins_histdd )

In [None]:
ray_dist_dd, bins_dd = np.histogramdd(
    ray_data_histdd,
    bins = bins_histdd,
    weights = weights,
)
ray_dist_dd /= ray_dist_dd.sum()

In [None]:
modeled_dist_dd, bins_dd = np.histogramdd(
    sl_stacked_histdd,
    bins = bins_histdd,
    weights = stacked_weights,
)
modeled_dist_dd /= modeled_dist_dd.sum()

In [None]:
# Loop through all properties
for j, x_key in enumerate( tqdm.tqdm( prop_keys, bar_format=bar_format ) ):
    for k, y_key in enumerate( prop_keys ):

        # Avoid duplicates
        if k < j:
            continue
            
        

In [None]:
( modeled_dist_dd * ray_dist_dd ).sum() / ( ray_dist_dd**2. ).sum()

# Plot

## Setup

### Parameters

In [None]:
contour_levels = pm['contour_levels']
contour_linewidths = pm['contour_linewidths']

In [None]:
mosaic = [
    [ 'vlos', 'legend', '.', '.' ],
    [ 'T_vlos', 'T', '.', '.' ],
    [ 'nH_vlos', 'nH_T', 'nH', '.' ],
    [ 'Z_vlos', 'Z_T', 'Z_nH', 'Z', ],
]

In [None]:
def one_color_linear_cmap( color, name, f_white=0.95, f_saturated=1.0, ):
    '''A function that turns a single color into linear colormap that
    goes from a color that is whiter than the original color to a color
    that is more saturated than the original color.
    '''
    
    color_hsv = matplotlib.colors.rgb_to_hsv( color )
    start_color_hsv = copy.copy( color_hsv )
    
    start_color_hsv = copy.copy( color_hsv )
    start_color_hsv[1] -= f_white * start_color_hsv[1]
    start_color_hsv[2] += f_white * ( 1. - start_color_hsv[2] )
    start_color = matplotlib.colors.hsv_to_rgb( start_color_hsv )
    
    end_color_hsv = copy.copy( color_hsv )
    end_color_hsv[1] += f_saturated * ( 1. - end_color_hsv[1] )
    end_color = matplotlib.colors.hsv_to_rgb( end_color_hsv )
    
    return matplotlib.colors.LinearSegmentedColormap.from_list( name, [ start_color, end_color ] )

In [None]:
color_modeled = cmap[0]
color_data = cmap[1]
cmap_modeled = one_color_linear_cmap( color_modeled, 'modeled' )
cmap_data = one_color_linear_cmap( color_data, 'data' )
cmap_data = matplotlib.colors.LinearSegmentedColormap.from_list( 'data', [ 'w', color_data ] )

In [None]:
panel_length = 4.

### Other Setup

In [None]:
class ContourCalc( object ):
    
    def __init__( self, arr ):
        
        is_not_nan = np.invert( np.isnan( arr ) )
        is_finite = np.invert( np.isinf( arr ) )
        is_valid = is_not_nan & is_finite
        self.values_sorted = np.sort( arr[is_valid] )[::-1]
        
        self.values_fraction = np.cumsum( self.values_sorted )
        self.values_fraction /= self.values_fraction[-1]
        
        self.interp_fn = scipy.interpolate.interp1d( self.values_fraction, self.values_sorted )
        
    def get_level( self, q, f_min_is_average=True ):
        
        f = np.array( q ) / 100.
        
        if f_min_is_average:
            f_min = 0.5 * ( self.values_fraction[0] + self.values_fraction[1] )
        else:
            f_min = self.values_fraction[0]
        
        if pd.api.types.is_list_like( f ):
            f = np.array( f )
            f[f<f_min] = f_min
        else:
            if f < f_min:
                f = f_min

        return self.interp_fn( f ) 

## Plot

In [None]:
for i, ray in enumerate( rays ):
    
#     # DEBUG
#     if i != 7:
#         continue
    
    print( '\nMaking comparison for ray {}\n'.format( i ) )

    ray_data = ray_datas[i]
    weights = ray_weights[i]

    # Modeled sightline
    sl_data = sl_datas[i]
    sl_used = sl_useds[i]
    sl_stacked = sl_stackeds[i]
    model_weights = model_weightss[i]
    stacked_weights = stacked_weightss[i]
    total_weights = total_weightss[i]
    median_NHI = median_NHIs[i]
    
    all_bins = all_binss[i]
    all_dx = all_dxs[i]
    all_centers = all_centerss[i]
    
    bins = all_bins['n_bins_1D']
    dx = all_dx['n_bins_1D']
    centers = all_centers['n_bins_1D']
    
    # Setup Figure
    n_cols = len( prop_keys )
    fig = plt.figure( figsize=( panel_length*n_cols, panel_length*n_cols ), facecolor='w' )
    ax_dict = fig.subplot_mosaic( mosaic )

    # Loop through all properties
    for j, x_key in enumerate( tqdm.tqdm( prop_keys, bar_format=bar_format ) ):
        for k, y_key in enumerate( prop_keys ):

            # Avoid duplicates
            if k < j:
                continue
            
#             # DEBUG
#             if j != 1 and k != 0:
#                 continue

            # Check for out-of-bounds
            oob_labels = [ 'modeled', 'ray' ]
            for ii, key in enumerate([ x_key, y_key ]):
                for jj, values in enumerate([ sl_stacked[key], ray_data[key] ]):
                    n_low = ( values < bins[key][0] ).sum()
                    n_high = ( values > bins[key][-1] ).sum()
                    bounds = [ 'below', 'above' ]
                    for kk, n_oob in enumerate([ n_low, n_high ]):
                        if n_oob / values.size > 0.02:
                            warnings.warn(
                                '{} {} points ({:.2g}%) with {} {} {:.3g}'.format(
                                    n_oob,
                                    oob_labels[jj],
                                    n_oob / values.size * 100,
                                    key,
                                    bounds[kk],
                                    lims[key][kk],
                                )
                            )

            # 1D histogram
            if j == k:
                ax = ax_dict[x_key]
                subplotspec = ax.get_subplotspec()

                x_label = labels[x_key]
                y_label = labels_1D[x_key]
                
                bins = all_bins['n_bins_1D']
                dx = all_dx['n_bins_1D']
                centers = all_centers['n_bins_1D']

                # Observational
                if pm['1D_dist_estimation'] == 'histogram':
                    hist_o, edges = np.histogram(
                        sl_stacked[x_key],
                        bins = bins[x_key],
                        weights = stacked_weights,
                        density = False,
                    )
                    hist_o /= hist_o.sum() * dx[x_key]
                    ax.step(
                        edges[:-1],
                        hist_o,
                        color = color_modeled,
                        where = 'post',
                        linewidth = 2,
                    )
                elif pm['1D_dist_estimation'] == 'kde':
                    # Change to logspace for kde
                    if logscale[x_key]:
                        sl_kde = np.log10( sl_stacked[x_key] )
                        kde_centers = np.log10( centers[x_key] )
                    else:
                        sl_kde = sl_stacked[x_key].value
                        kde_centers = centers[x_key].value
                    kde_centers, hist_o = kale.density(
                            sl_kde,
                            points = kde_centers,
                            weights = stacked_weights,
                            probability = False,
                        )
                    hist_o /= hist_o.sum() * dx[x_key]
                    hist_o *= median_NHI.sum()
                    ax.plot(
                        centers[x_key],
                        hist_o,
                        linewidth = 5,
                        color = 'k',
                        label = 'modeled',
                    )
                    
                    # Individual components
                    for kk, sl_used_x in enumerate( sl_used[x_key] ):       
                        if logscale[x_key]:
                            sl_used_x = np.log10( sl_used_x )
                        else:
                            sl_used_x = sl_used_x.value
                        kde_centers, hist_o = kale.density(
                                sl_used_x,
                                points = kde_centers,
                                weights = model_weights[kk],
                                probability = False,
                            )
                        hist_o /= hist_o.sum() * dx[x_key]
                        hist_o *= median_NHI[kk]
                        
                        # Cycle through colors, skipping the color reserved for the ray data
                        if kk == 0:
                            cmap_kk = kk
                        else:
                            cmap_kk = kk + 1
                    
                        ax.plot(
                            centers[x_key],
                            hist_o,
                            linewidth = 2,
                            color = cmap[cmap_kk],
                            label = r'    component $\log N_{\rm H\,I}=$' + '{:.3g}'.format( np.log10( median_NHI[kk] ) ),
                        )

                # Ray
                bins = all_bins['n_bins_data_1D']
                dx = all_dx['n_bins_data_1D']
                centers = all_centers['n_bins_data_1D']
                
                if pm['1D_dist_estimation_data'] == 'histogram':
                    hist_r, edges = np.histogram(
                        ray_data[x_key],
                        bins = bins[x_key],
                        weights = weights,
                        density = False,
                    )
                    hist_r /= hist_r.sum() * dx[x_key]
                    hist_r *= ray_data['NHI'].sum()
                    ax.fill_between(
                        edges[:-1],
                        hist_r,
                        color = color_data,
                        step = 'post',
                        label = 'data',
                    )
                elif pm['1D_dist_estimation_data'] == 'kde':
                    # Change to logspace for kde
                    if logscale[x_key]:
                        sl_kde = np.log10( ray_data[x_key] )
                        kde_centers = np.log10( centers[x_key] )
                    else:
                        sl_kde = ray_data[x_key].value
                        kde_centers = centers[x_key].value
                    kde_centers, hist_r = kale.density(
                            sl_kde,
                            points = kde_centers,
                            weights = weights,
                            probability = False,
                        )
                    hist_r /= hist_r.sum() * dx[x_key]
                    hist_r *= ray_data['NHI'].sum()
                    ax.fill_between(
                        centers[x_key],
                        hist_r,
                        color = color_data,
                    )

#                 y_min = 10.**np.nanmin( [ np.nanmin( np.log10( hist_o[hist_o>0] ) ), np.nanmin( np.log10( hist_r[hist_r>0] ) ) ] )
#                 y_min = lims_1D[x_key]
#                 y_max = np.nanmax([ np.nanmax( hist_r ), np.nanmax( hist_o ) ])
#                 ax.set_ylim( y_min, y_max * 1.05 )
                ax.set_ylim( lims_1D[x_key] )
                ax.set_xlim( bins[x_key][0], bins[x_key][-1] )

                if logscale[x_key]:
                    ax.set_xscale( 'log' )
                ax.set_yscale( 'log' )
                
                if x_key in [ 'T', 'nH', 'Z' ]:
                    ax.yaxis.set_label_position( 'right' )
                    ax.set_ylabel( y_label, fontsize=16 )

                ax.tick_params(
                    which = 'both',
                    labelleft = subplotspec.is_first_col(),
                    right = True,
                    labelright = True,
                )

            # 2D histogram
            else:
                
                bins = all_bins['n_bins_2D']
                dx = all_dx['n_bins_2D']
                centers = all_centers['n_bins_2D']
                
                centers_x = copy.copy( centers[x_key] )
                centers_y = copy.copy( centers[y_key] )
                
                try:
                    ax = ax_dict['{}_{}'.format( x_key, y_key )]
                except KeyError:
                    ax = ax_dict['{}_{}'.format( y_key, x_key )]
                subplotspec = ax.get_subplotspec()
                              
                # Upsample centers
                upsample = pm['upsample_2D_dist']
                if upsample is not None:
                    centers_x = scipy.ndimage.zoom( centers_x, upsample )
                    centers_y = scipy.ndimage.zoom( centers_y, upsample )
                
                # Observational per component
                img_arr_comps = []
                    
                for kk, sl_used_x in enumerate( sl_used[x_key] ):
                    sl_used_y = sl_used[y_key][kk]
                    norm_kk = total_weights[kk] * dx[x_key] * dx[y_key]
                    
                    # Histogram version
                    if pm['2D_dist_estimation'] == 'histogram':
                        hist2d_kk, x_edges, y_edges = np.histogram2d(
                            sl_used_x,
                            sl_used_y,
                            bins = [ bins[x_key], bins[y_key], ],
                            weights = model_weights[kk] / norm_kk,
                        )
                        img_arr_kk = np.transpose( hist2d_kk )
                        
                    # KDE version
                    elif pm['2D_dist_estimation'] == 'kde':
                        
                        # Change to logspace for kde
                        if logscale[x_key]:
                            sl_used_x = np.log10( sl_used_x )
                            kde_centers_x = np.log10( centers_x )
                        else:
                            kde_centers_x = centers[x_key]
                        if logscale[y_key]:
                            sl_used_y = np.log10( sl_used_y )
                            kde_centers_y = np.log10( centers_y )
                        else:
                            kde_centers_y = centers[y_key]
                            
                        kde_data = np.array([ sl_used_x, sl_used_y ])
                        points, img_arr_kk = kale.density(
                            kde_data,
                            points = [ kde_centers_x, kde_centers_y ],
                            grid = True,
                            weights = model_weights[kk] / norm_kk,
                        )
                        
                    # Upsample and smooth
                    if upsample is not None:
                        img_arr_kk = scipy.ndimage.zoom( img_arr_kk, upsample )                                                                                                
                    
                    if pm['smooth_2D_dist'] is not None:
                        if upsample is not None:                                                   
                            sigma = upsample * pm['smooth_2D_dist']
                        else:
                            sigma = pm['smooth_2D_dist']
                        img_arr_kk = scipy.ndimage.filters.gaussian_filter( img_arr_kk, sigma )
        
                    # Get levels corresponding to percentages enclose
                    c_calc_kk = ContourCalc( img_arr_kk )
                    levels = c_calc_kk.get_level( contour_levels )
                
                    # Cycle through colors, skipping the color reserved for the ray data
                    if kk == 0:
                        cmap_kk = kk
                    else:
                        cmap_kk = kk + 1
                    contour_colors = [ cmap[cmap_kk], ] * len( levels )
                        
                    # Prevent invisible low-contribution components
                    alpha_min = 0.5
                    contour_alpha = ( total_weights[kk] / total_weights.max() ) * ( 1. - alpha_min ) + alpha_min
                    
                    ax.contour(
                        centers_x,
                        centers_y,
                        img_arr_kk,
                        levels,
                        colors = contour_colors,
                        linewidths = contour_linewidths,
                        alpha = contour_alpha
                    )
                    
                # Ray
                bins = all_bins['n_bins_data_2D']
                dx = all_dx['n_bins_data_2D']
                centers = all_centers['n_bins_data_2D']
                
                used_weights = weights / ( weights.sum() * dx[x_key] * dx[y_key] )
                hist2d_r, x_edges, y_edges = np.histogram2d(
                    ray_data[x_key],
                    ray_data[y_key],
                    bins = [ bins[x_key], bins[y_key], ],
                    weights = used_weights,
                )
                img_arr_r = np.transpose( hist2d_r )
                
                if pm['2D_dist_data_display'] == 'histogram':
                    ax.pcolormesh(
                        centers[x_key],
                        centers[y_key],
                        img_arr_r,
                        cmap = cmap_data,
                        shading = 'nearest',
                        norm = matplotlib.colors.LogNorm(),
                    )
                elif pm['2D_dist_data_display'] == 'contour':
                    
                    contour_centers_x = copy.copy( all_centers['n_bins_data_2D'][x_key] )
                    contour_centers_y = copy.copy( all_centers['n_bins_data_2D'][y_key] )

                    # Upsample and smooth
                    if upsample is not None:
                        img_arr_r = scipy.ndimage.zoom( img_arr_r, upsample )
                        
                        contour_centers_x = scipy.ndimage.zoom( contour_centers_x, upsample )
                        contour_centers_y = scipy.ndimage.zoom( contour_centers_y, upsample )

                    if pm['smooth_2D_dist'] is not None:
                        if upsample is not None:                                                   
                            sigma = upsample * pm['smooth_2D_dist']
                        else:
                            sigma = pm['smooth_2D_dist']
                        img_arr_r = scipy.ndimage.filters.gaussian_filter( img_arr_r, sigma )

                    # Get levels corresponding to percentages enclose
                    c_calc_kk = ContourCalc( img_arr_r )
                    levels = c_calc_kk.get_level( contour_levels + [ -1, ], False )
                    
                    ax.contourf(
                        contour_centers_x,
                        contour_centers_y,
                        img_arr_r,
                        levels,
                        cmap = cmap_data,
                        zorder = -100,
                    )

                ax.set_xlim( bins[x_key][0], bins[x_key][-1] )
                ax.set_ylim( bins[y_key][0], bins[y_key][-1] )

                if logscale[x_key]:
                    ax.set_xscale( 'log' )
                if logscale[y_key]:
                    ax.set_yscale( 'log' )

                x_label = labels[x_key]
                y_label = labels[y_key]

            if subplotspec.is_last_row():
                ax.set_xlabel( x_label, fontsize=16 )
            if subplotspec.is_first_col():
                ax.set_ylabel( y_label, fontsize=16 )
                
    # Add a legend
    h, l = ax_dict['vlos'].get_legend_handles_labels()
    ax_dict['legend'].legend( h, l, loc='lower left' )
    ax_dict['legend'].axis( 'off' )

    # Save
    savedir = './figures/sample2/comparison'
    if pm['weighting'] == 'density':
        savedir = os.path.join( savedir, 'density_weighting' )
    os.makedirs( savedir, exist_ok=True )
    savefile = 'sightline_{}.png'.format( os.path.basename( sl_fps[i] ) )
    save_fp = os.path.join( savedir, savefile )
    plt.savefig( save_fp, bbox_inches='tight' )
    
    plt.close()