In [1]:
import numbers

import torch
import torch.nn as nn
import tensorflow as tf

In [2]:
class CustomUpScale2d(nn.Module):
    """
    Nearest-neighbour upscaling of the last two (spatial) dimensions by an integer factor.

    Every input element is repeated ``factor`` times along height and width, so a
    ``(..., h, w)`` input (typically ``(b, c, h, w)``) becomes
    ``(..., h * factor, w * factor)``. This is the same reshape -> tile -> reshape
    trick as the TensorFlow ``upscale2d`` reference.
    Upscaling done by torch.tile: https://pytorch.org/docs/stable/generated/torch.tile.html
    """

    def __init__(self, factor: int = 2):
        """
        :param factor: positive integer scaling factor applied to both height and width.
        :raises AssertionError: if ``factor`` is not an integer >= 1.
        """
        super().__init__()

        # torch.tile only accepts integer repeats; reject floats at construction time
        # (same contract as the TF reference). numbers.Integral also admits NumPy ints.
        assert isinstance(factor, numbers.Integral) and factor >= 1, \
            "Scaling factor must be a positive integer!"

        self.factor = int(factor)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: torch.Tensor
            :shape: (..., h, w), typically (b, c, h, w)

        :return : torch.Tensor
            :shape: (..., h * self.factor, w * self.factor)

        :raises ValueError: if ``x`` has fewer than 2 dimensions.
        """
        if x.dim() < 2:
            raise ValueError(f"Expected a tensor with at least 2 dims (..., h, w), got shape {tuple(x.shape)}")

        *lead, h, w = x.shape
        f = self.factor

        # (..., h, w) -> (..., h, 1, w, 1): singleton axes to repeat along.
        # Explicit sizes instead of -1 keep this working for empty batches, where a
        # -1 dimension is ambiguous and reshape raises.
        x = torch.reshape(x, (*lead, h, 1, w, 1))

        # (..., h, 1, w, 1) -> (..., h, f, w, f); torch.tile pads missing leading reps with 1.
        x = torch.tile(x, (1, f, 1, f))

        # (..., h, f, w, f) -> (..., h * f, w * f)
        return torch.reshape(x, (*lead, h * f, w * f))

In [19]:
def upscale2d(x, factor=2):
    """Nearest-neighbour upscale of axes 2 and 3 by an integer ``factor`` (TensorFlow reference).

    Axis 1 is carried through unchanged and axes 2/3 are treated as height/width,
    so the input is expected to be laid out as (b, c, h, w).

    NOTE(review): dimensions are read from the static ``x.shape``; in graph mode with
    unknown spatial dims this would need ``tf.shape(x)`` instead — confirm usage.
    """
    assert isinstance(factor, int) and factor >= 1

    # Identity: hand back the input untouched.
    if factor == 1:
        return x

    dims = x.shape
    channels, height, width = dims[1], dims[2], dims[3]

    # (b, c, h, w) -> (b, c, h, 1, w, 1): singleton axes to repeat along.
    expanded = tf.reshape(x, [-1, channels, height, 1, width, 1])
    # Repeat each element `factor` times along both singleton axes.
    tiled = tf.tile(expanded, [1, 1, 1, factor, 1, factor])
    # Fold (h, factor) and (w, factor) into the upscaled spatial axes.
    return tf.reshape(tiled, [-1, channels, height * factor, width * factor])

In [20]:
# Seed the RNG so the random input (and every output derived from it) is
# reproducible on Restart & Run All.
torch.manual_seed(0)
# Random (b, c, h, w) = (2, 2, 2, 2) input; requires_grad is already False by default.
a = torch.rand((2, 2, 2, 2))

In [27]:
a.size()

torch.Size([2, 2, 2, 2])

In [28]:
# Give the TF reference an explicit, detached NumPy array rather than relying on
# implicit torch -> TF conversion (which presumably goes through the NumPy array
# protocol and would fail if `a` ever required grad).
upscale2d(a.detach().numpy())

<tf.Tensor: shape=(2, 2, 4, 4), dtype=float32, numpy=
array([[[[0.6254536 , 0.6254536 , 0.3439632 , 0.3439632 ],
         [0.6254536 , 0.6254536 , 0.3439632 , 0.3439632 ],
         [0.12942779, 0.12942779, 0.3274418 , 0.3274418 ],
         [0.12942779, 0.12942779, 0.3274418 , 0.3274418 ]],

        [[0.00451815, 0.00451815, 0.40985316, 0.40985316],
         [0.00451815, 0.00451815, 0.40985316, 0.40985316],
         [0.92387766, 0.92387766, 0.34538257, 0.34538257],
         [0.92387766, 0.92387766, 0.34538257, 0.34538257]]],


       [[[0.5654647 , 0.5654647 , 0.71000516, 0.71000516],
         [0.5654647 , 0.5654647 , 0.71000516, 0.71000516],
         [0.29286325, 0.29286325, 0.7952678 , 0.7952678 ],
         [0.29286325, 0.29286325, 0.7952678 , 0.7952678 ]],

        [[0.6335821 , 0.6335821 , 0.36711532, 0.36711532],
         [0.6335821 , 0.6335821 , 0.36711532, 0.36711532],
         [0.27536094, 0.27536094, 0.60316616, 0.60316616],
         [0.27536094, 0.27536094, 0.60316616, 0.60316

In [29]:
# Module under test: the PyTorch port, upscaling by 2 in each spatial dimension.
scale = CustomUpScale2d(2)

In [30]:
# Check the PyTorch port against the TF reference exactly: the two cells print at
# different precisions, so comparing the displayed outputs by eye is not a real check.
torch_out = scale(a)
tf_out = upscale2d(a.detach().numpy())
assert torch.equal(torch_out, torch.tensor(tf_out.numpy())), "PyTorch and TF upscale2d disagree"
torch_out

tensor([[[[0.6255, 0.6255, 0.3440, 0.3440],
          [0.6255, 0.6255, 0.3440, 0.3440],
          [0.1294, 0.1294, 0.3274, 0.3274],
          [0.1294, 0.1294, 0.3274, 0.3274]],

         [[0.0045, 0.0045, 0.4099, 0.4099],
          [0.0045, 0.0045, 0.4099, 0.4099],
          [0.9239, 0.9239, 0.3454, 0.3454],
          [0.9239, 0.9239, 0.3454, 0.3454]]],


        [[[0.5655, 0.5655, 0.7100, 0.7100],
          [0.5655, 0.5655, 0.7100, 0.7100],
          [0.2929, 0.2929, 0.7953, 0.7953],
          [0.2929, 0.2929, 0.7953, 0.7953]],

         [[0.6336, 0.6336, 0.3671, 0.3671],
          [0.6336, 0.6336, 0.3671, 0.3671],
          [0.2754, 0.2754, 0.6032, 0.6032],
          [0.2754, 0.2754, 0.6032, 0.6032]]]])