In [1]:
# Following this guid to benchmark PyTorch operations: https://pytorch.org/tutorials/recipes/recipes/benchmark.html#benchmarking-with-torch-utils-benchmark-timer

import torch
import torch.utils.benchmark as benchmark


In [None]:
# define the functions to compare/benchmark/time
def index_using_gather(tensor, indices):
    """Selects elements from a tensor using gather (for N, 1)."""
    return torch.gather(tensor, dim=0, index=indices.unsqueeze(1))  # Since dim is 1

def index_using_integral_indexing(tensor, indices):
    """Selects elements from a tensor using integer indexing (for N, 1)."""
    return tensor[indices]  # Direct indexing on the first dimension

# Sample tensor and indices
tensor = torch.randn(1000, 1)
indices = torch.randint(0, tensor.shape[0], (100, ))  # Generate random indices for N


In [4]:
# Benchmarking with pytorch.utils.benchmark
t_gather = benchmark.Timer(
    stmt="index_using_gather(tensor.clone(), indices.clone())",
    setup="from __main__ import index_using_gather, tensor, indices",
)
t_indexing = benchmark.Timer(
    stmt="index_using_integral_indexing(tensor.clone(), indices.clone())",
    setup="from __main__ import index_using_integral_indexing, tensor, indices",
)

# Repeatedly run the timers for more accurate measurements
print("Gather:")
print(t_gather.timeit(number=1000))  # Run 1000 times for better accuracy
print("Integer Indexing:")
print(t_indexing.timeit(number=1000))

# Ensure outputs are the same
assert torch.allclose(index_using_gather(tensor.clone(), indices.clone()), index_using_integral_indexing(tensor.clone(), indices.clone()))


Gather:
<torch.utils.benchmark.utils.common.Measurement object at 0x7f5b4b215550>
index_using_gather(tensor.clone(), indices.clone())
setup: from __main__ import index_using_gather, tensor, indices
  6.72 us
  1 measurement, 1000 runs , 1 thread
Integer Indexing:
<torch.utils.benchmark.utils.common.Measurement object at 0x7f5a85f6c490>
index_using_integral_indexing(tensor.clone(), indices.clone())
setup: from __main__ import index_using_integral_indexing, tensor, indices
  6.89 us
  1 measurement, 1000 runs , 1 thread
