# Periodic 1D CNN Classes

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Use double precision for every tensor/parameter created from here on
# (torch.double is the same dtype object as torch.float64).
torch.set_default_dtype(torch.double)

In [3]:
class PConv1d(nn.Module):
    """1D convolution with periodic (circular) boundary conditions.

    The signal is wrapped around by ``padding`` samples on each side and then
    convolved with a kernel of size ``2 * padding + 1``. The output therefore
    has the same length as the input, and the layer commutes with circular
    shifts of the input.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        padding: Half-width of the kernel; also the amount of wrap-around
            added to each side of the signal.
    """

    def __init__(self, in_channels, out_channels, padding):
        super().__init__()
        kernel_size = 2 * padding + 1
        # F.pad takes (left, right) amounts for the last dimension.
        self.padding = (padding,) * 2
        self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size)

    def forward(self, x):
        wrapped = F.pad(x, self.padding, mode='circular')
        return self.conv1d(wrapped)

In [4]:
class PMaxPool1d(nn.Module):
    """Max pooling over a periodic 1D signal with adaptive phase selection.

    ``nn.MaxPool1d`` with stride ``kernel_size`` only sees one of the
    ``kernel_size`` possible alignments ("phases") of the input, so rolling the
    input changes which samples are pooled together. This module pools every
    circular phase ``torch.roll(x, i)`` for ``i in range(kernel_size)`` and,
    independently for each sample in the batch, keeps the phase whose pooled
    values have the largest standard deviation. Ties go to the lowest phase
    index (``torch.argmax`` returns the first maximum).

    Args:
        kernel_size: Pooling window, also used as the stride.

    Shape:
        Input ``(N, C, L)`` or unbatched ``(C, L)``; output
        ``(N, C, L // kernel_size)`` or ``(C, L // kernel_size)``.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size
        self.maxpool1d = nn.MaxPool1d(self.kernel_size)

    def forward(self, x):
        if x.dim() == 2:
            # Unbatched (C, L) input: process it as a batch of one.
            return self.forward(x.unsqueeze(0)).squeeze(0)
        if self.kernel_size == 1:
            # A single phase exists, so there is nothing to choose between.
            return self.maxpool1d(x)

        # (kernel_size, N, C, L // kernel_size): one pooled candidate per phase.
        pooled = torch.stack(
            [self.maxpool1d(torch.roll(x, shift, dims=-1)) for shift in range(self.kernel_size)]
        )
        # Population std is defined even when a sample pools to a single value
        # (the unbiased estimate is NaN there). It differs from the unbiased one
        # by a constant positive factor, so the selected phase is the same.
        spread = torch.std(pooled.flatten(start_dim=2), dim=-1, unbiased=False)  # (kernel_size, N)
        phase = torch.argmax(spread, dim=0)  # (N,)
        batch = torch.arange(x.shape[0], device=x.device)
        # Pick candidate phase[b] for every sample b -> (N, C, L // kernel_size).
        return pooled[phase, batch]

In [5]:
import numpy as np

# Random test signal with six samples in [0, 1), shaped (batch=1, channels=1, length=6).
# Drawing the (1, 1, 6) array directly consumes the same RNG values as drawing 6 and reshaping.
x = torch.from_numpy(np.random.random((1, 1, 6)))

# Test Convolution

In [6]:
# Periodic conv vs. a standard zero-padded conv with the same kernel size (5)
# and the same output length.
pconv = PConv1d(in_channels=1, out_channels=2, padding=2)
conv = nn.Conv1d(in_channels=1, out_channels=2, kernel_size=5, padding=2)

In [7]:
# Output on the original signal, and output on the signal circularly shifted
# right by one sample, shifted back by one so the two can be compared.
pconv_x = pconv(x)
roll_pconv_x = pconv(x.roll(1, -1)).roll(-1, -1)

In [8]:
# Same comparison for the standard (zero-padded) convolution.
conv_x = conv(x)
roll_conv_x = conv(x.roll(1, -1)).roll(-1, -1)

In [9]:
# RMS difference between the two periodic-conv outputs (0 means the layer commutes with the shift).
(pconv_x - roll_pconv_x).pow(2).mean().item() ** 0.5

0.0

In [10]:
# RMS difference for the standard conv; zero padding breaks the shift symmetry at the edges.
(conv_x - roll_conv_x).pow(2).mean().item() ** 0.5

0.17003060432833264

# Test Max Pooling

In [11]:
# Periodic max pool vs. the standard max pool, both with window (and stride) 2.
pmaxpool = PMaxPool1d(kernel_size=2)
maxpool = nn.MaxPool1d(kernel_size=2)

In [12]:
# Pooled output on the original and on the one-sample circularly shifted signal;
# the outputs are compared directly, without undoing the shift.
pmaxpool_x = pmaxpool(x)
roll_pmaxpool_x = pmaxpool(x.roll(1, -1))

In [13]:
# Same comparison for the standard max pool.
maxpool_x = maxpool(x)
roll_maxpool_x = maxpool(x.roll(1, -1))

In [14]:
# Display the random input signal used by the pooling comparison.
x

tensor([[[0.3471, 0.7069, 0.7339, 0.8489, 0.2652, 0.4307]]])

In [15]:
# Periodic max-pool output on the unshifted input.
pmaxpool_x

tensor([[[0.4307, 0.7339, 0.8489]]])

In [16]:
# Periodic max-pool output on the input circularly shifted by one sample; compare with pmaxpool_x.
roll_pmaxpool_x

tensor([[[0.4307, 0.7339, 0.8489]]])

In [17]:
# RMS difference between the periodic pooling outputs (0 means identical outputs).
(pmaxpool_x - roll_pmaxpool_x).pow(2).mean().item() ** 0.5

0.0

In [18]:
# RMS difference between the standard pooling outputs, for contrast.
(maxpool_x - roll_maxpool_x).pow(2).mean().item() ** 0.5

0.2968667643470012