In [None]:
from init_notebook import *

In [None]:
class KANPolyLayer(nn.Module):
    """
    Dense layer where every input->output connection is a learnable polynomial.

    For input features ``x[..., i]`` the output is::

        y[..., o] = sum_i sum_{p=0..order} coeffs[o, i, p] * x[..., i] ** p  (+ bias[o])

    based on https://github.com/SciYu/KAE/blob/main/DenseLayerPack/KAE.py

    :param input_dim: number of input features (size of the last input axis)
    :param out_dim: number of output features
    :param order: highest polynomial degree, must be >= 0
    :param bias: if True, add a learnable per-output bias
    :raises ValueError: if ``order`` is negative
    """
    def __init__(
            self,
            input_dim: int,
            out_dim: int,
            order: int,
            bias: bool = True,
    ):
        super().__init__()
        if order < 0:
            raise ValueError(f"order must be >= 0, got {order}")
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.order = order
        # scaled-down random init keeps the initial polynomials close to zero
        self.coeffs = nn.Parameter(torch.randn(out_dim, input_dim, order + 1) * 0.01)
        self.bias = None
        if bias:
            # stored as (1, out_dim); flattened in forward so any batch shape broadcasts
            self.bias = nn.Parameter(torch.zeros(1, out_dim))

    def extra_repr(self):
        return f"input_dim={self.input_dim}, out_dim={self.out_dim}, order={self.order}, bias={self.bias is not None}"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the polynomial layer.

        :param x: tensor of shape ``(..., input_dim)``; any number of leading axes
        :return: tensor of shape ``(..., out_dim)``
        """
        # compute in the promoted dtype (as an element-wise x**p * coeffs would),
        # which also makes integer inputs work
        dtype = torch.promote_types(x.dtype, self.coeffs.dtype)
        x = x.to(dtype)

        # (..., input_dim, order + 1): the powers x**0 .. x**order
        powers = torch.stack([x ** p for p in range(self.order + 1)], dim=-1)

        # contract over input features and degree in one go, without
        # materializing a (..., out_dim, input_dim) intermediate per degree
        y = torch.einsum("...ip,oip->...o", powers, self.coeffs.to(dtype))

        if self.bias is not None:
            y = y + self.bias.reshape(-1)

        return y


class KanMLPLayer(nn.Module):
    """
    A single dense layer with optional activation and optional residual connection.

    :param channels_in: number of input features
    :param channels_out: number of output features
    :param layer_type: ``"mlp"`` for ``nn.Linear``, or ``"kanpoly<O>"`` (e.g. ``"kanpoly3"``)
        for a ``KANPolyLayer`` of polynomial order ``O`` (a non-negative integer)
    :param activation: converted with ``activation_to_module``; applied after the
        dense layer when that conversion returns something other than None
    :param bias: whether the dense layer has a bias
    :param residual: add the input to the output; silently disabled when
        ``channels_in != channels_out`` because the shapes would not match
    :raises ValueError: for an unknown ``layer_type`` or a malformed kanpoly order
    """

    def __init__(
            self,
            channels_in: int,
            channels_out: int,
            layer_type: str = "mlp",  # "mlp", "kanpoly<O>",
            activation: Union[None, str, Callable] = None,
            bias: bool = True,
            residual: bool = True,
    ):
        super().__init__()
        # a residual add needs matching shapes, so it is only enabled for square layers
        self._residual = residual and channels_in == channels_out
        self._layer_type = layer_type

        kanpoly_prefix = "kanpoly"
        if layer_type == "mlp":
            self.module = nn.Linear(channels_in, channels_out, bias=bias)
        elif layer_type.startswith(kanpoly_prefix):
            order_str = layer_type[len(kanpoly_prefix):]
            # isdecimal() rejects "", signs ("-1" would otherwise be accepted by int())
            # and other garbage that int() reports with an unhelpful message
            if not order_str.isdecimal():
                raise ValueError(
                    f"Invalid layer_type '{layer_type}', expected 'kanpoly<order>' "
                    f"with a non-negative integer order, e.g. 'kanpoly3'"
                )
            self.module = KANPolyLayer(channels_in, channels_out, order=int(order_str), bias=bias)
        else:
            raise ValueError(f"Unknown layer_type '{layer_type}'")

        self.act = activation_to_module(activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.module(x)
        if self.act is not None:
            y = self.act(y)
        if self._residual:
            y = y + x
        return y

    def extra_repr(self):
        return f"layer_type='{self._layer_type}', residual={self._residual}"


class KanMLPStack(nn.Module):
    """
    Sequence of ``KanMLPLayer`` s mapping ``channels[0] -> channels[1] -> ... -> channels[-1]``.

    :param channels: feature sizes; must have exactly one more entry than there are layers
    :param layer_type: one layer type per layer (see ``KanMLPLayer``)
    :param activation: one activation per layer (see ``KanMLPLayer``)
    :raises ValueError: if ``channels`` is empty, or if the lengths of ``layer_type``
        or ``activation`` differ from ``len(channels) - 1`` (previously such
        mismatches were silently truncated by ``zip``, dropping layers)
    """

    def __init__(
            self,
            channels: Tuple[int, ...],
            layer_type: Tuple[str, ...],
            activation: Tuple[Union[None, str, Callable], ...],
    ):
        super().__init__()
        channels = tuple(channels)
        layer_type = tuple(layer_type)
        activation = tuple(activation)

        if not channels:
            raise ValueError("channels must contain at least one entry")
        num_layers = len(channels) - 1
        if len(layer_type) != num_layers or len(activation) != num_layers:
            raise ValueError(
                f"Expected {num_layers} layer_type and activation entries for channels={list(channels)}, "
                f"got {len(layer_type)} layer_type and {len(activation)} activation entries"
            )

        self.stack = nn.Sequential()
        for (ch, next_ch), layer_type_, act_ in zip(zip(channels, channels[1:]), layer_type, activation):
            self.stack.append(KanMLPLayer(ch, next_ch, layer_type=layer_type_, activation=act_))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through all layers in order (identity if there are none)."""
        return self.stack(x)
                            

