In [1]:
import torch
import interpol
import jitfields
from torch.nn import functional as F

In [2]:
ndim = 3
shape = [128] * ndim

# Synthetic test image: a binary ball of radius 48 voxels centred in the
# field of view.
# `indexing='ij'` is the ordering torch.meshgrid already used by default
# (mesh k varies along dimension k); passing it explicitly avoids the
# warning about the missing `indexing` argument.
img = torch.stack(torch.meshgrid(*[torch.arange(s).float() for s in shape], indexing='ij'), -1)
# move the coordinate origin to the centre of the field of view
img -= (torch.as_tensor(shape).float() - 1) / 2
# Euclidean distance of every voxel to the centre
img = img.square().sum(-1).sqrt()
# 1 inside the ball, 0 outside
img = (img < 48).float()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [3]:
cshape = [12] * ndim   # number of control points
scale = 2              # standard deviation of random displacement size

# Draw the control points from a dedicated, seeded generator so that the
# benchmarked deformation is identical on every run, without altering the
# global RNG state.
disp = torch.randn([*cshape, ndim], generator=torch.Generator().manual_seed(1234)) * scale

# interpol.resize expects the number of channels to be first, so we move
# it around
disp = disp.movedim(-1, 0)
disp = interpol.resize(disp, shape=shape, interpolation=3)
disp = disp.movedim(0, -1)

# convert the *displacement* field into a *sampling* field
# (`indexing='ij'` is the ordering torch.meshgrid already used by default;
# passing it explicitly avoids the warning about the missing argument)
identity = torch.stack(torch.meshgrid(*[torch.arange(s).float() for s in shape], indexing='ij'), -1)
grid = identity + disp

In [7]:
# Interpolation order read by the jitfields and interpol wrappers defined
# below (presumably 1 = linear, matching the 'bilinear' mode hard-coded in
# wrp_torch -- confirm against the libraries' documentation).
order = 1

def wrp_jit(inp, grid):
    """Pull (warp) `inp` through `grid` with jitfields and wait for completion.

    A trailing singleton axis is appended to `inp` for the call to
    `jitfields.pull` and stripped from the result. The interpolation order
    is taken from the notebook-level `order` variable.
    """
    warped = jitfields.pull(inp[..., None], grid, order=order)
    warped = warped.squeeze(-1)
    # CUDA work is queued asynchronously: block until it has finished so
    # that timing this function measures the actual computation.
    if inp.is_cuda:
        torch.cuda.synchronize(inp.device)
    return warped

def psh_jit(inp, grid):
    """Push (splat) `inp` along `grid` with jitfields and wait for completion.

    A trailing singleton axis is appended to `inp` for the call to
    `jitfields.push` and stripped from the result. The interpolation order
    is taken from the notebook-level `order` variable.
    """
    splatted = jitfields.push(inp[..., None], grid, order=order)
    splatted = splatted.squeeze(-1)
    # CUDA work is queued asynchronously: block until it has finished so
    # that timing this function measures the actual computation.
    if inp.is_cuda:
        torch.cuda.synchronize(inp.device)
    return splatted

def wrp_ts(inp, grid):
    """Pull (warp) `inp` through `grid` with `interpol.grid_pull`.

    The notebook-level `order` is passed as the `interpolation` argument.
    On CUDA inputs the function blocks until the device is idle.
    """
    warped = interpol.grid_pull(inp, grid, interpolation=order)
    # CUDA work is queued asynchronously: block until it has finished so
    # that timing this function measures the actual computation.
    if inp.is_cuda:
        torch.cuda.synchronize(inp.device)
    return warped

def psh_ts(inp, grid):
    """Push (splat) `inp` along `grid` with `interpol.grid_push`.

    The notebook-level `order` is passed as the `interpolation` argument.
    On CUDA inputs the function blocks until the device is idle.
    """
    splatted = interpol.grid_push(inp, grid, interpolation=order)
    # CUDA work is queued asynchronously: block until it has finished so
    # that timing this function measures the actual computation.
    if inp.is_cuda:
        torch.cuda.synchronize(inp.device)
    return splatted

