In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import glob
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from bon.utils import utils
from bon.utils.shotgun_utils import get_powerlaw_baseline, process_powerlaw_data

warnings.filterwarnings("ignore", category=DeprecationWarning)

# Load direct requests and audio/text repeated sampling

In [None]:
# Load the direct-request records; later cells read the `rewrite` column and the row index.
direct_requests = Path("./data").joinpath("direct_request.jsonl")
df_direct = utils.load_jsonl_df(direct_requests)

In [None]:
# Baseline result files of the audio-only run, matched through the `step*` wildcard.
root_dir_baseline = Path("./exp/figure1/baselines/audio-only")
flash_pattern = "flash_direct_request_high_temp1.0_step*.jsonl"
# glob.glob returns matches in arbitrary, filesystem-dependent order. Sort them so the
# file list -- and every label derived from a file's position in it -- is the same on
# every machine and every run.
# NOTE(review): the sort is lexicographic, so "...step10" comes before "...step2".
flash_files = sorted(glob.glob(str(root_dir_baseline / flash_pattern)))

In [None]:
# Read every baseline file and tag its rows with the file's position in `flash_files`.
flash_baselines = []
for step_idx, jsonl_path in enumerate(flash_files):
    flash_df = utils.load_jsonl_df(jsonl_path)
    flash_df["step"] = step_idx
    flash_baselines.append(flash_df)

# Stack the per-file frames into a single baseline frame.
df_flash_baseline = pd.concat(flash_baselines)


In [None]:
# Per-step total of `proportion_flagged_audio`, largest first, then divided by 159.
# NOTE(review): 159 is hard-coded; presumably the number of direct requests -- confirm.
flash_grp = (
    df_flash_baseline
    .groupby(["step"])["proportion_flagged_audio"]
    .sum()
    .reset_index()
    .sort_values("proportion_flagged_audio", ascending=False)
)
flash_grp["proportion_flagged_audio"] = flash_grp["proportion_flagged_audio"].div(159)

In [None]:
# Lookup from each direct request's `rewrite` text to its row index in `df_direct`
# (for duplicated texts the last row wins).
dr_idx_map = {
    rewrite_text: row_idx
    for rewrite_text, row_idx in zip(df_direct["rewrite"].tolist(), df_direct.index.tolist())
}

In [None]:
# Pass the concatenated baseline frame and the rewrite -> index map through
# get_powerlaw_baseline and wrap the result in a DataFrame; this is the frame later
# handed to calculate_asr_trajectories.
# NOTE(review): this rebinds `df_flash_baseline`, so the cell is not idempotent --
# re-running it (or the `flash_grp` cell) without reloading the baseline files feeds
# the already-processed frame back in.
df_flash_baseline = pd.DataFrame(get_powerlaw_baseline(df_flash_baseline, dr_idx_map))

# Load and Process Shotgun

In [None]:
# Results directory of the gemini-1.5-flash-001 run inside the figure-1 experiment folder.
root_dir = Path("./exp/figure1")
aug_attack_flash = root_dir.joinpath("gemini-1.5-flash-001")
# NOTE(review): the trailing positional arguments (8, 10, 5) are passed through as-is;
# their meaning is defined by process_powerlaw_data -- confirm there.
df_flash = process_powerlaw_data(
    aug_attack_flash,
    df_direct,
    "direct_request_search_steps.jsonl",
    8,
    10,
    5,
)

In [None]:
#@file_cache()
def calculate_asr_trajectories(df, num_repeats=20, num_tiles = 1):
    """Build attack-success-rate (ASR) trajectories from randomly re-ordered samples.

    For every repeat, the ``flagged`` values of each group (rows sharing the same
    ``i``) are tiled ``num_tiles`` times and shuffled independently per group. The
    trajectory value at step ``t`` is the fraction of groups that have at least
    one flagged entry among their first ``t`` shuffled values.

    Every repeat is seeded with its own index, so calling the function twice on
    the same data returns identical trajectories.

    Args:
        df: DataFrame with an ``i`` column (group id, N unique values) and a
            ``flagged`` column (non-zero / True when the sample was flagged).
            Every group must contain the same number of rows T, otherwise
            ``np.stack`` raises ``ValueError``.
        num_repeats: number of trajectories to generate.
        num_tiles: how many times each group's T values are repeated before
            shuffling; only the first T steps of the shuffled result are kept.

    Returns:
        A list with ``num_repeats`` entries, each a list of T floats in [0, 1]
        that is non-decreasing in the step index.
    """
    asrs = []
    # Reported by the summary prints below; stay 0 when no repeat is executed.
    T = N = 0
    for num_repeat in range(num_repeats):
        # One generator per repeat, seeded like the pandas shuffle, so the whole
        # trajectory is reproducible (the global NumPy RNG is left untouched).
        rng = np.random.default_rng(num_repeat)

        # GroupBy.sample shuffles rows within each group and keeps the "i" column,
        # which groupby(...).apply(...) does not guarantee across pandas versions.
        shuffled_df = df.groupby("i").sample(frac=1, random_state=num_repeat).reset_index(drop=True)

        # size (N * T) -> (T, N)
        # N = number of unique i
        # T = number of steps (len(df) / N)
        per_group = [values.to_numpy() for _, values in shuffled_df.groupby("i")["flagged"]]
        arr = np.stack(per_group).T
        T, N = arr.shape
        arr = np.tile(arr, (num_tiles, 1))

        # Shuffle every group's (tiled) column independently, in place.
        for col in range(arr.shape[1]):
            rng.shuffle(arr[:, col])

        # A group counts as broken from the first step at which its running
        # count of flagged samples becomes positive.
        cumulative_flagged = np.cumsum(arr, axis=0)
        idx_flagged = cumulative_flagged > 0

        # Fraction of groups broken at each step, truncated to the original T steps.
        asr_trajectory = np.sum(idx_flagged, axis=1) / N
        asr_trajectory = asr_trajectory[:T]
        asrs.append(asr_trajectory.tolist())

    print(T,N)
    print(len(asrs))
    return asrs