class KanMLPAE(nn.Module):
    """
    Fully-connected auto-encoder assembled from two ``KanMLPStack`` s.

    Encoder: ``nn.Flatten(-3)`` merges the last three input axes (an image of
    ``image_shape``) into ``prod(image_shape)`` features, which are then mapped
    through ``encoder_channels``.

    Decoder: maps ``decoder_channels`` back to ``prod(image_shape)`` features and
    passes them through ``Reshape(image_shape)`` (presumably restoring the image
    layout — ``Reshape`` is defined outside this file).

    NOTE(review): ``decoder_channels[0]`` must equal the encoder's output size for
    ``forward`` to work; this is not checked here.
    """

    def __init__(
            self,
            image_shape: Tuple[int, int, int],
            encoder_channels: Tuple[int, ...],
            encoder_layer_type: Tuple[str, ...],
            encoder_activation: Tuple[Union[None, str, Callable], ...],
            decoder_channels: Tuple[int, ...],
            decoder_layer_type: Tuple[str, ...],
            decoder_activation: Tuple[Union[None, str, Callable], ...],
    ):
        super().__init__()
        num_pixels = math.prod(image_shape)

        # the encoder is built (and registered) before the decoder
        encoder_stack = KanMLPStack(
            channels=[num_pixels, *encoder_channels],
            layer_type=encoder_layer_type,
            activation=encoder_activation,
        )
        self.encoder = nn.Sequential(nn.Flatten(-3), encoder_stack)

        decoder_stack = KanMLPStack(
            channels=[*decoder_channels, num_pixels],
            layer_type=decoder_layer_type,
            activation=decoder_activation,
        )
        self.decoder = nn.Sequential(decoder_stack, Reshape(image_shape))

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Encode ``batch`` to the latent features and decode them back."""
        latent = self.encoder(batch)
        return self.decoder(latent)
        
# Smoke test: build a small auto-encoder for 3x64x64 images, report its
# parameter count and print the output shape of one forward pass.
module = KanMLPAE(
    image_shape=(3, 64, 64),
    encoder_channels=(128, 64),
    encoder_layer_type=("mlp", "kanpoly3"),
    encoder_activation=("gelu", "none"),
    decoder_channels=(64, 128),
    decoder_layer_type=("mlp", "mlp"),
    decoder_activation=("gelu", "none"),
)
print(f"params: {num_module_parameters(module):,}")

inp = torch.ones(2, 3, 64, 64)  # batch of two constant images
outp = module(inp)
print("outp", outp.shape)
display(module)