In [1]:
import collections
import itertools
import json
import pathlib

from jax import tree_util
import matplotlib
from matplotlib import pyplot as plt
import ml_collections
import numpy as np
import pandas
import torch

import hier
import infer
import main
import metrics
import progmet

In [2]:
# Prefer the GPU, but fall back to the CPU so the notebook can still be run
# (slowly) on a machine without CUDA instead of failing at the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# NOTE(review): this import lives here rather than in the import cell at the top
# of the notebook.
import configs.tiny_imagenet

# Start from the stock Tiny ImageNet config and point it at the local dataset copy.
# NOTE(review): absolute, machine-specific path; it must be edited to run elsewhere.
base_config = configs.tiny_imagenet.get_config()
base_config.dataset_root = '/home/jack/data/manual/tiny_imagenet/'

# Keep the evaluation dataset, the full tree, its node names and the evaluation
# label map; the first and fifth values returned by main.make_datasets are discarded
# (presumably the training dataset and training label map; confirm against main).
_, eval_dataset, tree, node_names, _, eval_label_map = main.make_datasets(base_config)

In [4]:
# Batched loader over the evaluation dataset. Shuffling is disabled, so the
# examples are visited in dataset order on every pass.
eval_loader = torch.utils.data.DataLoader(
    dataset=eval_dataset,
    batch_size=256,
    shuffle=False,
    num_workers=8)

In [5]:
# While metrics will be computed using the full tree, we will project
# the predictions to a subset of the tree before evaluating.

# Expect that "exact" will be zero but "correct" will be higher than usual.

def load_subtree(fname, tree, node_names):
    """Builds the evaluation sub-tree defined by a class-subset file.

    Args:
        fname: Path of a text file with one node name per line. Surrounding
            whitespace is stripped and blank lines are ignored.
        tree: The full hierarchy, passed through to hier.subtree.
        node_names: Sequence of node names; a name's position is its node index.

    Returns:
        Whatever hier.subtree(tree, indices) returns. The caller in this notebook
        unpacks it as (subtree, node_subset, project_to_subtree).

    Raises:
        KeyError: If the file lists a name that is not in node_names.
    """
    name_to_ind = {name: ind for ind, name in enumerate(node_names)}
    with open(fname) as f:
        stripped = (line.strip() for line in f)
        # Skip blank lines: a trailing empty line would otherwise be looked up as
        # the node name '' and fail with an unhelpful `KeyError: ''`.
        min_subset_names = {name for name in stripped if name}
    unknown = sorted(min_subset_names.difference(name_to_ind))
    if unknown:
        raise KeyError(f'{fname}: names not found in node_names: {unknown}')
    min_subset_inds = {name_to_ind[name] for name in min_subset_names}
    return hier.subtree(tree, min_subset_inds)

# Class subset that defines the evaluation sub-tree.
subset_fname = 'resources/class_subset/tiny_imagenet_fiveai_max_leaf_size_20.txt'
# The result is unpacked into the sub-tree itself, `node_subset` (indexed into
# node_names below, i.e. full-tree node indices of the sub-tree nodes) and
# `project_to_subtree` (used in apply_model to map full-tree nodes into the sub-tree).
subtree, node_subset, project_to_subtree = load_subtree(subset_fname, tree, node_names)
# Names of the sub-tree nodes, in the order given by node_subset.
subtree_node_names = [node_names[i] for i in node_subset]

In [6]:
# The pair of files that identifies one trained model: its config and a checkpoint.
Experiment = collections.namedtuple('Experiment', ('config_file', 'model_file'))


def standard_experiment(experiment_dir, epoch):
    """Returns the Experiment for a directory that follows the standard layout.

    The config is expected at `<experiment_dir>/config.json` and the checkpoint
    at `<experiment_dir>/checkpoints/epoch-<epoch, zero-padded to 4 digits>.pth`.

    Args:
        experiment_dir: Directory of the experiment (str or path-like).
        epoch: Integer epoch number of the checkpoint to use.
    """
    root = pathlib.Path(experiment_dir)
    checkpoint = root / f'checkpoints/epoch-{epoch:04d}.pth'
    config = root / 'config.json'
    return Experiment(config_file=config, model_file=checkpoint)

In [7]:
# Trained models to compare. The key is the label under which results are stored
# and shown in the summary table and plot legends. Both use the epoch-100 checkpoint.
# NOTE(review): the key 'min-leaf-5' does not match its directory name
# ('flat-max-leaf-size-20'); confirm which label is intended.
experiments = {
    'full': standard_experiment(
        experiment_dir='/mnt/ssd1/projects/2022-01-hierarchical/experiments/2022-04-04-tiny-imagenet-cut/flat/',
        epoch=100),
    'min-leaf-5': standard_experiment(
        experiment_dir='/mnt/ssd1/projects/2022-01-hierarchical/experiments/2022-04-04-tiny-imagenet-cut/flat-max-leaf-size-20/',
        epoch=100),
}

