In [36]:
import torch
torch.set_printoptions(sci_mode=False)

In [37]:
image = torch.rand((5,5,1))
image3c = torch.rand((5,5,3))
x, y = torch.tensor(0.5), torch.tensor(0.5)
x_b, y_b = torch.tensor(0.5).repeat(16).reshape(16,1), torch.tensor(0.5).repeat(16).reshape(16,1)

In [38]:
def compare():
    bi = torch.Tensor(tuple(bilinear_interpolate(image3c[:,:,c], x, y).item() for c in range(3))).repeat(16)
    bmm = matrix_interpolation_3cb(image3c, x_b, y_b).flatten()
    return torch.isclose(bi,bmm).all().item()

In [39]:
def bilinear_interpolate(im, x, y, dtype=torch.FloatTensor, dtype_long=torch.LongTensor):
    # Get four corner indicies
    x0 = torch.floor(x).type(dtype_long)
    x1 = x0 + 1
    y0 = torch.floor(y).type(dtype_long)
    y1 = y0 + 1
    # Clamp within h, w boundries
    x0 = torch.clamp(x0, 0, im.shape[1]-1)
    x1 = torch.clamp(x1, 0, im.shape[1]-1)
    y0 = torch.clamp(y0, 0, im.shape[0]-1)
    y1 = torch.clamp(y1, 0, im.shape[0]-1)
    # Get four corner pixel values
    Ia = im[x0, y0]
    Ib = im[x0, y1]
    Ic = im[x1, y0]
    Id = im[x1, y1]
    # Weight four corner pixel values
    wa = (x1.type(dtype)-x) * (y1.type(dtype)-y)
    wc = (x1.type(dtype)-x) * (y-y0.type(dtype))
    wb = (x-x0.type(dtype)) * (y1.type(dtype)-y)
    wd = (x-x0.type(dtype)) * (y-y0.type(dtype))
    return torch.t((torch.t(Ia)*wa)) + torch.t(torch.t(Ib)*wb) + torch.t(torch.t(Ic)*wc) + torch.t(torch.t(Id)*wd)

In [247]:
def matrix_interpolation(im, x, y, dtype=torch.FloatTensor, dtype_long=torch.LongTensor):
    # Get four corner indicies
    x0 = torch.floor(x).type(dtype_long)
    x1 = x0 + 1
    y0 = torch.floor(y).type(dtype_long)
    y1 = y0 + 1
    # Clamp within h, w boundries
    x0 = torch.clamp(x0, 0, im.shape[1]-1)
    x1 = torch.clamp(x1, 0, im.shape[1]-1)
    y0 = torch.clamp(y0, 0, im.shape[0]-1)
    y1 = torch.clamp(y1, 0, im.shape[0]-1)
    # Get four corner pixel values
    Ia = im[x0, y0]
    Ib = im[x0, y1]
    Ic = im[x1, y0]
    Id = im[x1, y1]
    # Define matricies
    scale = 1 / ( (x1-x0) * (y1-y0) )
    m1 = torch.Tensor([x1-x, x-x0])
    m2 = torch.Tensor([
        [Ib, Ia],
        [Id, Ic]
    ])
    m3 = torch.Tensor([
        [y1-y],
        [y-y0]
    ])
    return scale * torch.matmul( torch.matmul(m1, m2), m3 )

In [248]:
%%time
for i in range(100):
    res = bilinear_interpolate(image, x, y)

CPU times: user 35.4 ms, sys: 1e+03 µs, total: 36.4 ms
Wall time: 64.2 ms


In [249]:
%%time
for i in range(100):
    res = matrix_interpolation(image, x, y)

CPU times: user 31.9 ms, sys: 4.7 ms, total: 36.6 ms
Wall time: 50.7 ms


In [250]:
def matrix_interpolation_3c(im, x, y, dtype=torch.FloatTensor, dtype_long=torch.LongTensor):
    # Get four corner indicies
    x0 = torch.floor(x).type(dtype_long)
    x1 = x0 + 1
    y0 = torch.floor(y).type(dtype_long)
    y1 = y0 + 1
    # Clamp within h, w boundries
    x0 = torch.clamp(x0, 0, im.shape[1]-1)
    x1 = torch.clamp(x1, 0, im.shape[1]-1)
    y0 = torch.clamp(y0, 0, im.shape[0]-1)
    y1 = torch.clamp(y1, 0, im.shape[0]-1)
    # Get four corner pixel values
    Ia = im[x0, y0]
    Ib = im[x0, y1]
    Ic = im[x1, y0]
    Id = im[x1, y1]
    # Define matricies
    scale = 1 / ( (x1-x0) * (y1-y0) )
    m1 = torch.Tensor([x1-x, x-x0])
    m2 = torch.cat([Ib, Ia, Id, Ic]).reshape(2,2,3)
    m3 = torch.Tensor([
        [y1-y],
        [y-y0]
    ])
    return scale * torch.matmul( torch.matmul(m1, m2).t(), m3 ).t()

In [251]:
%%time
for i in range(100):
    for c in range(3):
        res = matrix_interpolation(image3c[:,:,c], x, y)

