In [None]:
"""
Plot the Re-exposure forgetting results from the corresponding experiment's dump.
Compares the experiment results of SGD vs ER.
"""
import datetime
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import scipy.stats
import seaborn as sns
import torch
from IPython.core.interactiveshell import InteractiveShell

# Show the result of every expression statement in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

# Allow up to 1000 rows when a DataFrame is displayed.
pd.set_option('display.max_rows', 1000)

# Notebook-wide default fonts for matplotlib figures.
plt.rcParams.update({
    'font.family': 'DeJavu Serif',
    'font.serif': ['Times New Roman'],
})

In [None]:
"""
CONFIG: Add your config params here
"""

########### RESULTS READ CONFIG ###############

# READ PATHS (Dumps)
# Presumably the naming pattern of the per-user subdirectories; it is not
# referenced by the plotting cells below -- TODO confirm it is still needed.
user_dir_fmt = "user_{}"
# Name of the result dump file read from every user directory.
dump_filename = "stream_info_dump.pth"

# FT vs Replay: Paths to user_logs in your results of both experiments.
# From these the local result dumps will be extracted.
SGD_parent_outputdir = "/your/path/to/logs/2022-09-23_21-12-33_UIDec506d45-3018-468f-b63a-89744c5d10f9/user_logs/"
replay_parent_outputdir = "/your/path/to/logs/NO_GRID/2022-10-22_00-33-32_UID17d7f432-ee2f-4312-8142-be019ea89b1a/user_logs"
# Directory that is scanned for the user directory names. The same names are
# then looked up under each method's directory, so both experiments must
# contain the same user subdirectories.
parent_outputdir = replay_parent_outputdir  # Choose one of both to scan user directory names

########### PLOT CONFIG ###############
# Method label -> parent directory holding that method's per-user dumps.
label_to_parent_dumpdir = {
    'SGD': SGD_parent_outputdir,
    'Replay': replay_parent_outputdir,
}

# Plot output: figures are written below <main_outdir>/<title>/.
main_outdir = "../imgs"
title = "FORG_REEXPOSURE_SGD_VS_REPLAY"

# Plot configs
plot_config = {
    "color": 'royalblue',
    "dpi": 600,
    "figsize": (6, 6),
    "xlabel": "re-exposure iterations",
    "ylabel": "",  # filled in below from ylabel_map
    "title": None
}

# RESULT METRIC NAMES IN THE DUMP
saved_dumpkeys = [
    'train_action_past/FORG_EXPOSE_loss',
    'train_action_past/FORG_EXPOSE_top1acc',

    'train_verb_past/FORG_EXPOSE_loss',
    'train_verb_past/FORG_EXPOSE_top1acc',
    'train_verb_past/FORG_EXPOSE_top5acc',
    'train_verb_past/FORG_EXPOSE_top20acc',

    'train_noun_past/FORG_EXPOSE_loss',
    'train_noun_past/FORG_EXPOSE_top1acc',
    'train_noun_past/FORG_EXPOSE_top5acc',
    'train_noun_past/FORG_EXPOSE_top20acc'
]
CHOSEN_KEY = 'train_action_past/FORG_EXPOSE_loss'  # The one we select for our plot

# Mapping of the metric to the ylabel names (only the loss metrics have a label).
ylabel_map = {
    'train_action_past/FORG_EXPOSE_loss': r"RF_action",
    'train_verb_past/FORG_EXPOSE_loss': r"RF_verb",
    'train_noun_past/FORG_EXPOSE_loss': r"RF_noun",
}
# Fall back to the raw metric name when no label is registered for CHOSEN_KEY
# (e.g. one of the accuracy metrics) instead of failing with a KeyError.
plot_config['ylabel'] = ylabel_map.get(CHOSEN_KEY, CHOSEN_KEY)

In [None]:
""" Plot """

# Seaborn theme with enlarged fonts; applied before the matplotlib overrides below.
sns.set(font_scale=1.6)

plt.rcParams.update({
    # Fonts
    'font.family': 'DeJavu Serif',
    'font.serif': ['Times New Roman'],
    'font.size': 14,
    # Render text with LaTeX; amsmath is loaded for the \text command
    'text.usetex': True,
    'text.latex.preamble': r'\usepackage{amsmath}',
    # Custom mathtext font set
    'mathtext.fontset': 'custom',
    'mathtext.rm': 'Bitstream Vera Sans',
    'mathtext.it': 'Bitstream Vera Sans:italic',
    'mathtext.bf': 'Bitstream Vera Sans:bold',
})

