In [1]:
from functools import partial
import pathlib

import numpy as np
import torch
from torch import nn
from torch.utils import benchmark

In [2]:
import sys

# Put the parent of the working directory at the front of the module search
# path so that `import hier` below resolves without an installed package.
# TODO: pip install -e hier
PROJECT_ROOT = pathlib.Path.cwd().parent
sys.path.insert(0, str(PROJECT_ROOT))

import hier

In [3]:
# Read the edge list from the resources directory and build the hierarchy
# (plus the accompanying names) from it.
RESOURCE_DIR = pathlib.Path('../resources')
hierarchy_path = RESOURCE_DIR / 'hierarchy' / 'imagenet_fiveai.csv'
with hierarchy_path.open() as f:
    edges = hier.load_edges(f)
tree, names = hier.make_hierarchy_from_edges(edges)

In [4]:
# Display the number of leaf nodes and the total number of nodes in the tree.
tree.num_leaf_nodes(), tree.num_nodes()

(1000, 1372)

In [None]:
# Random benchmark input on the GPU: b rows, one column per leaf node.
device = torch.device('cuda')
b, d = 1024, tree.num_leaf_nodes()

theta = torch.randn(b, d).to(device)

In [None]:
def sum_leaf_descendants(
        tree: hier.Hierarchy,
        values: torch.Tensor,
        dim: int = -1) -> torch.Tensor:
    """Computes sum over leaf descendants for each node.

    Args:
        tree: Hierarchy providing `ancestor_mask()` and `leaf_mask()`.
        values: Tensor with one entry per leaf node along dimension `dim`.
        dim: Dimension of `values` that indexes the leaf nodes.

    Returns:
        Tensor shaped like `values` except that dimension `dim` has one entry
        per node: the sum of `values` over the leaves that the ancestor mask
        marks as descendants of that node.
    """
    # The value is_ancestor[i, j] is true if i is an ancestor of j.
    is_ancestor = tree.ancestor_mask()
    leaf_is_descendant = is_ancestor[:, tree.leaf_mask()].T
    # Follow the dtype of `values` (not the global default) so that float64 or
    # half-precision inputs do not hit a dtype mismatch in tensordot.
    matrix = torch.from_numpy(leaf_is_descendant)
    matrix = matrix.to(device=values.device, dtype=values.dtype)
    # tensordot contracts the last dimension of `values`, so bring `dim` to the
    # end first and put the resulting node dimension back afterwards.
    values = values.movedim(dim, -1)
    return torch.tensordot(values, matrix, dims=1).movedim(-1, dim)

In [None]:
class Sum(nn.Module):
    """Avoids re-computation in sum_xxx().

    The summation matrix is derived from the tree's ancestor mask once, at
    construction time, and registered as a buffer so that it moves with the
    module (e.g. `.to(device)`) instead of being rebuilt on every call.

    Args:
        tree: Hierarchy providing `ancestor_mask()` and `leaf_mask()`.
        leaf: If true, keep only the columns of the ancestor mask that belong
            to leaf nodes.
        transpose: If true, transpose the (possibly restricted) mask. Without
            the transpose each output is a sum over ancestors; with it, a sum
            over descendants.
        strict: Forwarded to `tree.ancestor_mask()`.
    """

    def __init__(
            self,
            tree: hier.Hierarchy,
            leaf: bool,
            transpose: bool,
            strict: bool = False):
        super().__init__()

        # The value is_ancestor[i, j] is true if i is an ancestor of j.
        matrix = tree.ancestor_mask(strict=strict)
        if leaf:
            matrix = matrix[:, tree.leaf_mask()]
        if transpose:
            matrix = matrix.T
        matrix = torch.from_numpy(matrix).type(torch.get_default_dtype())

        # A buffer (not a plain attribute) so that it follows .to()/.cuda().
        self.register_buffer('matrix', matrix)
        # Annotation only, for type checkers; `Optional` is not imported here.
        self.matrix: torch.Tensor

    def forward(self, values: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Applies the pre-computed sum along dimension `dim` of `values`.

        Args:
            values: Tensor whose dimension `dim` matches the number of rows of
                the stored matrix.
            dim: Dimension of `values` to sum over.

        Returns:
            Tensor shaped like `values` except that dimension `dim` has one
            entry per column of the stored matrix.
        """
        # tensordot contracts the last dimension of `values`, so bring `dim`
        # to the end first and put the resulting dimension back afterwards.
        values = values.movedim(dim, -1)
        return torch.tensordot(values, self.matrix, dims=1).movedim(-1, dim)

In [None]:
# Sum computes out[j] = sum_i values[i] * matrix[i, j], where matrix starts as
# is_ancestor (is_ancestor[i, j] is true if i is an ancestor of j).
#   transpose=False sums over ancestors; transpose=True sums over descendants.
#   leaf=True restricts the leaf side: outputs are per leaf for the ancestor
#   sums, inputs are per leaf for the descendant sums.
SumAncestors = partial(Sum, leaf=False, transpose=False)
# Was leaf=False, transpose=True, which duplicated SumDescendants.
SumLeafAncestors = partial(Sum, leaf=True, transpose=False)
SumDescendants = partial(Sum, leaf=False, transpose=True)
SumLeafDescendants = partial(Sum, leaf=True, transpose=True)

In [None]:
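# TODO: Implement a sparse-matrix variant. The placeholder below is currently
# an unmodified copy of the dense sum_leaf_descendants().
# def sum_leaf_descendants_sparse(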
#         tree: hier.Hierarchy,
#         values: torch.Tensor,
#         dim: int = -1) -> torch.Tensor:
#     """Computes sum over leaf descendants for each node."""
#     # The value is_ancestor[i, j] is true if i is an ancestor of j.
#     is_ancestor = tree.ancestor_mask()
#     leaf_is_descendant = is_ancestor[:, tree.leaf_mask()].T
#     matrix = torch.from_numpy(leaf_is_descendant)
#     matrix = matrix.to(device=values.device, dtype=torch.get_default_dtype())
#     # TODO: Re-order dimensions to make this work with dim != -1.
#     assert dim == -1 or dim == values.ndim - 1
#     return torch.tensordot(values, matrix, dims=1)

In [None]:
# Functional version: the summation matrix is rebuilt on every call.
t_dense = benchmark.Timer(
    stmt='sum_leaf_descendants(tree, theta)',
    setup='from __main__ import sum_leaf_descendants',
    globals=dict(tree=tree, theta=theta),
)

# Module version: the matrix is built once and moved to the device up front.
t_dense_pre = benchmark.Timer(
    stmt='sum_fn(theta)',
    globals=dict(sum_fn=SumLeafDescendants(tree).to(device), theta=theta),
)

In [None]:
# Run the timer for the functional sum_leaf_descendants() call; the returned
# measurement is shown as the cell output.
t_dense.blocked_autorange()

In [None]:
# Run the timer for the pre-computed SumLeafDescendants module; the returned
# measurement is shown as the cell output.
t_dense_pre.blocked_autorange()

In [None]:
# The value matrix[i, j] is true if i is an ancestor of j.
matrix = tree.ancestor_mask()

In [None]:
# For every row of the mask, the column indices of its non-zero entries.
row_subset = list(map(np.flatnonzero, matrix))

In [None]:
# Preview only the first few entries; the full list holds one index array per
# row of the mask and would flood the cell output.
row_subset[:5]

In [None]:
# Random benchmark input on the GPU for the hierarchical softmax. The width is
# one less than the node count (presumably one score per non-root node --
# confirm against the hierarchy).
device = torch.device('cuda')
b, d = 1024, tree.num_nodes() - 1

theta = torch.randn(b, d).to(device)

In [None]:
def hier_log_softmax(
        tree: hier.Hierarchy,
        scores: torch.Tensor,
        dim: int = -1) -> torch.Tensor:
    """Computes log-probabilities of all nodes from per-child scores.

    The entries of `scores` along `dim` are consumed as consecutive groups,
    one group per internal node in the order of `tree.internal_subset()`, with
    one score per child of that node. Each group is normalized with a
    log-softmax, giving the conditional log-probability of each child given
    its parent; these are then summed over each node's ancestors.

    Args:
        tree: Hierarchy providing `internal_subset()`, `children()`,
            `num_nodes()` and `ancestor_mask()`.
        scores: Floating-point tensor whose size along `dim` equals the total
            number of children over all internal nodes.
        dim: Dimension of `scores` that indexes the child scores.

    Returns:
        Tensor shaped like `scores` except that dimension `dim` has
        `tree.num_nodes()` entries, with the dtype and device of `scores`.
    """
    internal_nodes = tree.internal_subset()
    node_to_children = tree.children()
    cond_children = [node_to_children[x] for x in internal_nodes]
    cond_sizes = [len(x) for x in cond_children]
    cond_scores = scores.split(cond_sizes, dim=dim)
    cond_log_softmax = [x.log_softmax(dim=dim) for x in cond_scores]
    shape = list(scores.shape)
    shape[dim] = tree.num_nodes()
    child_index = torch.from_numpy(np.concatenate(cond_children))
    # index_add requires the destination and the source to share a dtype, so
    # allocate the zeros with the dtype of the scores rather than the default.
    # Nodes that are nobody's child keep a conditional log-probability of 0.
    log_cond_prob = torch.zeros(
        shape, dtype=scores.dtype, device=scores.device,
    ).index_add(
        dim,
        child_index.to(scores.device),
        torch.cat(cond_log_softmax, dim=dim))
    log_prob = sum_ancestors(tree, log_cond_prob, dim=dim, strict=False)
    return log_prob


def sum_ancestors(
        tree: hier.Hierarchy,
        values: torch.Tensor,
        dim: int = -1,
        strict: bool = False) -> torch.Tensor:
    """Computes sum over ancestors of each node.

    Args:
        tree: Hierarchy providing `ancestor_mask()`.
        values: Tensor with one entry per node along dimension `dim`.
        dim: Dimension of `values` that indexes the nodes.
        strict: Forwarded to `tree.ancestor_mask()`.

    Returns:
        Tensor with the same shape as `values`, where each entry along `dim`
        is the sum of `values` over the nodes that the ancestor mask marks as
        ancestors of that node.
    """
    # The value is_ancestor[i, j] is true if i is an ancestor of j.
    is_ancestor = tree.ancestor_mask(strict=strict)
    # Follow the dtype of `values` (not the global default) so that float64 or
    # half-precision inputs do not hit a dtype mismatch in tensordot.
    matrix = (torch.from_numpy(is_ancestor)
              .to(device=values.device, dtype=values.dtype))
    # tensordot contracts the last dimension of `values`, so bring `dim` to the
    # end first and put the node dimension back afterwards.
    values = values.movedim(dim, -1)
    return torch.tensordot(values, matrix, dims=1).movedim(-1, dim)

In [None]:
class HierLogSoftmax(nn.Module):
    """Avoids re-computation in hier_log_softmax().

    Everything that depends only on the tree (group sizes, the concatenated
    child indices and the ancestor-sum module) is prepared once in the
    constructor; `forward` then only does the per-batch work.

    Args:
        tree: Hierarchy providing `internal_subset()`, `children()` and
            `num_nodes()`; it is also passed on to `SumAncestors`.
    """

    def __init__(self, tree: hier.Hierarchy):
        super().__init__()
        internal_nodes = tree.internal_subset()
        node_to_children = tree.children()
        cond_children = [node_to_children[x] for x in internal_nodes]
        cond_sizes = [len(x) for x in cond_children]
        cat_cond_children = torch.from_numpy(np.concatenate(cond_children))

        self.cond_sizes = cond_sizes
        self.num_nodes = tree.num_nodes()
        # A buffer (not a plain attribute) so that it follows .to()/.cuda().
        self.register_buffer('cat_cond_children', cat_cond_children)
        # Annotation only, for type checkers; `Optional` is not imported here.
        self.cat_cond_children: torch.Tensor
        self.sum_ancestors = SumAncestors(tree, strict=False)

    def forward(self, scores: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Computes log-probabilities of all nodes from per-child scores.

        Args:
            scores: Floating-point tensor whose size along `dim` equals
                `sum(self.cond_sizes)`; consecutive groups hold the scores of
                the children of one internal node.
            dim: Dimension of `scores` that indexes the child scores.

        Returns:
            Tensor shaped like `scores` except that dimension `dim` has
            `self.num_nodes` entries.
        """
        device = scores.device
        cond_scores = scores.split(self.cond_sizes, dim=dim)
        cond_log_softmax = [x.log_softmax(dim=dim) for x in cond_scores]
        shape = list(scores.shape)
        shape[dim] = self.num_nodes
        # index_add requires the destination and the source to share a dtype,
        # so allocate the zeros with the dtype of the scores rather than the
        # default. Nodes that are nobody's child keep a value of 0.
        log_cond_prob = torch.zeros(
            shape, dtype=scores.dtype, device=device,
        ).index_add(
            dim, self.cat_cond_children, torch.cat(cond_log_softmax, dim=dim))
        log_prob = self.sum_ancestors(log_cond_prob, dim=dim)
        return log_prob

In [None]:
# Reference point: a flat log-softmax over the same input.
t_baseline = benchmark.Timer(
    stmt='theta.log_softmax(dim=-1)',
    globals=dict(theta=theta),
)

# Functional version: tree-dependent quantities are rebuilt on every call.
t_dense = benchmark.Timer(
    stmt='hier_log_softmax(tree, theta)',
    setup='from __main__ import hier_log_softmax',
    globals=dict(tree=tree, theta=theta),
)

# Module version: tree-dependent quantities are prepared once, up front.
t_dense_pre = benchmark.Timer(
    stmt='hier_log_softmax_fn(theta)',
    globals=dict(
        hier_log_softmax_fn=HierLogSoftmax(tree).to(device),
        theta=theta,
    ),
)

In [None]:
# Run the timer for the flat log-softmax baseline; the returned measurement is
# shown as the cell output.
t_baseline.blocked_autorange()

In [None]:
# Run the timer for the functional hier_log_softmax() call; the returned
# measurement is shown as the cell output.
t_dense.blocked_autorange()

In [None]:
# Run the timer for the pre-computed HierLogSoftmax module; the returned
# measurement is shown as the cell output.
t_dense_pre.blocked_autorange()