In [1]:
from functools import partial
import pathlib

import numpy as np
import torch
from torch import nn
from torch.utils import benchmark

In [2]:
import sys

# TODO: pip install -e hier
# Until then, put the parent of the current working directory at the front of
# the module search path (presumably where the `hier` package lives - confirm).
PROJECT_ROOT = pathlib.Path.cwd().parent
sys.path.insert(0, str(PROJECT_ROOT))

import hier

In [3]:
# Read the hierarchy edge list from the resources directory and build the tree.
# NOTE(review): `names` is whatever second value make_hierarchy_from_edges
# returns - presumably per-node names; confirm against hier.
RESOURCE_DIR = pathlib.Path('../resources')
with (RESOURCE_DIR / 'hierarchy/imagenet_fiveai.csv').open() as edge_file:
    edges = hier.load_edges(edge_file)
tree, names = hier.make_hierarchy_from_edges(edges)

In [4]:
# Display (number of leaf nodes, total number of nodes) of the loaded hierarchy.
tree.num_leaf_nodes(), tree.num_nodes()

(1000, 1372)

In [5]:
# Benchmark configuration.
# b is the batch size: the leading dimension of every benchmarked logit tensor.
b = 1024
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# top-to-bottom on a machine without CUDA (timings are then CPU timings).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
# Basic size statistics of the hierarchy.
num_nodes = tree.num_nodes()
num_leaf = tree.num_leaf_nodes()
num_internal = tree.num_internal_nodes()
# Number of nodes minus one.
num_edges = num_nodes - 1
# Largest child count over all entries of tree.children().
max_children = max(len(kids) for kids in tree.children().values())

In [7]:
# When True, labels refer to leaf nodes only (tree.leaf_subset()); when False,
# labels refer to every node in the tree.
with_leaf_targets = True

In [8]:
# Build, for every depth level of the hierarchy, the set of nodes that take
# part in that level's softmax, then map each label to its target index at
# each level.
node_depth = tree.depths()
is_leaf = tree.leaf_mask()
max_depth = np.max(node_depth)
num_levels = max_depth
# Levels are the depths 1..max_depth (depth 0 gets no level of its own).
level_depth = np.arange(1, max_depth + 1)
# Level d contains every node at depth d, plus every leaf shallower than d,
# so a leaf stays present in all levels deeper than itself.
level_masks = ((level_depth[:, None] == node_depth) |
               ((level_depth[:, None] > node_depth) & is_leaf))
level_sizes = np.count_nonzero(level_masks, axis=1)
level_nodes = [np.flatnonzero(mask) for mask in level_masks]

# Construct map from nodes to index within softmax at each level.
# Nodes that do not belong to a level keep the sentinel -1.
node_to_level_target = np.full([num_levels, tree.num_nodes()], -1, dtype=int)
for i in range(num_levels):
    node_to_level_target[i, level_nodes[i]] = np.arange(level_sizes[i])

# Labels index either the leaf nodes only or all nodes.
if with_leaf_targets:
    label_to_node = tree.leaf_subset()
else:
    label_to_node = np.arange(tree.num_nodes())

# NOTE(review): presumably paths[n, i] is the ancestor of node n at depth
# i + 1, padded with n itself beyond its own depth - confirm against hier.
paths = tree.paths_padded(method='self', exclude_root=True)
label_to_level_target = np.full([num_levels, len(label_to_node)], -1, dtype=int)
for i in range(num_levels):
    label_to_level_target[i, :] = node_to_level_target[i, paths[label_to_node, i]]
    # Every label should give a valid target.
    # Note that this will fail for non-leaf targets.
    # TODO: Implement no loss for descendants of label?
    assert np.all(label_to_level_target[i, :] >= 0)
    # Check only this level's row: comparing the whole array would mix in the
    # still-unfilled (-1) rows and the rows of other levels, so an out-of-range
    # target at level i could go unnoticed.
    assert np.all(label_to_level_target[i, :] < level_sizes[i])

In [9]:
# Number of entries in each level's softmax, ordered from depth 1 downwards.
level_sizes.tolist()

[2, 5, 22, 47, 112, 240, 403, 587, 736, 819, 887, 921, 977, 999, 1000]

In [10]:
# Vectorized padded softmax.
# All levels share one [batch, level, class] tensor padded to the widest
# level; padded logits are set to -inf so they receive zero probability.

theta = torch.randn([b, num_levels, max(level_sizes)]).to(device)
for i, size in enumerate(level_sizes):
    # Mask the padding of level i for every example in the batch. Indexing
    # `theta[i, size:]` would instead address batch element i and blank out
    # whole levels, whose log-softmax is then NaN.
    theta[:, i, size:] = -torch.inf

timer = benchmark.Timer(
    stmt='torch.nn.functional.log_softmax(theta, dim=-1)',
    globals={'theta': theta})
timer.blocked_autorange()

<torch.utils.benchmark.utils.common.Measurement object at 0x7f7148fbedf0>
torch.nn.functional.log_softmax(theta, dim=-1)
  Median: 536.40 us
  IQR:    65.29 us (505.95 to 571.25)
  4 measurements, 100 runs per measurement, 1 thread
           This could indicate system fluctuation.

In [11]:
# Individual softmaxes.
# One unpadded [batch, level_size] logit tensor per level, each passed through
# its own log_softmax call inside the timed statement.

theta = [torch.randn([b, level_size]).to(device) for level_size in level_sizes]

timer = benchmark.Timer(
    globals=dict(theta=theta),
    stmt='[torch.nn.functional.log_softmax(theta_i, dim=-1) for theta_i in theta]',
)
timer.blocked_autorange()

<torch.utils.benchmark.utils.common.Measurement object at 0x7f7149001a30>
[torch.nn.functional.log_softmax(theta_i, dim=-1) for theta_i in theta]
  344.14 us
  1 measurement, 1000 runs , 1 thread