In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Make float64 the default floating-point dtype for tensors created from here on.
torch.set_default_dtype(torch.float64)

In [63]:
class PMaxPool1d(nn.Module):
    """Max pooling that selects, per sample, one of ``kernel_size`` circular shifts.

    For every shift ``i`` in ``range(kernel_size)`` the input is rolled by ``i``
    along its last dimension and max-pooled with window and stride both equal
    to ``kernel_size``. For each sample, the candidate whose pooled values
    (taken over all channels and positions together) have the largest standard
    deviation is returned, along with the index of the chosen shift.

    When the input length is a multiple of ``kernel_size``, rolling the input
    by ``kernel_size`` produces the same candidates rolled by one position, so
    (barring ties or rounding in the standard deviation) the same shift is
    selected and the output is simply rolled by one.

    Args:
        kernel_size: window size and stride of the pooling; also the number of
            candidate shifts that are evaluated.
    """

    def __init__(self, kernel_size):
        super().__init__()

        self.kernel_size = kernel_size

    def extra_repr(self):
        """Show the kernel size in the module's printed representation."""
        return f"kernel_size={self.kernel_size}"

    def forward(self, x):
        """Pool ``x`` using the shift with the highest-variance pooled output.

        Args:
            x: tensor of shape ``(batch, channels, length)``, or an unbatched
                ``(channels, length)`` tensor, which is handled as a batch of one.

        Returns:
            A tuple ``(output, shifts)``. ``output`` has shape
            ``(batch, channels, length // kernel_size)`` and ``shifts`` has
            shape ``(batch,)`` holding the selected shift of each sample. For
            an unbatched input the batch dimension is dropped from both
            (``shifts`` is then a 0-dim tensor).
        """
        # Promote an unbatched input so the per-sample logic below always sees a batch axis.
        unbatched = x.dim() == 2
        if unbatched:
            x = x.unsqueeze(0)

        k = self.kernel_size

        # candidates[i] is the pooled result of the input rolled by i:
        # shape (k, batch, channels, pooled_length).
        candidates = torch.stack(
            [F.max_pool1d(torch.roll(x, i, dims=-1), k, k) for i in range(k)]
        )

        # One score per (shift, sample): std over all channels and pooled positions.
        scores = torch.std(candidates.reshape(k, x.size(0), -1), dim=-1)
        shifts = torch.argmax(scores, dim=0)

        # Gather the winning candidate of every sample along the shift axis;
        # the index broadcasts over channels and pooled positions.
        index = shifts.view(1, -1, 1, 1)
        output = torch.take_along_dim(candidates, index, dim=0).squeeze(dim=0)

        if unbatched:
            return output.squeeze(0), shifts.squeeze(0)
        return output, shifts

In [96]:
# Fixed seed so that "Restart Kernel -> Run All" rebuilds the same (3, 2, 10) float64 input.
# Outputs stored from earlier, unseeded runs will differ until the notebook is re-executed.
rng = np.random.default_rng(0)
x = torch.from_numpy(rng.random((3, 2, 10)))

In [97]:
x

tensor([[[0.6767, 0.9651, 0.7962, 0.5681, 0.3010, 0.6574, 0.3765, 0.6910,
          0.1010, 0.1385],
         [0.0229, 0.7823, 0.4980, 0.8121, 0.4782, 0.4385, 0.2437, 0.6675,
          0.5861, 0.8954]],

        [[0.4388, 0.1923, 0.1977, 0.8696, 0.5331, 0.1061, 0.6925, 0.3167,
          0.4269, 0.1175],
         [0.0770, 0.6403, 0.4687, 0.8042, 0.8443, 0.6269, 0.6245, 0.3739,
          0.4007, 0.4727]],

        [[0.7495, 0.8998, 0.2174, 0.7644, 0.2314, 0.3336, 0.5375, 0.0495,
          0.7012, 0.9177],
         [0.8719, 0.9034, 0.9668, 0.1388, 0.5222, 0.6102, 0.3155, 0.6227,
          0.0860, 0.2184]]])

In [98]:
# Pooling layer with window/stride 2, i.e. two candidate shifts (0 and 1) per sample.
model = PMaxPool1d(kernel_size=2)

In [99]:
# model(...) returns a tuple: (pooled tensor, selected shift per sample).
output_x = model(x)

In [105]:
# Circularly shift x by 2 positions (the kernel size used for `model`) along the last dimension.
y = x.roll(2, dims=-1)

In [106]:
y

tensor([[[0.1010, 0.1385, 0.6767, 0.9651, 0.7962, 0.5681, 0.3010, 0.6574,
          0.3765, 0.6910],
         [0.5861, 0.8954, 0.0229, 0.7823, 0.4980, 0.8121, 0.4782, 0.4385,
          0.2437, 0.6675]],

        [[0.4269, 0.1175, 0.4388, 0.1923, 0.1977, 0.8696, 0.5331, 0.1061,
          0.6925, 0.3167],
         [0.4007, 0.4727, 0.0770, 0.6403, 0.4687, 0.8042, 0.8443, 0.6269,
          0.6245, 0.3739]],

        [[0.7012, 0.9177, 0.7495, 0.8998, 0.2174, 0.7644, 0.2314, 0.3336,
          0.5375, 0.0495],
         [0.0860, 0.2184, 0.8719, 0.9034, 0.9668, 0.1388, 0.5222, 0.6102,
          0.3155, 0.6227]]])

In [107]:
# Same layer applied to the rolled input; again a (pooled tensor, shifts) tuple.
output_y = model(y)

In [108]:
output_x

(tensor([[[0.9651, 0.7962, 0.6574, 0.6910, 0.1385],
          [0.7823, 0.8121, 0.4782, 0.6675, 0.8954]],
 
         [[0.4388, 0.1977, 0.8696, 0.6925, 0.4269],
          [0.4727, 0.6403, 0.8443, 0.6269, 0.4007]],
 
         [[0.8998, 0.7644, 0.3336, 0.5375, 0.9177],
          [0.9034, 0.9668, 0.6102, 0.6227, 0.2184]]]),
 tensor([0, 1, 0]))

In [109]:
# y is x rolled by the kernel size (2), so the layer is expected to select the same
# shifts and to return the pooled output rolled by exactly one position.
# Assert it rather than comparing the printed tensors by eye.
pooled_x, shifts_x = output_x
pooled_y, shifts_y = output_y
assert torch.equal(shifts_y, shifts_x)
assert torch.equal(pooled_y, torch.roll(pooled_x, 1, dims=-1))
output_y

(tensor([[[0.1385, 0.9651, 0.7962, 0.6574, 0.6910],
          [0.8954, 0.7823, 0.8121, 0.4782, 0.6675]],
 
         [[0.4269, 0.4388, 0.1977, 0.8696, 0.6925],
          [0.4007, 0.4727, 0.6403, 0.8443, 0.6269]],
 
         [[0.9177, 0.8998, 0.7644, 0.3336, 0.5375],
          [0.2184, 0.9034, 0.9668, 0.6102, 0.6227]]]),
 tensor([0, 1, 0]))