In [None]:
!pip install torch torchaudio tqdm matplotlib > /dev/null


In [None]:
import os, random, torch, torch.nn as nn, torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython.display import Audio, display
import zipfile


In [None]:
# Mount Google Drive (Colab only) so the dataset archive under /content/drive/MyDrive is readable.
from google.colab import drive
drive.mount('/content/drive')


In [None]:
# =========================================================
# 0️⃣ File Extraction
# =========================================================
# Unpack the dataset archive from Google Drive into local storage.
# Failures are printed rather than raised (best effort); the training cell
# only proceeds once it can locate the extracted category folders.
ZIP_PATH = "/content/drive/MyDrive/the-frequency-quest.zip"
EXTRACT_ROOT = "/content/data"
BASE_PATH = None  # set later, once the folder holding all categories is located
print(f"Attempting to extract {ZIP_PATH} to {EXTRACT_ROOT}...")
if os.path.exists(ZIP_PATH):
    os.makedirs(EXTRACT_ROOT, exist_ok=True)
    try:
        with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
            zip_ref.extractall(EXTRACT_ROOT)
        print("✅ Extraction complete.")
    except Exception as e:
        # Broad on purpose: bad/encrypted/unsupported archives should be reported, not crash the notebook.
        print(f"❌ FATAL ERROR during extraction: {e}")
else:
    print(f"❌ FATAL ERROR: Zip file not found at {ZIP_PATH}. Check the path.")


In [None]:
# =========================================================
# 1️⃣ Dataset class
# =========================================================
class TrainAudioSpectrogramDataset(Dataset):
    """Labelled, normalized log-mel spectrograms read from per-category ``.wav`` folders.

    Expects the layout ``root_dir/<category>/*.wav``. Each item is ``(spec, label_vec)``:
    ``spec`` is a ``(1, n_mels, max_frames)`` tensor, log1p mel spectrogram with the time
    axis zero-padded or truncated to ``max_frames``, then normalized with the dataset
    mean/std. ``label_vec`` is a float one-hot vector over ``categories``.

    Args:
        root_dir: Directory containing one sub-folder per category (missing folders are skipped).
        categories: Category folder names; list order defines the label indices.
        max_frames: Number of spectrogram frames every item is padded/truncated to.
        fraction: Fraction of each category's files to sample at random; any non-empty
            category keeps at least one file.
        compute_stats: If True, scan the selected files once to compute the global
            mean/std used for normalization; otherwise mean=0.0 and std=1.0.
    """

    def __init__(self, root_dir, categories, max_frames=512, fraction=1.0, compute_stats=True):
        self.root_dir = root_dir
        self.categories = categories
        self.max_frames = max_frames
        self.file_list = []
        self.class_to_idx = {cat: i for i, cat in enumerate(categories)}

        for cat_name in self.categories:
            cat_dir = os.path.join(root_dir, cat_name)
            if not os.path.isdir(cat_dir):
                continue

            files_in_cat = [os.path.join(cat_dir, f) for f in os.listdir(cat_dir) if f.lower().endswith(".wav")]
            num_to_sample = int(len(files_in_cat) * fraction)
            if num_to_sample == 0 and files_in_cat:
                num_to_sample = 1
            sampled_files = random.sample(files_in_cat, num_to_sample) if files_in_cat else []
            label_idx = self.class_to_idx[cat_name]
            self.file_list.extend((file_path, label_idx) for file_path in sampled_files)

        self.n_mels = 128
        self.n_fft = 1024
        self.hop_length = 128
        # One MelSpectrogram per sample rate, so the filterbank is not rebuilt for every file.
        self._mel_transforms = {}

        self.mean = 0.0
        self.std = 1.0
        if compute_stats and self.file_list:
            self._compute_stats()

    def _load_log_mel(self, path):
        """Load ``path`` as mono; return its log1p mel spectrogram, shape (1, n_mels, max_frames)."""
        wav, sr = torchaudio.load(path)
        if wav.size(0) > 1:
            wav = wav.mean(dim=0, keepdim=True)
        # setdefault keeps this working on instances created without __init__ (e.g. old pickles).
        cache = self.__dict__.setdefault("_mel_transforms", {})
        mel = cache.get(sr)
        if mel is None:
            mel = torchaudio.transforms.MelSpectrogram(
                sample_rate=sr, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels
            )
            cache[sr] = mel
        log_spec = torch.log1p(mel(wav))
        n_frames = log_spec.size(-1)
        if n_frames < self.max_frames:
            log_spec = F.pad(log_spec, (0, self.max_frames - n_frames))
        else:
            log_spec = log_spec[..., :self.max_frames]
        return log_spec.reshape(1, self.n_mels, self.max_frames)

    def _compute_stats(self):
        """Set ``self.mean``/``self.std`` over every bin of every selected file's padded log-mel."""
        total, total_sq, count = 0.0, 0.0, 0
        for path, _ in tqdm(self.file_list, desc="Computing dataset stats", leave=False):
            # float64 accumulation avoids float32 rounding over many large spectrograms
            values = self._load_log_mel(path).reshape(-1).double()
            total += values.sum().item()
            total_sq += (values * values).sum().item()
            count += values.numel()
        self.mean = total / count
        # E[x^2] - E[x]^2 can dip just below 0 through rounding; a negative float ** 0.5 is complex.
        variance = max(total_sq / count - self.mean ** 2, 0.0)
        self.std = variance ** 0.5
        print(f"Dataset mean/std for log-mel: {self.mean:.6f} / {self.std:.6f}")

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        path, label = self.file_list[idx]
        log_spec = (self._load_log_mel(path) - self.mean) / (self.std + 1e-9)
        label_vec = F.one_hot(torch.tensor(label), num_classes=len(self.categories)).float()
        return log_spec, label_vec


