## Benchmark with `torch.utils.benchmark`

PyTorch provides `torch.utils.benchmark` for measuring the speed of PyTorch code, such as an `nn.Module` forward pass or a plain function. This is useful in many situations. In this notebook, we look at how `benchmark` works and how to compare different approaches across multiple runs.
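Since the introduction mentions `nn.Module`, here is a minimal sketch of timing a module's forward pass with `benchmark.Timer`. The layer sizes and input shape are arbitrary assumptions chosen for illustration. Passing the model and input through `globals` means no `setup` import is needed.

```python
import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark

# Hypothetical model and input sizes, chosen only for illustration
model = nn.Linear(64, 32)
inp = torch.rand(1000, 64)

# Every name used in `stmt` is supplied through `globals`
timer = benchmark.Timer(
    stmt='model(inp)',
    globals={'model': model, 'inp': inp},
    label='nn.Linear forward',
)

# Run blocks of the statement for at least 0.2 s
m = timer.blocked_autorange(min_run_time=0.2)

print(m)
print(f'median: {m.median * 1e6:.1f} us')
```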

### 0. Import Modules

```python
import torch
```

### 1. Defining functions to benchmark

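```python
def batched_dot_mul_sum(a, b):
    """Batched dot product via element-wise multiply, then sum."""
    return a.mul(b).sum(-1)

def batched_dot_bmm(a, b):
    """Batched dot product via batched matrix multiplication."""
    a = a.unsqueeze(1)   # [B, N] -> [B, 1, N]
    b = b.unsqueeze(-1)  # [B, N] -> [B, N, 1]
    return a.bmm(b).flatten(-3)  # [B, 1, 1] -> [B]

# Input used for the benchmarks below
x = torch.rand([1000, 64])

# Sanity check: both implementations give the same result
assert batched_dot_mul_sum(x, x).allclose(batched_dot_bmm(x, x))
```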
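To see why `batched_dot_bmm` computes the same thing as `batched_dot_mul_sum`, here is a short trace of the tensor shapes for a small batch (batch size 3 and vector length 4 are arbitrary). Each row becomes a `[1, N] @ [N, 1]` matrix product, which is exactly a dot product.

```python
import torch

a = torch.rand(3, 4)  # batch of 3 vectors of length 4

lhs = a.unsqueeze(1)    # shape [3, 1, 4]: each row as a 1 x N matrix
rhs = a.unsqueeze(-1)   # shape [3, 4, 1]: each row as an N x 1 matrix
prod = lhs.bmm(rhs)     # shape [3, 1, 1]: one 1 x 1 result per batch entry
out = prod.flatten(-3)  # shape [3]: merge the last three dims into one

print(lhs.shape, rhs.shape, prod.shape, out.shape)

# Same values as the element-wise multiply-and-sum version
ref = (a * a).sum(-1)
print(torch.allclose(out, ref))
```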

### 2. Benchmark with `timeit`

```python
import timeit

t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x}
)

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x}
)

# timeit(1000) returns the total seconds for 1000 runs; convert to microseconds per run
print(f'dot_mul: {t0.timeit(1000) / 1000 * 1e6:>5.1f}us')
print(f'dot_bmm: {t1.timeit(1000) / 1000 * 1e6:>5.1f}us')
```

```
dot_mul: 194.8us
dot_bmm:  28.1us
```
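The plain `timeit` numbers above run with whatever thread count PyTorch happens to be using, and `timeit` does no warmup. A minimal sketch of pinning the thread count before timing with `timeit`, assuming a single-threaded comparison is what you want. The functions are redefined so the block runs on its own, and the previous thread count is restored afterwards.

```python
import timeit
import torch

def batched_dot_mul_sum(a, b):
    return a.mul(b).sum(-1)

def batched_dot_bmm(a, b):
    return a.unsqueeze(1).bmm(b.unsqueeze(-1)).flatten(-3)

x = torch.rand([1000, 64])
n_runs = 100

prev_threads = torch.get_num_threads()
torch.set_num_threads(1)  # pin to one thread so both runs use the same setting
try:
    # Pass the functions and the input through `globals` instead of a setup import
    g = {'x': x, 'batched_dot_mul_sum': batched_dot_mul_sum, 'batched_dot_bmm': batched_dot_bmm}
    mul_us = timeit.timeit('batched_dot_mul_sum(x, x)', globals=g, number=n_runs) / n_runs * 1e6
    bmm_us = timeit.timeit('batched_dot_bmm(x, x)', globals=g, number=n_runs) / n_runs * 1e6
finally:
    torch.set_num_threads(prev_threads)  # restore the previous setting

print(f'dot_mul: {mul_us:>5.1f}us')
print(f'dot_bmm: {bmm_us:>5.1f}us')
```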


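### 3. Benchmark with `torch.utils.benchmark`

Things you should pay attention to:

1. The number of threads: `benchmark.Timer` runs with a single thread by default (`num_threads=1`), whereas `timeit` uses whatever thread count PyTorch is currently configured with.
2. Whether the code runs on the GPU or the CPU.

What `torch.utils.benchmark` takes care of for you:

- warmup runs before timing
- CPU/GPU synchronization (CUDA kernels are launched asynchronously)
- choosing a sufficient number of runs automatically (`blocked_autorange`)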
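To make the warmup and synchronization points concrete, here is roughly what you would have to do by hand without `torch.utils.benchmark`: run a few warmup iterations, and on a GPU call `torch.cuda.synchronize()` before reading the clock, since otherwise you would only time the kernel launches. This is a simplified sketch, not how `benchmark.Timer` is implemented internally. The helper name `manual_time` and the warmup/run counts are hypothetical.

```python
import time
import torch

def manual_time(fn, *args, warmup=10, runs=100):
    """Hypothetical helper: mean seconds per call of fn(*args)."""
    cuda = any(isinstance(a, torch.Tensor) and a.is_cuda for a in args)
    # Warmup: trigger lazy initialization and caching before timing
    for _ in range(warmup):
        fn(*args)
    if cuda:
        torch.cuda.synchronize()  # make sure warmup kernels have finished
    start = time.perf_counter()
    for _ in range(runs):
        fn(*args)
    if cuda:
        torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    return (time.perf_counter() - start) / runs

def batched_dot_mul_sum(a, b):
    return a.mul(b).sum(-1)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.rand([1000, 64], device=device)
t = manual_time(batched_dot_mul_sum, x, x)
print(f'{device}: {t * 1e6:.1f} us per call')
```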

```python
import torch.utils.benchmark as benchmark

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x':x}
)

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x':x}
)

print(t0.timeit(100))
print(t1.timeit(100))
```

```
<torch.utils.benchmark.utils.common.Measurement object at 0x7f940fa3b5d0>
batched_dot_mul_sum(x, x)
setup: from __main__ import batched_dot_mul_sum
  46.98 us
  1 measurement, 100 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f94172bad90>
batched_dot_bmm(x, x)
setup: from __main__ import batched_dot_bmm
  133.97 us
  1 measurement, 100 runs , 1 thread
```
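`Timer.timeit` returns a `Measurement` object rather than a bare number, as the printed output shows. A short sketch of reading its statistics programmatically. `mean`, `median`, `iqr`, `times` and `number_per_run` are attributes of `Measurement`; the function and input are repeated from above so the block is self-contained.

```python
import torch
import torch.utils.benchmark as benchmark

def batched_dot_mul_sum(a, b):
    return a.mul(b).sum(-1)

x = torch.rand([1000, 64])

m = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    globals={'x': x, 'batched_dot_mul_sum': batched_dot_mul_sum},
).timeit(100)

# All times are in seconds per run of `stmt`
print(f'mean:   {m.mean * 1e6:.2f} us')
print(f'median: {m.median * 1e6:.2f} us')
print(f'IQR:    {m.iqr * 1e6:.2f} us')
print('runs per measurement:', m.number_per_run)
print('measurements:', len(m.times))
```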


```python
import torch
import torch.utils.benchmark as benchmark

num_threads = torch.get_num_threads()

print(f'benchmarking on {num_threads} threads')

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x':x},
    num_threads=num_threads,
    label='Batched Dot',
    sub_label='Impl. with mul'
)

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x':x},
    num_threads=num_threads,
    label='Batched Dot',
    sub_label='Impl. with bmm'
)

print(t0.timeit(100))
print(t1.timeit(100))
```

```
benchmarking on 16 threads
<torch.utils.benchmark.utils.common.Measurement object at 0x7f940ee037d0>
Batched Dot: Impl. with mul
setup: from __main__ import batched_dot_mul_sum
  54.37 us
  1 measurement, 100 runs , 16 threads
<torch.utils.benchmark.utils.common.Measurement object at 0x7f940fa3b850>
Batched Dot: Impl. with bmm
setup: from __main__ import batched_dot_bmm
  333.63 us
  1 measurement, 100 runs , 16 threads
```
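The `setup='from __main__ import ...'` pattern relies on the functions being defined in the `__main__` module, which holds in a notebook or a directly run script but not when the benchmark code lives in an imported module. An alternative sketch passes the callables through `globals` together with the tensor, and builds the timers in a loop to avoid repeating arguments. The `impls` dict and the generic `fn` name are illustrative assumptions.

```python
import torch
import torch.utils.benchmark as benchmark

def batched_dot_mul_sum(a, b):
    return a.mul(b).sum(-1)

def batched_dot_bmm(a, b):
    return a.unsqueeze(1).bmm(b.unsqueeze(-1)).flatten(-3)

x = torch.rand([1000, 64])
num_threads = torch.get_num_threads()

# Hypothetical mapping from sub-label to implementation
impls = {'Impl. with mul': batched_dot_mul_sum, 'Impl. with bmm': batched_dot_bmm}

measurements = []
for sub_label, fn in impls.items():
    timer = benchmark.Timer(
        stmt='fn(x, x)',              # same statement for every implementation
        globals={'x': x, 'fn': fn},   # the callable is injected, no import needed
        num_threads=num_threads,
        label='Batched Dot',
        sub_label=sub_label,
    )
    measurements.append(timer.timeit(100))

for m in measurements:
    print(m)
```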


```python
print(t0.blocked_autorange())
print(t1.blocked_autorange())
```

```
<torch.utils.benchmark.utils.common.Measurement object at 0x7f940fa362d0>
Batched Dot: Impl. with mul
setup: from __main__ import batched_dot_mul_sum
  139.34 us
  1 measurement, 10000 runs , 16 threads
<torch.utils.benchmark.utils.common.Measurement object at 0x7f940edef710>
Batched Dot: Impl. with bmm
setup: from __main__ import batched_dot_bmm
  104.36 us
  1 measurement, 10000 runs , 16 threads
```
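`blocked_autorange` first chooses a block size (runs per measurement) large enough that timer overhead is a small fraction of each block, then keeps timing blocks until at least `min_run_time` seconds (0.2 by default) have accumulated. A sketch that raises `min_run_time` and inspects what came back; the value 0.5 is an arbitrary choice.

```python
import torch
import torch.utils.benchmark as benchmark

def batched_dot_bmm(a, b):
    return a.unsqueeze(1).bmm(b.unsqueeze(-1)).flatten(-3)

x = torch.rand([1000, 64])
timer = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    globals={'x': x, 'batched_dot_bmm': batched_dot_bmm},
)

# Keep collecting blocks until at least 0.5 s of timed runs have accumulated
m = timer.blocked_autorange(min_run_time=0.5)

print('runs per block:', m.number_per_run)
print('number of blocks:', len(m.times))
print(f'median per run: {m.median * 1e6:.2f} us, IQR: {m.iqr * 1e6:.2f} us')
```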


### 4. Run benchmarks across different inputs and conditions

```python
import torch
import torch.utils.benchmark as benchmark
from itertools import product

results = []

# Batch sizes and vector lengths to sweep over
size = [1, 64, 256, 1024]

for b, n in product(size, size):
    label = 'matrix multiplication'
    sub_label = f'[{b}, {n}]'  # row label in the comparison table
    x = torch.rand((b, n))

    for num_threads in [1, 4, 8, 16]:
        results.append(benchmark.Timer(
            stmt='batched_dot_mul_sum(x, x)',
            setup='from __main__ import batched_dot_mul_sum',
            globals={'x': x},
            label=label,
            sub_label=sub_label,
            description='mul',  # column label in the comparison table
            num_threads=num_threads
        ).blocked_autorange(min_run_time=1))

        results.append(benchmark.Timer(
            stmt='batched_dot_bmm(x, x)',
            setup='from __main__ import batched_dot_bmm',
            globals={'x': x},
            label=label,
            sub_label=sub_label,
            description='bmm',
            num_threads=num_threads
        ).blocked_autorange(min_run_time=1))

# One table per label, grouped by thread count: rows are sub_labels, columns are descriptions
compare = benchmark.Compare(results)
compare.print()
```
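`Compare` can also make a large table easier to read. A sketch using a smaller sweep (the sizes, thread counts and `min_run_time=0.1` are assumptions to keep it quick) together with `trim_significant_figures()`, which rounds entries to the precision the measurements support, and `colorize()`, which highlights relatively fast and slow entries with terminal colors.

```python
from itertools import product

import torch
import torch.utils.benchmark as benchmark

def batched_dot_mul_sum(a, b):
    return a.mul(b).sum(-1)

def batched_dot_bmm(a, b):
    return a.unsqueeze(1).bmm(b.unsqueeze(-1)).flatten(-3)

results = []
for b, n in product([1, 256], [64, 1024]):
    x = torch.rand((b, n))
    for num_threads in [1, 4]:
        for description, fn in [('mul', batched_dot_mul_sum), ('bmm', batched_dot_bmm)]:
            results.append(benchmark.Timer(
                stmt='fn(x, x)',
                globals={'x': x, 'fn': fn},
                label='Batched dot',
                sub_label=f'[{b}, {n}]',
                description=description,
                num_threads=num_threads,
            ).blocked_autorange(min_run_time=0.1))

compare = benchmark.Compare(results)
compare.trim_significant_figures()  # round entries to the figures the noise justifies
compare.colorize()                  # color-code relative speed in the terminal
compare.print()
```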

```
[------- matrix multiplication -------]
                    |   mul   |   bmm  
1 threads: ----------------------------
      [1, 1]        |    4.5  |     5.8
      [1, 64]       |    4.8  |     6.1
      [1, 256]      |    6.2  |     6.3
      [1, 1024]     |    6.2  |     8.9
      [64, 1]       |    4.8  |     6.9
      [64, 64]      |   13.5  |    15.6
      [64, 256]     |   12.6  |    24.8
      [64, 1024]    |   25.0  |   186.3
      [256, 1]      |    6.4  |     8.6
      [256, 64]     |   14.4  |    23.2
      [256, 256]    |   26.3  |    84.6
      [256, 1024]   |  142.3  |   726.3
      [1024, 1]     |   12.0  |    14.3
      [1024, 64]    |   33.6  |    76.4
      [1024, 256]   |  147.5  |   304.8
      [1024, 1024]  |  661.3  |  2862.0
4 threads: ----------------------------
      [1, 1]        |    7.6  |     5.8
      [1, 64]       |    5.2  |     5.8
      [1, 256]      |    6.2  |    11.0
      [1, 1024]     |    6.0  |     8.2
      [64, 1]       |    4.8  |     8.3
```
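Beyond the printed table, a list of measurements can be post-processed directly. A minimal sketch that computes how many times faster `mul` is than `bmm` for each input size and thread count by reading each measurement's `task_spec` fields. The reduced sweep here is a hypothetical stand-in so the block runs on its own; with the full `results` list from the sweep above, the grouping loop works unchanged.

```python
from collections import defaultdict

import torch
import torch.utils.benchmark as benchmark

def batched_dot_mul_sum(a, b):
    return a.mul(b).sum(-1)

def batched_dot_bmm(a, b):
    return a.unsqueeze(1).bmm(b.unsqueeze(-1)).flatten(-3)

# Reduced, hypothetical stand-in for the sweep above
results = []
for b, n in [(64, 1024), (1024, 64)]:
    x = torch.rand((b, n))
    for description, fn in [('mul', batched_dot_mul_sum), ('bmm', batched_dot_bmm)]:
        results.append(benchmark.Timer(
            stmt='fn(x, x)',
            globals={'x': x, 'fn': fn},
            sub_label=f'[{b}, {n}]',
            description=description,
            num_threads=1,
        ).blocked_autorange(min_run_time=0.1))

# Group median times by (sub_label, num_threads), keyed by description
medians = defaultdict(dict)
for m in results:
    spec = m.task_spec
    medians[(spec.sub_label, spec.num_threads)][spec.description] = m.median

speedups = {}
for key, times in medians.items():
    speedups[key] = times['bmm'] / times['mul']  # > 1 means mul is faster
    print(f'{key[0]:>12} | {key[1]} thread(s) | mul speedup over bmm: {speedups[key]:.2f}x')
```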