In [8]:
# Marker used for each inference method's single-prediction point in the plots.
markers = {
    'leaf': 's',
    'majority': 'o',
}

# One 'tab10' colour per experiment, assigned in the iteration order of `experiments`
# (the colormap is called with 0, 1, 2, ...).
# pyplot.get_cmap is used because matplotlib.cm.get_cmap was deprecated in
# Matplotlib 3.7 and removed in 3.9.
colors = dict(zip(
    experiments,
    map(plt.get_cmap('tab10'), itertools.count())))

In [9]:
# Perform inference in the sub-tree.

# Per-node quantities of the sub-tree, shared by the inference functions below
# and by apply_model (presumably arrays indexed by sub-tree node; confirm in hier).
is_leaf = subtree.leaf_mask()
# Negated number of leaf descendants, so a node with fewer leaf descendants
# gets a larger value.
specificity = -subtree.num_leaf_descendants()
# True for every node except those with exactly one child.
not_trivial = (subtree.num_children() != 1)

# Inference methods, keyed by the name used in the result tables and `markers`.
# Each takes probabilities `p` over the sub-tree nodes and returns predictions.
infer_fns = {
    # Arg-max restricted to the nodes flagged by is_leaf.
    'leaf': lambda p: infer.argmax_where(p, is_leaf),
    # 0.5 is passed as the third argument, presumably the confidence threshold
    # (confirm against infer.argmax_with_confidence); candidates are limited by not_trivial.
    'majority': lambda p: infer.argmax_with_confidence(specificity, p, 0.5, not_trivial),
}

In [10]:
def apply_model(net, pred_fn, prob_subset, min_threshold):
    """Runs `net` over `eval_loader` and collects predictions for each inference method.

    Inference is performed in the sub-tree; predictions are then mapped back to
    node indices of the full tree before being returned.

    Reads notebook-level state: eval_loader, eval_label_map, device, infer_fns,
    specificity, not_trivial, tree, subtree, node_subset and project_to_subtree.

    Args:
        net: Model to evaluate. It is put in eval mode and called on each input
            batch after the batch is moved to `device`.
        pred_fn: Maps the model output to a tensor of probabilities.
        prob_subset: Index applied to the last axis of the probabilities to pick
            out the entries that correspond to the sub-tree nodes.
        min_threshold: Forwarded to infer.pareto_optimal_predictions.

    Returns:
        A pair (outputs, seq_outputs).
        outputs: dict with 'gt' (ground-truth node per example, in the full tree)
            and 'pred' (dict from method name to predicted full-tree node per
            example), each concatenated over all minibatches.
        seq_outputs: dict with 'pred' and 'prob', each a list with one entry per
            example: the sequence of predictions (full-tree node indices) returned
            by infer.pareto_optimal_predictions, and the probability of each
            prediction in that sequence.

    Raises:
        AssertionError: If, for some minibatch, correctness of the 'majority'
            prediction differs between the full tree and the sub-tree.
    """
    # Per-example predictions.

    outputs = {
        'gt': [],  # Node in hierarchy.
        'pred': {method: [] for method in infer_fns},
    }
    # Sequence-per-example predictions. Cannot be concatenated due to ragged shape.
    seq_outputs = {
        'pred': [],
        'prob': [],
    }

    net.eval()
    with torch.inference_mode():
        meter = progmet.ProgressMeter('apply', interval_time=5)
        for minibatch in meter(eval_loader):
            inputs, gt_labels = minibatch
            # Translate dataset labels to nodes of the full tree.
            gt_node = eval_label_map.to_node[gt_labels]
            theta = net(inputs.to(device))
            prob = pred_fn(theta).cpu().numpy()

            # Perform inference in sub-tree.
            # Take subset of probability vector if necessary.
            prob = prob[..., prob_subset]

            # One prediction per example for each inference method (sub-tree indices).
            pred = {}
            for method, infer_fn in infer_fns.items():
                pred[method] = infer_fn(prob)
            # One sequence of predictions per example (sub-tree indices).
            pred_seqs = [
                infer.pareto_optimal_predictions(specificity, p, min_threshold, not_trivial)
                for p in prob
            ]
            # Probabilities of the predicted nodes; indexed while pred_seqs still
            # holds sub-tree indices, i.e. before the mapping back to the full tree below.
            prob_seqs = [prob[i, pred_i] for i, pred_i in enumerate(pred_seqs)]

            # Correctness should not change whether evaluated in full tree or sub-tree.
            # (This may be violated if leaf nodes do not map to leaf nodes in the sub-tree.)
            full_gt = gt_node
            sub_gt = project_to_subtree[full_gt]
            sub_pred = pred['majority']
            full_pred = node_subset[sub_pred]
            assert np.all(
                metrics.correct(tree, full_gt, full_pred) ==
                metrics.correct(subtree, sub_gt, sub_pred))

            # Evaluate metrics in full tree.
            # Apply inverse projection to return to larger label space.
            pred = {method: node_subset[pred[method]] for method in pred}
            pred_seqs = [node_subset[seq] for seq in pred_seqs]

            # The alternative would be to return gt and pred in the sub-tree.
            # To do this, we must project the ground-truth.
            # gt_node = project_to_subtree[gt_node]

            outputs['gt'].append(gt_node)
            for method in infer_fns:
                outputs['pred'][method].append(pred[method])
            seq_outputs['pred'].extend(pred_seqs)
            seq_outputs['prob'].extend(prob_seqs)

    # Concatenate results from minibatches.
    # Anything that is not a dict (here: each list of per-minibatch arrays) is a
    # leaf for tree_map, so np.concatenate receives the whole list.
    leaf_predicate = lambda x: not isinstance(x, dict)  # Treat lists as values, not containers.
    outputs = tree_util.tree_map(np.concatenate, outputs, is_leaf=leaf_predicate)

    return outputs, seq_outputs