# 100 shuffled trajectories for the attack runs and for the baseline, each computed
# with every group's samples tiled 10x before shuffling.
asrs_flash = calculate_asr_trajectories(df_flash, num_repeats=100, num_tiles=10)
baseline_flash = calculate_asr_trajectories(df_flash_baseline, num_repeats=100, num_tiles=10)

# Define plotting functions

In [None]:
from bon.utils.power_law import (
    exp_power_law,
    exp_power_law_single_term_no_constant,
    o1_simple_fit_power_law_no_constant,
    power_law,
    power_law_single_term_no_constant,
)


def plot_mean_and_std(ax, asr_mean, asr_std, steps, log_scale=True, epsilon=1e-5, exp_name="ASR", color=None):
    """Draw the mean ASR curve with a +/- one-standard-deviation band on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        asr_mean: per-step mean ASR (array-like, same length as ``steps``).
        asr_std: per-step standard deviation of the ASR (same length).
        steps: x values of the curve (labelled "N").
        log_scale: when True plot ``-log(ASR)`` on log-log axes; otherwise plot
            the ASR itself on linear axes.
        epsilon: lower clamp applied before taking the logarithm so that values
            <= 0 stay finite. Only used when ``log_scale`` is True.
        exp_name: legend label of the mean curve.
        color: colour of the curve; ``None`` lets Matplotlib choose one. The
            band always uses the colour of the curve that was actually drawn.

    Returns:
        None. The axes are modified in place.
    """
    asr_mean = np.asarray(asr_mean, dtype=float)
    asr_std = np.asarray(asr_std, dtype=float)
    band_low = asr_mean - asr_std
    band_high = asr_mean + asr_std

    if log_scale:
        # Clamp the curve exactly like the band so a mean ASR of 0 does not
        # produce inf (and a divide-by-zero RuntimeWarning).
        curve = -np.log(np.maximum(asr_mean, epsilon))
        band_low = -np.log(np.maximum(band_low, epsilon))
        band_high = -np.log(np.maximum(band_high, epsilon))
        ylabel = "-log(ASR)"
    else:
        curve = asr_mean
        ylabel = "ASR"

    line, = ax.plot(steps, curve, label=exp_name, color=color)
    # Tie the band to the curve's colour instead of the independent fill colour cycle.
    ax.fill_between(steps, band_low, band_high, alpha=0.3, color=line.get_color())
    ax.set_ylabel(ylabel)
    if log_scale:
        ax.set_xscale("log")
        ax.set_yscale("log")

    ax.set_xlabel("N")
    ax.set_xlim(left=1)


def plot_fitted_asr(ax, steps, params, color, log_scale=True, exp_name="", use_label=True):
    """Overlay a fitted curve on ``ax`` as a dashed line and refresh the legend.

    The curve function is chosen from the number of fitted parameters and from
    ``log_scale``: the ``power_law*`` variants are used when ``log_scale`` is
    True, the ``exp_power_law*`` variants otherwise.

    Args:
        ax: Matplotlib axes to draw on.
        steps: x values at which the fitted function is evaluated.
        params: fitted parameters; 2 values select the single-term fit without
            constant, 3 values select the power law with a constant term.
        color: line colour.
        log_scale: selects which family of curve functions is evaluated.
        exp_name: experiment name used as the prefix of the legend label.
        use_label: when False the line is drawn without a legend label.

    Raises:
        ValueError: if ``params`` has neither 2 nor 3 entries.
    """
    n_params = len(params)
    if n_params not in (2, 3):
        raise ValueError(f"Unexpected number of params: {len(params)}")

    if n_params == 2:
        curve_fn = power_law_single_term_no_constant if log_scale else exp_power_law_single_term_no_constant
        fitted_asr = curve_fn(steps, *params)
        label = f"{exp_name}: fitted"
    else:
        curve_fn = power_law if log_scale else exp_power_law
        fitted_asr = curve_fn(steps, *params)
        label = f'{exp_name}: {params[0]:.2f}x^{-params[1]:.2f} + {params[2]:.2f}'

    # Both fit families are drawn dashed; the label is attached only on request.
    line_kwargs = {"linestyle": '--', "color": color}
    if use_label:
        line_kwargs["label"] = label
    ax.plot(steps, fitted_asr, **line_kwargs)

    ax.legend()


