In [None]:
"""
Plot analysis of classifier weight and bias norms.

This notebook must be launched through the .sh script that imports PYTHON_PATH, which is required for loading the checkpoints (ckpt).
Before running this experiment, make sure you have run the Stream meta-data collector (src/continual_ego4d/processing/run_summarize_user_streams.py) and set the resulting paths in this notebook.
"""
import datetime
import json
import os
import pickle
from collections import Counter

import matplotlib.pyplot as plt
import pandas as pd
import scipy
import scipy.stats
import seaborn as sns
import torch
from IPython.core.interactiveshell import InteractiveShell

# Echo the value of every top-level expression of a cell, not only the last one
InteractiveShell.ast_node_interactivity = "all"
# Raise the pandas row-display limit to 1000
pd.options.display.max_rows = 1000

In [None]:
"""
CONFIG: Add your config params here
"""

########### PLOT CONFIG ###############
# The classifier head to analyze: the first entry (verbs) is selected by default
plot_modes = ["verb", "noun"]
plot_mode = plot_modes[0]

########### DATA CONFIG ###############
# Output of the Stream meta-data collector
# (src/continual_ego4d/processing/run_summarize_user_streams.py).
# Excludes unseen actions during pretrain.
train_user_stream_summary_path = (
    "/your/path/to/logs/"
    "2022-09-12_18-25-12_UID6131a811-f2c5-4479-a7ff-08dc74d4f9fc/"
    "dataset_entries_train_FEWSHOT=False_ego4d_LTA_train_usersplit_10users.ckpt"
)

# JSON of the pretraining user split (read directly)
pretrain_path = (
    "../data/EgoAdapt/usersplits/"
    "ego4d_LTA_pretrain_incl_nanusers_usersplit_148users.json"
)

########### MODELS CONFIG ###############
# Checkpoint directories of the models to analyze

# SGD, full model trained
sgd_ckp_model_dirs = (
    "/your/path/to/logs/"
    "GRID_SOLVER-BASE_LR=0-01_SOLVER-MOMENTUM=0-0_SOLVER-NESTEROV=True/"
    "2022-09-13_10-53-52_UID958392f7-c477-4a09-a7ac-c72cc81251c2/"
    "checkpoints"
)

# Replay with full storage, lr 0.01
replay_ckp_model_dirs = (
    "/your/path/to/logs/"
    "GRID_METHOD-REPLAY-MEMORY_SIZE_SAMPLES=1000000_METHOD-REPLAY-STORAGE_POLICY=window/"
    "2022-09-15_18-23-13_UID09ab4f67-814b-4331-aa80-839cd99a1d9f/"
    "checkpoints"
)

# SGD, classifier trained only
sgd_classifier_only_ckp_model_dirs = (
    "/your/path/to/logs/"
    "GRID_SOLVER-BASE_LR=0-1/"
    "2022-09-19_12-16-57_UID11a32c82-471b-4921-a5c7-0ee393fdabf0/"
    "checkpoints"
)

# Pretrained model checkpoint
pretrain_model_path = (
    "/your/path/to/pretrain_148usersplit_incl_nan/"
    "2022-09-05_10-34-05_UIDd05ed672-01c5-4c3c-b790-9d0c76548825/"
    "checkpoints/best_model.ckpt"
)

In [None]:
"""
 Get all seen verbs/nouns/actions in pretrain.
"""
with open(pretrain_path, "r") as f:
    pretrain_dataset = json.load(f)

# User-agnostic label statistics stored in the pretraining split JSON
pretrain_action_sets = pretrain_dataset['user_action_sets']['user_agnostic']

# Integer label id -> value of the entry's 'count' field, once for verbs and once for nouns
pretrain_verb_freq_dict = dict(
    (int(label_id), label_info['count'])
    for label_id, label_info in pretrain_action_sets['verb_to_name_dict'].items()
)
pretrain_noun_freq_dict = dict(
    (int(label_id), label_info['count'])
    for label_id, label_info in pretrain_action_sets['noun_to_name_dict'].items()
)