In [11]:
def assess_predictions(tree, outputs, seq_outputs):
    """Scores the predictions returned by apply_model against the ground truth.

    Args:
        tree: Hierarchy in which the metrics are evaluated.
        outputs: dict with 'gt' (ground-truth node per example) and 'pred'
            (dict from inference method to predicted node per example).
        seq_outputs: dict with 'pred' and 'prob'; one sequence of predictions and
            one sequence of probabilities per example.

    Returns:
        A triple (pred_metrics, pareto_scores, pareto_means).
        pred_metrics: dict method -> dict metric name -> mean over examples.
        pareto_scores: first value returned by metrics.operating_curve.
        pareto_means: second value returned by metrics.operating_curve (a dict
            keyed by metric name), with each entry divided by the number of examples.
    """
    gt = outputs['gt']
    method_preds = outputs['pred']
    seq_preds = seq_outputs['pred']
    seq_probs = seq_outputs['prob']

    # Metric functions, each called as fn(gt, pred). The insertion order of this
    # dict is the order of the metric fields in the results.
    info_metric = metrics.UniformLeafInfoMetric(tree)
    depth_metric = metrics.DepthMetric(tree)
    metric_fns = {
        'exact': lambda gt, pr: pr == gt,
        'correct': metrics.IsCorrect(tree),
        'info_excess': info_metric.excess,
        'info_deficient': info_metric.deficient,
        'info_dist': info_metric.dist,
        'info_recall': info_metric.recall,
        'info_precision': info_metric.precision,
        'depth_excess': depth_metric.excess,
        'depth_deficient': depth_metric.deficient,
        'depth_dist': depth_metric.dist,
        'depth_recall': depth_metric.recall,
        'depth_precision': depth_metric.precision,
    }

    # Single prediction per example: pass every method's predictions through
    # hier.truncate_at_lca first, then average each metric over the examples.
    truncated_preds = {}
    for method, method_pred in method_preds.items():
        truncated_preds[method] = hier.truncate_at_lca(tree, gt, method_pred)
    pred_metrics = {}
    for method, method_pred in truncated_preds.items():
        method_means = {}
        for field, metric_fn in metric_fns.items():
            method_means[field] = np.mean(metric_fn(gt, method_pred))
        pred_metrics[method] = method_means

    # Pareto sequences: truncate each example's sequence with the LCA found
    # for it, then evaluate every metric per example.
    find_lca = hier.FindLCA(tree)
    truncated_seqs = []
    for gt_i, seq_i in zip(gt, seq_preds):
        truncated_seqs.append(
            hier.truncate_given_lca(gt_i, seq_i, find_lca(gt_i, seq_i)))
    metric_values_seq = {}
    for field, metric_fn in metric_fns.items():
        metric_values_seq[field] = [
            metric_fn(gt_i, seq_i) for gt_i, seq_i in zip(gt, truncated_seqs)]

    pareto_scores, pareto_totals = metrics.operating_curve(seq_probs, metric_values_seq)
    # Convert the totals returned by operating_curve into per-example means.
    pareto_means = {}
    for field, total in pareto_totals.items():
        pareto_means[field] = total / len(gt)

    return pred_metrics, pareto_scores, pareto_means

In [12]:
# Per-experiment results, keyed by experiment name. The evaluation loop skips
# names already present here; re-running this cell clears that cache.
results = {}

