In [16]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
import torch
from torch.utils.data import Dataset
import torchaudio
import torchaudio.transforms as T
from glob import glob

In [12]:
# Load the saved CMGAN checkpoint onto the CPU so no GPU is required.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code; only use this with a checkpoint file that
# comes from a trusted source.
cmgan_checkpoint = torch.load(
    "./Checkpoint_CMGAN/CMGAN_epoch_34_0.124",
    map_location=torch.device("cpu"),
    weights_only=False,
)

In [14]:
class AudioDataset(Dataset):
    """Dataset of ``.wav`` files that yields ``(mel_spectrogram, path)`` pairs.

    Each item is loaded with ``torchaudio.load``, resampled to ``sample_rate``
    when the file's own rate differs, optionally passed through ``cmgan_model``,
    and finally converted with ``torchaudio.transforms.MelSpectrogram``.

    Args:
        data_dir: Directory searched (non-recursively) for ``*.wav`` files.
        sample_rate: Target sample rate in Hz for every returned item.
        cmgan_model: Optional callable applied to the (resampled) waveform
            before the mel transform. Assumed to map a waveform tensor to a
            waveform tensor -- TODO confirm against the CMGAN model's actual
            input format. ``None`` disables this step.
        max_files: Maximum number of files to keep, taken from the sorted
            file list. ``None`` keeps every file found.

    Raises:
        ValueError: If ``max_files`` is negative.
    """

    def __init__(self, data_dir, sample_rate=16000, cmgan_model=None, max_files=500):
        if max_files is not None and max_files < 0:
            raise ValueError(f"max_files must be non-negative or None, got {max_files}")

        # glob returns paths in filesystem order, which is not stable across
        # machines; sorting makes the selected subset reproducible.
        wav_paths = sorted(glob(os.path.join(data_dir, "*.wav")))
        # Stored on the instance because __len__ and __getitem__ read it.
        self.files = wav_paths[:max_files]

        self.sample_rate = sample_rate
        self.cmgan = cmgan_model

        self.spectrogram_transform = T.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=1024,
            hop_length=256,
            n_mels=128
        )

        # Kept for backward compatibility: resampler for 44.1 kHz sources.
        self.resampler = T.Resample(orig_freq=44100, new_freq=sample_rate)
        # One resampler per source rate, built lazily, so files recorded at
        # rates other than 44.1 kHz are converted with the correct ratio.
        self._resamplers = {44100: self.resampler}

    def _resampler_for(self, orig_freq):
        """Return (building and caching if needed) a resampler from ``orig_freq`` to ``self.sample_rate``."""
        resampler = self._resamplers.get(orig_freq)
        if resampler is None:
            resampler = T.Resample(orig_freq=orig_freq, new_freq=self.sample_rate)
            self._resamplers[orig_freq] = resampler
        return resampler

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(spectrogram, path)``.

        ``spectrogram`` is the output of the mel transform for the file at
        ``self.files[idx]``; ``path`` is that file's path string.
        """
        path = self.files[idx]

        waveform, sr = torchaudio.load(path)

        if sr != self.sample_rate:
            # Use the file's real sample rate rather than assuming 44.1 kHz.
            waveform = self._resampler_for(sr)(waveform)

        # Explicit None check: some valid callables (e.g. empty module
        # containers) are falsy and would otherwise be skipped silently.
        if self.cmgan is not None:
            # Enhancement here is preprocessing only; disabling autograd keeps
            # the returned tensors free of graph references.
            with torch.no_grad():
                waveform = self.cmgan(waveform)

        spectrogram = self.spectrogram_transform(waveform)

        return spectrogram, path

In [19]:
# Build the dataset over the noisy training recordings.
# NOTE(review): cmgan_model receives the object returned by torch.load; the
# dataset calls it on each waveform, so this presumably needs to be a callable
# model rather than a state dict -- confirm what the checkpoint file contains.
dataset = AudioDataset(
    "Dataset/noisy_trainset_56spk_wav/noisy_trainset_56spk_wav/",
    cmgan_model=cmgan_checkpoint,
)