In [None]:
import warnings
import numpy as np
import pandas as pd
from glob import glob
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
warnings.filterwarnings('ignore', category=np.VisibleDeprecationWarning)

In [None]:
def wrap_deg(angle):
    """
    Convert an angle from radians to degrees and fold it about 90 degrees.

    An angle theta and its supplement (180 - theta) map to the same value,
    so inputs in [0, pi] radians end up in [0, 90] degrees. The operation is
    element-wise (NumPy ufuncs), so scalars and array-likes are accepted.
    """
    in_degrees = np.degrees(angle)
    offset_from_right_angle = np.abs(in_degrees - 90)
    return 90 - offset_from_right_angle


def read_data(file, index, nth_in):
    """
    Collect one column from every ``nth_in``-th line of a whitespace-delimited
    text source.

    Lines are counted from zero, so lines 0, nth_in, 2*nth_in, ... are used.
    Every token on a selected line is parsed as ``np.float64`` and the token
    at position ``index`` is kept (negative indices count from the end).

    Parameters
    ----------
    file : iterable of str
        An open text file or any iterable yielding lines.
    index : int
        Column position to extract from each selected line.
    nth_in : int
        Line sampling stride.

    Returns
    -------
    list of np.float64
    """
    return [
        [np.float64(token) for token in line.split()][index]
        for line_no, line in enumerate(file)
        if line_no % nth_in == 0
    ]


def sort_binary(binary, truths):
    """
    Relabel binary IDs so that the labels follow ascending true-value order.

    The unique IDs in ``binary`` (ascending) are paired with ``truths`` by
    position. The ID whose truth is the k-th smallest receives the k-th
    smallest unique ID as its new label. This assumes one truth per unique ID,
    given in ascending-ID order (TODO confirm for partial result sets).

    Returns an array shaped like ``binary`` that holds the new labels.
    """
    unique_ids = np.unique(binary)
    by_truth = np.argsort(truths)
    relabel = {old_id: new_id
               for old_id, new_id in zip(unique_ids[by_truth], unique_ids)}
    lookup = np.vectorize(relabel.get)
    return lookup(binary)


def get_prior_type_data(df, prior_type, column, binary):
    """
    Gather the ``column`` values of one prior type, grouped by binary, in a
    form usable for violin plots.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'prior_type' and 'binary' columns plus ``column``.
    prior_type : str
        Value of df['prior_type'] to select.
    column : str
        Column whose values are collected.
    binary : array-like
        Binary IDs. Only its unique values (ascending) are used, and they
        define the groups.

    Returns
    -------
    numpy.ndarray of object, shape (n_binaries,)
        Element i is a 1-D array with the ``column`` values of the i-th
        smallest unique binary ID. It is empty if that binary has no rows of
        this prior type. A 1-D object array is used because the groups
        generally differ in length: ``np.array`` on such ragged input raises
        ValueError on NumPy >= 1.24. The result still supports integer-array
        indexing (``vals[cut]``) and iteration.
    """
    prior_type_mask = df['prior_type'].to_numpy() == prior_type
    binary_col = df['binary'].to_numpy()  # loop-invariant, read once
    groups = [df[prior_type_mask & (binary_col == b)][column].to_numpy()
              for b in np.unique(binary)]
    # Fill element by element. Assigning the whole list at once could make
    # NumPy try to broadcast equal-length groups into a 2-D block.
    vals = np.empty(len(groups), dtype=object)
    for i, group in enumerate(groups):
        vals[i] = group
    return vals