CPU times: user 99.4 ms, sys: 2.91 ms, total: 102 ms
Wall time: 263 ms


In [252]:
%%time
for i in range(100):
    res = matrix_interpolation_3c(image3c, x, y)

CPU times: user 42.1 ms, sys: 4.72 ms, total: 46.9 ms
Wall time: 193 ms


In [21]:
def matrix_interpolation_3cb(im, x, y, dtype=torch.FloatTensor, dtype_long=torch.LongTensor):
    # Get four corner indicies
    x0 = torch.floor(x).type(dtype_long)
    x1 = x0 + 1
    y0 = torch.floor(y).type(dtype_long)
    y1 = y0 + 1
    # Clamp within h, w boundries
    x0 = torch.clamp(x0, 0, im.shape[1]-1)
    x1 = torch.clamp(x1, 0, im.shape[1]-1)
    y0 = torch.clamp(y0, 0, im.shape[0]-1)
    y1 = torch.clamp(y1, 0, im.shape[0]-1)
    # Get four corner pixel values
    Ia = torch.cat([im[coord[0], coord[1], :].unsqueeze(0) for coord in zip(x0, y0)])
    Ib = torch.cat([im[coord[0], coord[1], :].unsqueeze(0) for coord in zip(x0, y1)])
    Ic = torch.cat([im[coord[0], coord[1], :].unsqueeze(0) for coord in zip(x1, y0)])
    Id = torch.cat([im[coord[0], coord[1], :].unsqueeze(0) for coord in zip(x1, y1)])
    # Define matricies
    scale = (1 / ( (x1-x0) * (y1-y0) ) ).flatten()
    m1 = torch.cat([ torch.sub(x1, x), torch.sub(x, x0)], dim=1)
    m2 = torch.stack([Ib, Ia, Id, Ic], dim=1).reshape(16,2,2,3)
    m3 = torch.cat([ torch.sub(y1, y), torch.sub(y, y0) ], dim=1)
    # Reshape for batch matmul
    m1 = m1.reshape(16,1,1,2).repeat(1,2,1,1)
    m3 = m3.reshape(16,1,2,1)
    return scale[:,None] * torch.matmul( torch.matmul(m1, m2).permute(0,3,2,1), m3 ).flatten(start_dim=1)

In [53]:
def batch_bli(im, x, y, channel_first=False, dtype=torch.FloatTensor, dtype_long=torch.LongTensor):
    # ensure channel last
    if channel_first:
        im = im.permute(0,2,3,1)
    batch = im.shape[0]
    num_points = x.shape[1]
    assert x.shape==y.shape
    # Get four corner indicies
    x0 = torch.floor(x).type(dtype_long)
    x1 = x0 + 1
    y0 = torch.floor(y).type(dtype_long)
    y1 = y0 + 1
    # Clamp within h, w boundries
    x0 = torch.clamp(x0, 0, im.shape[2]-1)
    x1 = torch.clamp(x1, 0, im.shape[2]-1)
    y0 = torch.clamp(y0, 0, im.shape[1]-1)
    y1 = torch.clamp(y1, 0, im.shape[1]-1)
    print(batch, x0.shape, y0.shape)
    # Get four corner pixel values
    Ia = torch.cat([im[b, x, y, :] for b in range(batch) for x, y in zip(x0[b], y0[b])])
    Ib = torch.cat([im[b, x, y, :] for b in range(batch) for x, y in zip(x0[b], y0[b])])
    Ic = torch.cat([im[b, x, y, :] for b in range(batch) for x, y in zip(x0[b], y0[b])])
    Id = torch.cat([im[b, x, y, :] for b in range(batch) for x, y in zip(x0[b], y0[b])])
    # Define matricies
    scale = (1 / ( (x1-x0) * (y1-y0) ) ).flatten()
    m1 = torch.cat([ torch.sub(x1, x), torch.sub(x, x0)], dim=1)
    m2 = torch.stack([Ib, Ia, Id, Ic], dim=1).reshape(batch*num_points,2,2,3)
    m3 = torch.cat([ torch.sub(y1, y), torch.sub(y, y0) ], dim=1)
    # Reshape for batch matmul
    m1 = m1.reshape(batch*num_points,1,1,2).repeat(1,2,1,1)
    m3 = m3.reshape(batch*num_points,1,2,1)
    return scale[:,None] * torch.matmul( torch.matmul(m1, m2).permute(0,3,2,1), m3 ).flatten(start_dim=1)

In [67]:
im = torch.arange(4).reshape(1,2,2,1).repeat(2,1,1,3)
print(im.shape)
mine = batch_bli(im, x_b.repeat(2,1), y_b.repeat(2,1))
theirs = torch.Tensor(tuple(bilinear_interpolate(im[b,:,:,c], x, y).item() for b in range(2) for c in range(3)))
print(mine, theirs)

torch.Size([2, 2, 2, 3])
2 torch.Size([32, 1]) torch.Size([32, 1])


RuntimeError: shape '[2, 1, 1, 2]' is invalid for input of size 64