In [None]:
# =========================================================
# 2️⃣ Generator and Discriminator - generator part
# =========================================================
from torch.nn.utils import spectral_norm

class CGAN_Generator(nn.Module):
    """Conditional generator mapping (noise, one-hot label) to a 1-channel image in [-1, 1].

    The concatenated ``[z, y]`` vector is projected to a 256x8x32 feature map. Four
    stride-2 transposed convolutions (each doubling height and width) then upsample it
    to a single-channel 128x512 output, and a final Tanh squashes it to [-1, 1].

    Attributes:
        latent_dim: Length of the noise vector ``z``.
        num_categories: Length of the one-hot condition ``y``.
    """

    def __init__(self, latent_dim, num_categories):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_categories = num_categories
        self.unflatten_shape = (256, 8, 32)
        channels, height, width = self.unflatten_shape
        self.fc = nn.Linear(latent_dim + num_categories, channels * height * width)

        # Upsampling stack: (ConvT -> BN -> ReLU) for each channel step, then ConvT -> Tanh.
        widths = (256, 128, 64, 32)
        layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            layers.extend([nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out), nn.ReLU(inplace=True)])
        layers.extend([nn.ConvTranspose2d(widths[-1], 1, 4, 2, 1), nn.Tanh()])
        self.net = nn.Sequential(*layers)

    def forward(self, z, y):
        conditioned = torch.cat((z, y), dim=1)
        feature_map = self.fc(conditioned).view(-1, *self.unflatten_shape)
        return self.net(feature_map)


