In [None]:
cd ..

In [None]:
import numpy as np
import scipy.stats as stats
import matplotlib
import matplotlib.pyplot as plt
from collections import defaultdict, Counter
import seaborn as sns
import os
import glob
import sys
import sqlite3
import matplotlib as mpl

# Apply the seaborn style FIRST: seaborn's style dict sets font.family to
# sans-serif, so calling it after the update below would override the
# serif choice.
sns.set_style('ticks')
plt.rcParams.update({
    "font.family": "serif",  # use serif/main font for text elements
    "text.usetex": True,     # render all text (labels, ticks, legends) through LaTeX
    "pgf.rcfonts": False,    # don't setup fonts from rc parameters
    "font.size": 12,
    "axes.labelsize": "large",
    "pgf.texsystem": "pdflatex",
    # Recent Matplotlib versions require pgf.preamble to be a single string
    # (list values were deprecated and later rejected), so join the lines.
    "pgf.preamble": "\n".join([
        r'\usepackage[T1]{fontenc}',
        r'\usepackage[utf8]{inputenc}',
        r'\usepackage{amsmath}',
        r'\usepackage{newtxtext}',
        r'\usepackage{newtxmath}',
#         r'\usepackage[lite,subscriptcorrection,slantedGreek,nofontinfo,amsbb,eucal]{mtpro2}'
    ]),
})

In [None]:
# Tiny throwaway figure, drawn while the default backend is still active
# (the backend is switched to pgf in a later cell).
plt.figure(figsize=(1, 1)).gca().plot(np.arange(10))

In [None]:
# Route all subsequent Matplotlib rendering through the PGF (LaTeX) backend.
matplotlib.use("pgf")

In [None]:
version: int = 3  # value of replay_stats.version selected by every query below

In [None]:
db_path = './data/traces.sqlite'
# sqlite3.connect silently creates a new, empty database when the file is
# missing (e.g. the working directory is wrong), which only surfaces later as
# a confusing "no such table" error. Fail fast with the cwd instead.
if not os.path.isfile(db_path):
    raise FileNotFoundError(f'{db_path} not found (cwd: {os.getcwd()})')
# isolation_level=None: autocommit mode; timeout=30: wait up to 30 s on locks.
conn = sqlite3.connect(db_path, timeout=30, isolation_level=None)
# NOTE(review): per the SQLite docs this pragma only has an effect for
# shared-cache connections — confirm it is still needed.
conn.execute("PRAGMA read_uncommitted = true;")

In [None]:
# Number of replay_stats rows for the selected version. Uses the `version`
# variable (not a hard-coded 3) so it agrees with every other query.
conn.execute(f'select count(*) from replay_stats where version={version}').fetchone()[0]

In [None]:
# Per-session count of created cells; only values below 200 are histogrammed.
num_cells_created = np.array([
    row[0]
    for row in conn.execute(f'select num_cells_created from replay_stats where version={version}').fetchall()
])
plt.hist(num_cells_created[num_cells_created < 200.], bins=8)

In [None]:
# Per-session number of cell executions.
num_cell_execs = np.array([
    row[0]
    for row in conn.execute(f'select num_cell_execs from replay_stats where version={version}')
])
plt.hist(num_cell_execs)

In [None]:
# Per-session number of successful cell executions.
num_successful_cell_execs = np.array([
    row[0]
    for row in conn.execute(f'select num_successful_cell_execs from replay_stats where version={version}').fetchall()
])
plt.hist(num_successful_cell_execs)