def wrp_torch(inp, grid):
    """Warp `inp` with `torch.nn.functional.grid_sample` (linear interpolation).

    Parameters
    ----------
    inp : tensor, shape (*spatial)
        Image to sample, without batch or channel dimensions
        (2 or 3 spatial dimensions, as supported by `grid_sample`).
    grid : tensor, shape (*out_spatial, ndim)
        Sampling coordinates expressed in voxels of `inp`. The last axis is
        ordered like the spatial dimensions of `inp` (coordinate along
        dimension 0 first). It is not modified.

    Returns
    -------
    out : tensor, shape (*out_spatial)
        Resampled image. Samples falling outside of `inp` use the default
        padding of `grid_sample`.
    """
    mode = 'bilinear'  # grid_sample's name for linear interpolation, in 2D and 3D
    nb_dim = grid.shape[-1]
    # grid_sample wants coordinates ordered (x, y[, z]), i.e. last spatial
    # dimension first. `flip` returns a copy, so the in-place operations
    # below never touch the caller's grid.
    grid = grid.flip(-1)
    # Convert voxel coordinates to the normalized [-1, 1] convention used
    # with align_corners=False. Every coordinate must be rescaled, each by
    # the size of the *input* dimension it indexes (after the flip,
    # component k indexes the k-th dimension counted from the end).
    for axis, size in enumerate(reversed(inp.shape[-nb_dim:])):
        grid[..., axis].add_(0.5).div_(size / 2).sub_(1)
    out = F.grid_sample(inp[None, None], grid[None], mode=mode, align_corners=False)[0, 0]
    # CUDA work is queued asynchronously: block until it has finished so
    # that timing this function measures the actual computation.
    if inp.is_cuda:
        torch.cuda.synchronize(inp.device)
    return out

In [5]:
# Restrict PyTorch to a single CPU thread, and request the same from
# jitfields (presumably so that the CPU timings compare single-threaded
# implementations -- confirm intent).
torch.set_num_threads(1)
jitfields.set_num_threads(1)

1

In [8]:
# --- CPU benchmark ---
# Rebind the test image and sampling grid to tensors on the target device.
device = 'cpu'
img = img.to(device)
grid = grid.to(device)

# compile kernels
# Warm-up: call each interpol / jitfields wrapper once before timing.
# `wrp` (the pulled image) is also the input of the push benchmarks below.
# Note that wrp_torch is not called during this warm-up.
wrp = wrp_ts(img, grid)
psh = psh_ts(wrp, grid)
wrp = wrp_jit(img, grid)
psh = psh_jit(wrp, grid)

# time pull
# (order: jitfields, interpol, torch grid_sample)
%timeit wrp_jit(img, grid)
%timeit wrp_ts(img, grid)
%timeit wrp_torch(img, grid)

# time push
# (order: jitfields, interpol; no torch wrapper is defined for push)
%timeit psh_jit(wrp, grid)
%timeit psh_ts(wrp, grid)


438 ms ± 4.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
392 ms ± 30.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
93.1 ms ± 1.32 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
521 ms ± 4.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
439 ms ± 9.65 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
# --- CUDA benchmark ---
# Rebind the test image and sampling grid to tensors on the GPU; from here
# on `img` and `grid` refer to the CUDA copies.
device = 'cuda'
img = img.to(device)
grid = grid.to(device)

# compile kernels
# Warm-up: call each interpol / jitfields wrapper once before timing.
# `wrp` (the pulled image) is also the input of the push benchmarks below.
# Note that wrp_torch is not called during this warm-up.
wrp = wrp_ts(img, grid)
psh = psh_ts(wrp, grid)
wrp = wrp_jit(img, grid)
psh = psh_jit(wrp, grid)

# time pull
# (order: jitfields, interpol, torch grid_sample; every wrapper calls
# torch.cuda.synchronize on CUDA inputs before returning)
%timeit wrp_jit(img, grid)
%timeit wrp_ts(img, grid)
%timeit wrp_torch(img, grid)

# time push
# (order: jitfields, interpol; no torch wrapper is defined for push)
%timeit psh_jit(wrp, grid)
%timeit psh_ts(wrp, grid)

810 µs ± 872 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
5.4 ms ± 87.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
789 µs ± 906 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.07 ms ± 2.93 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
5.91 ms ± 79.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
