# In [2]:
import time
import torch

def test_latency(fn, n_warmup: int, n_iters: int, device, *args, **kwargs) -> dict:
    """Measure the average per-call latency of ``fn``.

    If ``fn`` exposes a ``train_example`` attribute, that attribute is the
    callable being benchmarked; otherwise ``fn`` itself is called. The same
    callable is used for both the warmup and the timed iterations.

    Args:
        fn: Callable to benchmark, or an object with a ``train_example``
            method.
        n_warmup: Number of untimed calls made before measurement starts.
        n_iters: Number of timed calls. Must be non-zero; the total elapsed
            time is divided by this value.
        device: ``None`` or an object with a ``.type`` attribute (presumably
            a ``torch.device``). When ``device.type == "cuda"`` the device is
            synchronized before and after the timed loop so queued GPU work
            is included in the measurement.
        *args: Positional arguments forwarded to every call.
        **kwargs: Keyword arguments forwarded to every call.

    Returns:
        A dict with a single key ``"latency"``: average seconds per call.
        On CUDA, if any call raises, the exception is printed and the latency
        is reported as ``float("inf")``.

    Raises:
        Exception: On non-CUDA devices, whatever the benchmarked callable
            raises is propagated unchanged.
        ZeroDivisionError: If ``n_iters`` is 0 and the run did not fail.
    """
    # Resolve the target once so warmup exercises exactly what is timed.
    step = fn.train_example if hasattr(fn, "train_example") else fn
    on_cuda = device is not None and device.type == "cuda"

    failed = False
    start = None
    try:
        for _ in range(n_warmup):
            step(*args, **kwargs)
        if on_cuda:
            # Kernel launches are asynchronous; drain warmup work so it does
            # not leak into the timed region.
            torch.cuda.synchronize(device)
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(n_iters):
            step(*args, **kwargs)
    except Exception as e:
        # Only the CUDA path is best-effort (presumably to survive failures
        # such as running out of device memory); elsewhere errors propagate.
        if not on_cuda:
            raise
        failed = True
        print(e)

    if on_cuda:
        # Wait for queued kernels so the end timestamp covers GPU execution.
        torch.cuda.synchronize(device)
    end = time.perf_counter()

    # A failed run is reported as infinitely slow, never as a negative value.
    latency = float("inf") if failed else (end - start) / n_iters  # seconds
    return {"latency": latency}


# The name starts with "test_", so pytest would otherwise try to collect this
# benchmark helper as a test case and fail resolving its parameters as
# fixtures when it is imported into a test module.
test_latency.__test__ = False