In [None]:
import os
import math
import corner
import itertools
import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import sklearn.metrics as metrics
import matplotlib.lines as mlines

import matplotlib as mpl
from functools import partial
from matplotlib.ticker import AutoMinorLocator

In [None]:
# Palette references kept for future tweaking (none of these are active below):
# - Ross color scale: ["#000000","#6db6ff","#009292","#ff6db6","#924900","#490092","#006ddb","#b66dff","#004949","#b6dbff","#920000","#ffb6db","#db6d00","#24ff24","#ffff6d"]
# - Perceptually uniform plasma: cmap = plt.cm.plasma
# - Orange / purple pair: "#F57D15", "#65156E"
# - "Nice plasma" pair for histograms and scatter: ["#D44842", "#65156E"]

# Named hex colors available to later cells.
dark_purple, dark_green, dark_blue, dark_red = "#742881", "#1B7939", "#1065AB", "#BF2C23"
taupe, sage_green = "#B9A281", "#0A5C36"
terracotta, light_terracotta = "#CB6843", "#DA957B"
raspberry = "#E30B5C"
brown, dark_brown = "#A68A64", "#855e46"
yellow, blue = "#FDB338", "#025196"

# Active histogram colors: two samples near opposite ends of the 'turbo' colormap,
# converted to hex strings that include the alpha channel (keep_alpha=True).
choose_cmap = mpl.colormaps['turbo']
hex_left, hex_right = (
    mpl.colors.rgb2hex(choose_cmap(fraction), keep_alpha=True)
    for fraction in (0.05, 0.95)
)

colors_histogram = [hex_left, hex_right]

In [None]:
# Per-pipeline parameter samples keyed by parameter name (e.g. 'mchirp').
# NOTE(review): both start empty here and must be populated before the plotting
# cell runs; every key of data_sage must also exist in data_pycbc.
data_sage, data_pycbc = {}, {}

In [None]:
# Corner-style comparison of the PyCBC and Sage parameter samples.
#   - diagonal panels: overlaid 1-D histograms (PyCBC dashed, Sage solid) on shared bins;
#   - lower-triangle panels: hexbin of the pooled samples, each point tagged 0 (PyCBC)
#     or 1 (Sage) and each bin colored by `reduce_function` over those tags;
#   - upper-triangle panels: left empty.
# NOTE(review): `name_to_label`, `reduce_function` and `cmap_diverging_dark` are not
# defined in any visible cell -- Restart & Run All fails unless they are defined earlier.

assert data_sage, "data_sage is empty -- load the Sage parameter samples before plotting"
assert set(data_sage) <= set(data_pycbc), "every parameter in data_sage must also be in data_pycbc"

param_names = list(data_sage.keys())
num_rows = num_cols = len(param_names)
last_row = num_rows - 1
tick_rotation = 45.0  # degrees
hist_bins = 40
hexbin_gridsize = 32
colorbar_limit = 5  # symmetric vmin/vmax for the hexbin color scale
# Inset colorbar bounds [x0, y0, width, height] in *data* coordinates of the first
# lower-triangle panel, so they depend on that panel's parameter ranges.
colorbar_bounds = [150.0, 5.0, 3.0, 7.0]
output_path = "./evaluation_plots/compare_hexbinned.png"

# All panels live on one GridSpec of this figure, indexed by (row, col), so the layout
# works for any number of parameters and no unused duplicate axes are created.
fig = plt.figure(figsize=(4*num_cols, 4*num_rows), facecolor='white')
gs = fig.add_gridspec(num_rows, num_cols, wspace=0, hspace=0)

axes_grid = {}
colorbar_drawn = False  # the shared colorbar is attached to the first lower-triangle panel only

