In [None]:
"""
Plots the cumulative Adaptation Gain (y-axis) over iterations (x-axis), for all train users (10 lines).
"""
import datetime
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from IPython.core.interactiveshell import InteractiveShell

# Echo the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"
# Set the pandas row display limit to 100.
pd.set_option('display.max_rows', 100)

In [None]:
# Directory holding the per-user result logs of the plain SGD run.
SGD_path = "/your/path/to/logs/GRID_SOLVER-BASE_LR=0-01_SOLVER-MOMENTUM=0-0_SOLVER-NESTEROV=True/2022-09-13_10-53-52_UID958392f7-c477-4a09-a7ac-c72cc81251c2/user_logs"

# Selected metric to plot
metric_key = 'train_action_batch/AG_cumul'

# User split flag. It has to exist before `title` is built, because the title
# embeds it; without this assignment a top-to-bottom run fails here with a
# NameError. The plot cell further down assigns the same flag again, so keep
# the two values in sync when switching to the test users.
TRAIN_USERS_MODE = True

# Plot output config
main_outdir = "../imgs/SGD_users"
title = f"LABEL_HISTOGRAM_{'TRAIN' if TRAIN_USERS_MODE else 'TEST'}"

In [None]:
""" Plot. """
# ---- Result selection ----
TRAIN_USERS_MODE = True  # Picks the line opacity below (0.5 when True, 0.2 otherwise)
train_users = ['68', '265', '324', '30', '24', '421', '104', '108', '27', '29']
ZOOM = False  # True caps the axes at x < 50 and y < 500

# ---- Output location ----
# Each run writes into its own timestamped sub-directory of `main_outdir`.
parent_dirname = f"{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{title}"
parent_dirpath = os.path.join(main_outdir, parent_dirname)
os.makedirs(parent_dirpath, exist_ok=True)

# ---- Matplotlib setup ----
# Render text through LaTeX; amsmath provides the \text command used in `ylabel`.
plt.rcParams.update({
    'text.usetex': True,
    'text.latex.preamble': r'\usepackage{amsmath}',
})

# One font size drives all text elements; axis labels are two points larger.
fontsize = 16
params = {
    'font.family': 'DeJavu Serif',
    'font.serif': 'Times New Roman',
    'font.size': fontsize,
    'axes.labelsize': fontsize + 2,
    'axes.titlesize': fontsize,
    'legend.fontsize': fontsize,
    'xtick.labelsize': fontsize,
    'ytick.labelsize': fontsize,
}
plt.rcParams.update(params)

# Axis limits as (lower, upper); None leaves that bound unset.
xlim = (-1, 50) if ZOOM else (-1, None)
ylim = (None, 500) if ZOOM else (None, None)

figsize = (8, 4)

ylabel = r"$\text{OAG}_{\mathcal{L}, \text{action}}^+$"
xlabel = "iterations"

# ---- Line styling ----
line_colors = sns.color_palette("flare", 10)

cumul = True
linewidth = 2
line_alpha = 0.5 if TRAIN_USERS_MODE else 0.2

# ---- Bar-chart settings ----
# Of these, only `grid` is read by the plotting function visible in this file.
plot_yerror = None
legend_label = None
legend_labels = None
log = False
grid = False
bar_color = 'black'
bar_align = 'center'
bar_alpha = 0.1
bar_width = 0.4
bar_marker_offset = 0

class PlotEntry:
    """Plain mutable record for one plotted series.

    Attributes (all initialised to ``None``):
        x_vals: x coordinates of the series.
        y_vals: y coordinates of the series.
        label_name: label attached to the series.
    """

    def __init__(self):
        # Nothing is known at construction time; every field starts unset.
        self.x_vals = self.y_vals = self.label_name = None


def plot_user_histogram_lines_and_avg_bars():
    """Plot one curve of the selected metric per user and save the figure.

    For every id in ``train_users`` the file
    ``<SGD_path>/user_<id>/metrics.csv`` is read and the non-NaN values of the
    ``metric_key`` column are taken in file order. Each curve starts with an
    extra 0 at x=0, followed by those values at x=1, 2, ...

    Despite its name (kept so existing callers keep working), this function
    draws lines only; it plots neither histograms nor bars.

    All configuration comes from module-level names: ``SGD_path``,
    ``metric_key``, ``train_users``, ``figsize``, ``line_colors``,
    ``line_alpha``, ``linewidth``, ``xlim``, ``ylim``, ``xlabel``, ``ylabel``,
    ``grid``, ``title`` and ``parent_dirpath``.

    Side effects: when ``parent_dirpath`` is not None the figure is written to
    ``<parent_dirpath>/<title>_<xlabel>.pdf``; the figure is then shown and
    closed.
    """
    fig, ax = plt.subplots(figsize=figsize, dpi=600)

    for user_idx, user_id in enumerate(train_users):
        user_csv_path = os.path.join(SGD_path, f"user_{user_id}", "metrics.csv")

        # Only the metric column is plotted; rows where it is NaN are skipped.
        metric_values = pd.read_csv(user_csv_path)[metric_key].dropna()

        # Prepend 0 so every curve starts at the origin, then index by position.
        y_vals = [0] + metric_values.tolist()
        x_vals = list(range(len(y_vals)))

        # Cycle through the palette if there are more users than colors.
        color = line_colors[user_idx % len(line_colors)]
        ax.plot(x_vals, y_vals, alpha=line_alpha, color=color, linewidth=linewidth)

    # Faint zero line as a visual reference.
    ax.axhline(y=0, color='black', linestyle=':', alpha=0.6, linewidth=0.4)

    # Hide the right and top spines
    ax.spines.right.set_visible(False)
    ax.spines.top.set_visible(False)

    # Only show ticks on the left and bottom spines
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')

    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid(grid, which='both')
    fig.tight_layout()

    # Save
    if parent_dirpath is not None:
        filename = f"{title}_{xlabel}.pdf"
        filepath = os.path.join(parent_dirpath, filename)
        fig.savefig(filepath, bbox_inches='tight')
        print(f"Saved plot: {filepath}")

    plt.show()
    # Close this specific figure so it is released; clearing the "current"
    # figure instead would keep it registered (or create a new empty one if
    # the shown figure has already been closed).
    plt.close(fig)


# Draw and save the per-user line plot with the configuration defined above.
plot_user_histogram_lines_and_avg_bars()