In [1]:
from functools import partial
import pathlib

import numpy as np
import torch
from torch import nn
from torch.utils import benchmark

In [2]:
import sys

# TODO: pip install -e hier
# Until `hier` is installed, make it importable from the source tree: the parent
# of the current working directory is put at the front of the module search
# path. This assumes the notebook is run from a direct subdirectory of the
# project root — TODO confirm if the notebook is moved.
PROJECT_ROOT = pathlib.Path('.').absolute().parent
sys.path.insert(0, str(PROJECT_ROOT))

# Must come after the sys.path change above.
import hier

In [3]:
RESOURCE_DIR = pathlib.Path('../resources')
# Read the edge list of the ImageNet class hierarchy, then build `tree` and
# `names` (presumably the per-node labels) from it with the `hier` helpers.
with (RESOURCE_DIR / 'hierarchy' / 'imagenet_fiveai.csv').open() as edges_file:
    edges = hier.load_edges(edges_file)
tree, names = hier.make_hierarchy_from_edges(edges)

In [4]:
(tree.num_leaf_nodes(), tree.num_nodes())  # (leaf count, total node count)

(1000, 1372)

In [5]:
# `b`: batch size used by every benchmark below.
# `device`: where every benchmark tensor is allocated (the current CUDA device).
b, device = 1024, torch.device('cuda')

In [6]:
num_nodes = tree.num_nodes()
# A tree with N nodes has N - 1 edges (one into each non-root node).
num_edges = num_nodes - 1
num_internal = tree.num_internal_nodes()
num_leaf = tree.num_leaf_nodes()
# Largest fan-out of any node; used below as the padded width of each
# internal node's row of child scores.
max_children = max(len(kids) for kids in tree.children().values())

In [7]:
# Benchmark a vanilla softmax function, for reference.

# Sample directly on the target device (as the split_scores benchmark does)
# instead of sampling on the CPU and copying over.
theta = torch.randn([b, num_edges], device=device)

timer = benchmark.Timer(
    stmt='torch.log_softmax(theta, dim=-1)',
    globals={'theta': theta})
timer.blocked_autorange()

<torch.utils.benchmark.utils.common.Measurement object at 0x7f9c70d452e0>
torch.log_softmax(theta, dim=-1)
  20.73 us
  1 measurement, 10000 runs , 1 thread

In [8]:
# Benchmark cat . log_softmax . split.
# This is the minimum required for softmaxes of different sizes.

node_to_children = tree.children()
# One group of scores per internal node (the children of that node), in the
# order returned by `tree.internal_subset()`.
cond_children = [node_to_children[x] for x in tree.internal_subset()]
cond_sizes = [len(x) for x in cond_children]
# Every edge leads into exactly one child of exactly one internal node, so the
# group sizes must add up to the number of edges. torch.split / index_copy_
# below depend on this.
assert sum(cond_sizes) == num_edges, (sum(cond_sizes), num_edges)

def cat_log_softmax_split(scores, dim=-1):
    """Log-softmax each group of sibling scores independently.

    ``scores`` is cut along ``dim`` into consecutive chunks whose sizes come
    from the global ``cond_sizes`` (one chunk per internal node). Each chunk
    is normalised with its own log-softmax, and the chunks are concatenated
    back along ``dim``. The result therefore has the same shape as ``scores``.
    """
    chunks = torch.split(scores, cond_sizes, dim=dim)
    normalized = tuple(torch.log_softmax(chunk, dim=dim) for chunk in chunks)
    return torch.cat(normalized, dim=dim)

# Sample directly on the target device rather than on the CPU + copy.
theta = torch.randn([b, num_edges], device=device)

timer = benchmark.Timer(
    stmt='cat_log_softmax_split(theta, dim=-1)',
    globals={
        'theta': theta,
        'cat_log_softmax_split': cat_log_softmax_split,
    })
timer.blocked_autorange()

<torch.utils.benchmark.utils.common.Measurement object at 0x7f9c70d454c0>
cat_log_softmax_split(theta, dim=-1)
  2.81 ms
  1 measurement, 100 runs , 1 thread

