In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Make double precision the default for floating-point tensors created from here on
# (torch.double is an alias of torch.float64).
torch.set_default_dtype(torch.double)

In [45]:
class PMaxPoolnd(nn.Module):
    """Max pooling that picks, per sample, the best circular phase offset.

    For every offset in ``{0, ..., kernel_size - 1} ** n``, the input is circularly
    rolled (``torch.roll``) over its last ``n`` dims and max-pooled with
    ``kernel_size`` / ``stride = kernel_size``. For each sample in the batch, the
    pooled candidate whose values have the largest total sum is returned, together
    with the roll offset that produced it. Ties resolve to the first offset
    according to ``torch.argmax`` semantics.

    Args:
        kernel_size (int): pooling window and stride; also the number of roll
            offsets per spatial dim. Must be >= 1.
        n (int): number of spatial dims (1, 2 or 3), selecting
            ``F.max_pool1d`` / ``F.max_pool2d`` / ``F.max_pool3d``.
    """

    def __init__(self, kernel_size, n):
        super().__init__()
        if n not in (1, 2, 3):
            raise ValueError(f"n must be 1, 2 or 3, got {n!r}")
        if kernel_size < 1:
            raise ValueError(f"kernel_size must be >= 1, got {kernel_size!r}")

        self.kernel_size = kernel_size
        self.n = n
        # All roll offsets. For n == 1, cartesian_prod of a single tensor returns that
        # tensor, so this is a list of ints; for n >= 2 it is a list of n-element lists.
        self.shifts = torch.cartesian_prod(*((torch.arange(kernel_size),) * n)).tolist()
        self.max_pools = {1: F.max_pool1d, 2: F.max_pool2d, 3: F.max_pool3d}
        # Roll over the trailing n (spatial) dims; n == 1 uses a plain int dim.
        self.roll_dims = {1: -1, 2: (-2, -1), 3: (-3, -2, -1)}

    def forward(self, x):
        """Pool ``x`` at the per-sample best phase.

        Args:
            x: batched tensor of shape ``(B, C, *spatial)`` with ``n`` spatial dims.

        Returns:
            tuple ``(output, shift)``:
              - ``output``: pooled tensor of shape ``(B, C, *pooled_spatial)``.
              - ``shift``: int64 roll offsets on ``x``'s device, shape ``(B,)``
                when ``n == 1`` and ``(B, n)`` otherwise.

        Raises:
            ValueError: if ``x`` does not have ``n + 2`` dims.
        """
        if x.dim() != self.n + 2:
            raise ValueError(
                f"expected a {self.n + 2}-D input (B, C, *spatial), got shape {tuple(x.shape)}"
            )
        pool = self.max_pools[self.n]
        dims = self.roll_dims[self.n]

        # One pooled candidate per offset: (K**n, B, C, *pooled_spatial).
        rolling = torch.stack([pool(torch.roll(x, shift, dims=dims),
                                    self.kernel_size,
                                    self.kernel_size)
                               for shift in self.shifts])
        # Score each candidate per sample by the sum of all its pooled values: (K**n, B).
        scores = rolling.flatten(start_dim=2).sum(dim=-1)
        picked = torch.argmax(scores, dim=0)  # (B,)
        # Broadcast the per-sample choice over channel and spatial dims.
        index = picked.view(1, -1, *((1,) * (self.n + 1)))
        output = torch.take_along_dim(rolling, index, dim=0).squeeze(dim=0)
        # Build the lookup table on the same device as `picked`; indexing a CPU tensor
        # with CUDA indices fails.
        shift_table = torch.tensor(self.shifts, device=picked.device)
        return output, shift_table[picked]