In [2]:
import hydra

import numpy as np
import pandas as pd
import plotnine as pn

In [3]:
# Final results used by the plotting cell, which reads the
# `discriminative_need_gamma` and `min_epsilon` columns from this frame.
# NOTE(review): presumably one row per run/seed — confirm.
df = pd.read_csv(filepath_or_buffer="final_points.csv")

In [4]:
# Baseline results (scrambled utilities, judging by the file name).
# NOTE(review): not referenced by the plotting cell; confirm it is used
# further on, otherwise this load can be dropped.
df_baseline = pd.read_csv(filepath_or_buffer="scrambled_utility_points.csv")

In [5]:
# Helpers for computing group means and confidence intervals

from scipy import stats

# Mean and confidence interval of a single sample; mean_conf_df applies this per group (e.g. per iteration)
def mean_confidence_interval(data: pd.Series, confidence: float = 0.95) -> tuple[float, float, float]:
    """Return the sample mean and a two-sided Student-t confidence interval.

    Parameters
    ----------
    data : pd.Series or array-like of numbers
        Observations of one group. NaN entries are dropped first, so the mean,
        the standard error and the degrees of freedom all refer to the same
        observations. (Previously the mean skipped NaNs via ``Series.mean``
        while the standard error and ``len`` still counted them.)
    confidence : float, default 0.95
        Two-sided confidence level, in (0, 1).

    Returns
    -------
    tuple[float, float, float]
        ``(mean, lower, upper)``. With no observations all three are NaN; with
        a single observation the interval is undefined and both bounds are NaN.
    """
    values = np.asarray(data, dtype=float)
    values = values[~np.isnan(values)]
    n = values.size
    if n == 0:
        return (np.nan, np.nan, np.nan)
    mean = float(values.mean())
    if n < 2:
        # Standard error and t quantile need at least one degree of freedom;
        # return NaN bounds without triggering numpy/scipy runtime warnings.
        return (mean, np.nan, np.nan)
    se = stats.sem(values)
    # Half-width of the interval: se * t_{(1 + confidence) / 2, n - 1}.
    interval = se * stats.t.ppf((1 + confidence) / 2.0, n - 1)
    return (mean, mean - interval, mean + interval)

def mean_conf_df(
    trajectory_data: pd.DataFrame,
    colname: str,
    groupby_cols: "list[str] | None" = None,
    confidence: float = 0.95,
) -> pd.DataFrame:
    """Summarise ``colname`` per group as a mean and confidence interval.

    Parameters
    ----------
    trajectory_data : pd.DataFrame
        Long-format data containing ``colname`` and every column in
        ``groupby_cols``.
    colname : str
        Numeric column to summarise.
    groupby_cols : list[str], optional
        Grouping columns; defaults to ``['iteration', 'discriminative_need_gamma']``.
    confidence : float, default 0.95
        Confidence level forwarded to ``mean_confidence_interval``.

    Returns
    -------
    pd.DataFrame
        One row per group: the grouping columns followed by
        ``mean_<colname>``, ``lower_ci`` and ``upper_ci``. An empty input
        yields an empty frame that still has these columns.
    """
    if groupby_cols is None:
        # Avoid a shared mutable default argument.
        groupby_cols = ['iteration', 'discriminative_need_gamma']
    # Series indexed by group, each value a (mean, lower, upper) tuple.
    summary = trajectory_data.groupby(groupby_cols)[colname].apply(
        lambda group: mean_confidence_interval(group, confidence=confidence)
    )
    # Build the frame directly instead of `.apply(pd.Series)`: this keeps the
    # three named columns even when there are no groups at all.
    df_summary = pd.DataFrame(
        summary.tolist(),
        index=summary.index,
        columns=[f'mean_{colname}', 'lower_ci', 'upper_ci'],
    )
    return df_summary.reset_index()


In [None]:

# Average `min_epsilon` over seeds: one row per gamma with its mean and
# confidence interval (computed by mean_conf_df).
seed_avgd_data = mean_conf_df(df, "min_epsilon", groupby_cols=["discriminative_need_gamma"])

PLOT_COLOR = "blue"

plot = (
    pn.ggplot(
        seed_avgd_data,
        pn.aes(x="discriminative_need_gamma", y="mean_min_epsilon"),
    )
    # Draw the CI band first so its semi-transparent fill sits under the line
    # and points instead of tinting over them.
    + pn.geom_ribbon(
        pn.aes(ymin="lower_ci", ymax="upper_ci"),
        fill=PLOT_COLOR,
        alpha=0.5,
    )
    # The data is already one row per gamma, so it is plotted directly; a
    # stat="summary" layer would only re-summarise single points.
    + pn.geom_line(size=1, color=PLOT_COLOR)
    + pn.geom_point(size=2, fill=PLOT_COLOR)
    + pn.scale_x_log10(breaks=[1e-8, 1e-6, 1e-4, 1e-2, 1])
    # Raw string: "\g" is an invalid escape sequence in a plain string literal.
    + pn.xlab(r"$\gamma$")
    + pn.ylab("Converged efficiency loss")
    + pn.theme_classic()
    # No aesthetic is mapped to data, so there is no legend to style or guide.
    + pn.theme(
        axis_title=pn.element_text(size=32),
        axis_text=pn.element_text(size=24),
    )
)
# Last expression: Jupyter renders the figure through the plot's rich display.
plot