In [None]:
import collections
import itertools
import pathlib
import pickle
from typing import Optional, Sequence, Tuple, Union

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import hier
import main
import metrics

In [None]:
# An experiment run: its output directory and the epoch whose predictions to load.
Experiment = collections.namedtuple('Experiment', 'dir epoch')

In [None]:
# TINY IMAGENET
experiment_root = pathlib.Path('/home/jack/projects/2022-01-hierarchical/experiments/2022-02-28-tiny-incomplete-lr0.1-b256/')
with open('resources/hierarchy/tiny_imagenet_fiveai.csv') as f:
    tree, names = hier.make_hierarchy_from_edges(hier.load_edges(f))
# Every run lives in `<experiment_root>/complete-<name>`; all are evaluated at epoch 100.
experiments = {
    run_name: Experiment(experiment_root / f'complete-{run_name}', 100)
    for run_name in ['flat', 'hier', 'hxe-0.1', 'hxe-0.2', 'hxe-0.5']
}
# Not included: 'multilabel' -> experiment_root / 'complete-multilabel-sum'.
    
# # TINY IMAGENET - HIERARCHY SUBSET
# with open('resources/hierarchy/tiny_imagenet_fiveai.csv') as f:
#     tree, names = hier.make_hierarchy_from_edges(hier.load_edges(f))
# with open('resources/hierarchy/tiny_imagenet_fiveai_incomplete_subset.txt') as f:
#     name_subset = list(map(str.strip, f.readlines()))
# name_to_node = {name: i for i, name in enumerate(names)}
# node_subset = [name_to_node[name] for name in name_subset]
# tree, _, _ = hier.subtree(tree, node_subset)
# experiment_dirs = {
#     'flat': experiment_root / 'incomplete-flat',
#     'hier': experiment_root / 'incomplete-hier',
# }

# # CONDITIONAL NORMALIZATION EXPERIMENT - TINY IMAGENET
# experiment_root = pathlib.Path(
#     '/home/jack/projects/2022-01-hierarchical/experiments/2022-02-28-tiny-imagenet/')
# experiments = {
#     'flat': Experiment(experiment_root / 'flat', 100),
#     'hier': Experiment(experiment_root / 'hier', 100),
#     'hier-norm-parent': Experiment(experiment_root / 'hier-norm-parent', 100),
#     'hier-norm-self': Experiment(experiment_root / 'hier-norm-self', 100),
# }
# with open('resources/hierarchy/tiny_imagenet_fiveai.csv') as f:
#     tree, names = hier.make_hierarchy_from_edges(hier.load_edges(f))

# # IMAGENET
# experiment_root = pathlib.Path(
#     '/home/jack/projects/2022-01-hierarchical/experiments/2022-03-05-imagenet-acc/')
# with open('resources/hierarchy/imagenet_fiveai.csv') as f:
#     tree, names = hier.make_hierarchy_from_edges(hier.load_edges(f))
# experiments = {
#     'flat': Experiment(experiment_root / 'flat-lr-0.01-wd-0.0003-ep-15', 15),
#     'hier': Experiment(experiment_root / 'hier-lr-0.01-wd-0.0003-ep-15-b-64', 15),
# }

# # INATURALIST 2018
# experiment_root = pathlib.Path(
#     '/home/jack/projects/2022-01-hierarchical/experiments/2022-03-01-inat18-acc/')
# with open('resources/hierarchy/inat18.csv') as f:
#     tree, names = hier.make_hierarchy_from_edges(hier.load_edges(f))
# experiments = {
#     'flat-finetune': Experiment(experiment_root / '400px-finetune-lr-0.1-wd-0.0001-ep-20', 20),
#     'hier-finetune': Experiment(experiment_root / '400px-hier-finetune-lr-0.1-wd-0.0001-ep-20', 20),
#     'flat-scratch': Experiment(experiment_root / 'scratch-lr-0.1-wd-0.0001-ep-20', 20),
# #     'hier-finetune': Experiment(experiment_root / '400px-hier-finetune-lr-0.1-wd-0.0001-ep-20', 20),
# #     'hxe-0.1': experiment_root / 'complete-hxe-0.1',
# #     'hxe-0.2': experiment_root / 'complete-hxe-0.2',
# #     'hxe-0.5': experiment_root / 'complete-hxe-0.5',
# #     'multilabel': experiment_root / 'complete-multilabel-sum',
# }


