In [2]:
# Colab only: mount Google Drive at /content/drive so the dataset paths used
# by the pre-computation step (under /content/drive/MyDrive) are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import os
import random
import warnings

import torch
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

warnings.filterwarnings("ignore")

In [4]:
# ==============================================================================
# 1. DATASET
# ==============================================================================

class TrainAudioSpectrogramDataset(Dataset):
    """
    Dataset that converts audio files into fixed-width log-mel spectrograms.

    This dataset class is used ONCE to pre-process all our audio into
    spectrogram tensors. Audio is expected under ``root_dir/<category>/``;
    each category directory contributes its files with the category's
    position in ``categories`` as the integer label.

    Args:
        root_dir: Directory containing one sub-directory per category.
        categories: Ordered list of category (sub-directory) names.
        max_frames: Number of time frames every spectrogram is padded or
            truncated to.
        fraction: Fraction (0..1) of each category's files to keep after
            shuffling.
        sample_rate: Target sample rate; audio at other rates is resampled.
        seed: Optional seed for the per-category shuffle. ``None`` (default)
            uses the global ``random`` state, as before.
    """

    # Mel-spectrogram configuration. N_MELS also fixes the shape of the
    # placeholder tensor returned when a file cannot be loaded.
    N_FFT = 1024
    HOP_LENGTH = 256
    N_MELS = 128
    # Compared against the lower-cased file name, so '.WAV' matches too.
    AUDIO_EXTENSIONS = ('.wav', '.mp3', '.flac')

    def __init__(self, root_dir, categories, max_frames=512, fraction=1.0, sample_rate=22050, seed=None):
        self.root_dir = root_dir
        self.categories = categories
        self.max_frames = max_frames
        self.file_list = []
        self.class_to_idx = {cat: i for i, cat in enumerate(categories)}
        self.sample_rate = sample_rate

        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_fft=self.N_FFT,
            hop_length=self.HOP_LENGTH,
            n_mels=self.N_MELS,
        )

        # A dedicated generator makes the subset reproducible when a seed is
        # given; otherwise fall back to the module-level global state.
        rng = random.Random(seed) if seed is not None else random

        print("Populating file list...")
        for i, category in enumerate(self.categories):
            path = os.path.join(self.root_dir, category)
            if not os.path.isdir(path):
                print(f"Warning: Directory not found: {path}")
                continue

            # os.listdir order is arbitrary; sort first so that a seeded
            # shuffle always starts from the same sequence.
            files = sorted(
                os.path.join(path, f)
                for f in os.listdir(path)
                if f.lower().endswith(self.AUDIO_EXTENSIONS)
            )
            rng.shuffle(files)
            files = files[:int(len(files) * fraction)]
            self.file_list.extend((f, i) for f in files)
        print(f"Found {len(self.file_list)} audio files.")

    def __len__(self):
        """Return the number of audio files selected for this dataset."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """
        Load one file and return ``(log_spec, label_vec)``.

        ``log_spec`` is a ``(1, N_MELS, max_frames)`` tensor of ``log1p`` mel
        values and ``label_vec`` is a float one-hot vector over
        ``categories``. If the file cannot be loaded, an all-zero spectrogram
        is returned together with the file's real label.
        """
        path, label = self.file_list[idx]
        label_vec = F.one_hot(torch.tensor(label), num_classes=len(self.categories)).float()

        try:
            wav, sr = torchaudio.load(path)
        except Exception as e:
            # Best-effort: keep the pipeline running, but keep the true label
            # so the placeholder is not misfiled under category 0.
            print(f"\nError loading file {path}: {e}")
            return torch.zeros(1, self.N_MELS, self.max_frames), label_vec

        if sr != self.sample_rate:
            wav = torchaudio.functional.resample(wav, sr, self.sample_rate)

        # Mix multi-channel audio down to mono.
        if wav.size(0) > 1:
            wav = wav.mean(dim=0, keepdim=True)

        # Make sure the clip spans at least one full FFT window.
        if wav.shape[1] < self.N_FFT:
            wav = F.pad(wav, (0, self.N_FFT - wav.shape[1]))

        with torch.no_grad():
            mel_spec = self.mel_transform(wav)

        log_spec = torch.log1p(mel_spec)

        # Right-pad with zeros or truncate along time to exactly max_frames.
        n_frames = log_spec.shape[-1]
        if n_frames < self.max_frames:
            log_spec = F.pad(log_spec, (0, self.max_frames - n_frames))
        else:
            log_spec = log_spec[:, :, :self.max_frames]

        return log_spec, label_vec

In [5]:
# ==============================================================================
# 5. PRE-COMPUTATION SCRIPT (RUN THIS ONCE)
# ==============================================================================
def preprocess_and_save(
    base_path="/content/drive/MyDrive/decibel_duel",
    max_frames=256,
    sample_rate=22050,
    num_workers=2,
):
    """
    Pre-compute spectrograms for every training clip and save them to disk.

    Reads audio from ``<base_path>/train/train/<category>/`` and writes one
    ``spec_<index>.pt`` file per clip to
    ``<base_path>/train/precompute/<category>/``. Each file holds the tuple
    ``(spectrogram, one_hot_label)`` with the batch dimension removed.

    Args:
        base_path: Project root containing the ``train/train`` directory.
        max_frames: Number of time frames per saved spectrogram.
        sample_rate: Sample rate passed to the dataset.
        num_workers: Number of DataLoader worker processes.

    Raises:
        FileNotFoundError: If the training directory does not exist
            (for example when Drive is not mounted).
    """
    print("Starting dataset pre-computation...")

    train_path = os.path.join(base_path, 'train/train')
    precomputed_path = os.path.join(base_path, 'train/precompute')

    if not os.path.isdir(train_path):
        raise FileNotFoundError(f"Training directory not found: {train_path}")

    # Sorted so the category -> index mapping is stable between runs.
    train_categories = sorted(
        d for d in os.listdir(train_path)
        if os.path.isdir(os.path.join(train_path, d))
    )

    print(f"Output directory: {precomputed_path}")
    print(f"Categories: {train_categories}")
    print(f"Max Frames: {max_frames}")

    # 1. Initialize the *original* dataset class
    original_dataset = TrainAudioSpectrogramDataset(
        root_dir=train_path,
        categories=train_categories,
        max_frames=max_frames,
        sample_rate=sample_rate
    )

    if len(original_dataset) == 0:
        print("No audio files found; nothing to pre-compute.")
        return

    # 2. Create a simple dataloader to iterate
    #    (batch_size=1 and shuffle=False keep index i aligned with file_list)
    loader = DataLoader(original_dataset, batch_size=1, shuffle=False, num_workers=num_workers)

    # 3. Loop, create save paths, and save
    print(f"Processing {len(original_dataset)} files...")
    created_dirs = set()  # avoid one makedirs call per sample
    for i, (spec, label) in enumerate(tqdm(loader, desc="Pre-processing audio")):
        label_idx = label.argmax().item()
        category_name = train_categories[label_idx]

        save_dir = os.path.join(precomputed_path, category_name)
        if save_dir not in created_dirs:
            os.makedirs(save_dir, exist_ok=True)
            created_dirs.add(save_dir)

        save_path = os.path.join(save_dir, f"spec_{i:06d}.pt")

        # Squeeze to remove batch_size=1
        torch.save((spec.squeeze(0), label.squeeze(0)), save_path)

    print("--- Pre-computation Complete! ---")
    print(f"All spectrograms saved to {precomputed_path}")

# --- Uncomment and run the line below ONCE ---
#preprocess_and_save()