[Kaggle Reference (Keras)](https://www.kaggle.com/code/cdeotte/wavenet-starter-lb-0-52)

In [1]:
import dataclasses

import flax.linen as ln
import jax.numpy as jnp

In [2]:
class WaveBlock(ln.Module):
    """Stack of gated, dilated convolutions accumulated into a residual sum.

    The input is first projected to ``n_features`` channels with a size-1
    convolution.  Then, for each dilation ``2**i`` with
    ``i in range(n_dilations)``, two parallel dilated convolutions are
    computed, passed through ``tanh`` and ``sigmoid`` respectively and
    multiplied together (the gate).  A size-1 convolution follows, and its
    output is both fed to the next iteration and added to the running
    residual sum, which is what the block returns.
    """
    # Number of output channels of every convolution in the block.
    n_features: int = 64
    # Kernel size of the two gated (dilated) convolutions.
    kernel_size: int = 1
    # Number of gated layers; layer ``i`` uses dilation ``2**i``.
    n_dilations: int = 4

    @ln.compact
    def __call__(self, x):
        """Apply the block to ``x`` and return the residual sum.

        NOTE(review): the residual additions need every convolution to keep
        the input length; this relies on ``ln.Conv``'s default padding --
        confirm against the installed Flax version.
        """
        # Initial size-1 projection so the residual has ``n_features`` channels.
        x = ln.Conv(self.n_features, 1)(x)
        res_x = x
        for i in range(self.n_dilations):
            kernel_dilation = 2**i

            # Gating: tanh branch (content) times sigmoid branch (gate).
            tanh_x = jnp.tanh(
                ln.Conv(
                    self.n_features,
                    self.kernel_size,
                    kernel_dilation=kernel_dilation,
                )(x)
            )
            # ``jax.numpy`` has no ``sigmoid``; use the activation that Flax
            # re-exports instead.
            sigm_x = ln.sigmoid(
                ln.Conv(
                    self.n_features,
                    self.kernel_size,
                    kernel_dilation=kernel_dilation,
                )(x)
            )

            # Size-1 convolution on the gated signal; its output feeds the
            # next layer and is accumulated into the residual.
            x = ln.Conv(self.n_features, 1)(tanh_x * sigm_x)
            res_x = res_x + x
        return res_x

In [3]:
class WavePath(ln.Module):
    """Chain of ``n_blocks`` ``WaveBlock`` modules applied one after another.

    The first block uses ``init_features`` channels and ``init_dilations``
    dilated layers.  After each block the channel count is doubled and the
    number of dilated layers is reduced by 4, never going below 1.  With the
    defaults this yields blocks of (8, 12), (16, 8), (32, 4) and (64, 1)
    (features, dilations).
    """
    # Kernel size handed to every ``WaveBlock``.
    kernel_size: int = 3
    # Channel count of the first block; doubled for each following block.
    init_features: int = 8
    # Dilated-layer count of the first block; shrinks by 4 per block, floor 1.
    init_dilations: int = 12
    # How many blocks to chain.
    n_blocks: int = 4

    @ln.compact
    def __call__(self, x):
        """Run ``x`` through every block in sequence and return the result."""
        width = self.init_features
        depth = self.init_dilations
        out = x
        for _block in range(self.n_blocks):
            block = WaveBlock(
                n_features=width,
                kernel_size=self.kernel_size,
                n_dilations=depth,
            )
            out = block(out)
            # Schedule for the next block: wider, but with fewer dilations.
            width, depth = 2 * width, max(1, depth - 4)
        return out

In [6]:
class WaveNet(ln.Module):
    """Shared ``WavePath`` feature extractor applied per input channel.

    Each input channel is passed on its own through one shared ``WavePath``
    (the same module instance is reused, so its parameters are shared) and
    average-pooled.  Channels are combined in consecutive groups of
    ``feature_group_size`` by averaging their pooled outputs, and the group
    results are concatenated along the last axis.  Unless ``embed`` is set,
    two ``Dense`` + ``relu`` layers are applied on top.

    NOTE(review): assumes ``x`` is laid out as ``(..., time, features)`` --
    confirm against the caller.
    """
    # When True, return the concatenated embedding and skip the dense head.
    embed: bool = False
    # Width of the final dense layer.
    output_size: int = 1
    # Number of consecutive input channels averaged into one group.
    feature_group_size: int = 2
    # Width of the hidden dense layer.
    dense_size: int = 64
    # Keyword arguments forwarded to ``WavePath``.  Flax modules are
    # dataclasses, which reject a bare ``dict()`` default as mutable, so a
    # default factory is used instead.
    path_kwargs: dict = dataclasses.field(default_factory=dict)

    @ln.compact
    def __call__(self, x):
        """Return the model output (or the embedding when ``embed`` is True).

        Raises:
            ValueError: if the size of the last axis of ``x`` is not a
                multiple of ``feature_group_size``.
        """
        wave_path = WavePath(**self.path_kwargs)
        n_features = x.shape[-1]

        remainder = n_features % self.feature_group_size
        if remainder:
            raise ValueError(f"""
                Expected number of features to be divisible by {self.feature_group_size}, 
                got {n_features} from shape {x.shape}, leaving nonzero remainder {remainder}
            """)

        zs = []
        for start in range(0, n_features, self.feature_group_size):
            # Pool each channel's feature map over the second-to-last axis
            # (time, under the layout assumed above), mirroring the
            # GlobalAveragePooling1D step of the Keras reference.
            pooled = [
                jnp.mean(wave_path(x[..., k:k + 1]), axis=-2)
                for k in range(start, start + self.feature_group_size)
            ]
            # Element-wise average across the channels of this group.
            zs.append(sum(pooled) / len(pooled))

        # Join the group embeddings along the feature axis, not axis 0.
        y = jnp.concatenate(zs, axis=-1)
        if not self.embed:
            y = ln.relu(ln.Dense(self.dense_size)(y))
            y = ln.relu(ln.Dense(self.output_size)(y))
        return y