In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalConv1d(nn.Module):
    """1D convolution whose output at step ``t`` sees only input steps ``<= t``.

    The input is zero-padded on the left (the start of the time axis) by
    ``dilation * (kernel_size - 1)`` samples, then passed through an unpadded
    ``nn.Conv1d``. With the default stride of 1 this keeps the output length
    equal to the input length.

    Args:
        in_channels (int): Channels of the incoming tensor.
        out_channels (int): Channels produced by the convolution.
        kernel_size (int): Width of the convolution kernel.
        dilation (int): Spacing between kernel taps.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        # No built-in padding: forward() adds it on the left only.
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            dilation=dilation,
        )

    def forward(self, x):
        """Left-pad ``x`` along its last (time) axis and convolve it.

        Args:
            x (Tensor): Input for ``nn.Conv1d``, presumably shaped
                (batch, in_channels, time).

        Returns:
            Tensor: Convolution output with the same time length as ``x``.
        """
        # The kernel spans (kernel_size - 1) * dilation steps beyond the current one.
        left = (self.kernel_size - 1) * self.dilation
        return self.conv(F.pad(x, (left, 0)))

class WaveNetBlock(nn.Module):
    """One residual WaveNet layer built around a tanh/sigmoid gated activation.

    Two parallel dilated causal convolutions read the input. The ``tanh`` of
    the first (filter) is multiplied elementwise by the ``sigmoid`` of the
    second (gate). A 1x1 convolution of that gated signal is added back onto
    the input to form the residual output. A separate 1x1 convolution of it
    forms the skip output.

    Args:
        residual_channels (int): Channels of the input and of the residual output.
        skip_channels (int): Channels of the skip output.
        kernel_size (int): Kernel width of the two dilated causal convolutions.
        dilation (int): Dilation of the two dilated causal convolutions.
    """

    def __init__(self, residual_channels, skip_channels, kernel_size, dilation):
        super().__init__()
        # Keep this creation order: it fixes the state_dict key order and the
        # order in which parameters draw their random initial values.
        dilated_args = (residual_channels, residual_channels, kernel_size, dilation)
        self.filter_conv = CausalConv1d(*dilated_args)
        self.gate_conv = CausalConv1d(*dilated_args)
        self.residual_conv = nn.Conv1d(residual_channels, residual_channels, kernel_size=1)
        self.skip_conv = nn.Conv1d(residual_channels, skip_channels, kernel_size=1)

    def forward(self, x):
        """Run the gated unit and return ``(residual, skip)``.

        Args:
            x (Tensor): Block input with ``residual_channels`` channels,
                presumably shaped (batch, residual_channels, time).

        Returns:
            tuple[Tensor, Tensor]: ``x`` plus the 1x1-projected gated signal
            (same shape as ``x``), and the skip projection with
            ``skip_channels`` channels.
        """
        gated = torch.tanh(self.filter_conv(x)) * torch.sigmoid(self.gate_conv(x))
        return x + self.residual_conv(gated), self.skip_conv(gated)

class WaveNet(nn.Module):
    """
    A simplified WaveNet model that stacks multiple WaveNet blocks.

    Block ``i`` uses dilation ``2 ** (i % dilation_cycle_length)``. The skip
    outputs of all blocks are summed and passed through
    ReLU -> 1x1 conv -> ReLU -> 1x1 conv, giving ``in_channels`` logits per
    time step. Every convolution in the model is either a causal convolution
    or a 1x1 convolution. The output therefore keeps the input's time length,
    and output step ``t`` depends only on input steps ``<= t``.

    Attributes:
        dilations (list[int]): Dilation used by each block, in order.
        receptive_field (int): Number of input time steps (including the
            current one) that can influence a single output step.
    """
    def __init__(self, in_channels, residual_channels, skip_channels, end_channels,
                 kernel_size, num_layers, dilation_cycle_length):
        """
        Args:
            in_channels (int): Number of input channels (e.g. number of quantization channels).
            residual_channels (int): Number of channels in the residual layers.
            skip_channels (int): Number of channels in the skip connections.
            end_channels (int): Number of channels in the post-processing layers.
            kernel_size (int): Size of the convolutional kernel.
            num_layers (int): Total number of WaveNet blocks.
            dilation_cycle_length (int): Number of layers before dilation factors repeat.

        Raises:
            ValueError: If ``num_layers`` or ``dilation_cycle_length`` is less
                than 1. With no layers there would be no skip outputs to sum.
                A zero or negative cycle length cannot produce valid
                (positive integer) dilations.
        """
        super(WaveNet, self).__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be at least 1, got {num_layers}")
        if dilation_cycle_length < 1:
            raise ValueError(
                f"dilation_cycle_length must be at least 1, got {dilation_cycle_length}"
            )
        # Dilations double within each cycle, then restart at 1: 1, 2, 4, ..., 1, 2, 4, ...
        self.dilations = [2 ** (i % dilation_cycle_length) for i in range(num_layers)]
        # Each causal conv of width k and dilation d reaches (k - 1) * d steps further back.
        self.receptive_field = (kernel_size - 1) * sum(self.dilations) + 1
        # Submodules are created in the same order as before, so the state_dict
        # layout and the random-init draw order are unchanged.
        # Initial 1x1 convolution to match channel dimensions.
        self.input_conv = nn.Conv1d(in_channels, residual_channels, kernel_size=1)
        self.blocks = nn.ModuleList(
            WaveNetBlock(residual_channels, skip_channels, kernel_size, dilation)
            for dilation in self.dilations
        )
        self.relu = nn.ReLU()
        # Post-processing: 1x1 convolutions to produce final output logits.
        self.output_conv1 = nn.Conv1d(skip_channels, end_channels, kernel_size=1)
        self.output_conv2 = nn.Conv1d(end_channels, in_channels, kernel_size=1)

    def forward(self, x):
        """
        Args:
            x (Tensor): Input tensor of shape (batch, channels, time), where
                channels == ``in_channels``.
        Returns:
            Tensor: Output logits over the quantized audio channels, shaped
            (batch, in_channels, time).
        """
        x = self.input_conv(x)
        # Running sum of the skip outputs. Unlike collecting them in a list, this
        # lets each block's skip tensor be released as soon as it is added.
        # The additions happen in the same order, so the result is unchanged.
        skip_total = None
        for block in self.blocks:
            x, skip = block(x)
            skip_total = skip if skip_total is None else skip_total + skip
        out = self.relu(skip_total)
        out = self.output_conv1(out)
        out = self.relu(out)
        return self.output_conv2(out)

In [2]:
 # Example parameters (adjust these as needed)
batch_size = 2
# Suppose audio is quantized into 256 channels (e.g. μ-law quantization)
in_channels = 256
sequence_length = 16000  # e.g., one second of audio at 16kHz
seed = 0  # fixes the model's random init and the dummy input so reruns match
torch.manual_seed(seed)
model = WaveNet(in_channels=in_channels, residual_channels=32, skip_channels=32,
                end_channels=32, kernel_size=2, num_layers=10, dilation_cycle_length=10)
# Dummy input: random quantization classes, one-hot encoded over the channel
# axis and laid out as (batch, channels, time). This matches the quantized-audio
# setup described above better than Gaussian noise does.
samples = torch.randint(0, in_channels, (batch_size, sequence_length))
x = F.one_hot(samples, num_classes=in_channels).permute(0, 2, 1).float()
out = model(x)
# Causal convs keep the time length; the final 1x1 conv maps back to in_channels.
assert out.shape == (batch_size, in_channels, sequence_length)
print("Output shape:", out.shape)

Output shape: torch.Size([2, 256, 16000])
