In [None]:
"""
Grouped-barplot comparing gradient cosine-similarity of current batch with previous points k steps in history of the learning trajectory.

This is obtained by running experiments and setting in the config:
    cfg.METHOD.ANALYZE_GRADS_WINDOW = True  # Compare current grad with grads from previous steps
    cfg.METHOD.MAX_ANALYZE_GRADS_WINDOW_SIZE = 10 # Window size to look back at

Download the results of these experiments from WandB as a CSV, and make sure the '*_grad_cos_sim*' metrics are included.
These results are then visualized in this notebook.
"""
import datetime
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from IPython.core.interactiveshell import InteractiveShell

# Echo the value of every expression statement in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"
# Raise the pandas row-display limit to 1000 rows.
pd.options.display.max_rows = 1000

# Global matplotlib font settings, applied in one go.
plt.rcParams.update({
    'font.family': 'DeJavu Serif',
    'font.serif': ['Times New Roman'],
})

In [None]:
"""
CONFIG: Add your config params here
"""

########### READ RESULTS CONFIG ###############
# Path to the CSV exported from WandB (e.g. downloaded through the GUI).
csv_path = "/your/path/to/wandb_GUI_downloaded_csv/wandb_export_2022-10-23T15_13_53.325-07_00.csv"

# Model parts whose gradient cosine-similarity metrics are read from the CSV.
model_parts = [
    "full",
    "slow",
    "fast",
    "head",
    "feat",
]

########### OUTPUT CONFIG ###############
# The title names both the output directory and the saved figure.
title = "SGD_GRAD_ANALYSIS"
main_outdir = "../imgs/" + title

In [None]:
""" Get dataframe. """
# Load the CSV exported from WandB. Only the FIRST row of every metric column
# is used below; any further rows in the export are ignored.
orig_df = pd.read_csv(csv_path)
if orig_df.empty:
    raise ValueError(f"No rows found in WandB export: {csv_path}")

# Largest look-back distance k to read from the export. It must not exceed the
# cfg.METHOD.MAX_ANALYZE_GRADS_WINDOW_SIZE the experiments were run with.
max_lookback_steps = 10
metric_prefix = "adhoc_users_aggregate/analyze_action_batch"
stat_names = ("mean", "SE")  # Column order per model part: mean first, then SE

# Fail early, listing every missing column at once, instead of hitting a
# KeyError for one column at a time deep inside the reshaping loop.
expected_cols = [
    f"{metric_prefix}/LOOKBACK_STEP_{k}/{part}_grad_cos_sim/{stat}"
    for k in range(1, max_lookback_steps + 1)
    for part in model_parts
    for stat in stat_names
]
missing_cols = [col for col in expected_cols if col not in orig_df.columns]
if missing_cols:
    raise KeyError(f"Missing '*_grad_cos_sim*' metrics in {csv_path}: {missing_cols}")

# Reshape from wide (one column per look-back step, model part and statistic)
# to one row per look-back step, with one mean and one SE column per model part.
df_dict_list = []
for nb_steps_lookback in range(1, max_lookback_steps + 1):  # 1 row per step
    single_df_row_dict = {'step': nb_steps_lookback}

    for model_part in model_parts:
        for stat in stat_names:
            src_col = f"{metric_prefix}/LOOKBACK_STEP_{nb_steps_lookback}/{model_part}_grad_cos_sim/{stat}"
            dst_col = f"{metric_prefix}/{model_part}_grad_cos_sim/{stat}"
            # .iloc[0] reads the first value without converting the whole column to a list.
            single_df_row_dict[dst_col] = orig_df[src_col].iloc[0]

    df_dict_list.append(single_df_row_dict)

# Use the look-back step as index so it becomes the x-axis when plotting.
transformed_df = pd.DataFrame(df_dict_list).set_index('step')

In [None]:
""" Plot. """
plot_df = transformed_df

# Metric column prefix -> legend label. For every prefix, the '<prefix>/mean'
# and '<prefix>/SE' columns of plot_df are plotted as bar height and error bar.
GROUPS_TO_PLOT = {
    'adhoc_users_aggregate/analyze_action_batch/full_grad_cos_sim': r'$F \circ H$',
    'adhoc_users_aggregate/analyze_action_batch/feat_grad_cos_sim': r'$F$',
    'adhoc_users_aggregate/analyze_action_batch/head_grad_cos_sim': r'$H$',
}

# Font settings are applied after sns.set(), which itself writes font-related rcParams.
sns.set(font_scale=1.6)
plt.rcParams['font.family'] = 'DeJavu Serif'
plt.rcParams['font.serif'] = ['Times New Roman']
plt.rcParams['font.size'] = 18

# Use latex in mpl
plt.rcParams['text.usetex'] = True
plt.rcParams['text.latex.preamble'] = r'\usepackage{amsmath}'  # for \text command

# Save config: every run writes into its own timestamped sub-directory.
parent_dirname = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + "_" + title
parent_dirpath = os.path.join(main_outdir, parent_dirname)

# Plot configs. Only 'figsize' and 'dpi' are consumed below; the remaining
# entries are currently unused.
plot_config = {
    "color": 'royalblue',
    "dpi": 600,
    "figsize": (8, 4),
    "xlabel": "re-exposure iterations",
    "ylabel": "",
    "title": None
}

# One colour per plotted group, taken from the "rocket" palette (built once).
rocket_palette = sns.color_palette("rocket")
colors = [rocket_palette[idx] for idx in (1, 3, 5)]

# Collect, in matching order, the mean/SE column names and the legend labels.
yerr_names = []
y_names = []
legend_names = []
for name, legend_name in GROUPS_TO_PLOT.items():
    yerr_names.append(f"{name}/SE")
    y_names.append(f"{name}/mean")
    legend_names.append(legend_name)

# Transposed so that each row holds the errors of one plotted column.
yerr = plot_df[yerr_names].to_numpy().T

fig, ax = plt.subplots(figsize=plot_config['figsize'])
plot_df[y_names].plot(kind='bar', yerr=yerr, alpha=1, error_kw=dict(ecolor='k'), ax=ax, capsize=2, rot=0, color=colors)

ax.set(xlabel=r'$k$', ylabel=r'$\cos_\angle (g_t, \  g_{t-k})$')
# Single legend call; labels come from GROUPS_TO_PLOT so they stay in sync with the bars.
ax.legend(legend_names, loc='lower right', ncol=1)

plt.tight_layout()
filename = f'{title}.pdf'
filepath = os.path.join(parent_dirpath, filename)
print(f"Saving at: {filepath}")
os.makedirs(parent_dirpath, exist_ok=True)
fig.savefig(filepath, dpi=plot_config['dpi'])

plt.show()
plt.close('all')