for row, col in itertools.product(range(num_rows), range(num_cols)):
    if col > row:
        # Upper triangle is intentionally blank: no axes are created there.
        continue

    # param_1 is the row parameter (y axis), param_2 the column parameter (x axis).
    param_1, param_2 = param_names[row], param_names[col]
    ax = fig.add_subplot(gs[row, col])
    axes_grid[row, col] = ax

    if row == col:
        # Shared bin edges so both pipelines' histograms are directly comparable.
        bins = np.histogram_bin_edges(np.hstack((data_pycbc[param_1], data_sage[param_1])), bins=hist_bins)
        ax.hist(data_pycbc[param_1], density=False, bins=bins, alpha=0.8, histtype='step',
                color=colors_histogram[0], linewidth=2.0, linestyle='dashed')
        ax.hist(data_sage[param_1], density=False, bins=bins, alpha=1.0, histtype='step',
                color=colors_histogram[1], linewidth=2.0, linestyle='solid')
        plt.setp(ax.get_xticklabels(), rotation=tick_rotation, horizontalalignment='right')
        plt.setp(ax.get_yticklabels(), rotation=tick_rotation, horizontalalignment='right')

        if col == 0:
            # Top-left panel: keep tick marks, hide their labels.
            ax.set_xticklabels([])
            ax.set_yticklabels([])
        elif row == last_row:
            # Bottom-right panel carries the x label for the last parameter.
            ax.set_xlabel(name_to_label[param_2])
            ax.set_yticklabels([])
            ax.set_yticks([])
        else:
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_xticks([])
            ax.set_yticks([])
        continue

    # Lower triangle: pool both pipelines and tag each point with its origin
    # (0 = PyCBC, 1 = Sage); reduce_function turns the tags in a bin into a color value.
    x_data = np.hstack((data_pycbc[param_2], data_sage[param_2]))
    y_data = np.hstack((data_pycbc[param_1], data_sage[param_1]))
    origin = np.repeat([0, 1], [len(data_pycbc[param_1]), len(data_sage[param_1])])
    hexb = ax.hexbin(x_data, y_data, C=origin, reduce_C_function=reduce_function,
                     gridsize=hexbin_gridsize, cmap=cmap_diverging_dark,
                     vmin=-colorbar_limit, vmax=colorbar_limit)

    if not colorbar_drawn:
        cax = ax.inset_axes(colorbar_bounds, transform=ax.transData)
        cbar = fig.colorbar(hexb, cax=cax, orientation='vertical')
        cbar.set_label(r'$\mathregular{N^F_{sage} - N^F_{pycbc}}$', rotation=90)
        colorbar_drawn = True

    plt.setp(ax.get_xticklabels(), rotation=tick_rotation, horizontalalignment='right')
    plt.setp(ax.get_yticklabels(), rotation=tick_rotation, horizontalalignment='right')

    # Only the left column shows y ticks/labels and only the bottom row shows x ticks/labels.
    if col == 0:
        ax.set_ylabel(name_to_label[param_1])
        if row != last_row:
            ax.set_xticklabels([])
            ax.set_xticks([])
    if row == last_row:
        ax.set_xlabel(name_to_label[param_2])
        if col != 0:
            ax.set_yticklabels([])
            ax.set_yticks([])
    if col != 0 and row != last_row:
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])

# Legend and summary counts are positioned relative to the bottom-right diagonal panel.
anchor_ax = axes_grid[last_row, num_cols - 1]

sage_line = mlines.Line2D([], [], color=colors_histogram[1], label='Sage', linewidth=4.0)
pycbc_line = mlines.Line2D([], [], color=colors_histogram[0], label='PyCBC', linewidth=4.0, linestyle='dashed')
# bbox_to_anchor: left/right, up/down, width, height (axes fraction of anchor_ax)
anchor_ax.legend(handles=[sage_line, pycbc_line], bbox_to_anchor=(-0.45, 0.33, 1., 6.5),
                 loc='center left', frameon=False, fontsize=22, handlelength=2.6)

# Found/missed counts: samples are matched across pipelines by exact 'mchirp' value.
# NOTE(review): assumes both pipelines report identical mchirp values for the same event.
sage_set = set(data_sage['mchirp'])
pycbc_set = set(data_pycbc['mchirp'])
found_by_both = len(sage_set & pycbc_set)
found_only_by_sage = len(sage_set - pycbc_set)
found_only_by_pycbc = len(pycbc_set - sage_set)
# Trailing spaces keep the right-aligned lines visually aligned with the first one.
count_lines = [
    f'Found by both = {found_by_both}',
    f'Found only by Sage = {found_only_by_sage}  ',
    f'Found only by PyCBC = {found_only_by_pycbc}  ',
]
for line_idx, line_text in enumerate(count_lines):
    anchor_ax.text(0.35, 3.2 - 0.125*line_idx, line_text, transform=anchor_ax.transAxes,
                   horizontalalignment='right', fontsize=15)

# Save the plot, creating the output directory if needed.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
fig.savefig(output_path, bbox_inches='tight', dpi=300)