# Plot of posterior samples

_Kara Ponder (SLAC-->?), Emille Ishida (Clermont-Ferrand), Alex Malz (GCCL@RUB)_

Adapted from `Combination_plots.ipynb`.

In [None]:
from collections import OrderedDict
import glob
import gzip
import numpy as np
import os
import pandas as pd
import pickle as pkl
import scipy.stats as sps
from matplotlib.ticker import MaxNLocator
import sys

## Kara's plotting code

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.cm as cmx
from mpl_toolkits.axes_grid1 import make_axes_locatable

# import resspect.cosmo_metric_utils as cmu

In [None]:
import pylab
from mpl_toolkits.axes_grid1 import ImageGrid
from mpl_toolkits.axes_grid1.inset_locator import inset_axes


In [None]:
# Scatter marker per contaminant class. Every value must be a valid matplotlib
# marker code and distinct from the others ('p' = pentagon for CART).
all_shapes = {'SNIa-91bg': 'o',
              'SNIax': 's',
              'SNII': 'd',
              'SNIbc': 'X',
              'SLSN-I': 'v',
              'AGN': '^',
              'TDE': '<',
              'KN': '>',
              'CART': 'p'}

In [None]:
# Color map: reversed 'plasma' colours on a logarithmic scale over [1, 52],
# sampled at the integers 1..51 (one RGBA row per integer).
rainbow = plt.get_cmap('plasma_r')
cm = rainbow
cNorm = colors.LogNorm(vmin=1, vmax=52)
scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=rainbow)
color_map = scalarMap.to_rgba(np.arange(1, 52))

## Prepare the data

In [None]:
# Survey fields on the COIN server: lowercase key -> directory name.
file_extensions = dict(ddf='DDF', wfd='WFD')
# Version used to list the cases, empty glob suffix, and sample size (as a string).
ktot, kglob, nobjs = 3, '', '3000'

In [None]:
def get_cases(field, k='', nobjs=3000,
              base_dir='/media/RESSPECT/data/PLAsTiCC/for_metrics/final_data3/'):
    """
    List the sample files available for one field / version / sample size.

    Parameters
    ----------
    field: str
        field directory name, e.g. 'DDF' or 'WFD'
    k: str or int, optional
        version number; '' is treated as '0'
    nobjs: int or str, optional
        sample-size sub-directory name
    base_dir: str, optional
        root directory that holds the per-field results

    Returns
    -------
    cases: list of str
        entries of the samples directory, sorted, with hidden entries
        (names starting with '.', e.g. '.ipynb_checkpoints') removed
    dirname: str
        path of the samples directory, ending in a path separator
    """
    if k == '':
        k = '0'
    # Trailing '' keeps the final separator so callers can append file names.
    dirname = os.path.join(base_dir, field, 'results', 'v' + str(k),
                           str(nobjs), 'samples', '')
    # os.listdir order is arbitrary; sort so downstream dicts and plots are reproducible.
    cases = sorted(name for name in os.listdir(dirname) if not name.startswith('.'))
    return cases, dirname

In [None]:
# List the case files in each field's samples directory for version `ktot`.
cases, dirnames = {}, {}
for file_extension, field_dir in file_extensions.items():
    cases[file_extension], dirnames[file_extension] = get_cases(field_dir, k=str(ktot))

