In [1]:
import torch
import interpol
import jitfields

In [2]:
ndim = 3
shape = [128] * ndim
radius = 48  # radius of the synthetic ball, in voxels

# Voxel coordinates, shape (*shape, ndim). Passing indexing='ij' explicitly keeps
# the historical default and silences torch's meshgrid deprecation warning.
coords = torch.stack(
    torch.meshgrid(*[torch.arange(s).float() for s in shape], indexing='ij'), -1
)

# Euclidean distance of every voxel from the centre of the volume.
center = (torch.as_tensor(shape).float() - 1) / 2
dist = (coords - center).square().sum(-1).sqrt()

# Binary test image: 1 strictly inside the ball, 0 elsewhere.
img = (dist < radius).float()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [3]:
cshape = [12] * ndim   # number of control points along each axis
scale = 2              # standard deviation of the control-point displacements, in voxels
seed = 0               # fixed seed so the random deformation is reproducible

# Random displacement at each control point, shape (*cshape, ndim).
torch.manual_seed(seed)
disp = torch.randn([*cshape, ndim]) * scale

# interpol.resize expects the channel (vector-component) axis first: move it to
# the front, upsample the coarse field to the image shape with
# interpolation=3, then move the component axis back to the end.
disp = interpol.resize(disp.movedim(-1, 0), shape=shape, interpolation=3).movedim(0, -1)

# Convert the *displacement* field into a *sampling* field by adding it to the
# identity grid of voxel coordinates (explicit 'ij' indexing avoids torch's
# meshgrid deprecation warning and matches its previous default).
identity = torch.stack(
    torch.meshgrid(*[torch.arange(s).float() for s in shape], indexing='ij'), -1
)
grid = identity + disp

In [10]:
order = 2

def wrp_jit(inp, grid):
    out = jitfields.pull(inp.unsqueeze(-1), grid, order=order).squeeze(-1)
    if inp.is_cuda:
        torch.cuda.synchronize(inp.device)
    return out

def psh_jit(inp, grid):
    out = jitfields.push(inp.unsqueeze(-1), grid, order=order).squeeze(-1)
    if inp.is_cuda:
        torch.cuda.synchronize(inp.device)
    return out

def wrp_tch(inp, grid):
    out = interpol.grid_pull(inp, grid, interpolation=order)
    if inp.is_cuda:
        torch.cuda.synchronize(inp.device)
    return out

def psh_tch(inp, grid):
    out = interpol.grid_push(inp, grid, interpolation=order)
    if inp.is_cuda:
        torch.cuda.synchronize(inp.device)
    return out

In [11]:
# Benchmark single-threaded so both libraries run under the same CPU budget.
NUM_THREADS = 1
torch.set_num_threads(NUM_THREADS)
# Trailing ';' suppresses the value returned by this call from being displayed.
jitfields.set_num_threads(NUM_THREADS);

1

In [12]:
device = 'cpu'
img = img.to(device)
grid = grid.to(device)

# compile kernels
wrp = wrp_tch(img, grid)
psh = psh_tch(wrp, grid)
wrp = wrp_jit(img, grid)
psh = psh_jit(wrp, grid)

# time pull
%timeit wrp_jit(img, grid)
%timeit wrp_tch(img, grid)

# time push
%timeit psh_jit(wrp, grid)
%timeit psh_tch(wrp, grid)


679 ms ± 4.34 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.27 s ± 7.43 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
916 ms ± 8.76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.15 s ± 31.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
device = 'cuda'
img = img.to(device)
grid = grid.to(device)

# compile kernels
wrp = wrp_tch(img, grid)
psh = psh_tch(wrp, grid)
wrp = wrp_jit(img, grid)
psh = psh_jit(wrp, grid)

# time pull
%timeit wrp_jit(img, grid)
%timeit wrp_tch(img, grid)

# time push
%timeit psh_jit(wrp, grid)
%timeit psh_tch(wrp, grid)

1.38 ms ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
24.3 ms ± 8.41 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
1.92 ms ± 1.31 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
22.2 ms ± 3.33 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