# # COARSE LABELS EXPERIMENT - TINY-IMAGENET, RESNET18
# experiment_root = pathlib.Path(
#     '/home/jack/projects/2022-01-hierarchical/experiments/2022-03-10-tiny-imagenet-coarsen/')
# with open('resources/hierarchy/tiny_imagenet_fiveai.csv') as f:
#     tree, names = hier.make_hierarchy_from_edges(hier.load_edges(f))
# experiments = {
#     k: Experiment(experiment_root / ('coarsen-' + k), 100) for k in [
#         'beta-2-1-flat', 'beta-1-1-flat', # 'beta-1-2-flat',
#         'beta-2-1-hier', 'beta-1-1-hier', # 'beta-1-2-hier',
#     ]
# }

In [None]:
def load_results(experiment_dir, epoch):
    """Load and merge the pickled prediction outputs saved for one epoch.

    Reads `predictions/output-epoch-XXXX.pkl` and
    `predictions/full-output-epoch-XXXX.pkl` (epoch zero-padded to four
    digits) under `experiment_dir`, and merges them into one dict. Keys present
    in both files take the value from the "full" file.

    Args:
        experiment_dir: Experiment directory, as a `pathlib.Path` or `str`.
        epoch: Epoch number whose predictions should be loaded.

    Returns:
        The merged outputs dict (this notebook reads e.g. 'gt' and 'prob').

    Note:
        `pickle.load` can execute arbitrary code; only load trusted files.
    """
    # Accept plain strings too; `/` below requires a Path.
    predictions_dir = pathlib.Path(experiment_dir) / 'predictions'
    epoch_str = f'{epoch:04d}'
    outputs_file = predictions_dir / f'output-epoch-{epoch_str}.pkl'
    full_outputs_file = predictions_dir / f'full-output-epoch-{epoch_str}.pkl'

    with open(outputs_file, 'rb') as f:
        outputs = pickle.load(f)
    with open(full_outputs_file, 'rb') as f:
        full_outputs = pickle.load(f)

    outputs.update(full_outputs)
    return outputs

In [None]:
# Load the alphabetically-first experiment for the single-experiment analysis below.
outputs = load_results(*experiments[min(experiments)])

In [None]:
# Re-normalize.
# outputs['prob'] = outputs['prob'] / outputs['prob'][:, [0]]

In [None]:
# Per-node specificity score: the negated count of leaf descendants, so a node
# covering fewer leaves scores higher (is more specific).
specificity = -tree.num_leaf_descendants()

In [None]:
# Build, for every example, a sequence of predictions ordered by `specificity`
# (minimize number of leaf descendants, then maximize depth), cut at probability 0.5.
# Nodes with exactly one child are excluded: they cover the same leaves as their child.
node_mask = tree.num_children() != 1
pred_seqs = [
    main.prediction_sequence(specificity, example_prob, threshold=0.5, condition=node_mask)
    for example_prob in outputs['prob']
]
# Probability of each node in each example's prediction sequence.
prob_seqs = [outputs['prob'][row, seq] for row, seq in enumerate(pred_seqs)]

In [None]:
# Sanity check: is the final element of every prediction sequence (presumably
# its most specific prediction) a leaf node? Displays a single boolean.
is_leaf = tree.leaf_mask()
most_specific_pred = np.array([pr[-1] for pr in pred_seqs])
np.all(is_leaf[most_specific_pred])

In [None]:
# Truncate every prediction sequence at its LCA with the ground truth.
find_lca_fn = hier.FindLCA(tree)
lca_seqs = [find_lca_fn(gt_node, seq) for gt_node, seq in zip(outputs['gt'], pred_seqs)]
pred_seqs = [
    hier.truncate_given_lca(gt_node, seq, lca)
    for gt_node, seq, lca in zip(outputs['gt'], pred_seqs, lca_seqs)
]
# TODO: Remove redundant elements from pred_seqs and scores?
# Fraction of examples whose sequence now repeats a node (i.e. was truncated).
sum(len(set(seq)) != len(seq) for seq in pred_seqs) / len(pred_seqs)

In [None]:
info_metric = metrics.UniformLeafInfoMetric(tree)
depth_metric = metrics.DepthMetric(tree)
# depth_metric = metrics.LCAMetric(tree, tree.depths() + 1)