def quantile_cut(data, qt_cut):
    """
    Trim the tails of each distribution in ``data``.

    For every sample, the quantiles listed in ``qt_cut`` are evaluated.
    Values below the smallest of those quantiles or above the largest are
    dropped; both bounds are inclusive.

    Parameters
    ----------
    data : iterable of array-like
        Samples, possibly of different lengths (e.g. a 1-D object array of
        arrays, or a 2-D array whose rows are samples).
    qt_cut : sequence of float
        Quantile levels in [0, 1]. Only their minimum and maximum matter.

    Returns
    -------
    list of numpy.ndarray
        The trimmed samples, in input order. Empty samples are returned
        unchanged.
    """
    trimmed = []
    for vals in data:
        vals = np.asarray(vals)
        if vals.size == 0:
            # np.quantile cannot evaluate an empty sample
            trimmed.append(vals)
            continue
        # Mask each sample on its own: stacking per-sample masks of different
        # lengths into one ndarray raises ValueError on NumPy >= 1.24.
        bounds = np.quantile(vals, qt_cut)
        keep = (vals >= bounds.min()) & (vals <= bounds.max())
        trimmed.append(vals[keep])
    return trimmed


def set_violin_properties(violin_plots, facecolor, edgecolor='black',
                          alpha=0.5, side=None):
    """
    Style every violin body and optionally keep only one half of each.

    Parameters
    ----------
    violin_plots : dict
        Result of ``Axes.violinplot``. Only its 'bodies' entry is used.
    facecolor, edgecolor : color
        Fill and outline colours applied to each body.
    alpha : float
        Transparency applied to each body.
    side : {None, 'left', 'right'}
        'left' clamps each vertex's x to at most the mean vertex x (the
        centre line of a symmetric body), keeping the left half. 'right'
        clamps to at least that value. Any other value leaves the shape as is.

    The first path's vertex array is edited in place and then passed back
    through ``set_paths`` for every body, whatever ``side`` is.
    """
    for body in violin_plots['bodies']:
        body.set_facecolor(facecolor)
        body.set_edgecolor(edgecolor)
        body.set_alpha(alpha)
        vertices = body.get_paths()[0].vertices
        xs = vertices[:, 0]
        if side == 'left':
            vertices[:, 0] = np.clip(xs, -np.inf, np.mean(xs))
        elif side == 'right':
            vertices[:, 0] = np.clip(xs, np.mean(xs), np.inf)
        body.set_paths([vertices])


def plot_percentiles(violin_plots, data, truths, widths, color, ax=None):
    """
    Draw the 10th, 50th and 90th percentiles of each distribution as
    horizontal segments. Each segment runs from the violin outline to the
    violin's x position, so the lines stay inside the (half-)violins.

    Parameters
    ----------
    violin_plots : dict
        Result of ``Axes.violinplot``. Its 'bodies' are matched to ``data``
        and ``truths`` by position.
    data : iterable of array-like
        One sample per violin. Percentiles are computed from these.
    truths : array-like
        x position of each violin (expected to be the ``positions`` given to
        ``violinplot``). Used as the inner end of every segment.
    widths : float
        Violin width. Outline vertices closer than ``widths/100`` to the
        violin's x position are ignored when looking up the outline x at a
        given height. This skips the centre line of a clipped half-violin.
    color : color
        Colour of all three percentile lines.
    ax : matplotlib Axes, optional
        Axes to draw on. When None (the default), the lines go to the pyplot
        current axes via ``plt.hlines``, as before.
    """
    pc_bounds = []
    # Shape (3, n_violins): row 0 = 10th, row 1 = 50th, row 2 = 90th percentile
    pc_pos = np.array([np.percentile(vals, [10, 50, 90]) for vals in data]).T
    for b, poly_col in enumerate(violin_plots['bodies']):
        pcs = []
        verts = poly_col.get_paths()[0].vertices
        # Keep only outline vertices away from the violin's x position
        verts_cut = verts[np.abs(verts[:, 0] - truths[b]) > widths/100].T
        for pc in pc_pos:
            # Outline vertices whose height is closest to this percentile...
            y_min = np.abs(verts_cut[1] - pc[b])
            y_idx = np.where(y_min == np.min(y_min))[0]
            # ...and, among those, the one farthest from the violin's x
            x_max = np.abs(verts_cut[0][y_idx] - truths[b])
            x_idx = np.where(x_max == np.max(x_max))[0]
            pcs.append(verts_cut[0][y_idx[x_idx]][0])
        pc_bounds.append(pcs)
    pc_bounds = np.array(pc_bounds).T

    target = plt if ax is None else ax
    # Dashed thin lines for the 10th/90th percentiles, a solid thicker median
    line_styles = [('--', 1.25), ('-', 1.75), ('--', 1.25)]
    for y, x_outline, (linestyle, linewidth) in zip(pc_pos, pc_bounds,
                                                    line_styles):
        target.hlines(y, x_outline, truths, linestyle=linestyle,
                      linewidth=linewidth, color=color)