In [13]:
# Evaluate every experiment that does not yet have an entry in `results`.
for name, (config_file, model_file) in experiments.items():
    if name in results:
        print('cached:', name)
        continue

    # Load the config that the model was trained with.
    with open(config_file, 'r') as f:
        config = ml_collections.ConfigDict(json.load(f))

    # Depending on which tree was used during training,
    # may need to extract subset of probability vector.
    _, _, train_tree, train_node_names, _, _ = main.make_datasets(config)
    # Find location of the subtree nodes in the train nodes.
    # These will be the indices that we extract from the probability vector.
    prob_subset = [train_node_names.index(node) for node in subtree_node_names]

    num_outputs = main.get_num_outputs(config.predict, train_tree)
    net = main.make_model(config.model, num_outputs)
    # Load the weights onto the CPU first: without map_location, tensors are
    # restored to the device they were saved from, which fails if that device
    # (e.g. a second GPU on the training machine) does not exist here.
    net.load_state_dict(torch.load(model_file, map_location='cpu'), strict=True)

    net.to(device)
    # Only the prediction function is needed; the loss (first value) is discarded.
    _, pred_fn = main.make_loss(config, train_tree, device)
    outputs, seq_outputs = apply_model(net, pred_fn, prob_subset, min_threshold=0.1)
    # Metrics are evaluated in the full tree. The alternative is to evaluate in the sub-tree:
    pred_metrics, pareto_scores, pareto_metrics = assess_predictions(tree, outputs, seq_outputs)
    # pred_metrics, pareto_scores, pareto_metrics = assess_predictions(subtree, outputs, seq_outputs)

    results[name] = {
        'pred_metrics': pred_metrics,
        'pareto_scores': pareto_scores,
        'pareto_metrics': pareto_metrics,
    }

apply: 100% (40/40); T=0.0167 f=59.9; mean T=0.0167 f=59.9; elapsed 0:00:01
apply: 100% (40/40); T=0.0151 f=66.0; mean T=0.0151 f=66.0; elapsed 0:00:01


In [14]:
# Summary table: one row per (experiment, inference method), one column per metric.
pandas.DataFrame.from_dict(
    {
        (name, method): method_metrics
        for name, result in results.items()
        for method, method_metrics in result['pred_metrics'].items()
    },
    orient='index',
)

Unnamed: 0,Unnamed: 1,exact,correct,info_excess,info_deficient,info_dist,info_recall,info_precision,depth_excess,depth_deficient,depth_dist,depth_recall,depth_precision
full,leaf,0.0108,0.6134,1.456932,4.527882,5.984814,0.407644,0.68952,0.827,3.8067,4.6337,0.520514,0.829561
full,majority,0.009,0.7341,0.765205,4.778567,5.543772,0.374849,0.792605,0.5305,3.8982,4.4287,0.509378,0.878995
min-leaf-5,leaf,0.0094,0.5938,1.505366,4.629045,6.134411,0.39441,0.673787,0.8776,3.8509,4.7285,0.514342,0.818051
min-leaf-5,majority,0.0087,0.647,1.197004,4.700515,5.89752,0.38506,0.718975,0.7461,3.8685,4.6146,0.512079,0.840903


In [15]:
def plot_metrics(x, y):
    """Plots metric `y` against metric `x` for every experiment in `results`.

    For each experiment the operating curve is drawn in the experiment's colour:
    solid over the points whose score is >= 0.5 (plus the first point of the
    curve), dashed over the points whose score is <= 0.5. The single-prediction
    result of each inference method is added as an unfilled marker.

    Reads notebook-level state: results, colors and markers.

    Args:
        x: Name of the metric on the horizontal axis.
        y: Name of the metric on the vertical axis.
    """
    for name, result in results.items():
        point_metrics = result['pred_metrics']
        scores = result['pareto_scores']
        curve = result['pareto_metrics']
        color = colors[name]

        # The masks are one element longer than `scores`: the extra leading entry
        # puts the first curve point on the solid segment only.
        solid = np.concatenate(([True], scores >= 0.5))
        dashed = np.concatenate(([False], scores <= 0.5))
        plt.plot(curve[x][solid], curve[y][solid], color=color, label=name)
        plt.plot(curve[x][dashed], curve[y][dashed], color=color, linestyle='--')

        for method, values in point_metrics.items():
            plt.plot(values[x], values[y], color=color,
                     marker=markers[method], markerfacecolor='none')

    plt.ylim(top=1)
    plt.xlim(left=0)
    # plt.axis('equal')
    # plt.gca().set_aspect(1)
    plt.grid()
    plt.xlabel(x)
    plt.ylabel(y)
    plt.legend()

In [None]:
# Mean 'correct' against mean 'depth_recall' for every experiment.
plot_metrics('depth_recall', 'correct')

In [None]:
# Mean 'depth_precision' against mean 'depth_recall' for every experiment.
plot_metrics('depth_recall', 'depth_precision')

In [None]:
# Mean 'info_precision' against mean 'info_recall' for every experiment.
plot_metrics('info_recall', 'info_precision')