In [None]:
"""
Plot correlated and decorrelated OAG over a range of iterations/batch for online finetuning.

Before running these plots, make sure you run 'src/continual_ego4d/processing/run_postprocess_metrics_dump.py' to calculate and upload the correlated/decorrelated OAG results to WandB. Then, download the WandB results in a CSV (e.g. in the GUI), and refer to this CSV's path in this notebook.
"""
import datetime
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from IPython.core.interactiveshell import InteractiveShell

# Notebook display behaviour: run every top-level expression of a cell
# interactively (IPython's "all" mode) instead of only the last one.
InteractiveShell.ast_node_interactivity = "all"
# Allow up to 1000 rows when pandas renders a frame.
pd.options.display.max_rows = 1000

# Figure fonts; the keys are written in the same order as two separate assignments.
plt.rcParams.update({
    'font.family': 'DeJavu Serif',
    'font.serif': ['Times New Roman'],
})

In [None]:
"""
CONFIG: Add your config params here
"""

# ---------------- Input: where the results are read from ----------------
# CSV exported from WandB (e.g. downloaded through the GUI).
csv_path = (
    '/your/path/to/wandb_GUI_downloaded_csv/'
    'wandb_export_2022-10-26T15_27_02.501-07_00.csv'
)

# ---------------- Output: where results are written ----------------
# Experiment title; it also names the output directory below.
title = "MULTI_ITER_ANALYSIS_NEW"
main_outdir = "../imgs/" + title

In [None]:
# Load the WandB export and index it by the number of inner-loop iterations,
# which is used as the x-axis when plotting.
orig_df = pd.read_csv(csv_path).set_index('TRAIN.INNER_LOOP_ITERS')

# The plot connects points in row order, so order the rows by the x-value.
# Without this, an export whose rows are not in ascending iteration order
# would produce a zig-zagging line. A stable sort keeps the relative order
# of rows that share the same iteration count.
plot_df = orig_df.sort_index(kind='mergesort')

# Map metric name (column prefix in the CSV) to the LaTeX label shown in the legend.
# Each prefix is expected to have a '<prefix>/mean' and a '<prefix>/SE' column.
GROUPS_TO_PLOT = {
    # TOP1 balanced acc (OAG/HAG)
    #     'adhoc_users_aggregate/train_action_batch/balanced_top1_acc/adhoc_AG':r"$\overline{\text{OAG}}_{\text{action}}$",
    #     'adhoc_users_aggregate/test_action_batch/balanced_top1_acc/adhoc_hindsight_AG':r"$\overline{\text{HAG}}_{\text{action}}$",

    # Decorrelated vs correlated
    'adhoc_users_aggregate/train_action_batch/balanced_top1_acc/decorrelated/adhoc_AG':
        r"$\overline{\text{OAG}}^\text{decor.}_{\text{action}}$",
    'adhoc_users_aggregate/train_action_batch/balanced_top1_acc/correlated/adhoc_AG':
        r"$\overline{\text{OAG}}^\text{cor.}_{\text{action}}$",
}

# Plot settings
plt.rcParams['font.family'] = 'DeJavu Serif'
plt.rcParams['font.serif'] = ['Times New Roman']
plt.rcParams['font.size'] = 20
# Labels above use \text{...}, so render text through LaTeX with amsmath loaded.
plt.rcParams['text.usetex'] = True
plt.rcParams['text.latex.preamble'] = r'\usepackage{amsmath}'  # for \text command

# One colour per plotted group, taken from the current seaborn palette.
colors = sns.color_palette()
plot_config = {
    "color": 'royalblue',
    "dpi": 600,
    "figsize": (8, 4),
    "xlabel": "re-exposure iterations",
    "ylabel": "",
    "title": None
}

# Save config: every run writes into its own timestamped directory under main_outdir.
parent_dirname = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + "_" + title
parent_dirpath = os.path.join(main_outdir, parent_dirname)


fig, ax = plt.subplots(figsize=plot_config['figsize'])

# One error-bar series per metric group: mean on the y-axis, standard error as the bar.
x_vals = plot_df.index.tolist()
legend_handles = []
legend_names = []
for idx, (col_name, label_name) in enumerate(GROUPS_TO_PLOT.items()):
    y_vals = plot_df[f"{col_name}/mean"].tolist()
    yerr = plot_df[f"{col_name}/SE"].tolist()

    # errorbar already draws the connecting line and the markers, so no
    # additional plot() call is needed; drawing the line a second time
    # would paint over the markers and register the label twice.
    container = ax.errorbar(
        x_vals, y_vals, yerr=yerr,
        color=colors[idx % len(colors)],  # wrap around if there are more groups than colours
        capsize=3, elinewidth=0.8,
        markersize=3, marker='v', mfc='black', mec='black',
        ecolor='black',
        label=label_name,
    )
    legend_handles.append(container)
    legend_names.append(label_name)

ax.grid(axis='y')
# NOTE(review): plot_config['xlabel'] holds a different label ("re-exposure iterations");
# the hard-coded label is kept here -- confirm which one is intended.
ax.set(xlabel=r'Updates per batch',
       #        ylabel=r'$\cos_\angle (g_t, \  g_{t-k})$'
       )

ax.spines.right.set_visible(False)
ax.spines.top.set_visible(False)

# Pair handles and labels explicitly instead of relying on artist creation order.
ax.legend(legend_handles, legend_names, loc='best', ncol=4)

fig.tight_layout()
filename = f'{title}.pdf'
filepath = os.path.join(parent_dirpath, filename)
print(f"Saving at: {filepath}")
os.makedirs(parent_dirpath, exist_ok=True)
# Honour the configured resolution when writing the figure.
fig.savefig(filepath, dpi=plot_config['dpi'])

plt.show()
plt.close('all')