In [None]:
# =========================================================
# 2️⃣ Generator and Discriminator - discriminator part
# =========================================================
class CGAN_Discriminator(nn.Module):
    """Conditional discriminator scoring (spectrogram image, one-hot label) pairs.

    The label is projected by a linear layer to a full-size ``n_mels x max_frames`` map
    and stacked with the input as a second channel. Four stride-2 spectral-norm
    convolutions shrink each spatial dim by 16, and a spectral-norm linear head
    produces one unbounded score per sample, shape (batch, 1).

    Args:
        num_categories: Length of the one-hot label vector.
        n_mels: Height of the input image (defaults to the original hard-coded 128).
        max_frames: Width of the input image (defaults to the original hard-coded 512).
    """

    def __init__(self, num_categories, n_mels=128, max_frames=512):
        super().__init__()
        self.num_categories = num_categories
        self.n_mels = n_mels
        self.max_frames = max_frames
        self.label_embedding = nn.Linear(num_categories, n_mels * max_frames)
        self.model = nn.Sequential(
            spectral_norm(nn.Conv2d(2, 64, 4, 2, 1)), nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(64, 128, 4, 2, 1)), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(128, 256, 4, 2, 1)), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(256, 512, 4, 2, 1)), nn.BatchNorm2d(512), nn.LeakyReLU(0.2, inplace=True),
        )
        # Conv2d(kernel=4, stride=2, padding=1) maps size s to s // 2; four of them give s // 16.
        n_features = 512 * (n_mels // 16) * (max_frames // 16)
        # Built eagerly, not on the first forward, so its weights are already in
        # self.parameters() when an optimizer is constructed from this module.
        self.final_fc = spectral_norm(nn.Linear(n_features, 1))

    def forward(self, x, labels):
        label_map = self.label_embedding(labels).view(-1, 1, self.n_mels, self.max_frames)
        d_in = torch.cat([x, label_map], dim=1)
        features = self.model(d_in).flatten(1)
        return self.final_fc(features)


In [None]:
# =========================================================
# 3️⃣ Audio generation helper (Griffin-Lim with n_iter=64 for better audio)
# =========================================================
def generate_audio_gan(generator, category_idx, device, dataset_mean, dataset_std,
                       sample_rate=22050, n_fft=1024, hop_length=128, n_mels=128):
    """Synthesize one waveform for ``category_idx`` with the conditional generator.

    Processing steps:
        1. The generator output (a normalized log1p mel spectrogram) is de-normalized
           with the dataset statistics.
        2. It is mapped back to mel values with ``expm1``.
        3. ``InverseMelScale`` lifts it to a linear-frequency spectrogram.
        4. Griffin-Lim (64 iterations) reconstructs the phase.

    Args:
        generator: Model exposing ``latent_dim`` and ``num_categories``, called as ``generator(z, y)``.
        category_idx: Integer label to condition on.
        device: Device the generator lives on.
        dataset_mean, dataset_std: Normalization statistics used for the training spectrograms.
        sample_rate, n_fft, hop_length, n_mels: STFT/mel settings. They should match
            those used to build the training spectrograms.

    NOTE(review): the training dataset computes mels at each file's native sample rate,
    so confirm the source files are 22050 Hz.

    Returns:
        CPU waveform tensor clipped to [-1, 1]. Its shape is (1, 1, num_samples),
        assuming the generator emits (1, 1, n_mels, frames).

    The generator's train/eval mode is restored before returning.
    """
    was_training = generator.training
    generator.eval()
    try:
        y = F.one_hot(torch.tensor([category_idx]), num_classes=generator.num_categories).float().to(device)
        z = torch.randn(1, generator.latent_dim, device=device)
        with torch.no_grad():
            gen_norm_log = generator(z, y).squeeze(1).cpu()
    finally:
        # Leaving the generator in eval mode would make its BatchNorm layers use
        # running statistics for every training step that follows this preview.
        generator.train(was_training)

    gen_log = gen_norm_log * (dataset_std + 1e-9) + dataset_mean
    # Keep values strictly positive before the inverse mel / Griffin-Lim stages.
    mel_spec = torch.expm1(gen_log).clamp(min=1e-6)

    inverse_mel = torchaudio.transforms.InverseMelScale(
        n_stft=n_fft // 2 + 1, n_mels=n_mels, sample_rate=sample_rate,
        driver="gelsy"
    )
    griffin = torchaudio.transforms.GriffinLim(
        n_fft=n_fft, hop_length=hop_length, win_length=n_fft, n_iter=64
    )
    linear_spec = inverse_mel(mel_spec)
    # Everything after the generator call already lives on the CPU.
    wav = griffin(linear_spec).clamp(-1, 1)

    return wav.unsqueeze(0)

# =========================================================
# 4️⃣ Save and Play helper
# =========================================================
def save_and_play(wav, sample_rate, filename):
    """Write ``wav`` to ``filename`` and display an inline audio player.

    Args:
        wav: Waveform tensor. Accepted shapes:
            - (samples,)
            - (channels, samples)
            - a 3-D tensor whose dim 1 is squeezed away, e.g. (1, 1, samples)
            It may live on any device and may require grad.
        sample_rate: Sample rate in Hz, used for both the file and the player.
        filename: Output path; missing parent directories are created.
    """
    os.makedirs(os.path.dirname(filename) or ".", exist_ok=True)

    # torchaudio.save and .numpy() both need a plain CPU tensor.
    wav_to_save = wav.detach().cpu()
    if wav_to_save.dim() == 1:
        # torchaudio.save expects (channels, samples)
        wav_to_save = wav_to_save.unsqueeze(0)
    elif wav_to_save.dim() > 2:
        wav_to_save = wav_to_save.squeeze(1)

    torchaudio.save(filename, wav_to_save, sample_rate)
    print(f"✅ Saved: {filename}")

    display(Audio(wav_to_save.squeeze(0).numpy(), rate=sample_rate))


In [None]:
# =========================================================
# 5️⃣ Training (LSGAN)
# =========================================================
def train_gan(generator, discriminator, dataloader, device, categories, epochs, lr, latent_dim, dataset_mean, dataset_std):
    """Train a conditional GAN with least-squares (LSGAN) objectives.

    Each batch does one discriminator update (real -> 1, fake -> 0), then one generator
    update (fake -> 1). Both models use Adam(lr, betas=(0.5, 0.999)). Every 10th epoch,
    one clip per category (first three only) is synthesized with
    ``generate_audio_gan``, then saved and played from ``gan_generated_audio/``.

    Args:
        generator, discriminator: Conditional models, called as ``model(x, labels)``.
        dataloader: Yields (spectrogram, one-hot label) batches. Its ``dataset`` must
            expose ``hop_length``.
        device: Device for the batches and noise.
        categories: Category names; the first three are used for audio previews.
        epochs: Number of passes over ``dataloader``.
        lr: Learning rate for both optimizers.
        latent_dim: Size of the generator's noise vector.
        dataset_mean, dataset_std: Normalization statistics, used to de-normalize previews.
    """
    optG = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    optD = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    os.makedirs("gan_generated_audio", exist_ok=True)

    hop_length = dataloader.dataset.hop_length

    for epoch in range(1, epochs + 1):
        # Re-enter training mode every epoch. An audio preview may have switched the
        # generator to eval mode, which would freeze its BatchNorm statistics.
        generator.train()
        discriminator.train()
        loop = tqdm(dataloader, desc=f"Epoch {epoch}/{epochs}")
        for real_specs, labels in loop:
            real_specs, labels = real_specs.to(device), labels.to(device)
            bs = real_specs.size(0)

            # --- D step: push real scores toward 1 and fake scores toward 0 ---
            optD.zero_grad()
            real_preds = discriminator(real_specs, labels)
            z_for_d = torch.randn(bs, latent_dim, device=device)
            fake_specs = generator(z_for_d, labels)
            # detach so the discriminator loss does not backpropagate into the generator
            fake_preds = discriminator(fake_specs.detach(), labels)
            lossD = 0.5 * ((real_preds - 1) ** 2 + fake_preds ** 2).mean()
            lossD.backward()
            optD.step()

            # --- G step: push the discriminator's score on fresh fakes toward 1 ---
            optG.zero_grad()
            z_for_g = torch.randn(bs, latent_dim, device=device)
            fake_specs = generator(z_for_g, labels)
            fake_preds = discriminator(fake_specs, labels)
            lossG = 0.5 * ((fake_preds - 1) ** 2).mean()
            lossG.backward()
            optG.step()

            loop.set_postfix(lossD=lossD.item(), lossG=lossG.item())

        if epoch % 10 == 0:
            print(f"\nGenerating audio for epoch {epoch}...")
            for i, cat in enumerate(categories[:3]):
                wav = generate_audio_gan(generator, i, device, dataset_mean, dataset_std, hop_length=hop_length)
                save_and_play(wav, 22050, f"gan_generated_audio/{cat}_ep{epoch}.wav")


In [None]:
# =========================================================
# 6️⃣ Run training
# =========================================================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LATENT_DIM, EPOCHS, BATCH_SIZE, LR = 100, 300, 32, 2e-4

# Seed the RNGs used below (dataset file sampling, weight init, noise) so runs are repeatable.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Category folder names; their order defines the label indices (adjust if needed).
train_categories = ['dog_bark', 'drilling', 'engine_idling', 'siren', 'street_music']

# --- AUTOMATIC BASE PATH FINDER ---
def find_category_base_path(root_dir, categories):
    """Return the first directory under ``root_dir`` whose immediate sub-folders include every category.

    The walk is top-down, starting at ``root_dir`` itself, with sub-folders visited in
    sorted order so the result is deterministic. Returns None if no directory matches.
    """
    print(f"\nSearching for category folders within {root_dir}...")
    wanted = set(categories)
    for dirpath, dirnames, _ in os.walk(root_dir):
        # Sorting in place controls the order in which os.walk descends.
        dirnames.sort()
        if wanted.issubset(dirnames):
            return dirpath
    return None

# Reset here so this cell does not depend on another cell having initialised BASE_PATH.
BASE_PATH = None
if os.path.isdir(EXTRACT_ROOT):
    BASE_PATH = find_category_base_path(EXTRACT_ROOT, train_categories)

# --- Start Training ---
if BASE_PATH is None:
    print("\nFATAL ERROR: Could not find the base directory containing all category folders.")
    print(f"Please manually inspect the contents of {EXTRACT_ROOT} to determine the final path.")
else:
    print(f"✅ Automatically determined BASE_PATH: {BASE_PATH}")
    print("Categories:", train_categories)

    train_dataset = TrainAudioSpectrogramDataset(BASE_PATH, train_categories)

    if len(train_dataset) == 0:
        print("\nFATAL ERROR: Dataset is empty. Check that the extracted folders contain .wav files.")
    else:
        print(f"Dataset loaded successfully with {len(train_dataset)} samples.")
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
        G = CGAN_Generator(LATENT_DIM, len(train_categories)).to(DEVICE)
        D = CGAN_Discriminator(len(train_categories)).to(DEVICE)

        # Shape probe in eval/no_grad mode: builds no autograd graph and leaves the
        # BatchNorm running statistics untouched. Training mode is restored afterwards.
        G.eval()
        with torch.no_grad():
            probe_z = torch.randn(1, LATENT_DIM).to(DEVICE)
            probe_y = F.one_hot(torch.tensor([0]), num_classes=len(train_categories)).float().to(DEVICE)
            print("Generator output shape:", G(probe_z, probe_y).shape)
        G.train()

        train_gan(G, D, train_loader, DEVICE, train_categories, EPOCHS, LR, LATENT_DIM, train_dataset.mean, train_dataset.std)