In [None]:
def make_remap_dict(file_extension):
    """
    Map sample-file stems to the display labels used in the plots.

    Parameters
    ----------
    file_extension: str
        field key; 'wfd' selects the WFD table, any other value the DDF table

    Returns
    -------
    OrderedDict
        case stem -> label ('<contaminant> <percent>' or a null-case name),
        in plotting order
    """
    if file_extension != 'wfd':
        return OrderedDict([
            ('perfect3000', 'Perfect'),
            ('fiducial3000', 'Fiducial'),
            ('random3000', 'Random'),
            ('72SNIa28SNII', 'SN-II 28'),
            ('75SNIa25SNII', 'SN-II 25'),
            ('90SNIa10SNII', 'SN-II 10'),
            ('95SNIa5SNII', 'SN-II 5'),
            ('98SNIa2SNII', 'SN-II 2'),
            ('99SNIa1SNII', 'SN-II 1'),
            ('95SNIa5SNIbc', 'SN-Ibc 5'),
            ('98SNIa2SNIbc', 'SN-Ibc 2'),
            ('99SNIa1SNIbc', 'SN-Ibc 1'),
            ('90SNIa10SNIax', 'SN-Iax 10'),
            ('95SNIa5SNIax', 'SN-Iax 5'),
            ('98SNIa2SNIax', 'SN-Iax 2'),
            ('99SNIa1SNIax', 'SN-Iax 1'),
            ('99.4SNIa0.6CART', 'CART 0.6'),
            ('99.9SNIa0.1SLSN', 'SLSN 0.1'),
        ])
    return OrderedDict([
        ('perfect3000', 'Perfect'),
        ('fiducial3000', 'Fiducial'),
        ('random3000', 'Random'),
        ('75SNIa25SNII', 'SN-II 25'),
        ('90SNIa10SNII', 'SN-II 10'),
        ('95SNIa5SNII', 'SN-II 5'),
        ('98SNIa2SNII', 'SN-II 2'),
        ('99SNIa1SNII', 'SN-II 1'),
        ('90SNIa10SNIbc', 'SN-Ibc 10'),
        ('95SNIa5SNIbc', 'SN-Ibc 5'),
        ('98SNIa2SNIbc', 'SN-Ibc 2'),
        ('99SNIa1SNIbc', 'SN-Ibc 1'),
        ('75SNIa25SNIax', 'SN-Iax 25'),
        ('90SNIa10SNIax', 'SN-Iax 10'),
        ('95SNIa5SNIax', 'SN-Iax 5'),
        ('98SNIa2SNIax', 'SN-Iax 2'),
        ('99SNIa1SNIax', 'SN-Iax 1'),
        ('95SNIa5SNIa-91bg', 'SN-Ia-91bg 5'),
        ('98SNIa2SNIa-91bg', 'SN-Ia-91bg 2'),
        ('99SNIa1SNIa-91bg', 'SN-Ia-91bg 1'),
        ('98SNIa2AGN', 'AGN 2'),
        ('99SNIa1AGN', 'AGN 1'),
        ('99SNIa1CART', 'CART 1'),
    ])

In [None]:
# Keep only the labels for cases actually present on disk; the file name minus
# its last 4 characters is the key used by make_remap_dict.
remap_dicts = {}
for file_extension in file_extensions:
    full_remap = make_remap_dict(file_extension)
    present_stems = (case_file[:-4] for case_file in cases[file_extension])
    remap_dicts[file_extension] = {stem: full_remap[stem]
                                   for stem in present_stems
                                   if stem in full_remap}

In [None]:
# Mapping the percent contaminated to the colormap.
## size and order correspond one-to-one to make_remap_dict(file_extension)
def make_color_nums(file_extension):
    """
    Contamination percentage per case, aligned with make_remap_dict(file_extension).

    Null cases (Perfect/Fiducial/Random) get 1. Sub-percent DDF rates
    (CART 0.6, SLSN 0.1) are also set to 1, the lower limit (vmin) of the
    LogNorm colour scale used in this notebook.

    Parameters
    ----------
    file_extension: str
        field key; 'wfd' selects the WFD list, any other value the DDF list

    Returns
    -------
    color_num: ndarray of int
        one entry per key of make_remap_dict(file_extension), same order
    """
    if file_extension == 'wfd':
        color_num = np.array([1, 1, 1,                    # Special
                              25, 10, 5, 2, 1,            # II
                              10, 5, 2, 1,                # Ibc
                              25, 10, 5, 2, 1,            # Iax
                              5, 2, 1,                    # 91bg
                              2, 1,                       # AGN
                              1                           # CART
                              ])
    else:
        color_num = np.array([1, 1, 1,                    # Special
                              28, 25, 10, 5, 2, 1,        # II
                              5, 2, 1,                    # Ibc
                              10, 5, 2, 1,                # Iax
                              1,                          # CART (0.6%)
                              1                           # SLSN (0.1%)
                              ])
    return color_num

