In [49]:
import torch

In [637]:
image = torch.rand((5,5,1))
image3c = torch.rand((5,5,3))
image3cb = torch.rand((16,5,5,3))
x, y = torch.tensor(0.5), torch.tensor(0.5)
x_b, y_b= torch.rand((16,1)), torch.rand((16,1))

In [300]:
def bilinear_interpolate(im, x, y, dtype=torch.FloatTensor, dtype_long=torch.LongTensor):
    # Get four corner indicies
    x0 = torch.floor(x).type(dtype_long)
    x1 = x0 + 1
    y0 = torch.floor(y).type(dtype_long)
    y1 = y0 + 1
    # Clamp within h, w boundries
    x0 = torch.clamp(x0, 0, im.shape[1]-1)
    x1 = torch.clamp(x1, 0, im.shape[1]-1)
    y0 = torch.clamp(y0, 0, im.shape[0]-1)
    y1 = torch.clamp(y1, 0, im.shape[0]-1)
    # Get four corner pixel values
    Ia = im[x0, y0]
    Ib = im[x0, y1]
    Ic = im[x1, y0]
    Id = im[x1, y1]
    # Weight four corner pixel values
    wa = (x1.type(dtype)-x) * (y1.type(dtype)-y)
    wc = (x1.type(dtype)-x) * (y-y0.type(dtype))
    wb = (x-x0.type(dtype)) * (y1.type(dtype)-y)
    wd = (x-x0.type(dtype)) * (y-y0.type(dtype))
    return torch.t((torch.t(Ia)*wa)) + torch.t(torch.t(Ib)*wb) + torch.t(torch.t(Ic)*wc) + torch.t(torch.t(Id)*wd)

In [301]:
def matrix_interpolation(im, x, y, dtype=torch.FloatTensor, dtype_long=torch.LongTensor):
    # Get four corner indicies
    x0 = torch.floor(x).type(dtype_long)
    x1 = x0 + 1
    y0 = torch.floor(y).type(dtype_long)
    y1 = y0 + 1
    # Clamp within h, w boundries
    x0 = torch.clamp(x0, 0, im.shape[1]-1)
    x1 = torch.clamp(x1, 0, im.shape[1]-1)
    y0 = torch.clamp(y0, 0, im.shape[0]-1)
    y1 = torch.clamp(y1, 0, im.shape[0]-1)
    # Get four corner pixel values
    Ia = im[x0, y0]
    Ib = im[x0, y1]
    Ic = im[x1, y0]
    Id = im[x1, y1]
    # Define matricies
    scale = 1 / ( (x1-x0) * (y1-y0) )
    m1 = torch.Tensor([x1-x, x-x0])
    m2 = torch.Tensor([
        [Ib, Ia],
        [Id, Ic]
    ])
    m3 = torch.Tensor([
        [y1-y],
        [y-y0]
    ])
    return scale * torch.matmul( torch.matmul(m1, m2), m3 )

In [302]:
%%time
for i in range(100):
    res = bilinear_interpolate(image, x, y)

CPU times: user 32.9 ms, sys: 2.26 ms, total: 35.2 ms
Wall time: 53.1 ms


In [303]:
%%time
for i in range(100):
    res = matrix_interpolation(image, x, y)

CPU times: user 39 ms, sys: 1.29 ms, total: 40.3 ms
Wall time: 89.1 ms


In [621]:
def matrix_interpolation_3c(im, x, y, dtype=torch.FloatTensor, dtype_long=torch.LongTensor):
    # Get four corner indicies
    x0 = torch.floor(x).type(dtype_long)
    x1 = x0 + 1
    y0 = torch.floor(y).type(dtype_long)
    y1 = y0 + 1
    # Clamp within h, w boundries
    x0 = torch.clamp(x0, 0, im.shape[1]-1)
    x1 = torch.clamp(x1, 0, im.shape[1]-1)
    y0 = torch.clamp(y0, 0, im.shape[0]-1)
    y1 = torch.clamp(y1, 0, im.shape[0]-1)
    # Get four corner pixel values
    Ia = im[x0, y0]
    Ib = im[x0, y1]
    Ic = im[x1, y0]
    Id = im[x1, y1]
    # Define matricies
    scale = 1 / ( (x1-x0) * (y1-y0) )
    m1 = torch.Tensor([x1-x, x-x0])
    m2 = torch.cat([Ib, Ia, Id, Ic]).reshape(2,2,3)
    m3 = torch.Tensor([
        [y1-y],
        [y-y0]
    ])
    return scale * torch.matmul( torch.matmul(m1, m2).t(), m3 ).t()