# Metric key -> y-axis label. Replaces the mapping from the config cell; one
# figure is produced per entry.
ylabel_map = {
    'train_action_past/FORG_EXPOSE_loss': "RF - action",
    'train_verb_past/FORG_EXPOSE_loss': "RF - verb",
    'train_noun_past/FORG_EXPOSE_loss': "RF - noun",
}

# Method name -> line color: the two ends of the 10-color "Spectral" palette.
_spectral = sns.color_palette("Spectral", 10)
name_to_color = {
    'SGD': _spectral[0],
    'Replay': _spectral[-1],
}
# Method name -> name shown in the legend.
name_to_legend_name = dict(SGD='SGD', Replay='ER')

log_scale = True  # bin the re-exposure iterations on a log10 axis
nb_bins = 10      # number of bins along the x-axis

# Timestamped output directory: <main_outdir>/<title>/<timestamp>_<title>
parent_dirname = f"{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}_{title}"
parent_dirpath = os.path.join(main_outdir, title, parent_dirname)

def add_to_df(action_results_over_time: dict, df_row_list, user_id):
    """Append one row per re-exposure event of a single user to ``df_row_list``.

    For every action, the entries of ``"prev_after_update_iter"`` and
    ``"current_before_update_iter"`` are paired up and their difference becomes
    the row's ``iter_delta``; the matching entry of ``"delta"`` becomes ``forg``.

    Args:
        action_results_over_time: dict with the keys ``"delta"``,
            ``"prev_after_update_iter"`` and ``"current_before_update_iter"``,
            each mapping an action to a sequence of values.
        df_row_list: list that is extended in place with dicts holding the keys
            ``user``, ``iter_delta``, ``forg`` and ``action``.
        user_id: value stored in the ``user`` field of every appended row.

    Raises:
        AssertionError: if ``"delta"`` is empty, or if a current iteration is
            smaller than its paired previous iteration.
    """
    forg_per_action = action_results_over_time["delta"]
    assert len(forg_per_action) > 0  # Don't log if state is not kept

    # x-axis: iterations elapsed between the previous and the current exposure
    iter_gaps_per_action = defaultdict(list)
    for action, prev_iters in action_results_over_time["prev_after_update_iter"].items():
        cur_iters = action_results_over_time["current_before_update_iter"][action]
        for prev_t, new_t in zip(prev_iters, cur_iters):
            assert new_t >= prev_t, f"New iteration {new_t} <= prev iteration {prev_t}"
            iter_gaps_per_action[action].append(new_t - prev_t)

    # y-axis: pair every gap with the corresponding forgetting value
    for action, iter_gaps in iter_gaps_per_action.items():
        df_row_list.extend(
            {'user': user_id, 'iter_delta': gap, 'forg': forg, 'action': action}
            for gap, forg in zip(iter_gaps, forg_per_action[action])
        )


def _load_forgetting_df(dump_parentpath: str, metric_key: str, user_dirnames: list) -> pd.DataFrame:
    """Load one metric from every user's dump of a single method into a DataFrame.

    Args:
        dump_parentpath: directory holding one subdirectory per user.
        metric_key: key of the metric to read from each dump.
        user_dirnames: names of the user subdirectories to read.

    Returns:
        DataFrame with the columns written by ``add_to_df``
        (``user``, ``iter_delta``, ``forg``, ``action``).

    Raises:
        FileNotFoundError: if a user directory has no dump file.
        ValueError: if no rows could be collected at all.
    """
    df_row_list: list[dict] = []
    for user_dirname in user_dirnames:
        user_dump_path = os.path.join(dump_parentpath, user_dirname, dump_filename)
        if not os.path.isfile(user_dump_path):
            raise FileNotFoundError(f"Missing result dump: {user_dump_path}")
        user_id = user_dirname.split('_')[-1]

        # NOTE(review): torch.load unpickles the file -- only load dumps you produced yourself.
        dump = torch.load(user_dump_path)
        add_to_df(dump[metric_key], df_row_list, user_id)

    if not df_row_list:
        raise ValueError(f"No entries found for '{metric_key}' under {dump_parentpath}")
    return pd.DataFrame(df_row_list)


# The user directory names do not depend on the metric or the method: scan them
# once, and sort them so that the row order is reproducible across runs.
with os.scandir(parent_outputdir) as _dir_entries:
    user_dirnames = sorted(entry.name for entry in _dir_entries if entry.is_dir())