In [None]:
"""Get dictionary with dataframe per user stream (includes all labels and other meta data)."""

def ds_to_user_dfs(ds):
    """Convert per-user entry lists into DataFrames with an added 'action_label' column.

    Args:
        ds: Mapping from user id to that user's entries. Each user's entries are
            passed to `pd.json_normalize` and must yield the columns
            'verb_label' and 'noun_label'.

    Returns:
        dict mapping each user id to its DataFrame. The 'action_label' column holds
        "<verb>-<noun>" per row, or a list of such strings when the row's labels
        are lists (paired element-wise).
    """

    def label_fn(x):
        """Merge the (verb, noun) pair of a single row into its action label."""
        assert len(x) == 2, "Need two columns to merge"
        # The row is a Series indexed by column name, so positional access has to go
        # through .iloc; plain x[0] relies on a deprecated integer-position fallback.
        verb, noun = x.iloc[0], x.iloc[1]
        if not isinstance(verb, list):
            assert not isinstance(noun, list)
            return f"{verb}-{noun}"

        return [f"{l}-{r}" for l, r in zip(verb, noun)]

    ret = {}
    for user, user_entries in ds.items():
        user_df = pd.json_normalize(user_entries)  # Convert to DF

        # Create action column from the verb/noun pair of every row
        user_df['action_label'] = user_df[['verb_label', 'noun_label']].apply(label_fn, axis=1)
        ret[user] = user_df
        print("Created action_label column")
    return ret


# Read the pickled stream summary and build one DataFrame per user.
# NOTE: pickle executes code on load, so only point this at files produced by your own runs.
with open(train_user_stream_summary_path, mode='rb') as f:
    ds = pickle.load(f)

dfs = ds_to_user_dfs(ds=ds)

In [None]:
"""
COLLECT ALL WEIGHTS AND BIASES for verb or noun (select which one).
These stats are later used for plotting.
"""
# State-dict keys of the classifier head in the per-user checkpoints
head_layers = [
    'model.head.0.projections.0.weight',
    'model.head.0.projections.0.bias',
    'model.head.0.projections.1.weight',
    'model.head.0.projections.1.bias']

# Projection 0 is used for verbs, projection 1 for nouns
verb_bias_name = 'model.head.0.projections.0.bias'
verb_weight_name = 'model.head.0.projections.0.weight'
noun_bias_name = 'model.head.0.projections.1.bias'
noun_weight_name = 'model.head.0.projections.1.weight'

# Only labels that occur in the pretraining statistics are analyzed
verb_idx_to_keep = list(pretrain_verb_freq_dict.keys())
noun_idx_to_keep = list(pretrain_noun_freq_dict.keys())

# Size of the full (original) label index space per head
max_verbs_total = 115
max_nouns_total = 478
print(f"VERBS: Keep {len(verb_idx_to_keep)}/{max_verbs_total}")
print(f"NOUNS: Keep {len(noun_idx_to_keep)}/{max_nouns_total}")

# Resolve the mode-dependent names once. The pretrained checkpoint uses
# 'model.head.projections.*' keys (no '.0.' after 'head').
if plot_mode == 'verb':
    bias_name = verb_bias_name
    weight_name = verb_weight_name
    idx_to_keep = verb_idx_to_keep

    pretrain_bias_name = 'model.head.projections.0.bias'
    pretrain_weight_name = 'model.head.projections.0.weight'
    stream_stats_key = 'verb_label'
    max_total = max_verbs_total

elif plot_mode == 'noun':
    bias_name = noun_bias_name
    weight_name = noun_weight_name
    idx_to_keep = noun_idx_to_keep

    pretrain_bias_name = 'model.head.projections.1.bias'
    pretrain_weight_name = 'model.head.projections.1.weight'
    stream_stats_key = 'noun_label'
    max_total = max_nouns_total

