In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
torch.set_default_dtype(torch.float64)

In [3]:
class PUpsample(nn.Module):
    def __init__(self, scale_factor):
        super(PUpsample, self).__init__()
        self.scale_factor = scale_factor

    def forward(self, x, shifts):
        w = x.new_zeros(self.scale_factor)
        w[0] = 1
        output = F.conv_transpose1d(x, w.expand(x.size(1), 1, self.scale_factor), stride = self.scale_factor, groups = x.size(1))
        indices = (torch.arange(output.shape[2])[None, None, :].repeat(1, 2, 1) - shifts[:, None, None]) % output.shape[2]
        output = torch.gather(output, 2, indices)
        return output

In [4]:
x = torch.from_numpy(np.random.random((3, 2, 4)))
shifts = torch.from_numpy(np.random.randint(0, 2, size = 3))

In [5]:
x

tensor([[[0.9064, 0.6192, 0.6371, 0.9570],
         [0.3092, 0.3479, 0.1468, 0.3169]],

        [[0.8517, 0.2326, 0.6052, 0.1101],
         [0.7037, 0.2392, 0.6133, 0.7115]],

        [[0.0872, 0.8927, 0.1336, 0.4283],
         [0.2905, 0.8835, 0.1832, 0.7394]]])

In [6]:
shifts

tensor([1, 0, 1], dtype=torch.int32)

In [7]:
model = PUpsample(2)

In [8]:
output_x = model(x, shifts)

In [9]:
output_x

tensor([[[0.0000, 0.9064, 0.0000, 0.6192, 0.0000, 0.6371, 0.0000, 0.9570],
         [0.0000, 0.3092, 0.0000, 0.3479, 0.0000, 0.1468, 0.0000, 0.3169]],

        [[0.8517, 0.0000, 0.2326, 0.0000, 0.6052, 0.0000, 0.1101, 0.0000],
         [0.7037, 0.0000, 0.2392, 0.0000, 0.6133, 0.0000, 0.7115, 0.0000]],

        [[0.0000, 0.0872, 0.0000, 0.8927, 0.0000, 0.1336, 0.0000, 0.4283],
         [0.0000, 0.2905, 0.0000, 0.8835, 0.0000, 0.1832, 0.0000, 0.7394]]])