# Single plot over all users: one figure per metric, one curve per method
for y_label_key in ylabel_map.keys():
    plot_config['ylabel'] = ylabel_map[y_label_key]

    # ---- Pass 1: load all methods, so that the bins can be shared between them ----
    full_frames = {}  # method -> all entries (used for the printed statistics)
    plot_frames = {}  # method -> entries with iter_delta > 0, x-axis possibly in log10
    for method_label, dump_parentpath in label_to_parent_dumpdir.items():
        df_full = _load_forgetting_df(dump_parentpath, y_label_key, user_dirnames)

        # Filter out the 0-iter entries. The copy makes the column assignment
        # below operate on an independent frame rather than on a slice of df_full.
        df_plot = df_full.loc[df_full.iter_delta > 0].copy()
        if df_plot.empty:
            raise ValueError(f"No entries with iter_delta > 0 for method '{method_label}'")
        if log_scale:
            df_plot['iter_delta'] = np.log10(df_plot['iter_delta'])
            # forg is not log-scaled: negative and 0 values -> nan/-inf

        full_frames[method_label] = df_full
        plot_frames[method_label] = df_plot

    # Common bin edges over the joint x-range of all methods. Every curve is
    # drawn at the same integer positions and shares one set of tick labels, so
    # the bins themselves have to be identical for the labels to be correct.
    x_min = min(float(df_plot.iter_delta.min()) for df_plot in plot_frames.values())
    x_max = max(float(df_plot.iter_delta.max()) for df_plot in plot_frames.values())
    if x_min == x_max:
        # Widen a degenerate range so that the edges are strictly increasing.
        x_min, x_max = x_min - 0.5, x_max + 0.5
    bin_edges = np.linspace(x_min, x_max, nb_bins + 1)

    # Tick labels: upper boundary of each bin, de-logscaled only if the x-axis was log-scaled
    upper_edges = 10 ** bin_edges[1:] if log_scale else bin_edges[1:]
    x_bin_vals = [round(edge) for edge in upper_edges]
    x_vals = range(len(x_bin_vals))

    # ---- Pass 2: statistics and one errorbar curve per method ----
    fig, ax = plt.subplots(figsize=(8, 4))
    for method_label, df_full in full_frames.items():
        print(f"METHOD={method_label}, y_label_key={y_label_key}")

        # Correlation and average over all entries (including the 0-iter ones)
        corr = scipy.stats.pearsonr(df_full.iter_delta, df_full.forg)
        print(fr"$\rho = {round(corr[0],2)}$")
        print(fr"$\text{{avg. RF}} = {round(df_full.forg.mean(),2)} \pm {round(df_full.forg.sem(),2)}$")

        # Average and SEM without the zero-iteration entries
        df_plot = plot_frames[method_label]
        avg = df_plot.forg.mean()
        sem = df_plot.forg.sem()
        print(f"MEAN-DELTA={avg}, SE={sem}")

        vals_to_bin = df_plot.iter_delta
        stat_values = df_plot.forg
        bin_means, bin_boundaries, _ = scipy.stats.binned_statistic(
            vals_to_bin, stat_values, statistic='mean', bins=bin_edges)
        bin_sems, *_ = scipy.stats.binned_statistic(
            vals_to_bin, stat_values, statistic=scipy.stats.sem, bins=bin_edges)
        bin_counts, *_ = scipy.stats.binned_statistic(
            vals_to_bin, stat_values, statistic='count', bins=bin_edges)

        print(f"bin_means={bin_means}")
        print(f"bin_sems={bin_sems}")
        print(f"bin_boundaries={bin_boundaries}")
        print(f"bin_counts={bin_counts}")

        ax.errorbar(
            x_vals, bin_means, yerr=bin_sems, color=name_to_color[method_label],
            capsize=3, elinewidth=0.8,
            markersize=3, marker='v', mfc='black', mec='black',
            ecolor='black',
            label=name_to_legend_name[method_label],
        )

    # FINAL PLOT SETTINGS
    ax.set(xlabel='re-exposure iterations (bins)', ylabel=plot_config['ylabel'])
    ax.set_xticks(list(x_vals))
    ax.set_xticklabels(x_bin_vals, rotation='horizontal')
    ax.grid(True)
    ax.legend()

    fig.tight_layout()
    filename = plot_config['ylabel'].replace(' ', '') + '.pdf'
    filepath = os.path.join(parent_dirpath, filename)
    os.makedirs(parent_dirpath, exist_ok=True)
    fig.savefig(filepath)

    plt.show()
    plt.close(fig)
