In [None]:
import torch
import torch.nn as nn

In [None]:
class WaveNetBlock(nn.Module):
    """One gated, dilated residual layer of a WaveNet.

    The input is left-padded by ``dilation * (kernel_size - 1)`` samples before
    the dilated filter/gate convolutions. This makes the layer causal (the
    output at time ``t`` only depends on inputs at times ``<= t``) and keeps the
    sequence length unchanged, which is required for the residual sum
    ``x + residual`` and for summing skip outputs across layers.

    Args:
        in_channels: Channels of the input ``x`` and of the residual output.
        out_channels: Channels of the gated (tanh * sigmoid) activation.
        kernel_size: Kernel size of the dilated filter/gate convolutions.
        dilation: Dilation of the filter/gate convolutions.
        skip_channels: Channels of the skip output. Defaults to ``out_channels``.

    Forward:
        x: Tensor accepted by ``nn.Conv1d`` with ``in_channels`` channels,
            i.e. ``(batch, in_channels, length)`` or ``(in_channels, length)``.
        Returns ``(output, skip_connection)``: ``output`` has the same shape as
        ``x``; ``skip_connection`` has ``skip_channels`` channels and the same
        length as ``x``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation, skip_channels=None):
        super(WaveNetBlock, self).__init__()
        if skip_channels is None:
            skip_channels = out_channels
        self.dilation = dilation
        # Left padding that makes the dilated conv causal and length-preserving.
        self.causal_padding = dilation * (kernel_size - 1)
        self.conv_filter = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation)
        self.conv_gate = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation)
        self.conv_res = nn.Conv1d(out_channels, in_channels, 1)  # Residual connection
        self.conv_skip = nn.Conv1d(out_channels, skip_channels, 1)  # Skip connection

    def forward(self, x):
        # Pad only on the left (the past) so no output sees future samples and
        # the conv output has exactly the same length as x.
        x_padded = nn.functional.pad(x, (self.causal_padding, 0))

        # Gated activation unit: tanh(filter) * sigmoid(gate)
        filter_output = torch.tanh(self.conv_filter(x_padded))
        gate_output = torch.sigmoid(self.conv_gate(x_padded))
        gated_output = filter_output * gate_output

        # Residual and skip connections (1x1 convs, length preserved)
        residual = self.conv_res(gated_output)
        skip_connection = self.conv_skip(gated_output)
        output = x + residual

        return output, skip_connection

In [None]:
class WaveNet(nn.Module):
    """Stack of gated dilated residual layers with a summed-skip output head.

    The network has ``num_blocks * num_layers`` independent ``WaveNetBlock``
    layers. Within each of the ``num_blocks`` stacks the dilation doubles per
    layer (1, 2, 4, ..., ``2 ** (num_layers - 1)``) and then restarts at 1.

    Args:
        num_blocks: Number of times the dilation cycle is repeated.
        num_layers: Layers per dilation cycle.
        in_channels: Channels of the network input.
        out_channels: Channels of the network output.
        residual_channels: Width of the residual stream (and of each layer's
            skip output).
        skip_channels: Hidden width of the output head applied to the summed
            skip outputs.
        kernel_size: Kernel size of the dilated convolutions.

    Raises:
        ValueError: If ``num_blocks`` or ``num_layers`` is smaller than 1
            (there would be no skip outputs to sum).
    """

    def __init__(self, num_blocks, num_layers, in_channels, out_channels, residual_channels, skip_channels, kernel_size):
        super(WaveNet, self).__init__()
        if num_blocks < 1 or num_layers < 1:
            raise ValueError(
                f"num_blocks and num_layers must be >= 1, got {num_blocks} and {num_layers}"
            )
        self.num_blocks = num_blocks
        self.num_layers = num_layers
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.kernel_size = kernel_size

        self.start_conv = nn.Conv1d(in_channels, residual_channels, 1)
        # One independently parameterised layer per (stack, layer) position;
        # the dilation cycle restarts at 1 for every stack.
        self.blocks = nn.ModuleList([
            WaveNetBlock(residual_channels, residual_channels, kernel_size, 2 ** (i % num_layers))
            for i in range(num_blocks * num_layers)
        ])
        # The blocks are built without an explicit skip width, so their skip
        # outputs have residual_channels channels; the head lifts that to
        # skip_channels before projecting to out_channels.
        self.end_conv1 = nn.Conv1d(residual_channels, skip_channels, 1)
        self.end_conv2 = nn.Conv1d(skip_channels, out_channels, 1)

    def forward(self, x):
        """Run the network on ``x`` (``in_channels`` channels, Conv1d layout)."""
        hidden = self.start_conv(x)

        # Sum the skip outputs of every layer (at least one layer exists).
        combined_skip = None
        for layer in self.blocks:
            hidden, skip = layer(hidden)
            combined_skip = skip if combined_skip is None else combined_skip + skip

        output = torch.relu(combined_skip)
        output = self.end_conv1(output)
        output = torch.relu(output)
        output = self.end_conv2(output)

        return output

In [None]:
def EnergyConservingLoss(input_mix, input_voice, input_noise, generated_voice, reduction="none"):
    """L1 loss on both the voice estimate and the implied noise estimate.

    The noise estimate is not predicted separately. It is defined as whatever
    the model did not attribute to the voice (``input_mix - generated_voice``),
    so the voice and noise estimates always add back up to the mixture.

    Args:
        input_mix: Mixture signal.
        input_voice: Ground-truth voice signal.
        input_noise: Ground-truth noise signal.
        generated_voice: Model's voice estimate.
        reduction: ``"none"`` (default) returns the elementwise loss, as
            before. ``"mean"`` or ``"sum"`` reduces it via ``.mean()`` or
            ``.sum()`` (e.g. to get a scalar for ``backward()``).

    Returns:
        Elementwise ``|v - v_hat| + |n - (m - v_hat)|``, optionally reduced.

    Raises:
        ValueError: If ``reduction`` is not one of ``"none"``, ``"mean"``, ``"sum"``.
    """
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(f"reduction must be 'none', 'mean' or 'sum', got {reduction!r}")

    generated_noise = input_mix - generated_voice
    voice_diff = abs(input_voice - generated_voice)
    noise_diff = abs(input_noise - generated_noise)

    loss = voice_diff + noise_diff

    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss

In [None]:
# Hyperparameters
SEED = 0                # fixed seed so weight initialisation is reproducible
num_blocks = 3          # number of times the dilation cycle is repeated
num_layers = 10         # layers per dilation cycle (dilations 2**0 .. 2**(num_layers-1))
in_channels = 1         # input channels
out_channels = 1        # output channels
residual_channels = 64  # width of the residual stream
skip_channels = 256     # width of the output head
kernel_size = 2         # kernel size of the dilated convolutions

# Seed torch's RNG before creating any parameters.
torch.manual_seed(SEED)

# Create the WaveNet model
wavenet_model = WaveNet(
    num_blocks=num_blocks,
    num_layers=num_layers,
    in_channels=in_channels,
    out_channels=out_channels,
    residual_channels=residual_channels,
    skip_channels=skip_channels,
    kernel_size=kernel_size,
)