In [None]:
import random
import torch
import torch.nn as nn
import torchaudio
from datasets import load_dataset
import soundfile as sf


class WaveNet(nn.Module):
    """Simplified WaveNet-style stack of causal dilated 1-D convolutions.

    Maps an input of shape ``(batch, input_channels, time)`` to an output of
    shape ``(batch, input_channels, time)``.  With the default
    ``input_channels=256`` the input is presumably a one-hot encoding of 8-bit
    mu-law samples and the output per-step logits — TODO confirm with callers.

    Args:
        residual_channels: Width of the residual stream between layers.
        skip_channels: Width of the summed skip connections.
        dilation_channels: Output width of each dilated convolution.
        n_layers: Layers per block; layer ``i`` of a block uses dilation ``2 ** i``.
        n_blocks: Number of times the dilation cycle is repeated.
        input_channels: Channel count of both the input and the output.
    """

    def __init__(self, residual_channels, skip_channels, dilation_channels, n_layers, n_blocks, input_channels=256):
        super().__init__()
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.dilation_channels = dilation_channels
        self.n_layers = n_layers
        self.n_blocks = n_blocks

        # Dilations 1, 2, 4, ..., 2**(n_layers-1), repeated n_blocks times.
        self.dilations = [2 ** i for i in range(n_layers)] * n_blocks

        self.input_conv = nn.Conv1d(input_channels, residual_channels, kernel_size=1)
        self.residual_layers = nn.ModuleList()
        self.skip_layers = nn.ModuleList()
        # Brings each dilated-conv output back to the residual width so it can
        # be added to the residual stream.  Identity (parameter-free) when the
        # widths already match, so the parameter set is unchanged in that case.
        self.residual_projections = nn.ModuleList()

        for dilation in self.dilations:
            self.residual_layers.append(
                nn.Conv1d(residual_channels, dilation_channels, kernel_size=2, dilation=dilation)
            )
            self.skip_layers.append(
                nn.Conv1d(dilation_channels, skip_channels, kernel_size=1)
            )
            if dilation_channels == residual_channels:
                self.residual_projections.append(nn.Identity())
            else:
                self.residual_projections.append(
                    nn.Conv1d(dilation_channels, residual_channels, kernel_size=1)
                )

        self.output_conv1 = nn.Conv1d(skip_channels, skip_channels, kernel_size=1)
        self.output_conv2 = nn.Conv1d(skip_channels, input_channels, kernel_size=1)

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape ``(batch, input_channels, time)``.

        Returns:
            Tensor of shape ``(batch, input_channels, time)``; output step ``t``
            depends only on input steps ``<= t``.
        """
        x = self.input_conv(x)
        skip = 0

        for dilation, residual_layer, skip_layer, projection in zip(
            self.dilations, self.residual_layers, self.skip_layers, self.residual_projections
        ):
            # Left-pad by the dilation: the kernel-2 dilated conv then sees only
            # x[t - dilation] and x[t], so it is causal and preserves the time
            # length -- required for the residual and skip sums below.
            residual = residual_layer(nn.functional.pad(x, (dilation, 0)))
            skip = skip + skip_layer(residual)
            x = x + projection(residual)

        x = torch.relu(skip)
        x = torch.relu(self.output_conv1(x))
        x = self.output_conv2(x)

        return x


# Load the HSN configuration of the BirdSet dataset.
hsn = load_dataset('DBD-research-group/BirdSet', 'HSN')

# Keep a random subset of the training split so preprocessing stays fast.
subset_percentage = 0.05  # fraction of the split to keep (0.05 == 5%)
subset_seed = 0  # fixed seed so every run draws the same subset
train_split = hsn['train']
# At least one example (when the split is non-empty), never more than exist.
subset_size = min(len(train_split), max(1, int(len(train_split) * subset_percentage)))
# A dedicated Random instance keeps the draw reproducible without touching
# the global `random` state.
subset_indices = random.Random(subset_seed).sample(range(len(train_split)), subset_size)
hsn = train_split.select(subset_indices)
print(len(subset_indices))

def preprocess(batch):
    """Decode every audio file referenced by a batch of examples.

    Intended for ``Dataset.map(preprocess, batched=True)``: ``batch`` maps
    column names to lists, and ``batch['filepath']`` holds one path per
    example.  Each file is read with ``torchaudio.load``; the sample rate it
    returns is discarded.

    Returns:
        ``{'audio': [...]}`` with one waveform tensor per path, in batch order,
        so the new column has the same length as the batch for any batch size.
    """
    # torchaudio.load returns (waveform, sample_rate); keep only the waveform.
    audio_tensors = [torchaudio.load(path)[0] for path in batch['filepath']]
    return {'audio': audio_tensors}


# Decode the audio of every selected example into an 'audio' column.
hsn = hsn.map(preprocess, batched=True, batch_size=1)
print(hsn)

residual_channels = 32
skip_channels = 512
dilation_channels = 32
n_layers = 10
n_blocks = 3

model = WaveNet(residual_channels, skip_channels, dilation_channels, n_layers, n_blocks)

# The model's input/output width doubles as the number of mu-law levels.
n_classes = model.input_conv.in_channels
# Cap the excerpt length: activations grow linearly with the number of
# samples (skip_channels x time floats per layer) and recordings can be long.
max_duration_s = 5.0

# Dataset.map stores waveforms as nested Python lists and drops the sample
# rate, so reload the first example's file to get a tensor and its true rate.
waveform, sample_rate = torchaudio.load(hsn[0]['filepath'])  # (channels, time)
print(waveform.shape, sample_rate)

# Mix down to mono and crop to a short excerpt.
waveform = waveform.mean(dim=0)[: int(max_duration_s * sample_rate)]  # (time,)
# WaveNet's input_conv expects (batch, n_classes, time): one-hot mu-law codes.
codes = torchaudio.functional.mu_law_encoding(waveform, n_classes)  # (time,) int64
audio_data = nn.functional.one_hot(codes, n_classes).T.unsqueeze(0).float()

model.eval()
with torch.no_grad():
    logits = model(audio_data)  # (1, n_classes, time)

# Take the most likely code per time step and map it back to [-1, 1] audio.
generated_codes = logits.argmax(dim=1)  # (1, time)
generated_audio = torchaudio.functional.mu_law_decoding(generated_codes, n_classes)  # (1, time)

# NOTE: no training happens in this script, so the output is essentially noise.
torchaudio.save('generated_bird_sound.wav', generated_audio, sample_rate=sample_rate)





Downloading builder script:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/13.2k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/146k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/10.8k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.08G [00:00<?, ?B/s]