In [None]:
# Per-field contamination percentages used to pick colours.
color_nums = {file_extension: make_color_nums(file_extension)
              for file_extension in file_extensions.keys()}

In [None]:
# Color map (overrides the 1..52 scale defined earlier): reversed 'plasma'
# colours on a log scale over [1, 30], sampled at the integers 1..29.
rainbow = plt.get_cmap('plasma_r')
cm = rainbow
cNorm = colors.LogNorm(vmin=1, vmax=30)
scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=rainbow)
color_map = scalarMap.to_rgba(np.arange(1, 30))

## Calculate the curve(s)

Compute a kernel density estimate (KDE) of $w$ for each set of posterior samples.

In [None]:
# Smallest positive value safe_log substitutes for zeros / negatives.
eps = 2. * sys.float_info.min

def safe_log(arr, threshold=eps):
    """
    Takes the natural logarithm of an array that might contain zeros.

    Values below `threshold` are replaced by `threshold` before the log is
    taken. The input is never modified; NaNs propagate unchanged.

    Parameters
    ----------
    arr: array_like, float
        values to be logged
    threshold: float, optional
        small, positive value to replace zeros and negative numbers

    Returns
    -------
    logged: ndarray
        logged values, with small value replacing un-loggable values
    """
    # Work on a new float array. np.asarray alone would alias (and then mutate)
    # the caller's array, and an integer array would truncate `threshold` to 0,
    # giving log(0) = -inf.
    clipped = np.maximum(np.asarray(arr, dtype=float), threshold)
    logged = np.log(clipped)
    return logged

def make_grid(x, y, x_ngrid=100, y_ngrid=100):
    """
    Build a regular 2-D grid spanning the min/max of the samples.

    Parameters
    ----------
    x, y: ndarray or Series
        samples along each axis (only their min/max are used)
    x_ngrid, y_ngrid: int, optional
        number of grid points along each axis, endpoints included

    Returns
    -------
    ((x_min, y_min), (x_max, y_max)), (x_grid, y_grid), (x_vec, y_vec), (dx, dy)
        extrema, np.mgrid arrays of shape (x_ngrid, y_ngrid), their 1-D axes,
        and the grid spacings
    """
    lower = (x.min(), y.min())
    upper = (x.max(), y.max())
    x_grid, y_grid = np.mgrid[lower[0]:upper[0]:x_ngrid * 1j,
                              lower[1]:upper[1]:y_ngrid * 1j]
    spacing = ((upper[0] - lower[0]) / (x_ngrid - 1),
               (upper[1] - lower[1]) / (y_ngrid - 1))
    return (lower, upper), (x_grid, y_grid), (x_grid[:, 0], y_grid[0, :]), spacing

def make_kde(Xgrid, Ygrid, Xsamps, Ysamps, to_log=False, save=None, one_d=True):
    """
    Evaluate a Gaussian KDE (Scott's-rule bandwidth) of the samples on a grid.

    Parameters
    ----------
    Xgrid, Ygrid: ndarray
        2-D grids as returned by make_grid; in 1-D mode only Xgrid.T[0] is used
    Xsamps, Ysamps: array_like
        samples; Ysamps is ignored in 1-D mode
    to_log: bool, optional
        return safe_log of the density instead of the density
    save: optional
        currently unused
    one_d: bool, optional
        fit a 1-D KDE of Xsamps instead of a joint 2-D KDE

    Returns
    -------
    ndarray
        density (or log-density) at Xgrid.T[0] in 1-D mode, otherwise with
        the shape of Xgrid
    """
    if one_d:
        density = sps.gaussian_kde(Xsamps, bw_method='scott')(Xgrid.T[0])
    else:
        kernel = sps.gaussian_kde(np.vstack([Xsamps, Ysamps]), bw_method='scott')
        grid_points = np.vstack([Xgrid.ravel(), Ygrid.ravel()])
        density = np.reshape(kernel(grid_points).T, Xgrid.shape)
    # TODO: normalize before taking the log; `save` is not implemented yet.
    return safe_log(density) if to_log else density

In [None]:

# alloutputs = pd.DataFrame(columns=['path', 'KLD'])
#     # make reference sample
# with gzip.open(fullpath+refname) as reffn:
#     flatref = pd.read_csv(reffn)
# [w_ref, Omm_ref] = [flatref['w'], flatref['om']]
# ref_extrema, ref_grids, ref_vecs, ref_ds = make_grid(w_ref, Omm_ref)
# (w_vec, Omm_vec) = ref_vecs
# (dw, dOmm) = ref_ds
# ((xmin, ymin), (xmax, ymax)) = ref_extrema
# (w_grid, Omm_grid) = ref_grids
# d_ref = {'w': dw, 'Omm': dOmm}
# grid_ref = {'w': w_grid, 'Omm': Omm_grid}
# kde_ref = make_kde(w_grid, Omm_grid, w_ref, Omm_ref, one_d=True, to_log=True)

In [None]:
def get_posteriors(field, k, casename, nsn, withlowz=True,
                   base_dir='/media/RESSPECT/data/PLAsTiCC/for_metrics/final_data3/',
                   ext='.pkl'):
    """
    Load the posterior samples of w and Omega_m for one case.

    Parameters
    ----------
    field: str
        key into `file_extensions` ('ddf' or 'wfd')
    k: int or str
        version number used in the 'v<k>' directory
    casename: str
        case stem; the null cases ('perfect', 'random', 'fiducial') get the
        sample size appended when it is not already part of the name
    nsn: int or str
        sample size (directory name and null-case suffix)
    withlowz: bool, optional
        read the '_lowz_withbias' chains
    base_dir: str, optional
        root directory holding the per-field results
    ext: str, optional
        file format: '.pkl' (pickled mapping) or '.csv.gz' (gzipped CSV)

    Returns
    -------
    list
        [samples['w'], samples['om']]

    Raises
    ------
    ValueError
        if `ext` is not one of the supported formats
    """
    if ext not in ('.pkl', '.csv.gz'):
        raise ValueError('unsupported posterior file extension: ' + repr(ext))

    # Null-case files carry the sample size in their name (e.g. 'perfect3000').
    is_null_case = any(null in casename for null in ('perfect', 'random', 'fiducial'))
    if is_null_case and str(nsn) not in casename:
        case = casename + str(nsn)
    else:
        case = casename

    filename = 'chains_' + case
    if withlowz:
        filename = filename + '_lowz_withbias'
    path_pre = base_dir + file_extensions[field] + \
               '/results/v' + str(k) + '/' + str(nsn) + '/posteriors/pkl/'
    samppathname = path_pre + filename + ext

    if ext == '.pkl':
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(samppathname, 'rb') as sampfile:
            sampdata = pkl.load(sampfile)
    else:
        with gzip.open(samppathname) as sampfile:
            sampdata = pd.read_csv(sampfile)
    return [sampdata['w'], sampdata['om']]

In [None]:
# Settings for the KDE cells below. Note: this rebinds `ktot` (3 above, used to
# list the cases) to 1, so only version k = 0 is processed here.
null_cases = ['perfect', 'random', 'fiducial']
kmin, ktot = 0, 1
samp_sizes = [1500, 3000, 6000]
ngrid = 100  # number of grid points per KDE

In [None]:
# Log-KDEs of w for the null cases at each version k and sample size, cached to
# default_kdes.pkl. Layout: outdata[field][case][k][i] = [w_grid, log_kde].
outdata = {}
for field in file_extensions:
    outdata[field] = {}
    for casename in null_cases:
        # NaN-filled so slots k < kmin are visibly missing instead of holding
        # uninitialised memory from np.empty.
        outdata[field][casename] = np.full((ktot, len(samp_sizes), 2, ngrid), np.nan)
        for k in range(kmin, ktot):
            for i, nsn in enumerate(samp_sizes):
                [w_comp, Omm_comp] = get_posteriors(field, k, casename, nsn, withlowz=True)
                # Pass ngrid so the grid length always matches the array allocated above.
                comp_extrema, comp_grids, comp_vecs, comp_ds = make_grid(
                    w_comp, Omm_comp, x_ngrid=ngrid, y_ngrid=ngrid)
                (w_grid, Omm_grid) = comp_grids
                kde_comp = make_kde(w_grid, Omm_grid, w_comp, Omm_comp, one_d=True, to_log=True)
                outdata[field][casename][k][i] = np.array([w_grid.T[0], kde_comp])
