In [10]:
from functools import partial
import pathlib

import numpy as np
import torch
from torch import nn
from torch.utils import benchmark
import torch.nn.functional as F

In [28]:
import sys

# Make the in-repo `hier` / `hier_torch` modules importable from this notebook.
# Assumes the notebook's working directory is a direct child of the project root.
# TODO: pip install -e hier
PROJECT_ROOT = pathlib.Path('.').absolute().parent
# Guard the insertion so re-running this cell does not keep prepending
# duplicate entries to sys.path.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

import hier
import hier_torch

In [17]:
# Load the hierarchy defined by the edge list in imagenet_fiveai.csv.
# RESOURCE_DIR is derived from PROJECT_ROOT (equivalent to the CWD-relative
# '../resources') so the two paths cannot drift apart.
RESOURCE_DIR = PROJECT_ROOT / 'resources'
with open(RESOURCE_DIR / 'hierarchy/imagenet_fiveai.csv') as f:
    edges = hier.load_edges(f)
# `names` presumably maps node indices back to the labels in the CSV -- TODO confirm.
tree, names = hier.make_hierarchy_from_edges(edges)

In [18]:
# Display the hierarchy size as (number of leaf nodes, total number of nodes).
(tree.num_leaf_nodes(), tree.num_nodes())

(1000, 1372)

In [19]:
# Batch size used by every benchmark below.
b = 1024
# Fall back to CPU when no GPU is available so the notebook still runs end to
# end; the recorded timings in this notebook were measured with device='cuda'.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [22]:
# Size statistics of the hierarchy, shared by the benchmark cells below.
num_nodes = tree.num_nodes()
num_edges = num_nodes - 1  # a tree has exactly one edge per non-root node
num_internal = tree.num_internal_nodes()
num_leaf = tree.num_leaf_nodes()
# Largest number of children of any node (assumes tree.children() maps each
# node to its collection of children -- TODO confirm).
max_children = max(len(kids) for kids in tree.children().values())

In [61]:
# Benchmark plain binary cross-entropy with logits, for reference: one
# independent logit and random 0/1 target per tree edge
# (num_edges == num_nodes - 1), averaged over everything.
# Tensors are allocated directly on `device` rather than built on the CPU and copied.
theta = torch.randn([b, num_edges], device=device)
labels = torch.randint(2, size=[b, num_edges], device=device).float()

timer = benchmark.Timer(
    stmt='loss(theta, labels)',
    globals={
        'theta': theta,
        'labels': labels,
        'loss': nn.BCEWithLogitsLoss(reduction='mean'),
    })
timer.blocked_autorange()

<torch.utils.benchmark.utils.common.Measurement object at 0x7f22e84182e0>
loss(theta, labels)
  291.31 us
  1 measurement, 1000 runs , 1 thread

In [53]:
# Hierarchical multi-label NLL. Targets are random leaf indices in
# [0, num_leaf), passed to the loss with with_leaf_targets=True.
theta = torch.randn([b, num_edges]).to(device)
labels = torch.randint(num_leaf, size=[b]).to(device)

timer = benchmark.Timer(
    'loss(theta, labels, dim=-1)',
    globals=dict(
        theta=theta,
        labels=labels,
        loss=hier_torch.MultiLabelNLL(tree, with_leaf_targets=True).to(device),
    ),
)
timer.blocked_autorange()

<torch.utils.benchmark.utils.common.Measurement object at 0x7f22e8465a60>
loss(theta, labels, dim=-1)
  443.80 us
  1 measurement, 1000 runs , 1 thread

In [63]:
# Compare to a flat softmax cross-entropy over num_edges classes
# (one class per non-root node): much faster!
theta = torch.randn([b, num_edges]).to(device)
labels = torch.randint(num_edges, size=[b]).to(device)

timer = benchmark.Timer(
    'loss(theta, labels)',
    globals=dict(theta=theta, labels=labels, loss=nn.CrossEntropyLoss()),
)
timer.blocked_autorange()

<torch.utils.benchmark.utils.common.Measurement object at 0x7f22e8459490>
loss(theta, labels)
  Median: 44.58 us
  IQR:    15.27 us (42.68 to 57.95)
  4 measurements, 1000 runs per measurement, 1 thread
           This suggests significant environmental influence.

In [24]:
# Cost of a bare element-wise log-sigmoid over a logit tensor of the same shape.
theta = torch.randn([b, num_edges]).to(device)

timer = benchmark.Timer(
    'torch.nn.functional.logsigmoid(theta)',
    globals=dict(theta=theta),
)
timer.blocked_autorange()

<torch.utils.benchmark.utils.common.Measurement object at 0x7f24083959a0>
torch.nn.functional.logsigmoid(theta)
  33.95 us
  1 measurement, 10000 runs , 1 thread

In [40]:
# Benchmark the hierarchical likelihood computation on its own
# (hier_torch.multilabel_likelihood, no loss on top), on random logits with one
# entry per tree edge (num_edges == num_nodes - 1).
# Tensors are allocated directly on `device` rather than built on the CPU and copied.
theta = torch.randn([b, num_edges], device=device)

timer = benchmark.Timer(
    stmt='f(tree, theta, dim=-1)',
    globals={
        'tree': tree,
        'theta': theta,
        'f': hier_torch.multilabel_likelihood,
    })
timer.blocked_autorange()

<torch.utils.benchmark.utils.common.Measurement object at 0x7f23fca0be20>
f(tree, theta, dim=-1)
  Median: 58.65 us
  3 measurements, 1000 runs per measurement, 1 thread
           This suggests significant environmental influence.