In [None]:
from itertools import product

import torch
import torch.utils.benchmark as benchmark

# Every Measurement produced by the benchmark loop is collected here and
# finally handed to benchmark.Compare.
results = list()

# Side length of the square identity matrix each candidate builds.
n = 128
# The candidates read only this tensor's shape (leading dim), dtype and device;
# its random contents are never used by them.
input = torch.randn((24, 2, 3, 384, 384))


def eye_like_repeat(
    size: "int | None" = None, like: "torch.Tensor | None" = None
) -> torch.Tensor:
    """Return a batch of identity matrices built with ``Tensor.repeat``.

    Args:
        size: side length of each identity matrix. Defaults to the module-level
            ``n`` when omitted, so the zero-argument benchmark call is unchanged.
        like: tensor whose leading dimension sets the batch size and whose
            device/dtype the result adopts. Defaults to the module-level
            ``input`` when omitted.

    Returns:
        Tensor of shape ``(like.shape[0], size, size)`` where every batch entry
        is a ``size`` x ``size`` identity matrix.

    Raises:
        AssertionError: if ``size <= 0`` or ``like`` is zero-dimensional.
    """
    # Resolve the module globals at call time (not definition time) so that
    # rebinding ``n`` or ``input`` later in the notebook is still honoured.
    size = n if size is None else size
    like = input if like is None else like

    if size <= 0:
        raise AssertionError(type(size), size)
    # A 0-dim tensor has no leading dimension to use as the batch size.
    if len(like.shape) < 1:
        raise AssertionError(like.shape)

    identity = torch.eye(size, device=like.device, dtype=like.dtype)
    # ``repeat`` materialises a separate copy of the identity per batch entry.
    return identity[None].repeat(like.shape[0], 1, 1)


def eye_like_expand(
    size: "int | None" = None, like: "torch.Tensor | None" = None
) -> torch.Tensor:
    """Return a batch of identity matrices built with ``expand`` + ``clone``.

    Args:
        size: side length of each identity matrix. Defaults to the module-level
            ``n`` when omitted, so the zero-argument benchmark call is unchanged.
        like: tensor whose leading dimension sets the batch size and whose
            device/dtype the result adopts. Defaults to the module-level
            ``input`` when omitted.

    Returns:
        Tensor of shape ``(like.shape[0], size, size)`` where every batch entry
        is a ``size`` x ``size`` identity matrix.

    Raises:
        AssertionError: if ``size <= 0`` or ``like`` is zero-dimensional.
    """
    # Resolve the module globals at call time (not definition time) so that
    # rebinding ``n`` or ``input`` later in the notebook is still honoured.
    size = n if size is None else size
    like = input if like is None else like

    if size <= 0:
        raise AssertionError(type(size), size)
    # A 0-dim tensor has no leading dimension to use as the batch size.
    if len(like.shape) < 1:
        raise AssertionError(like.shape)

    identity = torch.eye(size, device=like.device, dtype=like.dtype)
    # ``expand`` only creates a broadcast view of the single identity; ``clone``
    # materialises it so each batch entry owns independent, writable memory.
    return identity[None].expand(like.shape[0], size, size).clone()


# Names of the candidate implementations defined in this file; each one
# becomes a row of the comparison table.
fcn_names = [
    "eye_like_repeat",
    "eye_like_expand",
]

for fcn_name in fcn_names:
    # No label/sub_label is set, so Compare names each row after the timed
    # stmt (e.g. "eye_like_repeat()"); the shared description is the column.
    timer = benchmark.Timer(
        stmt=f'{fcn_name}()',
        # Hand the function to the timer directly rather than importing it from
        # __main__, which only resolves when this code runs as the main module.
        globals={fcn_name: globals()[fcn_name]},
        description='test_eye_like',
    )
    results.append(timer.blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.print()

[-------------------  ------------------]
                         |  test_eye_like
1 threads: ------------------------------
      eye_like_repeat()  |      251.9    
      eye_like_expand()  |      263.3    

Times are in microseconds (us).

