# Upsampling

Here we want to create the operator for `torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)`.

In [2]:
import torch
from devito import Operator, Grid, Function, dimensions, Eq, Inc
import numpy as np
import sympy

In [4]:
# Initialise the input the same way as in the PyTorch tutorial: the values
# 1..4 as float32, viewed as a (1, 1, 2, 2) tensor.
# NOTE: the name `input` shadows the Python builtin of the same name.
input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)

In [6]:
# Reference result: PyTorch bilinear upsampling by a factor of 2 with
# align_corners=True, applied to the tensor bound to `input`.
m = torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
torch_out = m(input)
torch_out

tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],
          [1.6667, 2.0000, 2.3333, 2.6667],
          [2.3333, 2.6667, 3.0000, 3.3333],
          [3.0000, 3.3333, 3.6667, 4.0000]]]])

In [29]:
# Second reference: a random (1, 1, 3, 3) tensor upsampled by a factor of 2,
# this time with align_corners=False. This rebinds `torch_out`.
inp = torch.rand((1,1,3,3))
print(inp)
n = torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
torch_out = n(inp)
torch_out

tensor([[[[0.6250, 0.1692, 0.5333],
          [0.1604, 0.4394, 0.0758],
          [0.4560, 0.8314, 0.4342]]]])


tensor([[[[0.6250, 0.5111, 0.2832, 0.2602, 0.4423, 0.5333],
          [0.5089, 0.4408, 0.3048, 0.2823, 0.3734, 0.4189],
          [0.2765, 0.3004, 0.3480, 0.3264, 0.2356, 0.1901],
          [0.2343, 0.3101, 0.4616, 0.4444, 0.2584, 0.1654],
          [0.3821, 0.4699, 0.6456, 0.6362, 0.4418, 0.3446],
          [0.4560, 0.5499, 0.7376, 0.7321, 0.5335, 0.4342]]]])

In [79]:
def resize_bilinear(input: np.ndarray, scale_factor: float, *,
                    align_corners: bool = False,
                    verbose: bool = True) -> np.ndarray:
    """Resize a 2-D array with bilinear interpolation.

    Each output pixel is mapped back to a (generally fractional) position in
    the source array; its value is the weighted sum of the four surrounding
    source pixels, where each weight is given by the distance to the opposite
    neighbour. Reads outside the source are clamped to the nearest edge pixel.

    Args:
        input: Source array indexed as ``input[row, column]``; only the first
            two entries of ``input.shape`` are used.
        scale_factor: Multiplier applied to both height and width. The output
            size is ``int(size * scale_factor)`` per axis.
        align_corners: Selects the output-to-source coordinate mapping.
            ``False`` (default) treats pixels as unit cells and aligns their
            centres: ``src = (dst + 0.5) / scale_factor - 0.5``.
            ``True`` aligns the centres of the corner pixels:
            ``src = dst * (source_size - 1) / (resized_size - 1)``.
        verbose: When ``True`` (default, matching the historical behaviour of
            this function) print a trace line for every interpolated
            x coordinate and every pixel read. Pass ``False`` to silence it.

    Returns:
        A float32 array of shape ``(resized_H, resized_W)``.
    """
    source_H = input.shape[0]
    source_W = input.shape[1]
    resized_H = int(source_H * scale_factor)
    resized_W = int(source_W * scale_factor)

    output = np.zeros((resized_H, resized_W), dtype=np.float32)

    def read_pixel(x, y):
        # Clamp out-of-range indices to the border so that samples just
        # outside the source replicate the edge pixel.
        x = np.clip(x, 0, source_W - 1)
        y = np.clip(y, 0, source_H - 1)
        if verbose:
            print("x ,y, indeces", x,y,"input[y, x]" , input[y, x])
        return input[y, x]

    def bilinear_interpolate(x, y):
        x1 = int(np.floor(x))
        if verbose:
            print("x:", x,"x1:", x1)
        x2 = x1 + 1

        y1 = int(np.floor(y))
        y2 = y1 + 1

        P11 = read_pixel(x1, y1)
        P12 = read_pixel(x1, y2)
        P21 = read_pixel(x2, y1)
        P22 = read_pixel(x2, y2)

        # The neighbours are exactly one pixel apart, so the usual
        # normalisation by (x2 - x1) * (y2 - y1) is 1 and is omitted.
        return (P11 * (x2 - x) * (y2 - y) +
                P12 * (x2 - x) * (y - y1) +
                P21 * (x - x1) * (y2 - y) +
                P22 * (x - x1) * (y - y1))

    def source_coord(dst, source_size, resized_size):
        # Map an output index to its fractional position in the source.
        if align_corners:
            # A single output pixel has no second corner to align with;
            # sample the first source pixel instead of dividing by zero.
            if resized_size <= 1:
                return 0.0
            return dst * (source_size - 1) / (resized_size - 1)
        return (dst + 0.5) / scale_factor - 0.5

    for dst_y in range(resized_H):
        src_y = source_coord(dst_y, source_H, resized_H)
        for dst_x in range(resized_W):
            src_x = source_coord(dst_x, source_W, resized_W)
            output[dst_y, dst_x] = bilinear_interpolate(src_x, src_y)

    return output
