In [None]:
from init_notebook import *

In [None]:
class Conv2dPoly(nn.Module):
    """2D convolution over learned per-channel polynomial features.

    Each input channel ``c`` is first expanded into ``order`` learned affine
    copies (``w_{c,k} * x_c + b_{c,k}``) by a grouped 1x1 convolution. Copy
    ``k`` (0-based) is then raised to the power ``k + 1``, giving every input
    channel the terms of degree 1..``order``. The resulting
    ``in_channels * order`` feature maps are mixed by a regular (or
    transposed) convolution.

    Args:
        order: Highest polynomial degree; must be >= 1.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size of the combining convolution.
        stride: Stride of the combining convolution.
        padding: Padding of the combining convolution.
        transpose: If True, combine with ``nn.ConvTranspose2d`` instead of
            ``nn.Conv2d``.

    Raises:
        ValueError: If ``order`` is smaller than 1.
    """

    def __init__(
            self,
            order: int,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: int = 1,
            padding: int = 0,
            transpose: bool = False,
    ):
        super().__init__()
        if order < 1:
            raise ValueError(f"order must be >= 1, got {order}")
        self._order = order
        self._in_channels = in_channels

        # groups=in_channels: output channels [c*order, (c+1)*order) depend only
        # on input channel c, i.e. `order` independent affine copies per channel.
        self.conv_split = nn.Conv2d(in_channels, in_channels * order, kernel_size=1, groups=in_channels)
        conv_class = nn.ConvTranspose2d if transpose else nn.Conv2d
        self.conv_combine = conv_class(in_channels * order, out_channels, kernel_size, stride, padding)

    def extra_repr(self):
        return f"order={self._order}"

    def forward(self, x):
        """Apply the polynomial expansion followed by the combining convolution.

        Args:
            x: Input accepted by ``nn.Conv2d`` with ``in_channels`` channels
                (channels at dim -3, e.g. ``(N, C, H, W)`` or ``(C, H, W)``).

        Returns:
            Output of the combining (transposed) convolution with
            ``out_channels`` channels.
        """
        y = self.conv_split(x)

        # The grouped conv lays channels out as [c0_k0, c0_k1, ..., c1_k0, ...],
        # so stepping by `order` from offset k picks copy k of every input
        # channel; that copy is raised to the power k + 1.
        y = torch.concat([
            y[..., k::self._order, :, :] ** (k + 1)
            for k in range(self._order)
        ], dim=-3)

        return self.conv_combine(y)

# Smoke test: order-3 polynomial features over 3 input channels, projected to
# 10 output channels with a 3x3 kernel; padding=1 keeps the 5x5 spatial size.
torch.manual_seed(0)  # weights are randomly initialised; seed so the output is reproducible

m = Conv2dPoly(order=3, in_channels=3, out_channels=10, kernel_size=3, padding=1)
print(f"params: {num_module_parameters(m):,}")

inp = torch.ones(1, 3, 5, 5)
outp = m(inp)
print(inp.shape, "->", outp.shape)
assert outp.shape == (1, 10, 5, 5), outp.shape
display(outp)
m