In [None]:
def make_linechart_components(name, mark='-', agg=np.mean, prefix='', exception_threshold=1.0, npoints=30, ax=None):
    """Plot aggregated predictive power against a minimum number of safety errors.

    For each threshold ``i`` in ``0 .. npoints-1``, the sessions of the global
    ``version`` with ``num_safety_errors >= i`` are selected. ``agg`` of their
    ``{prefix}predictive_power_{name}`` values is drawn as a line, with a band
    of +/- ``scipy.stats.sem`` around it.

    Args:
        name: highlight-set suffix of the ``predictive_power_*`` column.
        mark: matplotlib format string for the line.
        agg: aggregation applied to each threshold's measurements.
        prefix: column prefix (e.g. ``'macro_'``).
        exception_threshold: keep only sessions whose
            ``num_exceptions / num_cell_execs`` is at most this value.
        npoints: number of thresholds (x positions) to plot.
        ax: axes to draw on; defaults to the current pyplot axes.

    Returns:
        ``(xs, line, err)`` numpy arrays of length ``npoints``. Thresholds that
        no session reaches are NaN and show up as gaps in the plot.
    """
    if ax is None:
        ax = plt.gca()
    column = f'{prefix}predictive_power_{name}'
    # SQLite yields NULL for division by zero, so sessions with
    # num_cell_execs = 0 fail the comparison below and are dropped.
    exception_fraction = 'num_exceptions * 1.0 / num_cell_execs'
    # Fetch every candidate session once and bucket by threshold in numpy,
    # instead of issuing one query per threshold. Rows with a NULL
    # num_safety_errors never satisfy `>= i`, so excluding them is equivalent.
    rows = conn.execute(f"""
        select {column}, num_safety_errors
        from replay_stats
            where version={version}
            and {column} is not null
            and num_safety_errors is not null
            and {exception_fraction} <= {exception_threshold}
        """).fetchall()
    values = np.array([row[0] for row in rows], dtype=float)
    safety_errors = np.array([row[1] for row in rows], dtype=float)

    xs = np.arange(npoints)
    line = np.full(npoints, np.nan)
    err = np.full(npoints, np.nan)
    for i in xs:
        selected = values[safety_errors >= i]
        if selected.size == 0:
            continue  # no session reaches this threshold: leave NaN
        line[i] = agg(selected)
        err[i] = stats.sem(selected)

    ax.plot(xs, line, mark)
    ax.fill_between(xs, line - err, line + err, alpha=.3)
    return xs, line, err

In [None]:
# Predictive power vs. min safety errors, using the prefix='macro_' columns.
for hl_name in ('new_or_refresher_cells', 'refresher_cells', 'new_live_cells'):
    make_linechart_components(hl_name, prefix='macro_')

In [None]:
def make_compare_next_refresher_plot(savename=None, **kwargs):
    """Overlay the next-cell and refresher-cell predictive-power curves.

    Extra keyword arguments are forwarded to ``make_linechart_components``.
    When ``savename`` is given, the finished figure is written there with
    ``plt.savefig``.
    """
    series = (('next_cell', '-'), ('refresher_cells', '--'))
    for component, linestyle in series:
        make_linechart_components(component, mark=linestyle, **kwargs)

    plt.grid(linestyle=':')
    plt.legend((r'\textrm{Next cell}', r'\textrm{Refresher cells}'), loc='upper left')
    plt.xlabel(r'\textrm{Min number safety errors in session}')
    plt.ylabel(r'\textrm{Predictive power}')
    plt.tight_layout()

    if savename is None:
        return
    plt.savefig(savename)

In [None]:
# Only sessions where at most half of the cell executions raised.
# Pass savename='pp-by-num-safety-issues.pgf' to export the figure.
make_compare_next_refresher_plot(exception_threshold=0.5)

In [None]:
def compute_highlight_measurement(name, prefix='', agg=np.mean, exception_threshold=1.0):
    """Fetch, and optionally aggregate, predictive power and size for one highlight set.

    Selects sessions of the global ``version`` whose
    ``{prefix}predictive_power_{name}`` is not NULL and whose
    ``num_exceptions / num_cell_execs`` is at most ``exception_threshold``.

    Args:
        name: highlight-set suffix (e.g. ``'live_cells'``, ``'next_cell'``).
        prefix: column prefix for the predictive-power column (e.g. ``'macro_'``).
        agg: ``None`` to return the raw per-session arrays; a single callable
            applied to both arrays; or a 2-sequence ``(agg_measurements, agg_counts)``.
        exception_threshold: upper bound on the per-session exception fraction.

    Returns:
        ``(measurements, counts)``: numpy arrays when ``agg`` is None, otherwise
        the aggregated values. When no session matches, the arrays are empty,
        so e.g. ``np.mean`` yields NaN.
    """
    exception_fraction = 'num_exceptions * 1.0 / num_cell_execs'
    # next_cell has no avg_num_* column queried here; its size is the constant 1.0.
    count_col = '1.0' if name == 'next_cell' else f'avg_num_{name}'
    rows = conn.execute(f"""
        select {prefix}predictive_power_{name}, {count_col}
        from replay_stats 
            where version={version} 
            and {prefix}predictive_power_{name} is not null 
            and {exception_fraction} <= {exception_threshold}
        """).fetchall()
    if rows:
        measurements, counts = map(np.array, zip(*rows))
    else:
        # zip(*[]) is empty, so unpacking it would raise an opaque ValueError.
        measurements = np.array([], dtype=float)
        counts = np.array([], dtype=float)

    if agg is None:
        return measurements, counts
    if isinstance(agg, (list, tuple)):
        assert len(agg) == 2, 'agg sequence must be (agg_measurements, agg_counts)'
        return agg[0](measurements), agg[1](counts)
    return agg(measurements), agg(counts)