# Run the NumPy implementation on the 2-D slice of the random input and check
# it against the PyTorch result held in `torch_out`.
python_out = resize_bilinear(inp[0][0].numpy(),2)
np.allclose(torch_out, python_out)

x: -0.25 x1: -1
x ,y, indeces 0 0 input[y, x] 0.6250223
x ,y, indeces 0 0 input[y, x] 0.6250223
x ,y, indeces 0 0 input[y, x] 0.6250223
x ,y, indeces 0 0 input[y, x] 0.6250223
x: 0.25 x1: 0
x ,y, indeces 0 0 input[y, x] 0.6250223
x ,y, indeces 0 0 input[y, x] 0.6250223
x ,y, indeces 1 0 input[y, x] 0.16919512
x ,y, indeces 1 0 input[y, x] 0.16919512
x: 0.75 x1: 0
x ,y, indeces 0 0 input[y, x] 0.6250223
x ,y, indeces 0 0 input[y, x] 0.6250223
x ,y, indeces 1 0 input[y, x] 0.16919512
x ,y, indeces 1 0 input[y, x] 0.16919512
x: 1.25 x1: 1
x ,y, indeces 1 0 input[y, x] 0.16919512
x ,y, indeces 1 0 input[y, x] 0.16919512
x ,y, indeces 2 0 input[y, x] 0.53328013
x ,y, indeces 2 0 input[y, x] 0.53328013
x: 1.75 x1: 1
x ,y, indeces 1 0 input[y, x] 0.16919512
x ,y, indeces 1 0 input[y, x] 0.16919512
x ,y, indeces 2 0 input[y, x] 0.53328013
x ,y, indeces 2 0 input[y, x] 0.53328013
x: 2.25 x1: 2
x ,y, indeces 2 0 input[y, x] 0.53328013
x ,y, indeces 2 0 input[y, x] 0.53328013
x ,y, indeces 2 0 in

True

In [45]:
# Keep the 2-D slice of the random input as a NumPy array for inspection.
inparray = inp[0][0].numpy()

In [46]:
inparray

array([[0.6250223 , 0.16919512, 0.53328013],
       [0.16036487, 0.43940836, 0.07576615],
       [0.4560393 , 0.8314114 , 0.43418592]], dtype=float32)

In [51]:
# Re-check that the NumPy result agrees with the PyTorch reference.
np.allclose(torch_out, python_out)

True

In [None]:
Bilinear interpolation is a weighted sum, but the weights depend on how far you are from each neighbouring pixel.

In [54]:
import inspect

In [None]:
inspect.