# Plot of posterior samples

_Kara Ponder (SLAC-->?), Emille Ishida (Clermont-Ferrand), Alex Malz (GCCL@RUB)_

plagiarized from `Combination_plots.ipynb`

In [None]:
from collections import OrderedDict
import glob
import gzip
import numpy as np
import os
import pandas as pd
import pickle as pkl
import scipy.stats as sps
import sys

## Kara's plotting code

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.cm as cmx
from mpl_toolkits.axes_grid1 import make_axes_locatable

# import resspect.cosmo_metric_utils as cmu

In [None]:
import pylab
from mpl_toolkits.axes_grid1 import ImageGrid
from mpl_toolkits.axes_grid1.inset_locator import inset_axes


In [None]:
# Marker symbol used for each contaminant class when plotting.
# NOTE(review): 'SLSN-I' and 'CART' both map to 'v' -- confirm this is intended.
all_shapes = dict([
    ('SNIa-91bg', 'o'),
    ('SNIax', 's'),
    ('SNII', 'd'),
    ('SNIbc', 'X'),
    ('SLSN-I', 'v'),
    ('AGN', '^'),
    ('TDE', '<'),
    ('KN', '>'),
    ('CART', 'v'),
])

In [None]:
# Colour lookup table: 'plasma_r' sampled on a log scale at the integers 1..51.
cm = plt.get_cmap('plasma_r')
rainbow = cm
# Linear alternative that was tried: colors.Normalize(vmin=0, vmax=50)
cNorm = colors.LogNorm(vmin=1, vmax=52)
scalarMap = cmx.ScalarMappable(cmap=rainbow, norm=cNorm)
color_map = scalarMap.to_rgba(np.arange(1, 52))

## prep for data

In [None]:
# DDF summary on the COIN server:

# file_extension =  'ddf'#'wfd'
# Lower-case survey key -> upper-case field label.
file_extensions = dict(ddf='DDF', wfd='WFD')
# NOTE: ktot is assigned again (to 6) in a later cell.
ktot, kglob = 5, ''
# field = 'DDF'#'WFD'

# if field == 'WFD':
#     k = ''

In [None]:
def get_cases(field):
    """
    List the sample files of one survey field, minus a fixed exclusion list.

    Parameters
    ----------
    field: str
        survey field label, either 'DDF' or 'WFD' (upper case)

    Returns
    -------
    cases: list of str
        directory entries that remain after the exclusions are removed
    dirname: str
        directory that was listed

    Raises
    ------
    ValueError
        if `field` is not 'DDF' or 'WFD', or if one of the entries that is
        expected to be excluded is not present in the directory
    OSError
        propagated from `os.listdir` if the directory cannot be read
    """
    # Null-case files at sample sizes 1500 and 6000; excluded for both fields.
    # The corresponding *3000.csv files are not excluded.
    null_cases_other_sizes = (
        'random1500.csv',
        'random6000.csv',
        'fiducial1500.csv',
        'fiducial6000.csv',
        'perfect6000.csv',
        'perfect1500.csv',
    )
    if field == 'DDF':
        dirname = '/media/RESSPECT/data/PLAsTiCC/for_metrics/ddf/emille_samples/'
        excluded = null_cases_other_sizes + ('all_DDF.csv',)
    elif field == 'WFD':
        dirname = '/media/RESSPECT/data/PLAsTiCC/for_metrics/wfd/emille_samples'
        excluded = (
            ('M0DIF', 'fitres', 'stan_summary', 'all_WFD.csv')
            + null_cases_other_sizes
            + (
                'perfect3000_IX.csv',
                'perfect3000_I.csv',
                'perfect3000_II.csv',
                'perfect3000_III.csv',
                'perfect3000_IV.csv',
                'perfect3000_V.csv',
                'perfect3000_VI.csv',
                'perfect3000_VII.csv',
                'perfect3000_VIII.csv',
                'perfect3000_0.csv',
            )
        )
    else:
        # Previously this fell through to an UnboundLocalError on `cases`.
        raise ValueError("unknown field %r; expected 'DDF' or 'WFD'" % (field,))

    cases = os.listdir(dirname)
    for name in excluded:
        # list.remove raises ValueError if an expected entry is missing,
        # exactly as the original chain of remove() calls did.
        cases.remove(name)
    if '.ipynb_checkpoints' in cases:
        cases.remove('.ipynb_checkpoints')
    return(cases, dirname)

In [None]:
# Per-survey case listings and the directories they were read from,
# keyed by the lower-case survey key.
cases, dirnames = {}, {}
for file_extension, field_label in file_extensions.items():
    cases[file_extension], dirnames[file_extension] = get_cases(field_label)

In [None]:
# Show the WFD case list returned by get_cases as the cell output.
cases['wfd']

In [None]:
def make_remap_dict(file_extension):
    """
    Build the ordered mapping from sample-file stem to plot label.

    The WFD and DDF tables share almost all entries.  They differ in two
    DDF-only entries ('92SNIa8SNIbc', '91SNIa9SNIax') and in where the CART
    group sits (last for WFD, directly after SN-Iax for DDF).

    Parameters
    ----------
    file_extension: str
        survey identifier; any string containing 'wfd' in either case
        (e.g. 'wfd' or 'WFD') selects the WFD table, anything else the DDF
        table.  The check used to be case-sensitive, so the upper-case
        labels passed by this notebook always got the DDF table.

    Returns
    -------
    remap_dict: collections.OrderedDict
        file stem -> label of the form '<contaminant> <percent>' (or a
        single/two-word label for the special cases)
    """
    is_wfd = 'wfd' in file_extension.lower()

    special = [
        ('perfect3000', 'Perfect'),
        ('fiducial3000', 'Fiducial'),
        ('random3000', 'Random'),
        ('all_objs_survived_SALT2_DDF', 'All SALT'),
        ('all_objs_survived_SALT2_WFD', 'All SALT'),
    ]
    snii = [
        ('50SNIa50SNII', 'SN-II 50'),
        ('68SNIa32SNII', 'SN-II 32'),
        ('72SNIa28SNII', 'SN-II 28'),
        ('75SNIa25SNII', 'SN-II 25'),
        ('90SNIa10SNII', 'SN-II 10'),
        ('95SNIa5SNII', 'SN-II 5'),
        ('98SNIa2SNII', 'SN-II 2'),
        ('99SNIa1SNII', 'SN-II 1'),
    ]
    snibc = [
        ('50SNIa50SNIbc', 'SN-Ibc 50'),
        ('68SNIa32SNIbc', 'SN-Ibc 32'),
        ('75SNIa25SNIbc', 'SN-Ibc 25'),
        ('83SNIa17SNIbc', 'SN-Ibc 17'),
        ('90SNIa10SNIbc', 'SN-Ibc 10'),
        ('92SNIa8SNIbc', 'SN-Ibc 8'),    # DDF only
        ('95SNIa5SNIbc', 'SN-Ibc 5'),
        ('98SNIa2SNIbc', 'SN-Ibc 2'),
        ('99SNIa1SNIbc', 'SN-Ibc 1'),
    ]
    sniax = [
        ('50SNIa50SNIax', 'SN-Iax 50'),
        ('68SNIa32SNIax', 'SN-Iax 32'),
        ('75SNIa25SNIax', 'SN-Iax 25'),
        ('86SNIa14SNIax', 'SN-Iax 14'),
        ('90SNIa10SNIax', 'SN-Iax 10'),
        ('91SNIa9SNIax', 'SN-Iax 9'),    # DDF only
        ('94SNIa6SNIax', 'SN-Iax 6'),
        ('95SNIa5SNIax', 'SN-Iax 5'),
        ('97SNIa3SNIax', 'SN-Iax 3'),
        ('98SNIa2SNIax', 'SN-Iax 2'),
        ('99SNIa1SNIax', 'SN-Iax 1'),
    ]
    snia91bg = [
        ('71SNIa29SNIa-91bg', 'SN-Ia-91bg 29'),
        ('75SNIa25SNIa-91bg', 'SN-Ia-91bg 25'),
        ('90SNIa10SNIa-91bg', 'SN-Ia-91bg 10'),
        ('95SNIa5SNIa-91bg', 'SN-Ia-91bg 5'),
        ('98SNIa2SNIa-91bg', 'SN-Ia-91bg 2'),
        ('99SNIa1SNIa-91bg', 'SN-Ia-91bg 1'),
        ('99.8SNIa0.2SNIa-91bg', 'SN-Ia-91bg 0.2'),
    ]
    agn = [
        ('57SNIa43AGN', 'AGN 43'),
        ('75SNIa25AGN', 'AGN 25'),
        ('90SNIa10AGN', 'AGN 10'),
        ('94SNIa6AGN', 'AGN 6'),
        ('95SNIa5AGN', 'AGN 5'),
        ('98SNIa2AGN', 'AGN 2'),
        ('99SNIa1AGN', 'AGN 1'),
        ('99.9SNIa0.1AGN', 'AGN 0.1'),
    ]
    # NOTE(review): '99SNIa1SLSN' is labelled 'SLSN 1' while every other entry
    # in this group uses 'SLSN-I'; kept as in the original -- confirm.
    slsn = [
        ('83SNIa17SLSN-I', 'SLSN-I 17'),
        ('90SNIa10SLSN-I', 'SLSN-I 10'),
        ('95SNIa5SLSN-I', 'SLSN-I 5'),
        ('98SNIa2SLSN-I', 'SLSN-I 2'),
        ('99SNIa1SLSN-I', 'SLSN-I 1'),
        ('99SNIa1SLSN', 'SLSN 1'),
        ('99.9SNIa0.1SLSN', 'SLSN-I 0.1'),
    ]
    tde = [
        ('95SNIa5TDE', 'TDE 5'),
        ('98SNIa2TDE', 'TDE 2'),
        ('99SNIa1TDE', 'TDE 1'),
        ('99.6SNIa0.4TDE', 'TDE 0.4'),
    ]
    cart = [
        ('99.1SNIa0.9CART', 'CART 0.9'),
        ('99.7SNIa0.3CART', 'CART 0.3'),
    ]

    if is_wfd:
        groups = [special, snii, snibc, sniax, snia91bg, agn, slsn, tde, cart]
        skipped = {'92SNIa8SNIbc', '91SNIa9SNIax'}
    else:
        groups = [special, snii, snibc, sniax, cart, snia91bg, agn, slsn, tde]
        skipped = set()

    remap_dict = OrderedDict(
        (key, label)
        for group in groups
        for (key, label) in group
        if key not in skipped
    )
    return(remap_dict)

In [None]:
# For every survey, keep only the label-map entries whose file is actually
# present; file names without a known label are printed for inspection.
remap_dicts = {}
for file_extension, field_label in file_extensions.items():
    thing = make_remap_dict(field_label)
    tempdict = {}
    for case in cases[file_extension]:
        stem = case[:-4]  # drop the 4-character extension
        if stem not in thing:
            print(case)
            continue
        tempdict[stem] = thing[stem]
    remap_dicts[file_extension] = tempdict

In [None]:
# Mapping the percent contaminated to the colormap.
## Group order follows make_remap_dict.
## NOTE(review): the 'special' group has 6 entries here but make_remap_dict
## currently has 5 special keys, so each array is one longer than its dict.
def make_color_nums(file_extension):
    """
    Return the colour-map value for every entry of the label table.

    Parameters
    ----------
    file_extension: str
        survey identifier; any string containing 'wfd' in either case
        selects the WFD layout, anything else the DDF layout.  The check
        used to be case-sensitive, so 'WFD' silently got the DDF layout.

    Returns
    -------
    color_num: ndarray, int
        contamination percentages; sub-percent contaminations and the
        special cases are stored as 1
    """
    special = [1, 1, 1, 1, 1, 1]
    snii = [50, 32, 28, 25, 10, 5, 2, 1]
    snia91bg = [29, 25, 10, 5, 2, 1, 1]
    agn = [43, 25, 10, 6, 5, 2, 1, 1]
    slsn = [17, 10, 5, 2, 1, 1, 1]
    tde = [5, 2, 1, 1]
    cart = [1, 1]

    if 'wfd' in file_extension.lower():
        snibc = [50, 32, 25, 17, 10, 5, 2, 1]
        sniax = [50, 32, 25, 14, 10, 6, 5, 3, 2, 1]
        groups = [special, snii, snibc, sniax, snia91bg, agn, slsn, tde, cart]
    else:
        # DDF has the extra 8% SN-Ibc and 9% SN-Iax cases, and CART comes
        # directly after SN-Iax.
        snibc = [50, 32, 25, 17, 10, 8, 5, 2, 1]
        sniax = [50, 32, 25, 14, 10, 9, 6, 5, 3, 2, 1]
        groups = [special, snii, snibc, sniax, cart, snia91bg, agn, slsn, tde]

    color_num = np.array([num for group in groups for num in group])
    return(color_num)

In [None]:
# Colour-map values per survey, keyed by the lower-case survey key.
color_nums = {}
for file_extension, field_label in file_extensions.items():
    color_nums[file_extension] = make_color_nums(field_label)

In [None]:
# Color map: log-normalised 'plasma_r', evaluated at the integers 1..51.
# (Linear alternative that was tried: colors.Normalize(vmin=0, vmax=50).)
cm = plt.get_cmap('plasma_r')
rainbow = cm
cNorm = colors.LogNorm(vmin=1, vmax=52)
scalarMap = cmx.ScalarMappable(cmap=rainbow, norm=cNorm)
sample_values = np.arange(1, 52)
color_map = scalarMap.to_rgba(sample_values)

In [None]:
# dist_loc_base = '/media/RESSPECT/data/PLAsTiCC/for_metrics/' + file_extension + '/distances/omprior_0.01_flat/emille_samples/*' #mu_photoIa_plasticc*'

# if 'wfd' in file_extension:
#     table_loc = '/media2/RESSPECT2/data/posteriors_wfd/omprior_0.01_flat/summary_cases_omprior_0.01_flat_emille.csv'
# else:
#     table_loc = '/media2/RESSPECT2/data/posteriors_ddf/omprior_0.01_flat/summary_cases_emille.csv' 

# dist_loc_files = glob.glob(dist_loc_base)

In [None]:
# df = pd.read_csv(table_loc)

## calculate the curve(s)

KDE for each set of posterior samples

In [None]:
# Smallest value that is logged as-is: twice the smallest positive
# normalised double.
eps = 2. * sys.float_info.min

def safe_log(arr, threshold=eps):
    """
    Takes the natural logarithm of an array that might contain zeros.

    The input is never modified: values are converted to a float64 copy
    before clipping.  (Previously an ndarray argument was clipped in place,
    and integer or float32 input turned the tiny threshold into 0, so the
    result contained -inf.)

    Parameters
    ----------
    arr: array_like, float
        values to be logged
    threshold: float, optional
        small, positive value to replace zeros and negative numbers

    Returns
    -------
    logged: ndarray
        logged values, with log(threshold) wherever the input was below
        `threshold`; NaN inputs stay NaN
    """
    values = np.asarray(arr, dtype=float)
    # np.where builds a new array, leaving the caller's data untouched.
    clipped = np.where(values < threshold, threshold, values)
    logged = np.log(clipped)
    return logged

def make_grid(x, y, x_ngrid=100, y_ngrid=100):
    """
    Build a regular 2-D grid spanning the range of two sets of values.

    Parameters
    ----------
    x, y: objects with .min() and .max() methods
        values whose extrema define the grid limits
    x_ngrid, y_ngrid: int, optional
        number of grid points along each axis

    Returns
    -------
    tuple
        ((x_min, y_min), (x_max, y_max)), (x_grid, y_grid),
        (x_vec, y_vec), (dx, dy) -- the corner points, the two 2-D
        coordinate arrays, the 1-D axis vectors and the grid spacings
    """
    lower = (x.min(), y.min())
    upper = (x.max(), y.max())

    # A complex "step" tells np.mgrid to use that many points, ends included.
    mesh = np.mgrid[lower[0]:upper[0]:x_ngrid*1.j, lower[1]:upper[1]:y_ngrid*1.j]
    x_grid, y_grid = mesh
    axis_vectors = (x_grid[:, 0], y_grid[0, :])
    spacings = ((upper[0] - lower[0]) / (x_ngrid - 1),
                (upper[1] - lower[1]) / (y_ngrid - 1))

    return((lower, upper), (x_grid, y_grid), axis_vectors, spacings)

def make_kde(Xgrid, Ygrid, Xsamps, Ysamps, to_log=False, save=None, one_d=True):
    """
    Evaluate a Gaussian kernel density estimate of the samples on a grid.

    Parameters
    ----------
    Xgrid, Ygrid: ndarray
        2-D coordinate arrays (as produced by make_grid)
    Xsamps, Ysamps: array_like
        samples along each axis
    to_log: bool, optional
        if True, return safe_log of the density instead of the density
    save: optional
        accepted but currently unused
    one_d: bool, optional
        if True (default), fit a 1-D KDE to Xsamps only and evaluate it on
        the first column of Xgrid (Ygrid and Ysamps are ignored); otherwise
        fit a 2-D KDE and evaluate it on the full grid

    Returns
    -------
    Z: ndarray
        density (or log-density), 1-D when one_d is True, else shaped like
        Xgrid
    """
    if one_d:
        eval_points = Xgrid.T[0]
        estimator = sps.gaussian_kde(Xsamps, bw_method='scott')
        Z = estimator(eval_points)
    else:
        eval_points = np.vstack([Xgrid.ravel(), Ygrid.ravel()])
        samples = np.vstack([Xsamps, Ysamps])
        estimator = sps.gaussian_kde(samples, bw_method='scott')
        Z = np.reshape(estimator(eval_points).T, Xgrid.shape)

    return safe_log(Z) if to_log else Z
#     if save is not None:
# TODO: normalize up here before log!

In [None]:

# alloutputs = pd.DataFrame(columns=['path', 'KLD'])
#     # make reference sample
# with gzip.open(fullpath+refname) as reffn:
#     flatref = pd.read_csv(reffn)
# [w_ref, Omm_ref] = [flatref['w'], flatref['om']]
# ref_extrema, ref_grids, ref_vecs, ref_ds = make_grid(w_ref, Omm_ref)
# (w_vec, Omm_vec) = ref_vecs
# (dw, dOmm) = ref_ds
# ((xmin, ymin), (xmax, ymax)) = ref_extrema
# (w_grid, Omm_grid) = ref_grids
# d_ref = {'w': dw, 'Omm': dOmm}
# grid_ref = {'w': w_grid, 'Omm': Omm_grid}
# kde_ref = make_kde(w_grid, Omm_grid, w_ref, Omm_ref, one_d=True, to_log=True)

In [None]:
def get_posteriors(field, k, casename, nsn, withlowz=True):
    """
    Load the posterior samples of `w` and `om` for one case.

    Parameters
    ----------
    field: str
        lower-case survey key, 'ddf' or 'wfd'
    k: str or int
        realisation suffix appended to the sample directory name; '' selects
        the un-suffixed directory
    casename: str
        case name, e.g. 'perfect'
    nsn: int or str
        appended to `casename` to form the file stem ('' for none)
    withlowz: bool, optional
        if True, read the '_lowz_withbias' variant of the chain file

    Returns
    -------
    list
        [sampdata['w'], sampdata['om']]

    Raises
    ------
    ValueError
        if `field` is neither 'ddf' nor 'wfd' (previously this surfaced as
        an UnboundLocalError on `ext`)
    OSError
        propagated if the chain file cannot be opened
    """
    filename = 'chains_' + casename + str(nsn)
    if withlowz:
        filename = filename + '_lowz_withbias'

    if field == 'ddf':
        if k == '':
            path_pre = '/media/RESSPECT/data/PLAsTiCC/for_metrics/ddf/posteriors/samples_emille/'
            ext = '.csv.gz'
        else:
            # Suffixed DDF realisations are stored as pickles in a different tree.
            path_pre = '/media/RESSPECT/data/PLAsTiCC/for_metrics/ddf/emille_samples'+str(k)+'/posteriors/'
            ext = '.pkl'
    elif field == 'wfd':
        path_pre = '/media/RESSPECT/data/PLAsTiCC/for_metrics/wfd/posteriors/samples_emille'+str(k)+'/'
        ext = '.csv.gz'
    else:
        raise ValueError("unknown field %r; expected 'ddf' or 'wfd'" % (field,))

    samppath = path_pre + filename + ext
    if ext == '.csv.gz':
        with gzip.open(samppath) as sampfile:
            sampdata = pd.read_csv(sampfile)
    else:
        # NOTE(review): pickle.load can execute arbitrary code; only use on
        # files produced by this project.
        with open(samppath, 'rb') as sampfile:
            sampdata = pkl.load(sampfile)
    return([sampdata['w'], sampdata['om']])

In [None]:
# Null-test configuration.
null_cases = ['perfect', 'random', 'fiducial']
samp_sizes = [1500, 3000, 6000]
# Realisation indices run over range(kmin, ktot).
ktot, kmin = 6, 0
# Number of grid points stored per KDE curve.
ngrid = 100

In [None]:
# temporary until all runs finish!
# file_extensions = {'wfd': 'WFD'}
# Overrides the kmin = 0 assigned in the previous cell, so loops over
# range(kmin, ktot) skip k = 0.
kmin = 1

In [None]:
# path_pre = dirname[:46]
# path_post = 'posteriors/samples_emille'+str(k)+'/'
# # refname = 'chains_perfect3000_lowz_withbias.csv.gz'

# fullpath = path_pre + path_post

# KDE of w for every null case, realisation k and sample size; results are
# stored as outdata[field][casename][k, i] = [w grid, log-density].
outdata = {}
for field in file_extensions:
    outdata[field] = {}
    for casename in null_cases:
        # NaN-filled rather than np.empty: realisations that are skipped
        # (k < kmin) would otherwise hold uninitialised memory, which the
        # plotting cell would draw as spurious curves.
        kdes = np.full((ktot, len(samp_sizes), 2, ngrid), np.nan)
        for k in range(kmin, ktot):
            # Realisation 0 lives in the un-suffixed directory.  Keep the
            # integer k for indexing; using '' as an array index raised
            # IndexError whenever kmin was 0.
            k_suffix = '' if k == 0 else k
            for i, nsn in enumerate(samp_sizes):
                [w_comp, Omm_comp] = get_posteriors(field, k_suffix, casename, nsn, withlowz=False)
                comp_extrema, comp_grids, comp_vecs, comp_ds = make_grid(
                    w_comp, Omm_comp, x_ngrid=ngrid, y_ngrid=ngrid)
                (w_grid, Omm_grid) = comp_grids
                kde_comp = make_kde(w_grid, Omm_grid, w_comp, Omm_comp, one_d=True, to_log=True)
                kdes[k, i] = np.array([w_grid.T[0], kde_comp])
        outdata[field][casename] = kdes
with open('default_kdes.pkl', 'wb') as outfile:
    pkl.dump(outdata, outfile)

In [None]:
# KDE of w for every contaminated test case; results are stored as
# outdata[field][file stem] = [w grid, log-density].
outdata = {}
for field in file_extensions:
    outdata[field] = {}
    for casename in cases[field]:
        stem = casename[:-4]  # drop the 4-character extension
        # Test cases use the un-suffixed directory and have no sample-size
        # suffix, hence the two empty strings.
        k = ''
        nsn = ''
        [w_comp, Omm_comp] = get_posteriors(field, k, stem, nsn, withlowz=True)
        # Pass ngrid explicitly so the stored curves follow the configured
        # grid size instead of make_grid's default.
        comp_extrema, comp_grids, comp_vecs, comp_ds = make_grid(
            w_comp, Omm_comp, x_ngrid=ngrid, y_ngrid=ngrid)
        (w_grid, Omm_grid) = comp_grids
        kde_comp = make_kde(w_grid, Omm_grid, w_comp, Omm_comp, one_d=True, to_log=True)
        # Assigned once, after the KDE succeeded; the former np.empty
        # placeholder was dead and could leave garbage behind on failure.
        outdata[field][stem] = np.array([w_grid.T[0], kde_comp])
with open('testcase_kdes.pkl', 'wb') as outfile:
    pkl.dump(outdata, outfile)

In [None]:
# for casenum in [1500, 3000, 6000]:
#     case = 'perfect'+str(casenum)+'.csv'
#     samppath = fullpath+'chains_'+case[:-4]+'_lowz_withbias.csv.gz'
#     with gzip.open(samppath) as sampfile:
#         sampdata = pd.read_csv(sampfile)
#         [w_comp, Omm_comp] = [sampdata['w'], sampdata['om']]
#         comp_extrema, comp_grids, comp_vecs, comp_ds = make_grid(w_comp, Omm_comp)
#         (w_grid, Omm_grid) = comp_grids
#         kde_comp = make_kde(w_grid, Omm_grid, w_comp, Omm_comp, one_d=True, to_log=True)
#         outdata[str(casenum)] = [w_grid, kde_comp]
# with open(field+'kdes_perfect.pkl', 'wb') as outfile:
#     pkl.dump(outdata, outfile)

## make plot(s)

In [None]:
# Colour per null case and line style per sample size.
def_colors = {'perfect': 'k', 'random': 'tab:red', 'fiducial': 'tab:blue'}
def_styles = {'1500': ':', '3000': '-', '6000': '--'}#{'DDF': '-', 'WFD': '--'}
# def_lowz = {'withbias': , 'nobias':}

# NOTE(review): pickle.load can execute arbitrary code; only load files
# written by this notebook.
with open('default_kdes.pkl', 'rb') as infile:
    indata = pkl.load(infile)

for field in file_extensions:
    for casename in null_cases:
        # Off-screen point whose only purpose is a legend entry for the colour.
        plt.scatter([0], [0], label=casename, color=def_colors[casename])
        for i, nsn in enumerate(samp_sizes):
            # Only realisations kmin..ktot-1 are computed by the KDE cell;
            # plotting range(ktot) also drew the never-filled k < kmin slots.
            for k in range(kmin, ktot):
                w_grid, kde_comp = indata[field][casename][k][i]
                # The stored curve is a log-density, so exponentiate it.
                plt.plot(w_grid, np.exp(kde_comp),
                         linestyle=def_styles[str(nsn)], color=def_colors[casename], alpha=0.5)
    for nsn in samp_sizes:
        # Off-screen line whose only purpose is a legend entry for the style.
        plt.plot([0], [0], label=str(nsn),
                 linestyle=def_styles[str(nsn)], color='tab:green')
    plt.title(file_extensions[field])
    plt.legend(loc='upper right')#, ncol=2)
    plt.xlabel(r'$w$')
    plt.xlim(-1.15, -0.85)
    plt.savefig(file_extensions[field]+'dists_null.png')
    plt.show()

todo: investigate the runs that are flat KDEs

todo: also with and without bias of lowz sample

In [None]:
# Split each plot label '<contaminant> <percent>' into its two parts.
# rates[field][stem] is the percentage, contaminants[field][stem] the name.
rates, contaminants = {}, {}
for field in file_extensions:
    rate, contaminant = {}, {}
    for key, label in remap_dicts[field].items():
        postsplit = label.split()
        if len(postsplit) < 2:
            # Single-word labels ('Perfect', 'Random', ...) are not contaminated cases.
            continue
        try:
            perc = float(postsplit[-1])
        except ValueError:
            # Multi-word labels without a trailing number (e.g. 'All SALT')
            # used to abort the whole cell here.
            continue
        rate[key] = perc
        contaminant[key] = postsplit[0]
    rates[field] = rate
    contaminants[field] = contaminant

In [None]:
# Overlaid histograms of the contamination percentages of both surveys.
hist_kwargs = {'bins': 25, 'alpha': 0.5}
plt.hist(rates['ddf'].values(), **hist_kwargs)
plt.hist(rates['wfd'].values(), **hist_kwargs)

todo: automate dividing into panels

In [None]:
# Panel edges in percent contamination.  Panel j holds the cases with
# cutoffs[j] <= rate < cutoffs[j + 1]; a rate equal to the first edge is
# excluded and a rate equal to the last edge goes into the last panel.
# The previous hand-written if-chain ignored this list and had two branches
# that both claimed rate == 15 (it landed in the lower panel; it now lands
# in the upper one -- no label in make_remap_dict has that rate).
cutoffs = [0., 1., 2., 5., 7.5, 15., 50.]
n_panels = len(cutoffs) - 1

panel_groups = {}
for field in file_extensions:
    panel_groups[field] = {j: [] for j in range(n_panels)}
    for casename, rate in rates[field].items():
        if rate <= cutoffs[0] or rate > cutoffs[-1]:
            # Outside every panel: report instead of dropping silently.
            print((casename, rate, 'outside cutoffs'))
            continue
        # side='right' gives the index of the first edge strictly above rate.
        j = int(np.searchsorted(cutoffs, rate, side='right')) - 1
        j = min(j, n_panels - 1)  # rate == cutoffs[-1] belongs to the last panel
        panel_groups[field][j].append(casename)
print(panel_groups)

In [None]:
# Union of the contaminant names over every configured survey (previously
# hard-coded to 'ddf' and 'wfd').
all_contaminants = set().union(*(contaminants[field].values() for field in file_extensions))

# Assign colours in sorted name order: iterating the set directly gave an
# order that can change between interpreter runs, so the same contaminant
# could get a different colour each time.  The modulo keeps the index
# inside the palette if there are ever more names than colours.
color_list = OrderedDict(
    (contaminant, plt.cm.tab10(i % plt.cm.tab10.N))
    for i, contaminant in enumerate(sorted(all_contaminants))
)

# contaminant_colors[field][file stem] -> colour of that case's contaminant.
contaminant_colors = {}
for field in file_extensions:
    contaminant_colors[field] = {
        casename: color_list[name]
        for casename, name in contaminants[field].items()
    }

In [None]:
# def_colors = {'perfect': 'k', 'random': 'tab:red', 'fiducial': 'tab:blue'}
# def_styles = {'1500': ':', '3000': '-', '6000': '--'}#{'DDF': '-', 'WFD': '--'}
# def_lowz = {'withbias': , 'nobias':}

# One 2x3 figure per survey: each panel overlays the posterior KDEs of w for
# the cases in one contamination-rate group, coloured by contaminant.

# NOTE(review): pickle.load can execute arbitrary code; only load files
# written by this notebook.
with open('testcase_kdes.pkl', 'rb') as infile:
    indata = pkl.load(infile)

numrows = 2
numcols = 3

for field in file_extensions:
    fig = pylab.figure(figsize=(15, 10))
    # Frameless, tick-less axes spanning the figure; only carries the title.
    bigAxes = pylab.axes(frameon=False)
    bigAxes.set_xticks([])
    bigAxes.set_yticks([])
    bigAxes.set_title(file_extensions[field], fontsize=20)

    for i, panel_cases in sorted(panel_groups[field].items()):
        per_panel_contaminants = [contaminants[field][casename] for casename in panel_cases]
        print(per_panel_contaminants)
        ax = fig.add_subplot(numrows, numcols, i + 1)

        # Nudge the panel slightly right and up.
        position = ax.get_position()
        position.x0 += 0.01
        position.y0 += 0.02
        position.x1 += 0.01
        position.y1 += 0.02
        ax.set_position(position)

        for casename in panel_cases:
            w_grid, kde_comp = indata[field][casename]
            # The stored curve is a log-density, so exponentiate it.
            ax.plot(w_grid, np.exp(kde_comp), alpha=0.5,
                    color=contaminant_colors[field][casename], label=casename)

        if panel_cases:
            # An empty panel has no labelled artists; skip its legend.
            ax.legend(fontsize='large')
        ax.set_xlabel(r'$w$', fontsize=20)
        ax.set_ylim(0., 40.)
        ax.set_yticks([])
        ax.set_yticklabels([])
        ax.set_xlim(-1.3, -0.9)
        ax.set_xticks([-1.2, -1.1, -1.])

    fig.subplots_adjust(wspace=0., hspace=0.)
    pylab.savefig(file_extensions[field]+'combos.png',format='png', bbox_inches='tight', pad_inches=0, dpi=250)

TODO: polish these for paper
- linestyle for contamination rate if more than one with same contaminant per panel

In [None]:
# def_classifiers = {'perfect': 'k', 'random': 'tab:red', 'fiducial': 'tab:blue'}
# # def_nums = {'1500': 2, '3000': 1, '6000': 0.5}
# def_styles = {'1500': ':', '3000': '-', '6000': '--'}#{'DDF': '-', 'WFD': '--'}
# # def_lowz = {'withbias': , 'nobias':}
# #182 lowz sample z<0.11

In [None]:
# for field in ['DDF']:#, 'WFD']:
#     with open(field+'kdes_perfect.pkl', 'rb') as infile:
#         indata = pkl.load(infile)

#     for casenum in [1500, 3000, 6000]:
#         [w_grid, kde_comp] = indata[str(casenum)]
#         plt.plot(w_grid.T[0], np.exp(kde_comp), label=field+'perfect'+str(casenum),
#                  linewidth=def_nums[str(casenum)], color=def_classifiers['perfect'], linestyle=def_fields[field])#, 
# plt.legend()
# # plt.title(field+k)
# plt.xlabel(r'$w$')
# # plt.yticklabels([])
# plt.savefig('dists_perfect.png')
# plt.show()

In [None]:
# fig, axes = plt.subplots(1, 6, figsize=(22,10), sharey=True)

# # wfit
# ax1 = axes[0]
# # Bayes
# ax2 = axes[1]

# ax1.axvline(-1, color='c', ls='--')
# ax2.axvline(-1, color='c', ls='--')

# ax1.axvline(df['wfit_w_lowz'].loc[df['case'] == 'perfect3000'].values, color='k', ls='-.')
# ax1.axvspan(df['wfit_w_lowz'].loc[df['case'] == 'perfect3000'].values-df['wfit_wsig_lowz'].loc[df['case'] == 'perfect3000'].values, 
#             df['wfit_w_lowz'].loc[df['case'] == 'perfect3000'].values+df['wfit_wsig_lowz'].loc[df['case'] == 'perfect3000'].values, 
#             alpha=0.15, color='grey')

# ax2.axvline(df['stan_w_lowz'].loc[df['case'] == 'perfect3000'].values, color='k', ls='-.')
# ax2.axvspan(df['stan_w_lowz'].loc[df['case'] == 'perfect3000'].values-df['stan_wsig_lowz'].loc[df['case'] == 'perfect3000'].values, 
#             df['stan_w_lowz'].loc[df['case'] == 'perfect3000'].values+df['stan_wsig_lowz'].loc[df['case'] == 'perfect3000'].values, 
#             alpha=0.15, color='grey')

# # Fisher
# # percent different
# ax3 = axes[2]
# ax3.axvline(0, color='k', ls='-.')
# df_fisher = pd.read_csv('/media/RESSPECT/data/PLAsTiCC/for_metrics/'+file_extension+'/distances/omprior_0.01_flat/emille_samples/stan_input_salt2mu_lowz_withbias_perfect3000.csv')
# sig_perf = cmu.fisher_results(df_fisher['z'].values, df_fisher['muerr'].values)[0]


# # Wasserstein
# ax4 = axes[4]
# ax4.axvline(0, color='k', ls='-.')

# # FOM3
# ax5 = axes[3]
# ax5.axvline(df['fom3'].loc[df['case'] == 'perfect3000'].values, color='k', ls='-.')

# # KLD
# ax6 = axes[5]

# i = 0
# tick_lbls = []
# i_list = []
# for j, (a, c) in enumerate(zip(remap_dict, color_nums)):
#     try:
#         # wfit
#         wfw = df['wfit_w_lowz'].loc[df['case'] == a].values
#         wfw_sig = df['wfit_wsig_lowz'].loc[df['case'] == a].values
#         class_ = str.split(remap_dict[a])[0]
        
#         # Fisher
#         file = glob.glob(dist_loc_base + a + '.csv')
#         df_ = pd.read_csv(str(file[0]))
#         sig = cmu.fisher_results(df_['z'].values, df_['muerr'].values)[0]
        
#         # Wasserstein
#         wsd = df['WassersteinDistanceMedian'].loc[df['case'] == a].values
        
#         # fom3
#         fom3 = df['fom3'].loc[df['case'] == a].values
        
#         # KLD
#         kld = df['KLD'].loc[df['case'] == a].values

#         if '91bg' in class_:
#             class_ = 'SNIa-91bg'
#         else:
#             class_ = class_.replace('-', '')

#         bad_data=False
#         if wfw[0] < -2.2:
#             wfw[0] = -1.4
#             bad_data=True
#             xuplims=[-1.5]
        
#         if 'ddf' in file_extension:
#             if 'fiducial' in a:
#                 mfc = 'tab:blue'
#             elif 'random' in a:
#                 mfc = 'tab:red'
#             elif 'perfect' in a:
#                 mfc = 'k'
#             else:
#                 mfc = color_map[c]
#         if 'wfd' in file_extension:
#             mfc = "none"

#         if 'fiducial' in a:
#             if bad_data:
#                 ax1.errorbar(wfw, [-i], xerr=[0.03], marker='*',color='tab:blue',
#                              xuplims=xuplims, markersize=12, mfc=mfc)
#             else:
#                 ax1.plot(wfw, -i, '*', color='tab:blue', ms=12, mfc=mfc)
#                 ax1.plot([wfw - wfw_sig, wfw + wfw_sig], [-i, -i], "|-", color='tab:blue', ms=10)
#         elif 'random' in a:
#             if bad_data:
#                 ax1.errorbar(wfw, [-i], xerr=[0.03], marker='*',color='tab:red',
#                              xuplims=xuplims, markersize=12, mfc=mfc)
#             else:
#                 ax1.plot(wfw, -i, '*', color='tab:red', ms=12, mfc=mfc)
#                 ax1.plot([wfw - wfw_sig, wfw + wfw_sig], [-i, -i], "|-", color='tab:red', ms=10)
#         elif 'perfect' in a:
#             ax1.plot(wfw, -i, '*', color='k', ms=12, mfc=mfc)
#             ax1.plot([wfw - wfw_sig, wfw + wfw_sig], [-i, -i], "|-", color='k', ms=10 )
#         elif 'all_objs_survived' in a:
#             ax1.plot(wfw, -i, '*', color='seagreen', ms=12, mfc=mfc)
#             ax1.plot([wfw - wfw_sig, wfw + wfw_sig], [-i, -i], "|-", color='seagreen', ms=10 )
#         else:
#             if bad_data:
#                 ax1.errorbar(wfw, [-i], xerr=[0.03], marker=all_shapes[class_],color=color_map[c],
#                              xuplims=xuplims, markersize=10)
#             else:
#                 ax1.plot(wfw, -i, color=color_map[c], ms=10, marker=all_shapes[class_], mfc=mfc)
#                 ax1.plot([wfw - wfw_sig, wfw + wfw_sig], [-i, -i], "|-", color=color_map[c], ms=10)

#         # Stan/Bayes
#         bw = df['stan_w_lowz'].loc[df['case'] == a].values
#         bw_sig = df['stan_wsig_lowz'].loc[df['case'] == a].values
        
#         bad_data=False
#         if bw[0] < -2.2:
#             bw[0] = -1.4
#             bad_data=True
#             xuplims=[-1.5]

#         if 'fiducial' in a:
#             if bad_data:
#                 ax2.errorbar(bw, [-i], xerr=[0.03], marker='*',color='tab:blue',
#                              xuplims=xuplims, markersize=12, mfc=mfc)
#             else:
#                 ax2.plot(bw, -i, '*', color='tab:blue', ms=12, mfc=mfc)
#                 ax2.plot([bw - bw_sig, bw + bw_sig], [-i, -i], "|-", color='tab:blue', ms=10)
#                 ax3.plot((sig[1]-sig_perf[1])/sig_perf[1], -i, color='tab:blue', marker='*', ms=10, mfc=mfc)
#                 ax4.plot(wsd, -i, '*', color='tab:blue', ms=12, mfc=mfc)
#                 ax5.plot(fom3, -i, '*', color='tab:blue', ms=12, mfc=mfc)
#                 ax6.semilogx(kld, -i, '*', color='tab:blue', ms=12, mfc=mfc)
#         elif 'random' in a:
#             if bad_data:
#                 ax2.errorbar(bw, [-i], xerr=[0.03], marker='*',color='tab:red',
#                              xuplims=xuplims, markersize=12, mfc=mfc)
#             else:
#                 ax2.plot(bw, -i, '*', color='tab:red', ms=12, mfc=mfc)
#                 ax2.plot([bw - bw_sig, bw + bw_sig], [-i, -i], "|-", color='tab:red', ms=10)
#                 ax3.plot((sig[1]-sig_perf[1])/sig_perf[1], -i, color='tab:red', marker='*', ms=10, mfc=mfc)
#                 ax4.plot(wsd, -i, '*', color='tab:red', ms=12, mfc=mfc)
#                 ax5.plot(fom3, -i, '*', color='tab:red', ms=12, mfc=mfc)
#                 if np.isnan(kld):
#                     # xlolims=[4e10]
#                     ax6.errorbar(1e10, [-i], xerr=[1e10], marker='*',color='tab:red',
#                                  xlolims=True, 
#                                  markersize=12, mfc=mfc)
#                 else:
#                     ax6.semilogx(kld, -i, '*', color='tab:red', ms=12, mfc=mfc)
#         elif 'perfect' in a:
#             ax2.plot(bw, -i, '*', color='k', ms=12, mfc=mfc)
#             ax2.plot([bw - bw_sig, bw + bw_sig], [-i, -i], "|-", color='k', ms=10 )
#             ax3.plot(0, -i, '*', color='k', ms=12, mfc=mfc)
#             ax4.plot(0, -i, '*', color='k', ms=12, mfc=mfc)
#             ax5.plot(fom3, -i, '*', color='k', ms=12, mfc=mfc)
#         else:
#             ax2.plot(bw, -i, color=color_map[c], ms=10, marker=all_shapes[class_], mfc=mfc)
#             ax2.plot([bw - bw_sig, bw + bw_sig], [-i, -i], "|-", color=color_map[c], ms=10)
#             ax3.plot((sig[1]-sig_perf[1])/sig_perf[1], -i, color=color_map[c], marker=all_shapes[class_], ms=10, mfc=mfc)
#             ax4.plot(wsd, -i, color=color_map[c], ms=10, marker=all_shapes[class_], mfc=mfc)
#             ax5.plot(fom3, -i, color=color_map[c], ms=10, marker=all_shapes[class_], mfc=mfc)
#             if np.isnan(kld):
#                 ax6.errorbar(1e10, [-i], xerr=[1e10],  marker=all_shapes[class_], color=color_map[c],
#                                  xlolims=True, 
#                                  markersize=10, mfc=mfc)
#             else:
#                 ax6.semilogx(kld, -i, color=color_map[c], ms=10, marker=all_shapes[class_], mfc=mfc)
            
#         tick_lbls.append(remap_dict[a])
#         i_list.append(-i)
#         i +=0.8
#         if 'random' in a or '99SNIa1' in a:
#             i_list.append(-i)
#             i += 1.1
#             tick_lbls.append('')
#     except:
#         continue
#         #print("Missing: ", a)

# tick_locs = i_list[::-1]
# #np.arange(-len(tick_lbls)+1, 1)
# ax1.set_yticks(tick_locs)
# ax1.set_yticklabels(tick_lbls[::-1], fontsize=13)

# ax1.set_ylim(i_list[-1]-0.5, i_list[0]+0.5)#-len(tick_lbls)+0.5, 0.5)


# ax1.set_xlabel(r'wfit $w$', fontsize=13)
# ax2.set_xlabel(r'Bayes $w$', fontsize=13)
# ax3.set_xlabel('FM Fractional Difference', fontsize=13)
# ax4.set_xlabel('Wasserstein Distance', fontsize=13)
# ax5.set_xlabel('FOM3', fontsize=13)
# ax6.set_xlabel('KLD', fontsize=13)
# plt.subplots_adjust(bottom=0.15, wspace=0.3) # wspace=0.05

        
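# # NOTE(review): `file_extension` is commented out in the "prep for data" cell, which now
# # defines the `file_extensions` dict instead -- confirm the name is defined before re-enabling this cell.
# if 'ddf' in file_extension: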
#     ax1.set_xlim(-1.44, -0.97)
#     ax2.set_xlim(-1.44, -0.97)
#     ax3.set_xlim(-0.025, 0.05)
#     ax4.set_xlim(-0.02, 0.22)
#     ax5.set_xlim(0, 1.05)
#     ax6.set_xlim(1e5, 3e10)
# if 'wfd' in file_extension:
#     ax1.set_xlim(-1.34, -0.77)
#     ax2.set_xlim(-1.2, -0.97)
#     ax3.set_xlim(-0.01, 0.03)
#     ax4.set_xlim(-0.02, 0.12)
#     ax5.set_xlim(-0.02, 1.05)
#     ax6.set_xlim(1e4, 1e9)

# if 'ddf' in file_extension:
#     #ticks = [-4, -13, -19, -24, -27, -30]
#     ticks = [-4, -11, -15, -21]

#     #ax1.axvspan(-2.3, -0.9, ymin=0.6, ymax=0.88, alpha=0.08, color='tab:green')
#     #ax1.axvspan(-2.3, -0.9, ymin=0.22, ymax=0.37, alpha=0.08, color='tab:green')
#     #ax1.axvspan(-2.3, -0.9, ymin=0., ymax=0.1, alpha=0.08, color='tab:green')
   
#     for ax in axes:
#         for t in ticks:
#             yticks = ax.yaxis.get_major_ticks()
#             yticks[t].set_visible(False)
#         ax.axvspan(-2.3, 5e10, ymin=0.59, ymax=0.83, alpha=0.08, color='tab:purple')
#         ax.axvspan(-2.3, 5e10, ymin=0.17, ymax=0.37, alpha=0.08, color='tab:purple')
    
#     #ax1.axvspan(-2.3, -0.9, ymin=0., ymax=0.1, alpha=0.08, color='tab:green')
#     #ax2.axvspan(-2.3, -0.9, ymin=0., ymax=0.1, alpha=0.08, color='tab:green')
    
# if 'wfd' in file_extension:
#     #ticks = [-8, -13, -19, -23, -27]
#     #ticks = [-3, -10, -15, -21, -25, -29]
#     ticks = [-3, -9, -14, -20, -24, -28]

#     #ax1.axvspan(-1.5, -0.7, ymin=0.59, ymax=0.73, alpha=0.08, color='tab:green')
#     #ax1.axvspan(-1.5, -0.7, ymin=0.25, ymax=0.35, alpha=0.08, color='tab:green')
#     #ax1.axvspan(-1.5, -0.7, ymin=0., ymax=0.07, alpha=0.08, color='tab:green')
    
#     for ax in axes:
#         for t in ticks:
#             yticks = ax.yaxis.get_major_ticks()
#             yticks[t].set_visible(False)
#         ax.axvspan(-1.5, 2e10, ymin=0.725, ymax=0.9, alpha=0.08, color='tab:purple')
#         ax.axvspan(-1.5, 2e10, ymin=0.35, ymax=0.55, alpha=0.08, color='tab:purple')
#         ax.axvspan(-1.5, 2e10, ymin=0.09, ymax=0.21, alpha=0.08, color='tab:purple')
    
# #plt.savefig('all_metrics_' + file_extension + '_lowz_20210629_emillesamp.pdf', bbox_inches='tight')