else:
    # Fail here instead of with a NameError on the first use of bias_name & co.
    raise ValueError(f"Unknown plot_mode {plot_mode!r}: expected 'verb' or 'noun'")


def get_bias_selection_distr(model_state_dict, verb_bias_name, verb_idx_to_keep, full_bias_t=None):
    """Return the absolute bias values of the selected entries, normalized to sum to 1.

    Args:
        model_state_dict: Mapping that holds the bias tensor under `verb_bias_name`;
            only read when `full_bias_t` is None.
        verb_bias_name: Key of the bias tensor in `model_state_dict`.
        verb_idx_to_keep: Indices of the bias entries to keep.
        full_bias_t: Optional bias tensor used instead of the state-dict lookup.

    Returns:
        list[float]: |b_i| / sum_j |b_j| over the kept entries. If all kept entries
        are zero the division is 0/0 and the entries are NaN.
    """
    if full_bias_t is None:
        full_bias_t = model_state_dict[verb_bias_name]

    verb_bias_t = full_bias_t[verb_idx_to_keep]  # Drop unseens in pretrain
    # Absolute value. abs() is used rather than pow(2).sqrt() because squaring
    # underflows/overflows for very small/large magnitudes.
    verb_bias_t_pos = verb_bias_t.abs()
    verb_bias_tn = verb_bias_t_pos / torch.sum(verb_bias_t_pos)  # Normalize to distr
    return verb_bias_tn.tolist()


def get_weight_selection_distr(model_state_dict, verb_weight_name, verb_idx_to_keep, full_weight_t=None):
    """Return the squared L2 row norms of the selected weight rows, normalized to sum to 1.

    Args:
        model_state_dict: Mapping that holds the weight matrix under `verb_weight_name`;
            only read when `full_weight_t` is None.
        verb_weight_name: Key of the weight matrix in `model_state_dict`.
        verb_idx_to_keep: Row indices to keep.
        full_weight_t: Optional weight matrix used instead of the state-dict lookup.

    Returns:
        list[float]: One value per kept row: the sum of squared entries of that row
        (no square root is taken) divided by the total over all kept rows.
    """
    weights = model_state_dict[verb_weight_name] if full_weight_t is None else full_weight_t

    kept_rows = weights[verb_idx_to_keep]
    squared_row_norms = (kept_rows ** 2).sum(dim=1)  # Non-negative by construction
    normalized = squared_row_norms / squared_row_norms.sum()
    return normalized.tolist()


# GET PRETRAINED
# Head keys in pretrain_state_dict: 'model.head.projections.0.weight', 'model.head.projections.0.bias',
# 'model.head.projections.1.weight', 'model.head.projections.1.bias'
ckpt = torch.load(pretrain_model_path, map_location='cpu')
pretrain_state_dict = ckpt['state_dict']

# Normalized bias / weight distributions of the pretrained head for the selected mode
verb_bias_distr_pretrain = get_bias_selection_distr(
    model_state_dict=pretrain_state_dict,
    verb_bias_name=pretrain_bias_name,
    verb_idx_to_keep=idx_to_keep,
)
verb_weight_distr_pretrain = get_weight_selection_distr(
    model_state_dict=pretrain_state_dict,
    verb_weight_name=pretrain_weight_name,
    verb_idx_to_keep=idx_to_keep,
)

# GET USERS
# One entry per analyzed model; a 'df' key with the collected statistics is added below.
model_results = {
    'sgd': {'path': sgd_ckp_model_dirs},
    'replay': {'path': replay_ckp_model_dirs},
    'sgd_classifier_only': {'path': sgd_classifier_only_ckp_model_dirs},
}

# Pretrain references are the same for every model and user: compute them once.
pretrain_bias_t = pretrain_state_dict[pretrain_bias_name]
pretrain_weight_norms_t = pretrain_state_dict[pretrain_weight_name].pow(2).sum(dim=1).sqrt()