with open('default_kdes.pkl', 'wb') as outfile:
    pkl.dump(outdata, outfile)

In [None]:
# Log-KDEs of w for every case listed in the samples directory (version 0,
# 3000 objects), cached to testcase_kdes.pkl as outdata[field][stem] = [w_grid, log_kde].
outdata = {}
k = '0'
nsn = '3000'
for field in file_extensions:
    outdata[field] = {}
    for casename in cases[field]:
        # Drop the last 4 characters of the listed file name (presumably a
        # file extension such as '.csv' — confirm against the samples dir).
        stem = casename[:-4]
        [w_comp, Omm_comp] = get_posteriors(field, k, stem, nsn, withlowz=True)
        comp_extrema, comp_grids, comp_vecs, comp_ds = make_grid(
            w_comp, Omm_comp, x_ngrid=ngrid, y_ngrid=ngrid)
        (w_grid, Omm_grid) = comp_grids
        kde_comp = make_kde(w_grid, Omm_grid, w_comp, Omm_comp, one_d=True, to_log=True)
        outdata[field][stem] = np.array([w_grid.T[0], kde_comp])
with open('testcase_kdes.pkl', 'wb') as outfile:
    pkl.dump(outdata, outfile)

## Make plot(s)

In [None]:
# Overlay, for each field (one panel each), the w posterior densities of the
# null cases (perfect / random / fiducial) at every sample size in `samp_sizes`.
# Line colour per null case:
def_colors = {'perfect': 'k', 'random': 'tab:red', 'fiducial': 'tab:blue'}
# Line style per sample size, keyed by str(nsn):
def_styles = {'1500': ':', '3000': '-', '6000': '--'}#{'DDF': '-', 'WFD': '--'}
# def_lowz = {'withbias': , 'nobias':}

# Log-KDEs written by the default-KDE cell above.
# NOTE(review): pickle.load can execute arbitrary code; only load files you wrote.
with open('default_kdes.pkl', 'rb') as infile:
    indata = pkl.load(infile)

fig, ax = plt.subplots(2, 1, figsize=(6, 7))    
for j, field in enumerate(file_extensions):
    for casename in null_cases:
        # Proxy point at (0, 0), outside the x-range set below, so each null
        # case gets a coloured legend entry.
        ax[j].scatter([0], [0], label=casename, color=def_colors[casename])
        for i, nsn in enumerate(samp_sizes):
            # Only the last version (k = ktot - 1) is drawn.
            for k in range(ktot-1, ktot):
                w_grid, kde_comp = indata[field][casename][k][i]#[w_grid, kde_comp] = indata[casename]
#                 if k == 0:
#                     lw_boost = 2
# #                     print(kde_comp)
#                 else:
#                     lw_boost = 1
                # KDEs were stored as logs (to_log=True); exponentiate back to a density.
                ax[j].plot(w_grid, np.exp(kde_comp),# label=field+casename, 
                linestyle=def_styles[str(nsn)], color=def_colors[casename], alpha=0.8, linewidth=1.25)
    # Proxy lines at (0, 0) giving one legend entry (linestyle) per sample size.
    for nsn in samp_sizes:
        ax[j].plot([0], [0], label=str(nsn), 
                 linestyle=def_styles[str(nsn)], color='tab:green', alpha=1., linewidth=1.25) 
    ax[j].set_yticks([10, 30, 50])
    ax[j].set_yticklabels([10, 30, 50], fontsize=14)
    ax[j].set_ylabel(r'PDF ($w^{-1}$)', fontsize=18)
    # Vertical guide at w = -1 spanning the y-limits current at this point.
    ax[j].vlines(-1., ax[j].get_ylim()[0], ax[j].get_ylim()[1], color='gray', alpha=0.5)
    #ax[j].set_ylim(0., 70.)
    # Top panel: field name near the top-left, no x ticks.
    if j == 0:
        yset = ax[j].get_ylim()[1]
        ax[j].text(-1.175, 0.85*yset, file_extensions[field], fontsize=20)
        ax[j].set_xticks([])
    # plt.title(field+k)
    # Bottom panel: x ticks, x label and the legend, plus the field name.
    if j == 1:
        ax[j].set_xticks([-1.2, -1.1, -1.])
        ax[j].set_xticklabels([-1.2, -1.1, -1.], fontsize=14)
        ax[j].legend(loc='lower left', fontsize=14)#, ncol=2)
        ax[j].set_xlabel(r'$w$', fontsize=18)
        yset = ax[j].get_ylim()[1]
        ax[j].text(-1.175, 0.85*yset, file_extensions[field], fontsize=20)
    # Both panels use the same x-range; the (0, 0) proxies fall outside it.
    ax[j].set_xlim(-1.2, -0.95)