# Every metric function takes (gt, pr). The info and depth families expose the
# same set of methods, so their entries are generated: '<family>_<suffix>'.
metric_fns = {
    'exact': lambda gt, pr: pr == gt,
    'correct': metrics.IsCorrect(tree),
    **{
        f'{family}_{suffix}': getattr(family_metric, method_name)
        for family, family_metric in (('info', info_metric), ('depth', depth_metric))
        for suffix, method_name in (
            ('deficient', 'deficient'),
            ('excess', 'excess'),
            ('recall', 'recall'),
            ('precision', 'precision'),
            ('f1', 'f1'),
            ('dist', 'dist'),
            ('lca', 'value_at_lca'),
            ('gt', 'value_at_gt'),
            ('pr', 'value_at_pr'),
        )
    },
}

In [None]:
# Evaluate every metric along every (truncated) prediction sequence.
# TODO: Could vectorize if necessary.
metric_seqs = {
    metric_name: [fn(gt_node, seq) for gt_node, seq in zip(outputs['gt'], pred_seqs)]
    for metric_name, fn in metric_fns.items()
}

In [None]:
# Pool the per-example probability/metric sequences into summed metric values
# ("totals", one array per metric) at each operating point, with the matching
# probability thresholds.
totals, thresholds = main.pool_operating_points(prob_seqs, metric_seqs)

In [None]:
# Keep one operating point per distinct threshold. `np.unique` on the negated
# thresholds gives the first occurrence of each value (thresholds presumably in
# decreasing order); `last_index` is the last occurrence of each value.
# `totals` presumably carries one extra leading entry (the point before any
# prediction is accepted, hence the prepended 0), so the total after position
# i is at index i + 1 -- the same indexing `construct_curves` uses.
# NOTE: reassigns `thresholds`/`totals`; re-run from the pooling cell, not alone.
_, first_index = np.unique(-thresholds, return_index=True)
last_index = np.concatenate([first_index[1:], [len(thresholds)]]) - 1
thresholds = thresholds[last_index]
totals = {k: v[np.concatenate([[0], last_index + 1])] for k, v in totals.items()}

In [None]:
# Convert summed metrics into per-example means.
means = {metric_name: total / len(outputs['gt']) for metric_name, total in totals.items()}

In [None]:
# Baseline: predict the most probable leaf node for every example.
pred_leaf = main.argmax_where(outputs['prob'], tree.leaf_mask())
metrics_leaf = {
    metric_name: np.mean(fn(outputs['gt'], pred_leaf))
    for metric_name, fn in metric_fns.items()
}

In [None]:
# Majority rule: among nodes with probability > 0.5 that pass `node_mask`,
# maximize specificity, then maximize confidence.
pred_maj = main.arglexmin_where(
    np.broadcast_arrays(-outputs['prob'], -specificity),
    (outputs['prob'] > 0.5) & node_mask,
)
metrics_maj = {
    metric_name: np.mean(fn(outputs['gt'], pred_maj))
    for metric_name, fn in metric_fns.items()
}

In [None]:
# means['info_recall'] = means['info_lca'] / means['info_gt']
# means['info_precision'] = means['info_lca'] / means['info_pr']

# # Add 1 to depth to count root node.
# means['depth_recall'] = means['depth_lca'] / means['depth_gt']
# means['depth_precision'] = means['depth_lca'] / means['depth_pr']

# means['depth_recall'] = (means['depth_lca'] + 1) / (means['depth_gt'] + 1)
# means['depth_precision'] = (means['depth_lca'] + 1) / (means['depth_pr'] + 1)

In [None]:
def plot_metrics(x, y):
    """Plot metric `y` against metric `x` for the single loaded experiment.

    Draws the threshold-sweep curve from `means`, then the majority-rule
    result (`metrics_maj`, circle) and the most-probable-leaf result
    (`metrics_leaf`, square) as single points.
    """
    plt.plot(means[x], means[y])
    for summary, marker in ((metrics_maj, 'o'), (metrics_leaf, 's')):
        plt.plot(summary[x], summary[y], marker=marker)
    plt.xlabel(x)
    plt.ylabel(y)
    plt.grid()
    plt.show()

In [None]:
# Exact-match rate vs. `metrics.IsCorrect` rate.
plot_metrics(x='exact', y='correct')

In [None]:
# Information deficient vs. excess.
plot_metrics(x='info_deficient', y='info_excess')

In [None]:
# Depth recall vs. depth precision.
plot_metrics(x='depth_recall', y='depth_precision')

