### imports:

In [22]:
import os
import torchaudio
from torch.utils.data import Dataset
import random
import torch
import torchaudio.transforms as T
import noisereduce as nr
from tqdm import tqdm

### DATASET CLASS:

In [44]:
class AudioDataset(Dataset):
    """Dataset of MP3 recordings named ``HW<n>_<intro|Q<m>>_<studentID>_<gender>.mp3``.

    Files in ``root_dir`` whose names do not follow that scheme are ignored.
    Each item is a dict with the keys ``features``, ``sample_rate``,
    ``label`` and ``metadata``.
    """

    # Sample rate every item returned by ``__getitem__`` is resampled to.
    TARGET_SAMPLE_RATE = 16000
    # Sample rate of the copies written by ``preprocess_and_save``.
    PREPROCESSED_SAMPLE_RATE = 44100

    def __init__(self, root_dir, task="gender", transform=None, feature_extractor=None, n_persons=None, seed=42,
                 use_preprocessed=False, preprocessed_dir="preprocessed"):
        """
        Args:
            root_dir (str): Path to the directory containing the audio files.
            task (str): The task to perform. Options are "gender" or "owner".
            transform (callable, optional): Optional transform to be applied on a sample.
            feature_extractor (callable, optional): Function called as
                ``feature_extractor(waveform, sample_rate)`` to extract features.
            n_persons (int, optional): Number of unique persons (studentIDs) to include for owner classification.
            seed (int, optional): Seed for reproducibility when selecting N persons.
            use_preprocessed (bool, optional): Whether to use preprocessed data if available.
            preprocessed_dir (str, optional): Directory containing preprocessed features.

        Raises:
            ValueError: If ``n_persons`` exceeds the number of distinct student
                IDs found (raised by ``random.sample``).
        """
        super().__init__()
        self.root_dir = root_dir
        self.task = task
        self.transform = transform
        self.feature_extractor = feature_extractor
        self.n_persons = n_persons
        self.seed = seed
        self.use_preprocessed = use_preprocessed
        self.preprocessed_dir = preprocessed_dir

        self.data = []
        self.student_ids = set()

        # os.listdir order is arbitrary; sort so indices are stable across runs.
        for file_name in sorted(os.listdir(root_dir)):
            parsed = self._parse_filename(file_name)
            if parsed is None:
                continue
            self.data.append({"file_path": os.path.join(root_dir, file_name), **parsed})
            self.student_ids.add(parsed["student_id"])

        if task == "owner" and n_persons is not None:
            # Sample from a sorted list: random.sample rejects sets on
            # Python 3.11+, and set order is not reproducible for a given seed.
            # A private RNG avoids reseeding the global random module.
            rng = random.Random(seed)
            selected_ids = set(rng.sample(sorted(self.student_ids), n_persons))
            self.data = [item for item in self.data if item["student_id"] in selected_ids]

    @staticmethod
    def _parse_filename(file_name):
        """Return the metadata encoded in *file_name*, or None if it does not match the naming scheme."""
        if not file_name.endswith(".mp3"):
            return None

        parts = file_name[:-4].split("_")  # Remove .mp3 and split by _
        if len(parts) != 4:
            return None

        homework, question, student_id, gender = parts
        if not homework.startswith("HW"):
            return None
        if question != "intro" and not question.startswith("Q"):
            return None
        if gender not in ("male", "female"):
            return None

        try:
            homework_number = int(homework[2:])
            question_number = None if question == "intro" else int(question[1:])
        except ValueError:
            # Non-numeric homework/question number: not a valid file name.
            return None

        return {
            "homework_number": homework_number,
            "question_number": question_number,
            "student_id": student_id,
            "gender": gender,
        }

    @staticmethod
    def _resample(waveform, sample_rate, target_sample_rate):
        """Resample *waveform* to *target_sample_rate* if needed; return ``(waveform, rate)``."""
        if sample_rate != target_sample_rate:
            resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
            waveform = resampler(waveform)
            sample_rate = target_sample_rate
        return waveform, sample_rate

    @staticmethod
    def _normalize(waveform):
        """Shift to zero mean and scale by the standard deviation (epsilon guards silent clips)."""
        return (waveform - waveform.mean()) / (waveform.std() + 1e-7)

    @staticmethod
    def _denoise(waveform, sample_rate):
        """Run noisereduce on *waveform* and return the result as a tensor."""
        return torch.tensor(nr.reduce_noise(y=waveform.numpy(), sr=sample_rate))

    def __len__(self):
        """Return the number of recordings kept after filtering."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load, preprocess and label the recording at *idx*.

        When ``use_preprocessed`` is set and a file with the same base name
        exists in ``preprocessed_dir``, that file is loaded instead of the raw
        recording and noise reduction is skipped (it was applied when the file
        was written). Both paths return the same dict at
        ``TARGET_SAMPLE_RATE``.

        Returns:
            dict: ``features``, ``sample_rate``, ``label`` (0 for male and 1
            for female on the gender task, the student ID on the owner task)
            and ``metadata`` (the parsed file-name record).

        Raises:
            ValueError: If ``self.task`` is neither "gender" nor "owner".
        """
        sample = self.data[idx]

        # Resolve the label first so an unsupported task fails before any audio I/O.
        if self.task == "gender":
            label = 0 if sample["gender"] == "male" else 1
        elif self.task == "owner":
            label = sample["student_id"]
        else:
            raise ValueError("Unsupported task. Use 'gender' or 'owner'.")

        file_path = sample["file_path"]
        from_cache = False
        if self.use_preprocessed and self.preprocessed_dir:
            preprocessed_path = os.path.join(self.preprocessed_dir, os.path.basename(file_path))
            if os.path.exists(preprocessed_path):
                file_path = preprocessed_path
                from_cache = True

        # Load audio
        waveform, sample_rate = torchaudio.load(file_path)

        # Resample to a consistent rate (preprocessed files are stored at a different rate)
        waveform, sample_rate = self._resample(waveform, sample_rate, self.TARGET_SAMPLE_RATE)

        # Apply noise reduction only to raw recordings
        if not from_cache:
            waveform = self._denoise(waveform, sample_rate)

        # Normalize waveform
        waveform = self._normalize(waveform)

        # Apply transform if specified
        if self.transform:
            waveform = self.transform(waveform)

        # Extract features if feature_extractor is specified
        if self.feature_extractor:
            features = self.feature_extractor(waveform, sample_rate)
        else:
            features = waveform

        return {
            "features": features,
            "sample_rate": sample_rate,
            "label": label,
            "metadata": sample
        }

    def preprocess_and_save(self):
        """
        Preprocess every recording and save it to ``preprocessed_dir`` under its
        original file name, so later runs can skip the expensive processing.

        Each file is resampled to ``PREPROCESSED_SAMPLE_RATE``, normalized and
        then noise-reduced before being written.
        """
        os.makedirs(self.preprocessed_dir, exist_ok=True)
        for sample in tqdm(self.data, desc="Processing Audio Files"):
            file_path = sample["file_path"]
            waveform, sample_rate = torchaudio.load(file_path)

            # Resample to a consistent rate
            waveform, sample_rate = self._resample(waveform, sample_rate, self.PREPROCESSED_SAMPLE_RATE)

            # Normalize, then apply noise reduction
            waveform = self._normalize(waveform)
            waveform = self._denoise(waveform, sample_rate)

            # Save under the original base name; the output format follows the extension
            save_path = os.path.join(self.preprocessed_dir, os.path.basename(file_path))
            torchaudio.save(save_path, waveform, sample_rate)



In [45]:
# Index the recordings under ../../HW1_M with the default "gender" task;
# ../../preprocessed is where preprocess_and_save() will write its output.
myData = AudioDataset("../../HW1_M", preprocessed_dir="../../preprocessed")

In [43]:
# Processing pass: writes a processed copy of every indexed recording into
# the dataset's preprocessed_dir (shows a tqdm progress bar while running).
myData.preprocess_and_save()

Processing Audio Files:  10%|█         | 69/676 [01:06<07:54,  1.28it/s]