fig.subplots_adjust(wspace=0., hspace=0.)
#plt.savefig('dists_null.pdf', bbox_inches='tight' ,dpi=250)
plt.show()

**TODO:** investigate the runs whose KDEs are flat.

**TODO:** also compare results with and without bias in the low-z sample.

In [None]:
# Split each label such as 'SN-II 25' into contaminant name and contamination
# percentage; single-word labels (the null cases) are skipped.
rates, contaminants = {}, {}
for field in file_extensions:
    rate, contaminant = {}, {}
    for key, label in remap_dicts[field].items():
        postsplit = label.split()
        if len(postsplit) <= 1:
            continue
        name = postsplit[0]
        perc = float(postsplit[-1])
        rate[key] = perc
        contaminant[key] = name
    rates[field] = rate
    contaminants[field] = contaminant

In [None]:
# for field in file_extensions:
#     plt.hist(rates[field].values(), bins=25, alpha=0.5, label=field)
# plt.legend()

**TODO:** automate dividing the cases into panels.

In [None]:
# Panel i holds the cases whose contamination rate (percent) falls in the
# half-open bin [cutoffs[i], cutoffs[i+1]). Exceptions: rate 0 is excluded
# and the top edge (50) is included in the last bin.
cutoffs = [0., 1., 2., 5., 7.5, 15., 50.]
cutofflabels = ['<1%', '1%', '2%', '5%', '10%', '25%']
assert len(cutofflabels) == len(cutoffs) - 1, 'need one label per bin'


def _panel_index(rate, edges=cutoffs):
    """Return the panel (bin) index for `rate`, or None if outside (edges[0], edges[-1]]."""
    if rate <= edges[0] or rate > edges[-1]:
        return None
    if rate == edges[-1]:
        return len(edges) - 2
    return int(np.searchsorted(edges, rate, side='right')) - 1


panel_groups = {}
for field in file_extensions:
    panel_groups[field] = {j: [] for j in range(len(cutofflabels))}
    for casename, rate in rates[field].items():
        panel = _panel_index(rate)
        if panel is not None:
            panel_groups[field][panel].append(casename)


In [None]:
# One tab10 colour per contaminant type, pooled over all fields so a contaminant
# keeps the same colour in every figure. NOTE: tab10 has only 10 colours; with
# more than 10 contaminant types the colours would stop being distinct.
all_contaminants = sorted(set().union(
    *(contaminants[field].values() for field in file_extensions)))

color_list = OrderedDict((contaminant, plt.cm.tab10(i))
                         for i, contaminant in enumerate(all_contaminants))

# Per field: case key -> colour of that case's contaminant.
contaminant_colors = {
    field: {case: color_list[name] for case, name in contaminants[field].items()}
    for field in file_extensions
}

In [None]:
# One 2x3 figure per field. Each panel is a contamination-rate bin
# (titled with `cutofflabels`) and shows, for each distinct contaminant in that
# bin, the w posterior density of its first case. Panel 0 (<1%) gets no curves.
axs = {}

# NOTE(review): pickle.load can execute arbitrary code; only load files you wrote.
with open('testcase_kdes.pkl', 'rb') as infile:
    indata = pkl.load(infile)