In [None]:
def add_metrics(means):
    """Return a shallow copy of `means`, the hook for adding derived metrics.

    Ratios such as recall/precision used to be derived here from the
    '*_lca', '*_gt' and '*_pr' means; `metric_fns` now provides them
    directly, so the copy is returned unchanged.
    """
    augmented = dict(means)
    return augmented

In [None]:
# Risk matrices for conditional risk minimization (CRM), one per distance/F1
# metric. Entry [i, j] is the metric value when node i is PREDICTED and leaf j
# (in `tree.leaf_subset()` order) is the GROUND TRUTH. `metric_fns` take
# (gt, pr), so the leaves go in the first argument. The shape stays
# (num_nodes, num_leaves) so callers keep using `risk_matrix[...].T`.
risk_matrix = {
    name: metric_fn(
        tree.leaf_subset()[None, :],  # gt: leaves, varying along columns
        np.arange(tree.num_nodes())[:, None])  # pr: every node, varying along rows
    for name, metric_fn in metric_fns.items()
    if '_dist' in name or '_f1' in name
}

In [None]:
def predict_and_evaluate(outputs):
    """Predict with several fixed decision rules and evaluate each one.

    Decision rules:
      - 'leaf': most probable leaf node.
      - 'majority': among nodes with probability > 0.5 that pass `node_mask`,
        maximize specificity, then confidence (via `main.arglexmin_where`).
      - 'crm_<metric>': conditional risk minimization -- the node with the
        best expected <metric> under the leaf probabilities, using
        `risk_matrix` (argmax for F1, argmin for distance).

    Every prediction is truncated at its LCA with the ground truth and then
    scored with every function in `metric_fns`.

    Args:
        outputs: dict with 'prob' (2-D, indexed [example, node]) and 'gt'
            (ground-truth node per example), e.g. from `load_results`.

    Returns:
        Dict mapping method name -> {metric name -> mean over examples}.
    """
    prob = outputs['prob']
    gt = outputs['gt']

    pred = {}
    pred['leaf'] = main.argmax_where(prob, tree.leaf_mask(), axis=-1)
    pred['majority'] = main.arglexmin_where(
        np.broadcast_arrays(-prob, -specificity),
        (prob > 0.5) & node_mask,
        axis=-1)

    # Expected metric of each candidate node: sum over leaves of
    # p(leaf) * risk_matrix[node, leaf]. The leaf slice is computed once.
    leaf_prob = prob[:, tree.leaf_mask()]
    crm_rules = [
        ('info_f1', np.argmax),
        ('info_dist', np.argmin),
        ('depth_f1', np.argmax),
        ('depth_dist', np.argmin),
    ]
    for metric_name, select in crm_rules:
        expected = np.dot(leaf_prob, risk_matrix[metric_name].T)
        pred['crm_' + metric_name] = select(expected, axis=-1)

    # Truncate at LCA.
    pred = {
        method: hier.truncate_given_lca(gt, pr, hier.lca(tree, gt, pr))
        for method, pr in pred.items()
    }
    return {
        method: add_metrics({
            field: np.mean(fn(gt, pr))
            for field, fn in metric_fns.items()
        })
        for method, pr in pred.items()
    }

In [None]:
# Per-method mean metrics for the single loaded experiment.
predict_and_evaluate(outputs=outputs)

In [None]:
def construct_curves(outputs, min_threshold=None):
    """Sweep the confidence threshold and return mean metrics per operating point.

    Args:
        outputs: dict with 'prob' (indexed [example, node]) and 'gt'.
        min_threshold: Passed as `threshold` to `main.prediction_sequence`;
            None leaves that argument unset (None).

    Returns:
        (means, thresholds). `means` maps metric name -> array of per-example
        means, with one more entry than `thresholds` (the leading entry
        presumably being the point before any prediction is accepted).
    """
    all_prob = outputs['prob']
    all_gt = outputs['gt']

    seqs = [
        main.prediction_sequence(specificity, example_prob,
                                 threshold=min_threshold, condition=node_mask)
        for example_prob in all_prob
    ]
    # Probabilities are taken from the untruncated sequences.
    seq_probs = [all_prob[row, seq] for row, seq in enumerate(seqs)]

    # Truncate at LCA.
    lca_of = hier.FindLCA(tree)
    seq_lcas = [lca_of(gt_node, seq) for gt_node, seq in zip(all_gt, seqs)]
    seqs = [
        hier.truncate_given_lca(gt_node, seq, lca)
        for gt_node, seq, lca in zip(all_gt, seqs, seq_lcas)
    ]

    metric_seqs = {}
    for metric_name, fn in metric_fns.items():
        metric_seqs[metric_name] = [fn(gt_node, seq) for gt_node, seq in zip(all_gt, seqs)]

    totals, thresholds = main.pool_operating_points(seq_probs, metric_seqs)

    # Keep only the last operating point of each distinct threshold.
    _, first = np.unique(-thresholds, return_index=True)
    last = np.concatenate([first[1:], [len(thresholds)]]) - 1
    keep = np.concatenate([[0], last + 1])

    num_examples = len(all_gt)
    means = add_metrics({
        metric_name: total[keep] / num_examples
        for metric_name, total in totals.items()
    })
    return means, thresholds[last]