# STREAM STATS depend only on the user stream, not on the model: compute once per user.
stream_stats = {}
for user_id, user_df in dfs.items():
    cnt_dict = Counter(user_df[stream_stats_key].tolist())
    verb_cnt_t = torch.zeros(max_total)  # Original idxs

    for verb_idx, cnt in cnt_dict.items():
        verb_cnt_t[verb_idx] = cnt

    verb_cnt_t = verb_cnt_t[idx_to_keep]  # Subset idxs
    stream_stats[user_id] = {
        'cnt': verb_cnt_t.tolist(),
        # Divide by sum (NaN if the stream holds none of the kept labels)
        'distr': (verb_cnt_t / verb_cnt_t.abs().sum()).tolist(),
    }

for name, model_result in model_results.items():
    print(f"Processing: {name}")

    df_list: list[dict] = []
    for user_id in dfs:
        print(f"Processing user-id:{user_id}")
        verb_cnt = stream_stats[user_id]['cnt']
        verb_distr_stream = stream_stats[user_id]['distr']

        # MODEL
        user_path = os.path.join(model_result['path'], f"user_{user_id}", 'last.ckpt')
        ckpt = torch.load(user_path, map_location='cpu')
        model_state_dict = ckpt['state_dict']

        # GET DELTAS w.r.t. Pretrain
        # Biases: signed difference, normalized by absolute mass
        bias_deltas = (model_state_dict[bias_name] - pretrain_bias_t)[idx_to_keep]
        bias_deltas_n = (bias_deltas / bias_deltas.abs().sum()).tolist()

        # Weights: difference of per-row L2 norms, normalized by absolute mass
        weight_norms_t = model_state_dict[weight_name].pow(2).sum(dim=1).sqrt()
        weight_norm_deltas = (weight_norms_t - pretrain_weight_norms_t)[idx_to_keep]
        weight_norm_deltas_n = (weight_norm_deltas / weight_norm_deltas.abs().sum()).tolist()

        # One record per kept label. NOTE: the stored id is the position within
        # idx_to_keep, not the original label index.
        for kept_pos in range(len(idx_to_keep)):
            df_list.append({
                f"id_{plot_mode}": kept_pos,
                "user": user_id,
                'stream_cnt': verb_cnt[kept_pos],
                # Distributions:
                "weight_norm_delta_p": weight_norm_deltas_n[kept_pos],
                "bias_delta_p": bias_deltas_n[kept_pos],
                'stream_p': verb_distr_stream[kept_pos],
            })

    # Store results for sgd/replay
    users_df = pd.DataFrame(df_list)
    model_result['df'] = users_df

In [None]:
"""
Plot
"""

ZOOM = True  # Cut the x-axis at the first label without stream samples
xlim = (-2, None)
PLOT_SELECTION = ['sgd', 'sgd_classifier_only']  # 'replay'

legend_name_mapping = {
    'sgd': r'SGD - $F \circ H$',
    'sgd_classifier_only': r'SGD - $H$',
    'replay': 'replay',
}

# Set fonts
plt.rcParams['font.family'] = 'DeJavu Serif'
plt.rcParams['font.serif'] = ['Times New Roman']
plt.rcParams['font.size'] = 14

# Use latex in mpl
plt.rcParams['text.usetex'] = True
plt.rcParams['text.latex.preamble'] = r'\usepackage{amsmath}'  #for \text command

# Set latex font in mpl
plt.rcParams['mathtext.fontset'] = 'custom'
plt.rcParams['mathtext.rm'] = 'Bitstream Vera Sans'
plt.rcParams['mathtext.it'] = 'Bitstream Vera Sans:italic'
plt.rcParams['mathtext.bf'] = 'Bitstream Vera Sans:bold'

figsize = (5, 4)