In [9]:
# Try the same thing with dim 0 instead of dim -1.
# Scores are laid out as (num_edges, batch) here, so each group is a
# contiguous block of rows.

theta = torch.randn([num_edges, b], device=device)

timer = benchmark.Timer(
    stmt='cat_log_softmax_split(theta, dim=0)',
    globals={
        'theta': theta,
        'cat_log_softmax_split': cat_log_softmax_split,
    })
timer.blocked_autorange()

<torch.utils.benchmark.utils.common.Measurement object at 0x7f9c70c7ac40>
cat_log_softmax_split(theta, dim=0)
  Median: 1.46 ms
  2 measurements, 100 runs per measurement, 1 thread

In [10]:
# Benchmark a full softmax for each internal node.
# (Before considering masked softmax.)
# Each of the num_internal rows is a softmax over all num_edges scores.

theta = torch.randn([b, num_edges], device=device)
theta = theta.unsqueeze(-2).tile([1, num_internal, 1])
print(theta.shape)

timer = benchmark.Timer(
    stmt='torch.log_softmax(theta, dim=-1)',
    globals={'theta': theta})
timer.blocked_autorange()

torch.Size([1024, 372, 1371])


<torch.utils.benchmark.utils.common.Measurement object at 0x7f9d6f7c35b0>
torch.log_softmax(theta, dim=-1)
  Median: 6.91 ms
  3 measurements, 10 runs per measurement, 1 thread

In [11]:
# Try unique values rather than tiled.

# The previous cell's tensor of this size (~2 GB, assuming float32) is still
# referenced by `theta` and by the old `timer`'s globals. Drop both before
# allocating a second one, to keep peak GPU memory down.
theta = timer = None
# Sample on the device: sampling ~5e8 values on the CPU and copying them over
# is slow and needs ~2 GB of host memory.
theta = torch.randn([b, num_internal, num_edges], device=device)
print(theta.shape)

timer = benchmark.Timer(
    stmt='torch.log_softmax(theta, dim=-1)',
    globals={'theta': theta})
timer.blocked_autorange()

torch.Size([1024, 372, 1371])


<torch.utils.benchmark.utils.common.Measurement object at 0x7f9c70c7a580>
torch.log_softmax(theta, dim=-1)
  Median: 6.93 ms
  3 measurements, 10 runs per measurement, 1 thread

In [12]:
# Try smaller softmax: one row of max_children scores per internal node,
# i.e. the padded size a vectorized per-node softmax would need.

theta = torch.randn([b, num_internal, max_children], device=device)
print(theta.shape)

timer = benchmark.Timer(
    stmt='torch.log_softmax(theta, dim=-1)',
    globals={'theta': theta})
timer.blocked_autorange()

torch.Size([1024, 372, 26])


<torch.utils.benchmark.utils.common.Measurement object at 0x7f9d8d06c970>
torch.log_softmax(theta, dim=-1)
  Median: 120.35 us
  2 measurements, 1000 runs per measurement, 1 thread

In [13]:
# OK, there is potential for a speed-up.
# Try constructing array for vectorized log-softmax.

# Layout: internal node i owns row i of a (num_internal, max_children) grid,
# and its children fill columns 0 .. cond_sizes[i] - 1. The rest is padding.
assert len(cond_sizes) == num_internal, (len(cond_sizes), num_internal)
assert max(cond_sizes) <= max_children, (max(cond_sizes), max_children)
row_index = np.concatenate([np.full(n, i) for i, n in enumerate(cond_sizes)])
col_index = np.concatenate([np.arange(n) for n in cond_sizes])
# Position of each edge score within the flattened grid.
flat_index = row_index * max_children + col_index
# Tensor.index_copy_ requires an int64 index. NumPy's default integer is int32
# on some platforms (e.g. Windows with NumPy < 2), so cast explicitly.
flat_index = torch.from_numpy(flat_index.astype(np.int64)).to(device)