In [None]:
# Threshold sweep and fixed-rule evaluation for every experiment.
# Each experiment's outputs go into a loop-local name so the notebook-level
# `outputs` used by earlier cells is not silently replaced by the last
# experiment's results (which would change those cells on re-run).
min_threshold = 0.5
# min_threshold = 2 / tree.num_leaf_nodes()
# min_threshold = None
curves = {}
points = {}
for name, experiment in experiments.items():
    exp_outputs = load_results(*experiment)
    curves[name], _ = construct_curves(exp_outputs, min_threshold)
    points[name] = predict_and_evaluate(exp_outputs)
    print('frac with a leaf node > 0.5:', name,
          np.mean(np.max(exp_outputs['prob'][:, tree.leaf_mask()], axis=-1) > 0.5))

In [None]:
# Marker for each baseline decision rule drawn by `plot_metrics`.
MARKERS = {
    'leaf': 's',
    'majority': 'o',
}

# One 'tab10' colour per experiment, in dict order. `matplotlib.cm.get_cmap`
# was deprecated in Matplotlib 3.7 and removed in 3.9, so use
# `pyplot.get_cmap`. Indices wrap so that more than 10 experiments cycle
# through the palette instead of all clamping to the last colour.
_tab10 = plt.get_cmap('tab10')
COLORS = {name: _tab10(i % _tab10.N) for i, name in enumerate(experiments)}

In [None]:
def plot_metrics(x, y, extra=None):
    """Plot metric `y` against metric `x` for every experiment in `curves`.

    Each experiment gets its colour from `COLORS`:
      - a line for its threshold curve;
      - hollow markers (`MARKERS`) for the 'majority' and 'leaf' rules;
      - if `extra` names another method in `points`, a filled star for it.
    """
    for name, curve in curves.items():
        color = COLORS[name]
        plt.plot(curve[x], curve[y], label=name, color=color)
        method_points = points[name]
        for method in ('majority', 'leaf'):
            plt.plot(method_points[method][x], method_points[method][y],
                     marker=MARKERS[method], color=color, markerfacecolor='none')
        if extra is None:
            continue
        plt.plot(method_points[extra][x], method_points[extra][y],
                 marker='*', color=color)
    plt.xlabel(x)
    plt.ylabel(y)
    plt.grid()
    plt.legend()
    plt.show()

In [None]:
# Exact-match rate vs. `metrics.IsCorrect` rate, all experiments.
plot_metrics(x='exact', y='correct')

In [None]:
# Depth deficient vs. excess, with the depth-distance CRM point.
plot_metrics(x='depth_deficient', y='depth_excess', extra='crm_depth_dist')

In [None]:
# Information deficient vs. excess, with the info-distance CRM point.
plot_metrics(x='info_deficient', y='info_excess', extra='crm_info_dist')

In [None]:
# Information recall vs. precision, with the info-F1 CRM point.
plot_metrics(x='info_recall', y='info_precision', extra='crm_info_f1')

In [None]:
# Depth recall vs. precision, with the depth-F1 CRM point.
plot_metrics(x='depth_recall', y='depth_precision', extra='crm_depth_f1')

In [None]:
# Exact-match rate vs. depth distance, with the depth-distance CRM point.
plot_metrics(x='exact', y='depth_dist', extra='crm_depth_dist')

In [None]:
# Depth recall vs. `metrics.IsCorrect` rate.
plot_metrics(x='depth_recall', y='correct')

In [None]:
# Exact-match rate vs. information distance, with the info-distance CRM point.
plot_metrics(x='exact', y='info_dist', extra='crm_info_dist')

In [None]:
# NOTE(review): identical to the earlier info_deficient/info_excess cell -- keep or drop?
plot_metrics(x='info_deficient', y='info_excess', extra='crm_info_dist')