In [None]:
import os
import torch
import torchaudio
from torchaudio.transforms import MFCC

# Feature-extraction settings shared by the MFCC transform and the dataset
sample_rate = 16000  # Hz
n_mfcc = 20          # cepstral coefficients kept per frame
max_frames = 250     # fixed time length after pad/trim (250 hops of 10 ms = ~2.5 s)
melkwargs = dict(
    n_fft=400,       # 400 samples = 25 ms analysis window at 16 kHz
    hop_length=160,  # 160 samples = 10 ms hop at 16 kHz
    n_mels=40,       # number of Mel filterbank channels
)

# Module-level MFCC transform built from the settings above
mfcc_transform = MFCC(sample_rate=sample_rate, n_mfcc=n_mfcc, melkwargs=melkwargs)

class LibriSpeechMFCC(torch.utils.data.Dataset):
    """
    PyTorch Dataset that wraps LibriSpeech and returns fixed-length,
    per-coefficient normalized MFCC features together with the speaker id.

    Each item goes through: resample (if needed) -> mono mix-down ->
    transform -> crop to ``max_frames`` -> normalize -> zero-pad.

    Args:
        root: Directory that holds (or will receive) the LibriSpeech data.
        url: LibriSpeech subset name, e.g. "train-clean-100".
        download: Forwarded to ``torchaudio.datasets.LIBRISPEECH``.
        transform: Callable mapping a (1, samples) waveform to a
            (1, n_coeff, time_frames) tensor. When None, the module-level
            ``mfcc_transform`` is used.
        max_frames: Number of time frames every returned item has.
        target_sample_rate: Rate (Hz) the audio is resampled to before the
            transform. When None, the module-level ``sample_rate`` is used.

    Returns (per item):
        mfcc: Tensor of shape (n_coeff, max_frames)
        speaker_id: int
    """
    def __init__(self, root="./data", url="train-clean-100", download=False,
                 transform=None, max_frames=250, target_sample_rate=None):
        super().__init__()
        self.dataset = torchaudio.datasets.LIBRISPEECH(root, url=url, download=download)
        # Fall back to the module defaults so the None defaults are usable
        self.transform = mfcc_transform if transform is None else transform
        self.max_frames = max_frames
        self.target_sample_rate = (
            sample_rate if target_sample_rate is None else target_sample_rate
        )

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _normalize(frames):
        """Scale each coefficient row of (n_coeff, T) to mean 0 / std 1 over time."""
        n_frames = frames.shape[1]
        if n_frames == 0:
            # Nothing to normalize; avoids NaN from the mean of an empty row
            return frames
        mean = frames.mean(dim=1, keepdim=True)
        if n_frames > 1:
            std = frames.std(dim=1, keepdim=True)
        else:
            # Unbiased std of a single sample is NaN; treat it as zero spread
            std = torch.zeros_like(mean)
        return (frames - mean) / (std + 1e-6)

    def __getitem__(self, idx):
        waveform, sr, _, speaker_id, _, _ = self.dataset[idx]
        # Resample if needed
        if sr != self.target_sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, sr, self.target_sample_rate
            )
        # Convert to mono if needed
        if waveform.size(0) > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        # Apply transform: (1, n_coeff, time_frames) -> (n_coeff, time_frames)
        mfcc = self.transform(waveform).squeeze(0)

        # 1) Crop to at most max_frames
        mfcc = mfcc[:, :self.max_frames]

        # 2) Normalize on real frames only, *before* padding, so the padded
        #    zeros do not distort the statistics of short utterances
        mfcc = self._normalize(mfcc)

        # 3) Zero-pad the time axis up to max_frames. The pad takes its
        #    coefficient count, dtype and device from the features themselves.
        missing = self.max_frames - mfcc.shape[1]
        if missing > 0:
            mfcc = torch.nn.functional.pad(mfcc, (0, missing))

        return mfcc, speaker_id

# Usage example: build the dataset, wrap it in a DataLoader and inspect one batch
if __name__ == "__main__":
    data_root = "./data"
    subset = "train-clean-100"

    # Ensure data directory exists
    os.makedirs(data_root, exist_ok=True)

    # Download only when this particular subset is missing. Checking just the
    # top-level "LibriSpeech" folder would skip the download whenever a
    # different subset had been fetched earlier.
    subset_dir = os.path.join(data_root, "LibriSpeech", subset)
    dataset = LibriSpeechMFCC(
        root=data_root,
        url=subset,
        download=not os.path.isdir(subset_dir),
        transform=mfcc_transform,
        max_frames=max_frames,
    )

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=16,
        shuffle=True,
        num_workers=2,
    )

    # Pull a single batch to check shapes
    batch_mfcc, batch_speaker = next(iter(dataloader))
    print("MFCC batch shape:", batch_mfcc.shape)  # (batch, n_mfcc, max_frames)
    print("Speaker IDs:", batch_speaker)