# Per-model line styling (identical for the weight and the bias figure)
colors = {
    'sgd': sns.color_palette("Spectral", 10)[0],
    'replay': 'black',
    'sgd_classifier_only': sns.color_palette("Spectral", 10)[9],
}
linestyles = {
    'sgd': '-',
    'replay': (0, (1, 1)),
    'sgd_classifier_only': '-',
}
alphas = {
    'sgd': 1,
    'replay': 0.6,
    'sgd_classifier_only': 0.7,
}

# Stream label distribution, averaged over users. It is the same for both figures,
# so it is computed once, and its descending order defines the x-axis of every curve.
id_col = f"id_{plot_mode}"
sgd_df = model_results['sgd']['df']
stream_group = sgd_df.groupby([id_col], as_index=False).agg({'stream_p': 'mean'})
sort_idxs = stream_group['stream_p'].argsort().tolist()[::-1]
stream_group = stream_group.loc[sort_idxs]

stream_x = list(range(len(stream_group)))
stream_y = stream_group['stream_p'].tolist()

# First position whose mean stream probability is not positive. Falls back to the
# number of labels when every label occurs (an unbounded scan would run off the list).
zero_idx = next((pos for pos, p in enumerate(stream_y) if not p > 0), len(stream_y))

for plot_weights in [True, False]:
    # Paths: every figure goes into its own timestamped directory
    main_outdir = "../imgs/classifier_analysis_final"
    title = f"classifier_analysis_{plot_mode}_"
    title += "weight" if plot_weights else "bias"
    parent_dirname = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + "_" + title
    parent_dirpath = os.path.join(main_outdir, parent_dirname)
    filename = f"{title}.pdf"
    os.makedirs(parent_dirpath, exist_ok=True)

    # PLOT
    fig, ax = plt.subplots(figsize=figsize, dpi=600)  # So all bars are visible!

    # Plot histogram distr
    ax.fill_between(stream_x, stream_y, alpha=0.6, label=r"$P_{\text{label}}$", color='gray')

    if ZOOM:
        ax.set_xlim(xlim[0], zero_idx)
    else:
        ax.axvline(zero_idx, color='gray', linestyle='--', linewidth=0.8)
        ax.set_xlim(*xlim)
    ax.axhline(0, color='gray', linestyle='-', linewidth=0.8)

    # Plot deltas for weights or biases
    if plot_weights:
        ax.set(xlabel=plot_mode, ylabel=rf"$\text{{P}}(\lVert w_{{|S|}} \rVert - \lVert w_{{0}} \rVert )$")
        y_name = 'weight_norm_delta_p'
    else:
        # NOTE(review): 'bias_delta_p' is built from the signed difference b - b_0 in the
        # collection cell, while this label shows |b| - |b_0| -- confirm which is intended.
        ax.set(xlabel=plot_mode, ylabel=rf"$\text{{P}}(|b_{{|S|}}| - |b_{{0}}|)$")
        y_name = 'bias_delta_p'

    for model_key in PLOT_SELECTION:
        model_df = model_results[model_key]['df']

        # Mean and standard error over users per label, in the stream-frequency order
        df_group = model_df.groupby([id_col], as_index=False).agg({y_name: ['mean', scipy.stats.sem]})
        df_reorder = df_group.loc[sort_idxs]

        x = list(range(len(df_reorder)))
        y = df_reorder[y_name]['mean']
        y_err = df_reorder[y_name]['sem']

        ax.plot(x, y, alpha=alphas[model_key], linestyle=linestyles[model_key], color=colors[model_key],
                linewidth=1.2, label=legend_name_mapping[model_key])  # Takes avg over users
        ax.fill_between(x, y - y_err, y + y_err, alpha=0.2, color=colors[model_key], linewidth=0.0)

    ax.spines.right.set_visible(False)
    ax.spines.top.set_visible(False)

    ax.legend()

    fig.tight_layout()
    path = os.path.join(parent_dirpath, filename)
    fig.savefig(path, bbox_inches='tight')

    # PLOT BOTH
    plt.show()
    plt.close('all')