In [None]:
# new_live_cells: (mean predictive power, mean of avg_num_new_live_cells)
compute_highlight_measurement(name='new_live_cells', agg=np.mean)

In [None]:
# live_cells: (mean predictive power, mean of avg_num_live_cells)
compute_highlight_measurement(name='live_cells', agg=np.mean)

In [None]:
# refresher_cells: (mean predictive power, mean of avg_num_refresher_cells)
compute_highlight_measurement(name='refresher_cells', agg=np.mean)

In [None]:
# new_or_refresher_cells: (mean predictive power, mean of avg_num_new_or_refresher_cells)
compute_highlight_measurement(name='new_or_refresher_cells', agg=np.mean)

In [None]:
# stale_cells: (mean predictive power, mean of avg_num_stale_cells)
compute_highlight_measurement(name='stale_cells', agg=np.mean)

In [None]:
# next_cell: (mean predictive power, 1.0 — its size column is the constant 1.0)
compute_highlight_measurement(name='next_cell', agg=np.mean)

In [None]:
# LaTeX column headers for each highlight set shown in the summary table.
hl_set_to_latex = {
    'next_cell': r'$\mathcal{H}_\text{next}$',
    'live_cells': r'$\mathcal{H}_\text{fresh}$',
    'stale_cells': r'$\mathcal{H}_\text{stale}$',
    'refresher_cells': r'$\mathcal{H}_\text{refresher}$',
    'new_live_cells': r"$\mathcal{H}'_\text{fresh}$",
}
# Aggregation names accepted by make_table(agg=...); the key is also the row label.
agg_to_fun = {
    'avg': np.mean,
    'median': np.median
}
def make_table(highlight_sets=None, agg='avg', measure_fn=None):
    """Render a LaTeX ``tabular`` summarising predictive power and set size.

    Args:
        highlight_sets: keys of ``hl_set_to_latex``, one column each; defaults
            to all five sets.
        agg: key of ``agg_to_fun``. It selects how per-session values are
            aggregated and is used as the row label.
        measure_fn: callable ``(name, agg=fn) -> (pp, count)``; defaults to
            ``compute_highlight_measurement``.

    Returns:
        The table source as a newline-joined string.

    Raises:
        KeyError: if ``agg`` or a highlight set has no entry in the lookup dicts.
    """
    if highlight_sets is None:
        highlight_sets = ['next_cell', 'live_cells', 'stale_cells', 'refresher_cells', 'new_live_cells']
    if measure_fn is None:
        measure_fn = compute_highlight_measurement

    agg_fun = agg_to_fun[agg]
    pps = []
    counts = []
    for hls in highlight_sets:
        # Forward the selected aggregation so agg='median' really yields medians.
        pp, cnt = measure_fn(hls, agg=agg_fun)
        pps.append(pp)
        counts.append(cnt)

    table_begin = r'\begin{tabular}{|C{2.2cm}|' + '|'.join('c' for hls in highlight_sets) + '|}'
    table_header = r'\hline\rowcolor[HTML]{C0C0C0}{\bf Quantity} & ' + ' & '.join(hl_set_to_latex[hls] for hls in highlight_sets) + r'\\\hline'
    agg_pp_line = agg + r' $\mathcal{P}(\mathcal{H}_*)$ &' + ' & '.join('$' + ('%.3f' % pp) + '$' for pp in pps) + r'\\\hline'
    agg_count_line = agg + r' $|\mathcal{H}_*|$ &' + ' & '.join(' $' + ('%.3f' % cnt) + '$' for cnt in counts) + r'\\\hline'
    table_end = r'\end{tabular}'
    return '\n'.join([
        table_begin,
        table_header,
        agg_pp_line,
        agg_count_line,
        table_end,
    ])

In [None]:
# print (not a bare expression) so the LaTeX source shows without repr escaping.
latex_table = make_table()
print(latex_table)