JAX/Flax port of the Keras WaveNet starter notebook by Kaggle user cdeotte: [Kaggle Reference (Keras)](https://www.kaggle.com/code/cdeotte/wavenet-starter-lb-0-52)

In [None]:
import jax.numpy as jnp
import flax.linen as ln

In [None]:
class WaveBlock(ln.Module):
    """Gated, dilated residual block in the style of WaveNet.

    The input is first projected to ``n_features`` channels with a 1x1
    convolution. Then ``n_dilations`` gated layers are applied in sequence,
    layer ``i`` using dilation ``2**i`` (1, 2, 4, ...). Each gated layer
    computes ``tanh(conv_a(h)) * sigmoid(conv_b(h))`` followed by a 1x1
    convolution, where ``h`` is the previous gated layer's output (or the
    projection, for the first layer). The block returns the projection plus
    the outputs of every gated layer.

    All convolutions use Flax's default stride 1 and ``'SAME'`` padding, so
    the sequence length is preserved and the output has ``n_features``
    channels. Input is channels-last, e.g. ``(batch, length, channels)``.

    Attributes:
        n_features: Channel count of every convolution in the block.
        kernel_size: Width of the dilated gated convolutions. With the
            default of 1 the dilation has no effect.
        n_dilations: Number of gated layers (dilations ``2**0`` through
            ``2**(n_dilations - 1)``).
    """

    n_features: int = 64
    kernel_size: int = 1
    n_dilations: int = 4

    @ln.compact
    def __call__(self, x):
        # Kernel sizes are given as 1-tuples: this fixes the convolution to a
        # single spatial axis, and some Flax releases reject a bare int.
        pointwise = (1,)
        gated_kernel = (self.kernel_size,)

        # Project to n_features channels so the residual sums below line up.
        x = ln.Conv(self.n_features, pointwise)(x)
        res_x = x
        for i in range(self.n_dilations):
            dilation = 2**i
            tanh_x = jnp.tanh(
                ln.Conv(self.n_features, gated_kernel, kernel_dilation=dilation)(x)
            )
            # jax.numpy has no sigmoid; Flax re-exports jax.nn.sigmoid.
            sigm_x = ln.sigmoid(
                ln.Conv(self.n_features, gated_kernel, kernel_dilation=dilation)(x)
            )
            # The next gated layer consumes this layer's output, not the
            # running residual sum.
            x = ln.Conv(self.n_features, pointwise)(tanh_x * sigm_x)
            res_x = res_x + x
        return res_x

In [None]:
class WavePath(ln.Module):
    """Sequential stack of ``WaveBlock`` modules.

    Block ``k`` (0-based) has ``init_features * 2**k`` channels. The first
    block uses ``init_dilations`` gated layers; each later block uses four
    fewer than the block before it, but never fewer than one.

    Attributes:
        kernel_size: Gated-convolution width passed to every block.
        init_features: Channel count of the first block.
        init_dilations: Number of gated layers in the first block.
        n_blocks: Number of stacked blocks.
    """

    kernel_size: int = 3
    init_features: int = 8
    init_dilations: int = 12
    n_blocks: int = 4

    @ln.compact
    def __call__(self, x):
        features, dilations = self.init_features, self.init_dilations
        for _ in range(self.n_blocks):
            block = WaveBlock(
                n_features=features,
                kernel_size=self.kernel_size,
                n_dilations=dilations,
            )
            x = block(x)
            # Widen the next block and shorten its dilation stack (floor 1).
            features, dilations = features * 2, max(1, dilations - 4)
        return x

In [None]:
class WaveNet(ln.Module):
    """Classifier that shares one ``WavePath`` across 8 input channels.

    Each of channels 0-7 of the input is fed separately through a single
    ``WavePath`` instance, so all channels use the same parameters. Each
    result is average-pooled over the time axis. The pooled vectors of
    channel pairs (0, 1), (2, 3), (4, 5), (6, 7) are averaged, and the four
    pair vectors are concatenated. A ReLU hidden layer and a softmax output
    layer then produce ``n_classes`` probabilities.

    Input is channels-last, e.g. ``(batch, time, channels)`` with at least
    8 channels; only the first 8 are read.

    Attributes:
        kernel_size, init_features, init_dilations, n_blocks: Forwarded to
            the shared ``WavePath`` (same order and defaults as its fields).
        n_hidden: Width of the ReLU hidden layer.
        n_classes: Number of output classes.
    """

    kernel_size: int = 3
    init_features: int = 8
    init_dilations: int = 12
    n_blocks: int = 4
    n_hidden: int = 64
    n_classes: int = 6

    @ln.compact
    def __call__(self, x):
        # A single instance called repeatedly -> its parameters are shared
        # by every channel.
        wave_path = WavePath(
            kernel_size=self.kernel_size,
            init_features=self.init_features,
            init_dilations=self.init_dilations,
            n_blocks=self.n_blocks,
        )

        def pooled(channel):
            # Keep a trailing channel axis of size 1, run the shared path,
            # then average over the time axis (-2, channels-last layout).
            return jnp.mean(wave_path(x[..., channel:channel + 1]), axis=-2)

        # Elementwise mean of each adjacent channel pair.
        zs = [(pooled(i) + pooled(i + 1)) / 2 for i in range(0, 8, 2)]
        # Concatenate along features, not along the batch axis.
        y = jnp.concatenate(zs, axis=-1)
        y = ln.relu(ln.Dense(self.n_hidden)(y))
        # Softmax, not ReLU: the head should emit a probability distribution
        # over classes rather than clipped logits.
        return ln.softmax(ln.Dense(self.n_classes)(y))