In [31]:
from functools import partial
import pathlib

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import benchmark

In [2]:
import sys

# TODO: pip install -e hier
# Until the package is installed, put the parent of the current working
# directory at the front of sys.path before importing the project modules
# (presumably that is where `hier` and `hier_torch` live -- confirm).
PROJECT_ROOT = pathlib.Path('.').absolute().parent
sys.path.insert(0, str(PROJECT_ROOT))

import hier
import hier_torch

In [3]:
# Resources are looked up relative to the notebook's working directory.
RESOURCE_DIR = pathlib.Path('../resources')

# Read the edge list from the CSV file, then build the hierarchy from it.
hierarchy_csv = RESOURCE_DIR / 'hierarchy/imagenet_fiveai.csv'
with open(hierarchy_csv) as edges_file:
    edges = hier.load_edges(edges_file)

tree, names = hier.make_hierarchy_from_edges(edges)

In [4]:
# Display (number of leaf nodes, total number of nodes) of the loaded hierarchy.
tree.num_leaf_nodes(), tree.num_nodes()

(1000, 1372)

In [5]:
# NOTE(review): `b` looks like an intended batch size, but it is not used in
# the cells visible here; `labels` is created with a hard-coded size of 64.
b = 128
# All tensors and loss modules below are moved to this device (requires CUDA).
device = torch.device('cuda')

In [6]:
# Size statistics of the hierarchy.
num_nodes = tree.num_nodes()
num_leaf = tree.num_leaf_nodes()
num_internal = tree.num_internal_nodes()
# Defined as one less than the number of nodes.
num_edges = num_nodes - 1
# Largest number of children over all entries returned by tree.children().
max_children = max(len(kids) for kids in tree.children().values())

In [7]:
# Random leaf indices in [0, num_leaf_nodes) for a batch of 64 examples, on `device`.
# NOTE(review): no random seed is set, so the labels differ between runs.
labels = torch.randint(tree.num_leaf_nodes(), size=(64,)).to(device)

In [21]:
# Benchmark the descendant softmax for reference.

# One random score per node of the hierarchy for every label in the batch.
scores = torch.randn(*labels.shape, tree.num_nodes()).to(device)

loss_fn = hier_torch.MaxCutSoftmaxLoss(
    tree,
    with_leaf_targets=True,
    reduction='mean',
).to(device)
# Call once before timing.
loss_fn(scores, labels)

timer = benchmark.Timer(
    stmt='loss_fn(scores, labels)',
    globals={'scores': scores, 'labels': labels, 'loss_fn': loss_fn},
)
timer.blocked_autorange()

<torch.utils.benchmark.utils.common.Measurement object at 0x7fcf651b02e0>
loss_fn(scores, labels)
  2.85 ms
  1 measurement, 100 runs , 1 thread

In [22]:
# Benchmark hier_torch.FlatBertinettoHXE (alpha=0.1) for reference.

# This loss takes one score per leaf node (not one per node of the hierarchy).
scores = torch.randn((*labels.shape, tree.num_leaf_nodes())).to(device)

loss_fn = hier_torch.FlatBertinettoHXE(tree, alpha=0.1, with_leaf_targets=True).to(device)
# Call once before timing.
loss_fn(scores, labels)

timer = benchmark.Timer(
    stmt='loss_fn(scores, labels)',
    globals=dict(
        scores=scores,
        labels=labels,
        loss_fn=loss_fn,
    ))
timer.blocked_autorange()

<torch.utils.benchmark.utils.common.Measurement object at 0x7fcf6a698eb0>
loss_fn(scores, labels)
  3.06 ms
  1 measurement, 100 runs , 1 thread

In [23]:
# Benchmark the softmax-margin loss.

# One random score per node of the hierarchy for every label in the batch.
scores = torch.randn(*labels.shape, tree.num_nodes()).to(device)

loss_fn = hier_torch.MarginLoss(
    tree,
    margin='incorrect',
    tau=5.0,
    hardness='soft',
    with_leaf_targets=True,
).to(device)
# Call once before timing.
loss_fn(scores, labels)

timer = benchmark.Timer(
    stmt='loss_fn(scores, labels)',
    globals={'scores': scores, 'labels': labels, 'loss_fn': loss_fn},
)
timer.blocked_autorange()

