### imports:

In [1]:
import os
import torchaudio
from torch.utils.data import Dataset
import random

### DATASET CLASS:

In [None]:
class AudioDataset(Dataset):
    """Dataset of homework audio recordings whose metadata is encoded in the filename.

    Files in ``root_dir`` are expected to be named
    ``HW<homework>_<intro|Q<question>>_<studentID>_<male|female>.mp3``;
    any other file is silently ignored. Each item is a dict with the loaded
    waveform, its sample rate, a task-dependent label and the parsed metadata.
    """

    def __init__(self, root_dir, task="gender", transform=None, n_persons=None, seed=42):
        """
        Args:
            root_dir (str): Path to the directory containing the audio files.
            task (str): The task to perform. Options are "gender" or "owner".
            transform (callable, optional): Optional transform applied to each
                loaded waveform.
            n_persons (int, optional): Number of unique persons (studentIDs) to
                include for owner classification. Only used when task == "owner".
            seed (int, optional): Seed for reproducibility when selecting N persons.

        Raises:
            ValueError: If ``task`` is not "gender" or "owner", or if ``n_persons``
                is negative or larger than the number of distinct students found.
        """
        # Fail fast instead of on the first __getitem__ call.
        if task not in ("gender", "owner"):
            raise ValueError("Unsupported task. Use 'gender' or 'owner'.")

        self.root_dir = root_dir
        self.task = task
        self.transform = transform
        self.n_persons = n_persons
        self.seed = seed

        self.data = []
        self.student_ids = set()

        # os.listdir order is arbitrary; sorting makes item indices (and any
        # index-based train/val split) reproducible across machines and runs.
        for file in sorted(os.listdir(root_dir)):
            record = self._parse_filename(root_dir, file)
            if record is None:
                continue
            self.data.append(record)
            self.student_ids.add(record["student_id"])

        if task == "owner" and n_persons is not None:
            # Sample from a sorted list, not the set: random.sample rejects sets
            # on Python 3.11+, and set iteration order for strings changes between
            # interpreter runs (hash randomization), which would defeat the seed.
            population = sorted(self.student_ids)
            if not 0 <= n_persons <= len(population):
                raise ValueError(
                    f"n_persons must be between 0 and {len(population)} "
                    f"(number of distinct students found), got {n_persons}."
                )
            # A private RNG keeps the draw reproducible without reseeding the
            # global `random` module for the rest of the program.
            rng = random.Random(seed)
            selected_ids = set(rng.sample(population, n_persons))
            self.data = [item for item in self.data if item["student_id"] in selected_ids]
            # Keep student_ids consistent with the filtered data.
            self.student_ids = selected_ids

    @staticmethod
    def _parse_filename(root_dir, file):
        """Return the metadata dict for ``file``, or None if it does not match the naming scheme."""
        if not file.endswith(".mp3"):
            return None
        parts = file[: -len(".mp3")].split("_")
        if len(parts) != 4:
            return None

        hw_part, question_part, student_id, gender = parts
        if not (
            hw_part.startswith("HW")
            and (question_part == "intro" or question_part.startswith("Q"))
            and gender in ("male", "female")
        ):
            return None

        try:
            homework_number = int(hw_part[2:])
            question_number = None if question_part == "intro" else int(question_part[1:])
        except ValueError:
            # Prefix matched but the number is missing or not an integer (e.g. "HWx", "Qa").
            return None

        return {
            "file_path": os.path.join(root_dir, file),
            "homework_number": homework_number,
            "question_number": question_number,
            "student_id": student_id,
            "gender": gender,
        }

    def __len__(self):
        """Return the number of (possibly person-filtered) recordings."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load recording ``idx`` and return it with its label and metadata.

        Returns:
            dict: ``waveform`` (after ``transform``, if any), ``sample_rate``
            (as returned by ``torchaudio.load``), ``label`` (0 = male,
            1 = female for "gender"; the studentID string for "owner") and
            ``metadata`` (a copy of the parsed filename record).

        Raises:
            ValueError: If ``self.task`` was changed to an unsupported value.
        """
        sample = self.data[idx]

        waveform, sample_rate = torchaudio.load(sample["file_path"])

        if self.transform is not None:
            waveform = self.transform(waveform)

        if self.task == "gender":
            label = 0 if sample["gender"] == "male" else 1
        elif self.task == "owner":
            # NOTE(review): the label is the raw studentID string; a loss function
            # will likely need it mapped to an integer class index by the caller.
            label = sample["student_id"]
        else:
            raise ValueError("Unsupported task. Use 'gender' or 'owner'.")

        return {
            "waveform": waveform,
            "sample_rate": sample_rate,
            "label": label,
            # Copy so callers cannot accidentally mutate the dataset's records.
            "metadata": dict(sample),
        }