In [None]:
"""
Plot the action distribution of the pretrain split against that of the train or test user split.
x-axis: actions (ordered from most to least frequent in pretraining), y-axis: fraction of labels per action.

Before running this notebook, run the stream meta-data collector (src/continual_ego4d/processing/run_summarize_user_streams.py) and set the resulting checkpoint path in the CONFIG cell below.
"""
import datetime
import json
import os
import pickle
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.core.interactiveshell import InteractiveShell

# Display the value of every expression statement in a cell, not only the last one
InteractiveShell.ast_node_interactivity = "all"
pd.set_option('display.max_rows', 1000)

# Render all matplotlib text through LaTeX; amsmath provides the \text command used in the labels
plt.rcParams.update({
    'text.usetex': True,
    'text.latex.preamble': r'\usepackage{amsmath}',
})

In [None]:
"""
CONFIG: Add your config params here
"""

########### PLOT CONFIG ###############
TRAIN_USERS_MODE = False

if TRAIN_USERS_MODE:  # Path obtained by Stream meta-data collector (src/continual_ego4d/processing/run_summarize_user_streams.py)
    FILE_TO_ANALYZE = "/your/path/to/logs/2022-10-07_04-49-02_UIDa5c4c52b-a8d8-4155-b1f4-bed9cd82374e/dataset_entries_train_ego4d_LTA_train_usersplit_10users.ckpt"

else:  # Path obtained by Stream meta-data collector (src/continual_ego4d/processing/run_summarize_user_streams.py)
    FILE_TO_ANALYZE = "/your/path/to/logs/2022-10-07_04-33-34_UIDd679068a-dc6e-40ff-b146-70ffe0671a97/dataset_entries_test_ego4d_LTA_test_usersplit_40users.ckpt"

# The pretrain JSON file
pretrain_unsegmented_json = '../data/EgoAdapt/usersplits/ego4d_LTA_pretrain_incl_nanusers_usersplit_148users.json'

main_outdir = "../imgs/pretrain_vs_user_histogram"  # Output path of imgs
title = f"PRETRAIN_VS_USER_LABEL_HISTOGRAM_{'TRAIN' if TRAIN_USERS_MODE else 'TEST'}"  # Title for the plot file

In [None]:
"""Get dictionary with dataframe per user stream (includes all labels and other meta data)."""

with open(FILE_TO_ANALYZE, 'rb') as f:
    ds = pickle.load(f)


def ds_to_user_dfs(ds):
    """Build one dataframe per user stream, extended with an ``action_label`` column.

    Args:
        ds: dict mapping a user id to that user's list of entry dicts. Each entry must provide
            'verb_label' and 'noun_label', either both as scalars or both as equally long lists.

    Returns:
        dict mapping each user id to a dataframe of its flattened entries
        (``pd.json_normalize``) plus an 'action_label' column. That column holds
        "<verb>-<noun>" strings, or lists of them for list-valued labels.
    """

    def label_fn(row):
        """Merge a row's verb and noun labels into "<verb>-<noun>" (element-wise for lists)."""
        # Access by column label: positional Series indexing (row[0]) is deprecated in pandas.
        verb, noun = row['verb_label'], row['noun_label']
        if not isinstance(verb, list):
            assert not isinstance(noun, list), "verb_label is a scalar but noun_label is a list"
            return f"{verb}-{noun}"

        # Without these checks zip() would silently truncate, or iterate a string noun char by char.
        assert isinstance(noun, list), "verb_label is a list but noun_label is not"
        assert len(verb) == len(noun), "verb_label and noun_label lists differ in length"
        return [f"{l}-{r}" for l, r in zip(verb, noun)]

    ret = {}
    for user, user_entries in ds.items():
        user_df = pd.json_normalize(user_entries)  # Convert to DF
        user_df['action_label'] = user_df[['verb_label', 'noun_label']].apply(label_fn, axis=1)
        ret[user] = user_df
        print("Created action_label column")
    return ret


dfs = ds_to_user_dfs(ds=ds)  # user id -> that user's stream dataframe, incl. the derived action_label column

In [None]:
""" Preprocess pretraining action sets. """
with open(pretrain_unsegmented_json, 'r') as f:
    action_sets = json.load(f)['user_action_sets']['user_agnostic']

# maps action-idx to {'name':, 'count':} for name of action and frequency count
action_to_name_and_freq_dict = action_sets['action_to_name_dict']

action_to_freq_dict = {a: value_dict['count'] for a, value_dict in action_to_name_and_freq_dict.items()}
action_to_freq_tuples = [(a, cnt) for a, cnt in action_to_freq_dict.items()]
pretrain_action_to_freq_tuples_s = sorted(action_to_freq_tuples, key=lambda x: x[1], reverse=True)
pretrain_action_to_freq_tuples_s

pretrain_actions = [x[0] for x in pretrain_action_to_freq_tuples_s]
pretrain_action_counts = [x[1] for x in pretrain_action_to_freq_tuples_s]

total_pretrain_count = sum(pretrain_action_counts)
pretrain_action_distr = [x / total_pretrain_count for x in pretrain_action_counts]

action_to_plot_idx = {action: idx for idx, action in enumerate(pretrain_actions)}

print(len(pretrain_action_to_freq_tuples_s))

In [None]:
""" Generate plot. """

parent_dirname = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + "_" + title
parent_dirpath = os.path.join(main_outdir, parent_dirname)
os.makedirs(parent_dirpath, exist_ok=True)
user_set = 'train' if TRAIN_USERS_MODE else 'test'