<torch.utils.benchmark.utils.common.Measurement object at 0x7fcf6cd39430>
loss_fn(scores, labels)
  Median: 98.34 us
  3 measurements, 1000 runs per measurement, 1 thread

In [33]:
# Benchmark the flat-softmax loss (with possibility of internal labels).

# One random score per leaf node for every label in the batch.
scores = torch.randn(*labels.shape, tree.num_leaf_nodes()).to(device)

# Translate leaf indices into the corresponding entries of tree.leaf_subset(),
# since this loss is given node labels rather than leaf indices.
leaf_nodes = torch.from_numpy(tree.leaf_subset()).to(device)
label_nodes = leaf_nodes[labels]

loss_fn = hier_torch.FlatSoftmaxNLL(tree).to(device)
# Call once before timing.
loss_fn(scores, label_nodes)

timer = benchmark.Timer(
    stmt='loss_fn(scores, label_nodes)',
    globals={'scores': scores, 'label_nodes': label_nodes, 'loss_fn': loss_fn},
)
timer.blocked_autorange()

<torch.utils.benchmark.utils.common.Measurement object at 0x7fce2fe79c40>
loss_fn(scores, label_nodes)
  Median: 10.07 ms
  3 measurements, 10 runs per measurement, 1 thread

In [34]:
class FlatSoftmaxNLL(nn.Module):
    """Like cross_entropy() but supports internal labels.

    Scores are given for the leaf nodes only. The probability of a label node
    is the sum of the softmax probabilities of the leaves for which it is a
    (non-strict) ancestor; the loss is the mean negative log of that
    probability over all labels.

    Args:
        tree: Hierarchy object providing `ancestor_mask(strict=False)` and
            `leaf_mask()`. Both are expected to return boolean numpy arrays:
            the first indexed [ancestor, node], the second indexed by node.
    """

    def __init__(self, tree):
        super().__init__()
        # The value is_ancestor[i, j] is true if node i is an ancestor of node j.
        is_ancestor = tree.ancestor_mask(strict=False)
        # The value node_to_leaf_mask[i, k] is true if node i is an ancestor of leaf k.
        node_to_leaf_mask = is_ancestor[:, tree.leaf_mask()]
        # Registered as a non-persistent buffer so that .to()/.cuda() move it
        # through the standard nn.Module machinery (no private _apply()
        # override needed) while state_dict() stays unchanged.
        self.register_buffer(
            'node_to_leaf_mask',
            torch.from_numpy(node_to_leaf_mask),
            persistent=False,
        )

    def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Returns the mean negative log-probability of the label nodes.

        Args:
            scores: Unnormalized scores with the leaves on the last dimension.
            labels: Integer node indices (leaf or internal) used to index the
                rows of `node_to_leaf_mask`.

        Returns:
            Scalar tensor: mean over all labels of -log p(label).
        """
        logp_leaf = F.log_softmax(scores, dim=-1)
        # Obtain logp for leaf descendants, -inf for other leaves.
        label_leaf_mask = self.node_to_leaf_mask[labels, :]
        # Scalar -inf created with the dtype and device of the log-probabilities.
        neg_inf = logp_leaf.new_full((), -torch.inf)
        logp_descendants = torch.where(label_leaf_mask, logp_leaf, neg_inf)
        # log(sum of descendant leaf probabilities); -inf entries contribute zero.
        logp_label = torch.logsumexp(logp_descendants, dim=-1)
        return torch.mean(-logp_label)

In [35]:
# Compare.

# One random score per leaf node for every label in the batch.
scores = torch.randn(*labels.shape, tree.num_leaf_nodes()).to(device)

# Translate leaf indices into the corresponding entries of tree.leaf_subset().
leaf_nodes = torch.from_numpy(tree.leaf_subset()).to(device)
label_nodes = leaf_nodes[labels]

# Time the FlatSoftmaxNLL class defined in this notebook (no hier_torch prefix).
loss_fn = FlatSoftmaxNLL(tree).to(device)
# Call once before timing.
loss_fn(scores, label_nodes)

timer = benchmark.Timer(
    stmt='loss_fn(scores, label_nodes)',
    globals={'scores': scores, 'label_nodes': label_nodes, 'loss_fn': loss_fn},
)
timer.blocked_autorange()

<torch.utils.benchmark.utils.common.Measurement object at 0x7fcf6cd9daf0>
loss_fn(scores, label_nodes)
  206.47 us
  1 measurement, 1000 runs , 1 thread