# In [None]:
import torch
import torchaudio
import os
import torch
import numpy as np
# dep: git clone https://github.com/NVIDIA/waveglow
from waveglow import glow
import os
from scipy.io.wavfile import write

# Smoke test: turn bird recordings into mel spectrograms and vocode them with a pre-trained WaveGlow model.
class BirdSoundDataset(torch.utils.data.Dataset):
    """Dataset of mel spectrograms computed from the ``.wav`` files in a directory.

    Each item is the output of ``torchaudio.transforms.MelSpectrogram`` applied
    to one recording; whatever leading (channel) dimension ``torchaudio.load``
    returns is preserved.

    Args:
        data_dir: Directory scanned (non-recursively) for names ending in ``.wav``.
        sr: Sample rate the mel transform is configured for. Recordings stored
            at a different rate are resampled to it before the transform.
        n_mels: Number of mel bands.
        hop_length: Hop between successive STFT frames, in samples.
    """

    def __init__(self, data_dir, sr=22050, n_mels=80, hop_length=256):
        self.data_dir = data_dir
        self.wav_files = self._list_wav_files(data_dir)
        self.sr = sr
        # NOTE(review): n_fft/win_length are left at torchaudio's defaults and no
        # log compression is applied; a pre-trained vocoder presumably expects the
        # exact mel recipe it was trained with - confirm against the model's docs.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sr, n_mels=n_mels, hop_length=hop_length
        )

    @staticmethod
    def _list_wav_files(data_dir):
        """Return the sorted paths of the ``.wav`` files directly inside *data_dir*."""
        # os.listdir() order is arbitrary; sorting makes index -> file stable
        # across runs and machines.
        return sorted(
            os.path.join(data_dir, name)
            for name in os.listdir(data_dir)
            if name.endswith('.wav')
        )

    def __len__(self):
        """Return the number of ``.wav`` files found at construction time."""
        return len(self.wav_files)

    def __getitem__(self, idx):
        """Load recording *idx* and return its mel spectrogram."""
        wav_file = self.wav_files[idx]
        waveform, file_sr = torchaudio.load(wav_file)
        # The mel filterbank is built for self.sr, so a recording stored at any
        # other rate must be resampled first or its frequency axis is wrong.
        if file_sr != self.sr:
            waveform = torchaudio.functional.resample(waveform, file_sr, self.sr)
        return self.mel_transform(waveform)

# Recordings to vocode; one spectrogram per batch, visited in random order.
data_dir = "D:/data"
dataset = BirdSoundDataset(data_dir)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

# Pre-trained WaveGlow vocoder fetched through torch.hub, switched to eval mode.
waveglow = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_waveglow')
waveglow.eval()

# Strip weight normalisation from the convolution layers for inference.
# remove_weight_norm() raises ValueError on a module that was never
# weight-normalised, and the name match on 'Conv' also catches such layers,
# so those are skipped instead of aborting the script.
for m in waveglow.modules():
    if 'Conv' in str(type(m)):
        try:
            torch.nn.utils.remove_weight_norm(m)
        except ValueError:
            continue


def generate_audio(mel_spectrogram, waveglow_model, output_dir, filename='generated_sample.wav',
                   sample_rate=22050, sigma=0.6):
    """Vocode a mel spectrogram with WaveGlow and save the result as a WAV file.

    Args:
        mel_spectrogram: Tensor passed to ``waveglow_model.infer``. A 4-D input
            (e.g. a DataLoader batch of ``(channels, n_mels, frames)`` items)
            has its first two dimensions merged so the model receives a 3-D
            tensor.
        waveglow_model: Object exposing ``to(device)`` and ``infer(mel, sigma=...)``.
        output_dir: Directory for the output file; created if it does not exist.
        filename: Name of the WAV file inside ``output_dir``.
        sample_rate: Sample rate written to the WAV header.
        sigma: Value forwarded to ``infer`` as ``sigma``.

    Returns:
        The path of the file that was written.
    """
    # Use the GPU when present, but stay usable on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): NVIDIA's WaveGlow.infer is understood to take
    # (batch, n_mels, frames); merge batch and channel dims to get there.
    if mel_spectrogram.dim() == 4:
        mel_spectrogram = mel_spectrogram.flatten(0, 1)

    mel_spectrogram = mel_spectrogram.to(device)
    waveglow_model = waveglow_model.to(device)

    with torch.no_grad():
        audio = waveglow_model.infer(mel_spectrogram, sigma=sigma)

    audio = audio.cpu().numpy().astype(np.float32).squeeze()
    if audio.ndim == 2:
        # Several rows came back (one per input channel/batch item); scipy
        # expects multi-channel data laid out as (samples, channels).
        audio = audio.T

    # scipy's write() does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, filename)
    write(output_path, sample_rate, audio)
    return output_path

# Vocode only the first batch the loader yields; nothing happens when the
# dataset is empty.
first_batch = next(enumerate(dataloader), None)
if first_batch is not None:
    idx, mel_spec = first_batch
    generate_audio(
        mel_spec,
        waveglow,
        output_dir="./gen_sample",
        filename=f'bird_sound_{idx}.wav',
    )
