In [1]:
import torch
import numpy as np

# Report the environment the timings below were recorded in.
print(f"Torch Version: {torch.__version__}")
# get_device_name raises when no CUDA device is usable, so only query it when one is.
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "not available"
print(f"GPU   Type   : {gpu_name}")

Torch Version: 1.4.0+cu100
GPU   Type   : Quadro RTX 5000


## Custom re-implementations of slow PyTorch functions, benchmarked against the built-ins below

In [2]:
def _colwise_max(input_tensor: torch.Tensor, 
                dim: int):
    """Perform column-wise max operation. Same as input_tensor.max(dim)[0].
    
    Orders of magnitude faster. Known bug in PyTorch.
    """
    ndim = len(input_tensor.size())
    if ndim == 1:
        input_tensor = input_tensor.view(len(input_tensor), 1)
    out = torch.stack(
        [torch.max(input_tensor[:, i]) for i in range(input_tensor.shape[1 - dim])]
    )
    return out


def _fast_bincount(input_tensor):
    """A faster version of bincount than torch.bincount. Same API."""
    ndim = len(input_tensor.size())
    if not ndim == 1:
        raise ValueError(f"input_tensor must be 1-d. Instead it is {ndim}-d.")
    if not (input_tensor[1:] >= input_tensor[:-1]).all():
        input_tensor = torch.sort(input_tensor)[0]
    diff_from_prev = input_tensor[1:] != input_tensor[:-1]
    first_ind = torch.nonzero(diff_from_prev)[:, 0] + 1
    inds = torch.zeros(len(first_ind) + 2, device=input_tensor.device, dtype=torch.long)
    inds[-1] = len(input_tensor)
    inds[1:-1] = first_ind
    bincounts = inds[1:] - inds[:-1]
    return bincounts

def to_np(x):
    """Convert a tensor to a NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU
    (if it is not already there) before conversion.
    """
    return x.detach().cpu().numpy()

## Timing for column-wise max operation

In [3]:
# create random input tensor (seeded so a re-run benchmarks the same data)
torch.manual_seed(0)
input_tensor = torch.rand([100000, 4]).cuda()

In [4]:
%%timeit
# Baseline: the built-in reduction over dim 0; [0] takes the first element of its result.
colmax = input_tensor.max(dim=0)[0]
# Wait for queued CUDA work to finish so that it is included in the measurement.
torch.cuda.synchronize()

3.89 ms ± 421 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [5]:
%%timeit
# Candidate: the custom per-column reduction defined above, on the same input.
colmax_custom = _colwise_max(input_tensor, dim=0)
# Wait for queued CUDA work to finish so that it is included in the measurement.
torch.cuda.synchronize()

171 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [6]:
# check that results are the same
colmax = torch.max(input_tensor, dim=0).values
colmax_custom = _colwise_max(input_tensor, dim=0)
np.testing.assert_equal(
    to_np(colmax),
    to_np(colmax_custom),
)

## Timing for bincount

In [7]:
# create random input tensor (seeded so a re-run benchmarks the same data)
torch.manual_seed(0)
input_tensor = torch.randint(20, size=[100000]).cuda()
input_tensor = torch.sort(input_tensor)[0]  # for a pre-sorted input tensor, _fast_bincount is much faster

In [8]:
%%timeit
# Baseline: the built-in bincount on the pre-sorted input.
counts = torch.bincount(input_tensor)
# Wait for queued CUDA work to finish so that it is included in the measurement.
torch.cuda.synchronize()

1.18 ms ± 717 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [9]:
%%timeit
# Candidate: the custom bincount defined above, on the same pre-sorted input.
counts_custom = _fast_bincount(input_tensor)
# Wait for queued CUDA work to finish so that it is included in the measurement.
torch.cuda.synchronize()

189 µs ± 1.81 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [10]:
# check that results are the same
counts = torch.bincount(input_tensor)
# Call it exactly as in the timed cell; the input is already sorted, and
# _fast_bincount sorts on its own when it is not.
counts_custom = _fast_bincount(input_tensor)
np.testing.assert_equal(to_np(counts), to_np(counts_custom))