def split_scores(theta, dim=-1):
    """Scatter flat edge scores into a padded per-internal-node grid.

    Args:
        theta: Tensor whose last dimension has size ``len(flat_index)``
            (``num_edges``). Entry ``k`` of that dimension is written to the
            padded slot ``flat_index[k]``.
        dim: Dimension holding the edge scores. Only the last dimension is
            supported, given either as ``-1`` or as ``theta.dim() - 1``.

    Returns:
        Tensor of shape ``(*theta.shape[:-1], num_internal, max_children)``
        with the same dtype and device as ``theta``. Row ``i`` holds the scores
        of the children of the ``i``-th internal node. Unused slots are
        ``-inf``, so they get zero weight under a softmax over the last
        dimension.
    """
    if dim < 0:
        dim += theta.dim()
    assert dim == theta.dim() - 1, 'split_scores only supports the last dimension'
    batch_shape = list(theta.shape[:-1])
    # Use theta's dtype: index_copy_ requires source and destination dtypes to
    # match, so a default-float32 buffer would reject float16/float64 input.
    flat = torch.full(
        [*batch_shape, num_internal * max_children], -torch.inf,
        dtype=theta.dtype, device=theta.device)
    flat.index_copy_(dim, flat_index, theta)
    return flat.reshape([*batch_shape, num_internal, max_children])

theta = torch.randn([b, num_edges], device=device)

# Time the scatter into the padded grid plus a single batched log-softmax
# over it.
timer = benchmark.Timer(
    stmt='torch.log_softmax(split_scores(theta, dim=-1), dim=-1)',
    globals=dict(theta=theta, split_scores=split_scores),
)
timer.blocked_autorange()

<torch.utils.benchmark.utils.common.Measurement object at 0x7f9c70d2c340>
torch.log_softmax(split_scores(theta, dim=-1), dim=-1)
  222.05 us
  1 measurement, 1000 runs , 1 thread

In [14]:
# Put it together.
# Split scores, take log-softmax and re-assemble.

def hier_log_softmax(theta, dim=-1):
    """Log-softmax of each edge score among its siblings, plus a root entry.

    The edge scores are scattered into a padded per-internal-node grid (see
    ``split_scores``). Each row (the children of one internal node) is
    normalised with a log-softmax, and the valid slots are gathered back into
    edge order. A zero is prepended for the root node.

    Args:
        theta: Tensor whose last dimension holds the ``num_edges`` edge scores.
        dim: Dimension holding the scores. Only the last dimension is
            supported, given either as ``-1`` or as ``theta.dim() - 1``.

    Returns:
        Tensor of shape ``(*theta.shape[:-1], num_edges + 1)`` on
        ``theta``'s device. Element 0 is ``0`` (the root). Element ``k + 1``
        is the log-softmax of ``theta[..., k]`` over the scores of its sibling
        group.
    """
    if dim < 0:
        dim += theta.dim()
    assert dim == theta.dim() - 1, 'hier_log_softmax only supports the last dimension'
    batch_shape = list(theta.shape[:-1])

    child_theta = split_scores(theta, dim=-1)
    child_logp = torch.log_softmax(child_theta, dim=-1)
    child_logp = child_logp.reshape([*batch_shape, num_internal * max_children])
    # Drop the padding slots and restore the original edge order.
    logp = child_logp.index_select(-1, flat_index)
    # Add a zero for the root node. Build it from logp, not the global `device`
    # and default dtype, so CPU or non-float32 inputs can be concatenated.
    zero = logp.new_zeros([*batch_shape, 1])
    return torch.cat([zero, logp], dim=-1)


# Time the complete pipeline: scatter -> log-softmax -> gather -> prepend root.
timer = benchmark.Timer(
    globals=dict(theta=theta, hier_log_softmax=hier_log_softmax),
    stmt='hier_log_softmax(theta, dim=-1)',
)
timer.blocked_autorange()

<torch.utils.benchmark.utils.common.Measurement object at 0x7f9c70c7aa90>
hier_log_softmax(theta, dim=-1)
  278.44 us
  1 measurement, 1000 runs , 1 thread