In [None]:
import seaborn as sns
from matplotlib.ticker import FuncFormatter

# Trajectories to plot, keyed by the name shown in the legend.
experiments = {
    "flash": asrs_flash,
    "flash_baseline": baseline_flash,
}

# Fit configuration.
#   start_n:     number of leading steps excluded from the fit
#   train_split: number of steps (after start_n) used for fitting; None = use all
method_name = "single_term_no_constant"
method_func = o1_simple_fit_power_law_no_constant
start_n = 0
train_split = None

# One panel per y-axis treatment: ASR on linear axes (left), -log(ASR) on log-log axes (right).
fig, axs = plt.subplots(1, 2, figsize=(15, 5))

# Qualitative palette; one colour per experiment.
color_palette = sns.color_palette("tab10")

# Extent of the shaded training window. It is derived into separate variables on
# purpose: overwriting `train_split` (None -> 0) inside the panel loop would make the
# second panel see an empty training slice and silently skip every fit.
span_start = start_n if start_n else 0
span_width = train_split if train_split else 0

for i, log_scale in enumerate([False, True]):
    ax = axs[i]
    for j, (exp_name, asrs) in enumerate(experiments.items()):

        # Prepare the data: rows are repeats, columns are steps.
        asr = np.array(asrs)
        asr_mean = np.mean(asr, axis=0)
        asr_std = np.std(asr, axis=0)
        steps = np.arange(1, len(asr_mean)+1)
        color = color_palette[j % len(color_palette)]
        plot_mean_and_std(ax, asr_mean, asr_std, steps, exp_name=exp_name, log_scale=log_scale, color=color)

        param_steps = steps[start_n:]
        param_asr = asr[:,start_n:]
        param_asr_mean = np.mean(param_asr, axis=0)

        # Get the train and test splits
        if train_split is not None:
            train_steps = param_steps[:train_split]
            train_asr = param_asr[:,:train_split]
            train_asr_mean = np.mean(train_asr, axis=0)
            test_split=train_split
        else:
            train_steps = param_steps
            train_asr = param_asr
            train_asr_mean = param_asr_mean
            test_split=0

        # Check if train_asr is empty
        if len(train_asr) == 0 or train_asr.size == 0:
            print(f"Skipping {exp_name} with {method_name} due to empty train_asr")
            continue

        # Fit the model
        if "chinchilla" in method_name:
            params = method_func(x=train_steps, y=-np.log(train_asr_mean + 1e-5))
        else:
            params = method_func(x=train_steps, y=train_asr)

        # Draw the fit beyond the measured range so the extrapolation is visible.
        steps_fitting = np.arange(1, 10000)
        plot_fitted_asr(ax, steps_fitting, params, color=color, exp_name=exp_name, log_scale=log_scale, use_label=False)

    # Add shaded area to show start and end of train split
    ax.axvspan(span_start, span_start + span_width, color='lightgrey', alpha=0.4)

    # ax.set_title(f'Train Split: {train_split if train_split else "All"}, Start N: {start_n}, Log Scale: {log_scale}')
    ax.grid(True, which="both", ls="-", alpha=0.2)
    if log_scale:
        ax.set_xlim(left=10)
        ax.set_ylim(top=5)
        ax.set_yticks([0.1, 0.2, 0.5, 1, 2, 5])
        ax.yaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{x:g}'))
        ax.set_xticks([10, 100, 1000, 10000])
        ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{x:g}'))
# Move legend outside the plot
axs[0].legend(bbox_to_anchor=(1, -0.15), loc='upper center', borderaxespad=0., ncol=len(experiments), fontsize=14)
# remove legend from second plot
axs[1].legend().remove()

# Set font size for all axes ticks and labels
for ax in axs:
    ax.tick_params(axis='both', which='major', labelsize=10)
    ax.tick_params(axis='both', which='minor', labelsize=10)
    ax.xaxis.label.set_size(14)
    ax.yaxis.label.set_size(14)

# tight layout
# plt.tight_layout()

plt.show()

