In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Double precision for any tensor created without an explicit dtype
# (`x` below is already float64 because it is built from a NumPy array).
torch.set_default_dtype(torch.double)

In [3]:
class PMaxPool1d(nn.Module):
    """Polyphase 1-D max pooling with circular-shift phase selection.

    The input is max-pooled ``kernel_size`` times (window = stride =
    ``kernel_size``), once for each circular roll ``0 .. kernel_size - 1``
    along the last dimension. For every batch element, the candidate whose
    pooled values (across all channels and positions) have the largest
    standard deviation is kept.

    Args:
        kernel_size: pooling window size; also the stride and the number of
            candidate phases.

    Shape:
        - Input: ``(N, C, L)`` or unbatched ``(C, L)``.
        - Output: tuple ``(pooled, shifts)``. ``pooled`` is
          ``(N, C, L // kernel_size)`` (``(C, L // kernel_size)`` if
          unbatched). ``shifts`` is the roll amount chosen per batch element,
          shape ``(N,)`` (a 0-dim tensor if unbatched).
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size

    def forward(self, x):
        """Return ``(pooled, shifts)`` for ``x``; see the class docstring for shapes."""
        # F.max_pool1d accepts unbatched (C, L) input, but the per-sample phase
        # selection below needs an explicit batch axis.
        if x.dim() == 2:
            pooled, shifts = self.forward(x.unsqueeze(0))
            return pooled.squeeze(0), shifts.squeeze(0)

        k = self.kernel_size
        # (k, N, C, L // k): candidate i is the max-pool of x circularly rolled by i.
        # NOTE: max_pool1d drops a trailing partial window, so when L % k != 0
        # different candidates cover different input samples.
        candidates = torch.stack([
            F.max_pool1d(torch.roll(x, i, dims=-1), k, k)
            for i in range(k)
        ])
        # Score each (phase, sample) pair by the std over channels and positions
        # -> (k, N). torch.std is unbiased by default, so the score is NaN when
        # C * (L // k) == 1.
        scores = torch.std(candidates.flatten(start_dim=2), dim=-1)
        # Highest-scoring phase per sample -> (N,); ties resolve to the smallest shift.
        shifts = torch.argmax(scores, dim=0)
        # Gather the winning candidate per sample; the index broadcasts over C and L.
        pooled = torch.take_along_dim(candidates, shifts.view(1, -1, 1, 1), dim=0).squeeze(0)
        return pooled, shifts

In [4]:
# Seeded generator so Restart & Run All reproduces the same input.
SEED = 0
rng = np.random.default_rng(SEED)
# Uniform [0, 1) float64 samples, shape (3, 2, 10) = (batch, channels, length)
# as consumed by F.max_pool1d.
x = torch.from_numpy(rng.random((3, 2, 10)))

In [5]:
# Inspect the input: shape (3, 2, 10), read by F.max_pool1d as (batch, channels, length).
x

tensor([[[0.5206, 0.2159, 0.4187, 0.3348, 0.2149, 0.2942, 0.6424, 0.4777,
          0.7326, 0.7020],
         [0.8080, 0.1481, 0.5421, 0.4342, 0.6052, 0.2413, 0.8262, 0.8674,
          0.1024, 0.3740]],

        [[0.6393, 0.7396, 0.3426, 0.6652, 0.2888, 0.4326, 0.4716, 0.0042,
          0.1273, 0.6447],
         [0.5164, 0.3286, 0.6494, 0.9759, 0.8425, 0.8491, 0.7484, 0.7740,
          0.0022, 0.2295]],

        [[0.2749, 0.1573, 0.2811, 0.5702, 0.2340, 0.6992, 0.4367, 0.4872,
          0.3189, 0.5648],
         [0.7371, 0.0755, 0.9873, 0.5832, 0.4614, 0.6994, 0.5516, 0.9964,
          0.7285, 0.1382]]])

In [6]:
# Window 2, stride 2, choosing between 2 circular phases per sample.
model = PMaxPool1d(kernel_size=2)

In [7]:
# PMaxPool1d.forward returns a (pooled, shifts) tuple.
output_x = model(x)

In [8]:
# Circularly shift x right by 2 samples along the length axis. Since 2 is a
# multiple of the kernel size, the pooled result is expected to equal
# output_x rolled by one position.
y = x.roll(shifts=2, dims=-1)

In [9]:
# y is x circularly shifted right by 2 along the last (length) axis.
y

tensor([[[0.7326, 0.7020, 0.5206, 0.2159, 0.4187, 0.3348, 0.2149, 0.2942,
          0.6424, 0.4777],
         [0.1024, 0.3740, 0.8080, 0.1481, 0.5421, 0.4342, 0.6052, 0.2413,
          0.8262, 0.8674]],

        [[0.1273, 0.6447, 0.6393, 0.7396, 0.3426, 0.6652, 0.2888, 0.4326,
          0.4716, 0.0042],
         [0.0022, 0.2295, 0.5164, 0.3286, 0.6494, 0.9759, 0.8425, 0.8491,
          0.7484, 0.7740]],

        [[0.3189, 0.5648, 0.2749, 0.1573, 0.2811, 0.5702, 0.2340, 0.6992,
          0.4367, 0.4872],
         [0.7285, 0.1382, 0.7371, 0.0755, 0.9873, 0.5832, 0.4614, 0.6994,
          0.5516, 0.9964]]])

In [10]:
# Pool the shifted input with the same module, for comparison with output_x.
output_y = model(y)

In [11]:
# (pooled, shifts) for the original input.
output_x

(tensor([[[0.5206, 0.4187, 0.2942, 0.6424, 0.7326],
          [0.8080, 0.5421, 0.6052, 0.8674, 0.3740]],
 
         [[0.6447, 0.7396, 0.6652, 0.4716, 0.1273],
          [0.5164, 0.6494, 0.9759, 0.8491, 0.7740]],
 
         [[0.2749, 0.5702, 0.6992, 0.4872, 0.5648],
          [0.7371, 0.9873, 0.6994, 0.9964, 0.7285]]]),
 tensor([0, 1, 0]))

In [12]:
# Shift-equivariance check: rolling the input by 2 (a multiple of the kernel
# size 2) should roll the pooled output by 2 // 2 = 1 position and leave the
# selected phase for every sample unchanged.
assert torch.equal(output_y[0], torch.roll(output_x[0], 2 // model.kernel_size, dims=-1))
assert torch.equal(output_y[1], output_x[1])
output_y

(tensor([[[0.7326, 0.5206, 0.4187, 0.2942, 0.6424],
          [0.3740, 0.8080, 0.5421, 0.6052, 0.8674]],
 
         [[0.1273, 0.6447, 0.7396, 0.6652, 0.4716],
          [0.7740, 0.5164, 0.6494, 0.9759, 0.8491]],
 
         [[0.5648, 0.2749, 0.5702, 0.6992, 0.4872],
          [0.7285, 0.7371, 0.9873, 0.6994, 0.9964]]]),
 tensor([0, 1, 0]))