for field in file_extensions:
    fig = pylab.figure(figsize=(15, 10))
    bigAxes = pylab.axes(frameon=False)     # hide frame
    bigAxes.set_xticks([])                  # don't want to see any ticks on this axis
    bigAxes.set_yticks([])

    bigAxes.set_title(file_extensions[field], fontsize=20)
    numrows = 2
    numcols = 3

    for i in range(len(panel_groups[field])):
        per_panel_contaminants = [contaminants[field][casename]
                                  for casename in panel_groups[field][i]]
        # Index of the first case of each distinct contaminant in this panel.
        uniques, unique_ind = np.unique(per_panel_contaminants, return_index=True)

        axs[i] = fig.add_subplot(numrows, numcols, i+1)
        ax = axs[i]

        for val in unique_ind:
            casename = panel_groups[field][i][val]
            w_grid, kde_comp = indata[field][casename]

            # Stored KDEs are logs; exponentiate back to a density.
            if i > 0:
                ax.plot(w_grid, np.exp(kde_comp), color=contaminant_colors[field][casename],
                        label=per_panel_contaminants[val])

            ax.set_xlim(-1.2, -0.9)
            ax.set_xlabel(r'$w$', fontsize=18)
            ax.set_xticks([-1.15, -1.1, -1.05, -1.0, -0.95])
            ax.set_xticklabels([-1.15, -1.1, -1.05, -1.0, -0.95], fontsize=16)

        # Only panels with curves get a legend. The title is styled inside each
        # branch so it never touches an unassigned or stale legend object.
        if i == 5:
            l = ax.legend(fontsize=13, loc='upper right', bbox_to_anchor=(1.02, 1.02), title=cutofflabels[i])
            plt.setp(l.get_title(), fontsize=14)
        elif i != 0:
            l = ax.legend(fontsize=13, loc='upper left', bbox_to_anchor=(-0.02, 1.02), title=cutofflabels[i])
            plt.setp(l.get_title(), fontsize=14)

        if i < 3 and field == 'ddf':
            axs[i].set_ylim(0., 32.)
        elif i > 2 and field == 'ddf':
            axs[i].set_ylim(0, 45)

        if i < 3 and field == 'wfd':
            axs[i].set_ylim(0., 25.)
        elif i > 2 and field == 'wfd':
            axs[i].set_ylim(0, 30)

        # Left column carries the y label; other panels share their row's y-range.
        if i % 3 == 0:
            ax.set_ylabel(r'PDF ($w^{-1}$)', fontsize=18)
        elif i in [1, 2]:
            ax.set_ylim(axs[0].get_ylim())
        elif i in [4, 5]:
            ax.set_ylim(axs[3].get_ylim())

    for ii in [1, 2, 4, 5]:
        axs[ii].set_yticks([])

    axs[0].set_yticks([0, 10, 20, 30])
    axs[0].set_yticklabels([0, 10, 20, 30], fontsize=14)

    if field == 'ddf':
        axs[3].set_yticks([0, 10, 20, 30, 40])
        axs[3].set_yticklabels([0, 10, 20, 30, 40], fontsize=14)
    else:
        axs[3].set_yticks([0, 5, 15, 25])
        axs[3].set_yticklabels([0, 5, 15, 25], fontsize=14)

    # Vertical guide at w = -1 on every panel except the empty panel 0.
    for i in range(1, 6):
        axs[i].vlines(-1., axs[i].get_ylim()[0], axs[i].get_ylim()[1], color='gray', alpha=0.5)

    fig.subplots_adjust(wspace=0., hspace=0.)
    #pylab.savefig(file_extensions[field]+'combos.png', bbox_inches='tight', pad_inches=0.3, dpi=250)

    fig.show()

TODO: polish these figures for the paper:
- use a different linestyle per contamination rate when a panel contains more than one case with the same contaminant.

In [None]:
# Save the per-field case -> colour mapping for reuse elsewhere. `with` closes
# the file even if pickling fails; `pkl` is the pickle module imported at the top.
with open("colors.pkl", "wb") as a_file:
    pkl.dump(contaminant_colors, a_file)