In [622]:
%%time
for i in range(100):
    for c in range(3):
        res = matrix_interpolation(image3c[:,:,c], x, y)

CPU times: user 98.7 ms, sys: 3.94 ms, total: 103 ms
Wall time: 252 ms


In [623]:
%%time
for i in range(100):
    res = matrix_interpolation_3c(image3c, x, y)

CPU times: user 46.4 ms, sys: 1.93 ms, total: 48.3 ms
Wall time: 193 ms


In [650]:
def matrix_interpolation_3cb(im, x, y, dtype=torch.FloatTensor, dtype_long=torch.LongTensor):
    # Get four corner indicies
    x0 = torch.floor(x).type(dtype_long)
    x1 = x0 + 1
    y0 = torch.floor(y).type(dtype_long)
    y1 = y0 + 1
    # Clamp within h, w boundries
    x0 = torch.clamp(x0, 0, im.shape[1]-1)
    x1 = torch.clamp(x1, 0, im.shape[1]-1)
    y0 = torch.clamp(y0, 0, im.shape[0]-1)
    y1 = torch.clamp(y1, 0, im.shape[0]-1)
    # Get four corner pixel values
    Ia = torch.cat([im[idx, coord[0], coord[1], :].unsqueeze(0) for idx, coord in enumerate(zip(x0, y0))])
    Ib = torch.cat([im[idx, coord[0], coord[1], :].unsqueeze(0) for idx, coord in enumerate(zip(x0, y1))])
    Ic = torch.cat([im[idx, coord[0], coord[1], :].unsqueeze(0) for idx, coord in enumerate(zip(x1, y0))])
    Id = torch.cat([im[idx, coord[0], coord[1], :].unsqueeze(0) for idx, coord in enumerate(zip(x1, y1))])
    # Define matricies
    scale = (1 / ( (x1-x0) * (y1-y0) ) ).flatten()
    m1 = torch.cat([ torch.sub(x1, x), torch.sub(x, x0)], dim=1).reshape(16,1,1,2).repeat(1,2,1,1)
    m2 = torch.cat([Ib, Ia, Id, Ic]).reshape(16,2,2,3)
    m3 = torch.cat([ torch.sub(y1, y), torch.sub(y, y0) ], dim=1).reshape(16,1,2,1)
    return scale[:,None] * torch.matmul( torch.matmul(m1, m2).permute(0,3,2,1), m3 ).flatten(start_dim=1)

In [651]:
%%time
for b in range(16):
    for xy in range(16):
        res = matrix_interpolation_3c(image3cb[b], x, y)

CPU times: user 91.9 ms, sys: 3.92 ms, total: 95.8 ms
Wall time: 202 ms


In [654]:
%%time
res = matrix_interpolation_3cb(image3cb, x_b, y_b)

CPU times: user 14.1 ms, sys: 612 µs, total: 14.8 ms
Wall time: 6.52 ms


In [655]:
matrix_interpolation_3c(image3cb[0,:,:,:], x, y)

tensor([[0.6300, 0.5568, 0.6429]])

In [656]:
matrix_interpolation_3cb(image3cb, x.repeat(16).reshape(16,1), y.repeat(16).reshape(16,1))

tensor([[0.6692, 0.8044, 0.4818],
        [0.7247, 0.1245, 0.4228],
        [0.5886, 0.5194, 0.5998],
        [0.6495, 0.5602, 0.6638],
        [0.5683, 0.6709, 0.4474],
        [0.2950, 0.5295, 0.7453],
        [0.4744, 0.6476, 0.7048],
        [0.4726, 0.4456, 0.5612],
        [0.5863, 0.5360, 0.5158],
        [0.7384, 0.3854, 0.5860],
        [0.4490, 0.5046, 0.4237],
        [0.2916, 0.5216, 0.5414],
        [0.5322, 0.3611, 0.6640],
        [0.4352, 0.5196, 0.4954],
        [0.6002, 0.5054, 0.5964],
        [0.3655, 0.6156, 0.5263]])