In [None]:
################################
prefix = 'eclipsing_'
prior_types = ['em', 'gw']
nth_in_gbmcmc = 20
nth_in_gwemlisa = 1
################################

# Read in true parameter values for selected binaries. 'incl' goes through
# wrap_deg, i.e. it is treated as radians and folded into [0, 90] degrees.
gbf_dir = Path.cwd().parent.parent.joinpath('ldasoft')
gbf_file = gbf_dir.joinpath(f'{prefix}gbfisher_parameters.dat')
gbf = pd.read_csv(gbf_file, delimiter=' ')
truths = np.array(wrap_deg(gbf['incl']))

# Read in gwemlisa posteriors (column 1 of each post_equal_weights.dat).
# Binary numbers are 1-based, parsed from name[6:9] of the run directory,
# and index `truths`.
prior_type, binary, residual, true_value = [], [], [], []
for prior in prior_types:
    post_dir = Path.cwd().parent.joinpath(f'{prefix}{prior}prior')
    for post_file in post_dir.glob('**/*row0*/post_equal_weights.dat'):
        b_num = int(post_file.parent.parent.name[6:9])
        t_val = truths[b_num - 1]
        with open(post_file) as raw:
            raw_value = read_data(raw, 1, nth_in_gwemlisa)
        prior_type.extend([prior]*len(raw_value))
        binary.extend([b_num]*len(raw_value))
        residual.extend([val - t_val for val in raw_value])
        true_value.extend([t_val]*len(raw_value))

# Read in gbmcmc posteriors (column 5 of every nth_in_gbmcmc-th chain line).
# The column is passed through arccos, so it presumably holds cos(incl).
post_dir = Path.cwd().parent.parent.joinpath(f'data/{prefix}results')
for post_file in post_dir.glob('**/chains/*.dat.1'):
    b_num = int(post_file.parent.parent.name[6:9])
    t_val = truths[b_num - 1]
    with open(post_file) as raw:
        raw_value = read_data(raw, 5, nth_in_gbmcmc)
    prior_type.extend(['gbf']*len(raw_value))
    binary.extend([b_num]*len(raw_value))
    residual.extend([val - t_val for val in wrap_deg(np.arccos(raw_value))])
    true_value.extend([t_val]*len(raw_value))

# Format data for plotting. Each column is built from its own list so it
# keeps a native dtype. Stacking the lists into one ndarray first would coerce
# every value to a string that then has to be parsed back with to_numeric.
post_columns = ['binary', 'prior_type', 'residual', 'true_value']
post = pd.DataFrame({
    'binary': sort_binary(binary, truths),
    'prior_type': prior_type,
    'residual': np.asarray(residual, dtype=np.float64),
    'true_value': np.asarray(true_value, dtype=np.float64),
})
post = post.sort_values(post_columns)
# gbmcmc runs (tagged 'gbf') are plotted as the GW posterior. The gwemlisa
# 'em' and 'gw' prior runs are plotted as the EM and GW+EM posteriors.
gw = get_prior_type_data(post, 'gbf', 'residual', binary)
em = get_prior_type_data(post, 'em', 'residual', binary)
gwem = get_prior_type_data(post, 'gw', 'residual', binary)

