In [175]:
%load_ext autoreload
%autoreload 2

import torch
import speckleret.torch as tspr

import matplotlib.pyplot as plt

In [186]:
from contextlib import contextmanager
import time

@contextmanager
def timing(device: str = "cpu"):
    if device.startswith("cuda"):
        torch.cuda.synchronize()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        yield lambda: start_event, end_event
        end_event.record()
        torch.cuda.synchronize()
        elapsed_time_ms = start_event.elapsed_time(end_event)
        print(f"Total time: {elapsed_time_ms/1000:.4f} seconds")
        print(f"Average time per loop: {elapsed_time_ms / n_loops:.4f} ms")
    else:
        start_time = time.time()
        yield lambda: None  # unused for CPU
        elapsed_time = time.time() - start_time
        print(f"Total time: {elapsed_time:.4f} seconds")
        print(f"Average time per loop: {elapsed_time / n_loops:.6f} seconds")



device = 'cuda'  # or 'cpu'
# device = 'cuda'  # or 'cpu'
n = 200
N = 100
C = 10
n_loops = 20

x = torch.randn((N, C, n, n), dtype=torch.cfloat, device=device)
_ = torch.fft.fft2(x)
torch.cuda.synchronize()

with timing(device):
    for _ in range(n_loops):
        x = torch.randn(size=(N, C, n, n), dtype=torch.cfloat, device=device)
        y = tspr.transforms.fourier_transform(x)
        # y = tspr.transforms.fourier_transform_n(x)





Total time: 0.8157 seconds
Average time per loop: 40.7848 ms


In [None]:
()