In [26]:
# Cell: Imports & Constants
import os, glob, numpy as np, torch, librosa
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader

# Directories produced by the preprocessing step; keep these in sync with the
# paths that step writes to.
DEN_DIR = '/home/jovyan/Features/denoised'
EMB_DIR = '/home/jovyan/Features/embeddings'

# Default mel-spectrogram parameters used by calculate_mel_spectrogram().
PANNS_SR = 32_000            # default `sr`
N_FFT, HOP_LENGTH = 2_048, 512
N_MELS = 128

# Default weight of the embedding term that create_augmented_mel() adds to the
# log-mel spectrogram.
ALPHA = 0.5


In [27]:
# Cell: Spectrogram & Augmentation Helpers

def calculate_mel_spectrogram(wave_np, sr=PANNS_SR,
                              n_fft=N_FFT, hop_length=HOP_LENGTH,
                              n_mels=N_MELS):
    """Return the mel spectrogram of ``wave_np`` expressed in decibels.

    Args:
        wave_np: Audio samples, forwarded to librosa as ``y``.
        sr: Sample rate forwarded to librosa.
        n_fft: FFT window size.
        hop_length: Hop between successive frames, in samples.
        n_mels: Number of mel bands.

    Returns:
        The output of ``librosa.power_to_db`` applied to the power mel
        spectrogram with ``ref=np.max``, i.e. dB values measured against that
        spectrogram's own maximum.
    """
    mel_kwargs = dict(
        y=wave_np,
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
    )
    power_mel = librosa.feature.melspectrogram(**mel_kwargs)
    # ref=np.max makes the scale relative per clip rather than absolute.
    return librosa.power_to_db(power_mel, ref=np.max)