# Set fonts
plt.rcParams['font.family'] = 'DeJavu Serif'
plt.rcParams['font.serif'] = ['Times New Roman']
plt.rcParams['font.size'] = 16

# MPL CONFIG
fontsize = 16
params = {'font.family': 'DeJavu Serif', 'font.serif': 'Times New Roman',
          'axes.labelsize': fontsize + 2, 'axes.titlesize': fontsize, 'font.size': fontsize,
          'legend.fontsize': fontsize, 'xtick.labelsize': fontsize, 'ytick.labelsize': fontsize}
plt.rcParams.update(params)

figsize = (8, 3)
xlim = (-100, None)
ylim = (-0.0005, None)
ylabel = r"$P_{\text{action}}$"

# LINE CONFIG
shade_color = sns.color_palette("rocket")[5]
bar_color = sns.color_palette("rocket")[1]

if TRAIN_USERS_MODE:
    line_alpha = 1
else:
    line_alpha = 1

# BARCHART CONFIG
plot_yerror = None
legend_label = None
bar_align = 'center'
log = False
legend_labels = None
grid = False
bar_alpha = 0.1
bar_marker_offset = 0


def plot_user_histogram_lines_and_avg_bars(dfs, label_col_name):
    """Plot the pooled user label distribution against the pretrain (population) distribution.

    The x-axis follows the pretrain action order (``pretrain_actions``, most to least frequent).
    The label counts of all users are summed, normalized to a distribution and drawn as bars.
    The pretrain distribution is drawn as a shaded area with a thin outline. The figure is saved
    as a PDF under ``parent_dirpath`` (when set), shown, and then closed.

    Args:
        dfs: dict mapping a user id to that user's dataframe (as produced by ``ds_to_user_dfs``).
        label_col_name: column to histogram; its name must contain 'action', 'verb' or 'noun'.

    Raises:
        ValueError: if ``label_col_name`` matches no supported label kind, or no labels are found.
        KeyError: if a user label is not part of the pretrain action set.
    """
    print("Plotting")
    num_bins = len(pretrain_actions)
    plot_x_vals = list(range(num_bins))

    # Validate the label kind before creating a figure, so an invalid call leaves no open figure
    if 'action' in label_col_name:
        bar_width = 2
        xlabel = 'action (high to low population frequency)'
    elif 'verb' in label_col_name:
        bar_width = 0.8
        xlabel = 'verb (high to low frequency)'
    elif 'noun' in label_col_name:
        bar_width = 0.8
        xlabel = 'noun (high to low frequency)'
    else:
        raise ValueError(f"Unsupported label column: {label_col_name}")

    # Sum the label counts of all users, placed at each label's pretrain x-position
    summed_counts = np.zeros(num_bins, dtype=np.int64)
    for user_idx, (user_id, user_df) in enumerate(dfs.items()):
        print(f"Idx {user_idx}: User {user_id}")
        for label, label_count in Counter(user_df[label_col_name].tolist()).items():
            if label not in action_to_plot_idx:
                raise KeyError(f"Label {label!r} of user {user_id} is not in the pretrain action set")
            summed_counts[action_to_plot_idx[label]] += label_count

    total_count = summed_counts.sum()
    if total_count == 0:
        raise ValueError(f"No '{label_col_name}' entries found in the user streams; nothing to plot")
    user_avg_distr_vals = summed_counts / total_count  # Make distr

    fig, ax = plt.subplots(figsize=figsize, dpi=600)

    # User distribution as bars, pretrain distribution as shaded area with an outline
    ax.bar(plot_x_vals, user_avg_distr_vals, bar_width,
           alpha=line_alpha, color=bar_color, edgecolor="none",
           label=rf"$\mathcal{{U}}_{{\text{{{user_set}}}}}$",
           )
    ax.fill_between(plot_x_vals, pretrain_action_distr,
                    alpha=0.5, color=shade_color, linewidth=1,
                    label=r"$\mathcal{U}_\text{population}$",
                    )
    ax.plot(plot_x_vals, pretrain_action_distr, alpha=0.7, color='black', linewidth=0.5)
    print("Finishing up plotting")
    ax.axhline(y=max(pretrain_action_distr), color='black', linestyle=':', alpha=0.6, linewidth=0.4)
    ax.axhline(y=0, color='black', linestyle=':', alpha=0.6, linewidth=0.4)

    # Hide the right and top spines; only show ticks on the left and bottom spines
    ax.spines.right.set_visible(False)
    ax.spines.top.set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')

    ax.set_ylim(*ylim)
    ax.set_xlim(*xlim)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid(grid, which='both')
    fig.tight_layout()

    ax.legend(prop={'size': 18})

    # Save
    if parent_dirpath is not None:
        filepath = os.path.join(parent_dirpath, f"{title}_{xlabel}.pdf")
        fig.savefig(filepath, bbox_inches='tight')
        print(f"Saved plot: {filepath}")

    plt.show()
    # Close (not just clear) the 600-dpi figure so repeated calls do not accumulate open figures
    plt.close(fig)


# One figure per label column. The x-axis is indexed by pretrain action keys, so only the action labels
# are plotted; 'verb_label' / 'noun_label' values would not match those keys.
for action_mode in ('action_label',):
    plot_user_histogram_lines_and_avg_bars(dfs, label_col_name=action_mode)