In [3]:
import torch
import time
print("Is torch.curda availabie ? ", torch.cuda.is_available())
print("Torch cuda device name  : ",torch.cuda.get_device_name(0))
print("Pytorch device number : ", torch.cuda.device_count())


Is torch.curda availabie ?  True
Torch cuda device name  :  GeForce RTX 3090
Pytorch device number :  2


In [5]:
x = torch.cuda.FloatTensor(10000, 500).normal_()
w = torch.cuda.FloatTensor(200, 500).normal_()

# ensure that context initialization and normal_() operations
# finish before you start measuring time
torch.cuda.synchronize()
torch.cuda.synchronize()

a = time.perf_counter()
y = x.mm(w.t())
torch.cuda.synchronize() # wait for mm to finish
b = time.perf_counter()
print('batch GPU {:.02e}s'.format(b - a))

a = time.perf_counter()
y = x.mm(w.t())
torch.cuda.synchronize() # wait for mm to finish
b = time.perf_counter()
print('batch GPU {:.02e}s'.format(b - a))

batch GPU 2.25e-04s
batch GPU 2.71e-04s


In [10]:
import torch.utils.benchmark as benchmark
def batched_dot_mul_sum(a, b):
    '''Computes batched dot by multiplying and summing'''
    return a.mul(b).sum(-1)


def batched_dot_bmm(a, b):
    '''Computes batched dot by reducing to bmm'''
    a = a.reshape(-1, 1, a.shape[-1])
    b = b.reshape(-1, b.shape[-1], 1)
    return torch.bmm(a, b).flatten(-3)


# Input for benchmarking
x = torch.randn(10000, 64)

# Ensure that both functions compute the same output
assert batched_dot_mul_sum(x, x).allclose(batched_dot_bmm(x, x))

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

print(t0.timeit(100))
print(t1.timeit(100))


<torch.utils.benchmark.utils.common.Measurement object at 0x7f933cbbbdf0>
batched_dot_mul_sum(x, x)
  105.06 us
  1 measurement, 100 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f94007a87f0>
batched_dot_bmm(x, x)
  333.84 us
  1 measurement, 100 runs , 1 thread


In [12]:
x = torch.randn(10000, 1024, device='cuda')
import timeit
t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

# Ran each twice to show difference before/after warmup
print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')


mul_sum(x, x):   17.8 us
mul_sum(x, x):   10.9 us
bmm(x, x):       19.8 us
bmm(x, x):       15.4 us