def create_augmented_mel(log_mel, emb, n_mels=N_MELS, alpha=ALPHA):
    """Add a standardized, embedding-derived per-bin offset to a log-mel spectrogram.

    The embedding is first brought to length ``n_mels``:

    * equal length: used as is;
    * longer and an exact multiple of ``n_mels``: averaged in consecutive groups;
    * longer but not a multiple: truncated to the first ``n_mels`` values;
    * shorter: zero-padded at the end.

    The resulting vector is standardized to zero mean / unit variance across
    the mel bins, scaled by ``alpha`` and added to every time frame.

    Args:
        log_mel: 2-D array whose second axis is time; its first axis must be
            broadcast-compatible with ``n_mels`` rows.
        emb: Embedding vector (assumed 1-D; only ``emb.shape[0]`` is inspected).
        n_mels: Target number of mel bins for the projected embedding.
        alpha: Weight of the embedding offset.

    Returns:
        ``log_mel + alpha * offset`` where ``offset`` has one value per mel bin
        repeated along time.
    """
    embed_dim = emb.shape[0]
    if embed_dim == n_mels:
        proj = emb
    elif embed_dim > n_mels:
        if embed_dim % n_mels == 0:
            # Average each run of `factor` consecutive embedding values.
            proj = emb.reshape(n_mels, embed_dim // n_mels).mean(axis=1)
        else:
            proj = emb[:n_mels]
    else:
        proj = np.pad(emb, (0, n_mels - embed_dim))

    # Standardize ACROSS mel bins. The previous code tiled the vector along
    # time and then standardized along that same time axis; every column was
    # constant, so the scaler's zero-variance handling produced all zeros and
    # the "augmented" spectrogram was identical to the plain log-mel.
    spread = proj.std()
    offset = (proj - proj.mean()) / (spread if spread > 0 else 1.0)

    # Broadcasting over the time axis replaces the explicit np.tile.
    return log_mel + alpha * offset[:, None]


In [29]:
class DenoisedWavDataset(Dataset):
    """Dataset over the ``.npz`` archives found recursively under ``den_dir``.

    Each item is ``(waveform, label)``: the archive's ``'waveform'`` array as a
    float32 tensor and its ``'label'`` entry as a scalar int64 tensor.
    """

    def __init__(self, den_dir):
        # Sorted so the index -> file mapping is stable between runs.
        self.paths = sorted(glob.glob(f"{den_dir}/**/*.npz", recursive=True))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # np.load on an .npz returns a lazy NpzFile that keeps the archive
        # open; read what is needed inside the context manager so the file
        # handle is released deterministically.
        with np.load(self.paths[idx]) as data:
            wav = data['waveform']
            lbl = int(data['label'])
        return torch.from_numpy(wav).float(), torch.tensor(lbl).long()

class EmbeddingDataset(Dataset):
    """Dataset over the ``*_emb.npz`` archives found recursively under ``emb_dir``.

    Each item is ``(embedding, label)``: the archive's ``'embedding'`` array as
    a float32 tensor and its ``'label'`` entry as a scalar int64 tensor.
    """

    def __init__(self, emb_dir):
        # Sorted so the index -> file mapping is stable between runs.
        self.paths = sorted(glob.glob(f"{emb_dir}/**/*_emb.npz", recursive=True))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Read inside the context manager so the NpzFile's underlying file
        # handle is closed as soon as the arrays have been extracted.
        with np.load(self.paths[idx]) as data:
            emb = data['embedding']
            lbl = int(data['label'])
        return torch.from_numpy(emb).float(), torch.tensor(lbl).long()

class MelDataset(Dataset):
    """Dataset yielding mel spectrograms computed on the fly from stored waveforms.

    Archives are the ``.npz`` files found recursively under ``den_dir``. Each
    item is ``(mel, label)`` where ``mel`` is the result of
    ``calculate_mel_spectrogram`` on the archive's ``'waveform'`` array, as a
    float32 tensor, and ``label`` is a scalar int64 tensor.
    """

    def __init__(self, den_dir):
        # Sorted so the index -> file mapping is stable between runs.
        self.paths = sorted(glob.glob(f"{den_dir}/**/*.npz", recursive=True))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Extract the arrays inside the context manager so the NpzFile's file
        # handle is closed before the (comparatively slow) spectrogram step.
        with np.load(self.paths[idx]) as data:
            wav = data['waveform']
            lbl = int(data['label'])
        mel = calculate_mel_spectrogram(wav)
        return torch.from_numpy(mel).float(), torch.tensor(lbl).long()

class AugmentedMelDataset(Dataset):
    """Dataset yielding embedding-augmented mel spectrograms.

    For every ``.npz`` archive found recursively under ``den_dir`` the matching
    embedding archive is looked up under ``emb_dir`` at the same relative path
    with the ``.npz`` extension replaced by ``_emb.npz``. Each item is
    ``(aug, label)`` where ``aug`` is
    ``create_augmented_mel(calculate_mel_spectrogram(waveform), embedding)`` as
    a float32 tensor and ``label`` is a scalar int64 tensor.
    """

    def __init__(self, den_dir, emb_dir):
        # Remember the root that was actually scanned: relative paths must be
        # computed against it, not against a module-level default.
        self.den_dir   = den_dir
        self.den_paths = sorted(glob.glob(f"{den_dir}/**/*.npz", recursive=True))
        self.emb_dir   = emb_dir

    def __len__(self):
        return len(self.den_paths)

    def _embedding_path(self, dpath):
        """Map a waveform archive path to its embedding archive path."""
        rel = os.path.relpath(dpath, self.den_dir)
        # Swap only the final extension; str.replace would also rewrite any
        # '.npz' occurring earlier in the relative path.
        stem, _ext = os.path.splitext(rel)
        return os.path.join(self.emb_dir, stem + '_emb.npz')

    def __getitem__(self, idx):
        dpath = self.den_paths[idx]
        # Close each archive as soon as its arrays have been read.
        with np.load(dpath) as data:
            wav = data['waveform']
            lbl = int(data['label'])
        with np.load(self._embedding_path(dpath)) as emb_data:
            emb = emb_data['embedding']
        mel = calculate_mel_spectrogram(wav)
        aug = create_augmented_mel(mel, emb)
        return torch.from_numpy(aug).float(), torch.tensor(lbl).long()

In [37]:
# One DataLoader per feature view. All four use the same batching settings;
# num_workers=0 keeps loading in the main process.

# Raw denoised waveforms.
wav_loader = DataLoader(
    dataset=DenoisedWavDataset(DEN_DIR),
    batch_size=256,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)

# Precomputed embeddings.
emb_loader = DataLoader(
    dataset=EmbeddingDataset(EMB_DIR),
    batch_size=256,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)

# Mel spectrograms computed on the fly from the waveforms.
mel_loader = DataLoader(
    dataset=MelDataset(DEN_DIR),
    batch_size=256,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)

# Mel spectrograms combined with the matching embedding.
augmel_loader = DataLoader(
    dataset=AugmentedMelDataset(DEN_DIR, EMB_DIR),
    batch_size=256,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)



In [38]:
# ───────── Sanity Check ─────────
# Draw a single batch from each loader and report the tensor shapes. This only
# verifies shapes, not the values inside the batches.
for name, loader in zip(
    ("WAV", "EMB", "MEL", "AUGMEL"),
    (wav_loader, emb_loader, mel_loader, augmel_loader),
):
    x, y = next(iter(loader))
    print(f"{name} batch: {x.shape}, labels: {y.shape}")

WAV batch: torch.Size([256, 320000]), labels: torch.Size([256])
EMB batch: torch.Size([256, 2048]), labels: torch.Size([256])
MEL batch: torch.Size([256, 128, 626]), labels: torch.Size([256])
AUGMEL batch: torch.Size([256, 128, 626]), labels: torch.Size([256])
