In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import pymongo

# Name of the MongoDB collection that holds the run analysed by this notebook.
COLLECTION_NAME = "PS-A"

# Connection settings, named so they are easy to find and change.
MONGO_URI = "mongodb://localhost:27017/"
DATABASE_NAME = "policysmith"

client = pymongo.MongoClient(MONGO_URI)
db = client[DATABASE_NAME]

# Load every document of the run into memory; the following cells build a
# DataFrame from this list.
run_collection = db[COLLECTION_NAME]
results = list(run_collection.find({}))

In [None]:
# Fetch the metadata document recorded for this run. find_one returns the
# first matching document, or None when there is none, so a missing entry is
# reported explicitly instead of surfacing as a bare IndexError.
run_info_doc = db["information"].find_one({"collection_id": COLLECTION_NAME})
if run_info_doc is None:
    raise LookupError(
        f"No document with collection_id={COLLECTION_NAME!r} in the 'information' collection"
    )
run_info = run_info_doc['task_args']

# Baseline documents stored for the same trace and the same 'eval_cache_size'
# setting as this run.
baselines = list(
    db['baselines_percent'].find(
        {'trace_name': run_info['trace'], 'percent': run_info['eval_cache_size']}
    )
)

In [None]:
def _first_results_list(records):
    """Return the 'results' list of the first record whose eval_results has one.

    Raises:
        ValueError: if no record carries an eval_results dict with a 'results'
            key (including when ``records`` is empty).
    """
    for record in records:
        eval_results = record.get('eval_results')
        if isinstance(eval_results, dict) and 'results' in eval_results:
            return eval_results['results']
    raise ValueError(
        f"No document in {COLLECTION_NAME!r} has eval_results['results']; cannot determine cache sizes"
    )


def _hit_rate(eval_results, cache_size, miss_key):
    """Return 1 - eval_results['results'][...][miss_key] for the given cache size.

    The first entry of eval_results['results'] whose 'cache_size_mb' equals
    ``cache_size`` is used. NaN is returned when ``eval_results`` is not a
    dict, has no 'results', or has no entry for that cache size carrying
    ``miss_key``, so one incomplete row does not abort the whole cell.
    """
    if not isinstance(eval_results, dict):
        return float('nan')
    for entry in eval_results.get('results') or []:
        if entry.get('cache_size_mb') == cache_size and miss_key in entry:
            return 1 - entry[miss_key]
    return float('nan')


# Cache sizes are read from the first document that actually has results.
CACHE_SIZES = [entry['cache_size_mb'] for entry in _first_results_list(results)]

df = pd.DataFrame(results)
# Linear position of each sample: iterations laid end to end, each one
# (max _sample + 1) slots wide.
df["idx"] = df["iter"] * (df["_sample"].max() + 1) + df["_sample"]
# Keep only rows where both status flags are True; .eq(True) treats a missing
# flag (NaN) as a failure instead of breaking the boolean mask.
succeeded = df['build_status'].eq(True) & df['exec_status'].eq(True)
df = df[succeeded].sort_values(by='idx')

# One hit-rate column per (metric, cache size), derived from the stored miss ratios.
for cache_size in CACHE_SIZES:
    for column_prefix, miss_key in (
        ('obj_hit_rate', 'miss_ratio'),
        ('byte_hit_rate', 'byte_miss_ratio'),
    ):
        df[f'{column_prefix}_{cache_size}'] = (
            df['eval_results']
            .apply(_hit_rate, args=(cache_size, miss_key))
            .astype(float)
        )

In [None]:
TO_PLOT = "obj_hit_rate"
# TO_PLOT = "byte_hit_rate"

# Field of a baseline document that corresponds to each plottable metric.
# NOTE(review): assumes baseline documents store the byte-level value under
# 'byte_miss_ratio', like the per-run results do; confirm against the
# 'baselines_percent' schema. Baselines lacking the field are simply skipped.
BASELINE_MISS_KEYS = {"obj_hit_rate": "miss_ratio", "byte_hit_rate": "byte_miss_ratio"}
baseline_miss_key = BASELINE_MISS_KEYS[TO_PLOT]

# Not used by the plot below; kept so the matching column names stay available
# for inspection.
relevant_columns = list(filter(lambda x: x.startswith(TO_PLOT), df.columns))

# Size the grid from the number of cache sizes instead of assuming exactly four
# panels (four cache sizes still give the original 2x2, 12x8 figure).
N_PLOT_COLS = 2
n_plot_rows = max(1, (len(CACHE_SIZES) + N_PLOT_COLS - 1) // N_PLOT_COLS)
fig, axes = plt.subplots(
    n_plot_rows,
    N_PLOT_COLS,
    figsize=(6 * N_PLOT_COLS, 4 * n_plot_rows),
    sharex=False,
    squeeze=False,  # always a 2-D array, so flatten() works for any grid shape
)
axes = axes.flatten()

for i, cache_size_to_plot in enumerate(CACHE_SIZES):
    col = f"{TO_PLOT}_{cache_size_to_plot}"
    ax = axes[i]

    # Running maximum of the metric, in idx order.
    ax.plot(df['idx'], df[col].cummax())
    ax.set_title(f"{COLLECTION_NAME} - {col}")
    ax.set_xlabel("Iteration")
    ax.set_ylabel(f"Best {TO_PLOT} so far")

    relevant_baselines = [
        baseline for baseline in baselines
        if baseline.get('cache_size_mb') == cache_size_to_plot and baseline_miss_key in baseline
    ]
    if relevant_baselines:
        # Draw only the three baselines with the lowest miss ratio.
        relevant_baselines = sorted(relevant_baselines, key=lambda x: x[baseline_miss_key])[:3]
        print(cache_size_to_plot, relevant_baselines)
        for horiz_line in relevant_baselines:
            baseline_hit_rate = 1 - horiz_line[baseline_miss_key]
            ax.axhline(y=baseline_hit_rate, ls='--')
            # Label at the right edge of the axes, at the height of the line.
            ax.annotate(f"{horiz_line['cache_name']}",
                xy=(1, baseline_hit_rate), xycoords=('axes fraction', 'data'),
                xytext=(4, 0), textcoords='offset points',
                va='center', ha='left', fontsize=6)

# Hide grid cells that have no cache size to show.
for unused_ax in axes[len(CACHE_SIZES):]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()

In [None]:
# Column whose running maximum is tracked.
COLUMN_TO_CALC_MAX = 'obj_hit_rate_2233'

tracked_values = df[COLUMN_TO_CALC_MAX]
# Running maximum over all earlier rows; -inf is placed ahead of the first row
# so that a non-NaN first value compares as greater.
best_before = tracked_values.cummax().shift(fill_value=float('-inf'))
# Index labels of the rows that strictly exceed everything before them.
is_new_best = tracked_values > best_before
increasing_max_idx = df.index[is_new_best]

In [None]:
# Show the rows at which the tracked column reached a new running maximum.
df.loc[increasing_max_idx, :]