In [None]:
################################
parameter = 'Inclination'
x_units = 'deg'
y_units = 'deg'
truths_fmt = '%.2f'
gw_qt_cut = [0.001, 0.999]
em_qt_cut = [0.001, 0.999]
gwem_qt_cut = [0.001, 0.999]
cut = np.arange(0, 99, 9)
widths = 1
################################

# Pick the binaries selected by `cut` (in ascending true-value order). Their
# sorted true values double as the violin x positions.
positions = np.sort(truths)[cut]
selected = {'gw': gw[cut], 'em': em[cut], 'gwem': gwem[cut]}

# Trim extreme tails for the violin shapes only. Percentile lines below are
# still computed from the untrimmed samples.
gw_cut = quantile_cut(selected['gw'], gw_qt_cut)
em_cut = quantile_cut(selected['em'], em_qt_cut)
gwem_cut = quantile_cut(selected['gwem'], gwem_qt_cut)

# Create the figure and axes
fig = plt.figure(figsize=(36, 18))
ax = fig.add_subplot(111)
ax2 = ax.secondary_xaxis('top')

# Draw the three posterior families in this order:
# (key, trimmed samples, face colour, percentile-line colour, half kept)
violin_specs = [
    ('gw', gw_cut, '#20BB00', '#105E00', 'left'),
    ('em', em_cut, '#E67E00', '#733F00', 'left'),
    ('gwem', gwem_cut, '#226EF1', '#113779', 'right'),
]
drawn = []
for key, trimmed, face, line_color, half in violin_specs:
    violin = ax.violinplot(trimmed, positions=positions,
                           widths=widths, showextrema=False)
    set_violin_properties(violin, facecolor=face, side=half)
    plot_percentiles(violin, selected[key], positions, widths, line_color)
    drawn.append(violin)
gw_post, em_post, gwem_post = drawn

# Legend: one swatch per family, taken from its first violin body
legend_labels = ['GW Posterior', 'EM Posterior', 'GW+EM Posterior']
ax.legend([violin['bodies'][0] for violin in drawn],
          legend_labels, loc='upper right', prop={'size': 24})

# Title, e.g. prefix 'eclipsing_' becomes 'Eclipsing Binaries ...'
title_prefix = prefix.capitalize().replace('_', ' ')
title = f'{title_prefix}Binaries Residual {parameter} vs True {parameter}'
ax.set_title(title, fontsize=40, pad=25)

# Axis labels, with a unit suffix only when a unit is configured
if x_units is None:
    xlabel = f'True {parameter}'
else:
    xlabel = f'True {parameter} [{x_units}]'
if y_units is None:
    ylabel = f'Residual {parameter}'
else:
    ylabel = f'Residual {parameter} [{y_units}]'
ax.set_xlabel(xlabel, fontsize=30, labelpad=15)
ax.set_ylabel(ylabel, fontsize=30, labelpad=15)

# Axis ticks
ax.minorticks_on()
ax.tick_params(which='minor', length=3.5, width=1)
ax.tick_params(which='major', length=6, width=1.25, labelsize=18)

# Top axis labels each violin with its true value, and a dotted guide
# marks that value on the main axes
ax2.set_xticks(positions)
ax2.xaxis.set_major_formatter(FormatStrFormatter(truths_fmt))
ax2.tick_params(which='major', axis='x', length=6, width=1.25, labelsize=18)
for pos in positions:
    ax.axvline(pos, color='black', linestyle=':', linewidth=1.75)

# Draw the y=0 reference only when 0 lies within the current y-limits
# (min/max keep this correct even for an inverted axis)
ylims = np.array(ax.get_ylim())
if ylims.min() <= 0 <= ylims.max():
    ax.axhline(0, color='#333333', linestyle='-', linewidth=1.75, zorder=0)
ax.grid()
plt.show()