In [1]:
import torch
import interpol
import jitfields

# Ask PyTorch to use a single CPU thread, so that the CPU timings in this
# notebook compare single-threaded implementations.
torch.set_num_threads(1)

In [2]:
ndim = 3
shape = [128] * ndim

# Per-axis voxel coordinates and the (fractional) centre of the field of view
axes = [torch.arange(size).float() for size in shape]
center = (torch.as_tensor(shape).float() - 1) / 2

# Euclidean distance of every voxel to the centre of the volume
radius = (torch.stack(torch.meshgrid(*axes), -1) - center).square().sum(-1).sqrt()

# Binary image of a ball: 1 inside a radius of 48 voxels, 0 outside
img = (radius < 48).float()

In [3]:
cshape = [12] * ndim   # number of control points
scale = 2              # standard deviation of random displacement size

# Draw the coarse displacement field from a locally seeded generator, so that
# re-running this cell (or the whole notebook) reproduces the same deformation
# without modifying the global RNG state.
rng = torch.Generator().manual_seed(0)
disp = torch.randn([*cshape, ndim], generator=rng) * scale

# interpol.resize expects the number of channels to be first, so we move
# it around, upsample the control points to the image shape with
# interpolation order 3, and move the channels back to the last dimension
disp = disp.movedim(-1, 0)
disp = interpol.resize(disp, shape=shape, interpolation=3)
disp = disp.movedim(0, -1)

# convert the *displacement* field into a *sampling* field by adding the
# identity grid (the voxel coordinates of each output location)
identity = torch.stack(torch.meshgrid(*[torch.arange(s).float() for s in shape]), -1)
grid = identity + disp

In [4]:
# Interpolation order used by both libraries in the benchmarks below
order = 1
# Single-threaded CPU run (re-applied here so the cell can be run on its own)
torch.set_num_threads(1)

# compile kernels: call pull and push once before timing so that the
# one-off compilation is not part of the measurements.
# NOTE(review): the unsqueeze(-1)/squeeze(-1) pair presumably adds/removes a
# trailing channel dimension expected by jitfields -- confirm against its docs.
wrp_jit = jitfields.pull(img.unsqueeze(-1), grid, order=order).squeeze(-1)
psh_jit = jitfields.push(wrp_jit.unsqueeze(-1), grid, order=order).squeeze(-1)

# time pull: jitfields first, then interpol on the same inputs
%timeit jitfields.pull(img.unsqueeze(-1), grid, order=order).squeeze(-1)
%timeit interpol.grid_pull(img, grid, interpolation=order)

# time push: jitfields first, then interpol on the same inputs
%timeit jitfields.push(wrp_jit.unsqueeze(-1), grid, order=order).squeeze(-1)
%timeit interpol.grid_push(wrp_jit, grid, interpolation=order)


503 ms ± 6.58 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
504 ms ± 117 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
585 ms ± 43.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
488 ms ± 120 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
img = img.cuda()
grid = grid.cuda()

# compile kernels
wrp_jit = jitfields.pull(img.unsqueeze(-1), grid, order=order).squeeze(-1)
psh_jit = jitfields.push(wrp_jit.unsqueeze(-1), grid, order=order).squeeze(-1)

# time pull
%timeit jitfields.pull(img.unsqueeze(-1), grid, order=order).squeeze(-1)
%timeit interpol.grid_pull(img, grid, interpolation=order)

# time push
%timeit jitfields.push(wrp_jit.unsqueeze(-1), grid, order=order).squeeze(-1)
%timeit interpol.grid_push(wrp_jit, grid, interpolation=order)

691 µs ± 535 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
7.43 ms ± 556 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
877 µs ± 701 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
7.52 ms ± 525 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
