<a href="https://colab.research.google.com/github/mallhartotey2903-png/Decibal-Duel-mallhar-250041026/blob/main/Copy_of_GAN_audio.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchaudio torchvision transformers



In [None]:
# ==============================================================================
# 0. IMPORTS & INITIAL SETUP
# ==============================================================================
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython.display import Audio, display
import math
import json
from pathlib import Path
torch.manual_seed(42)
random.seed(42)

In [None]:
# Mount Google Drive at /content/drive (Colab-only helper module).
# The main execution cell reads its dataset from a path under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# ==============================================================================
# 1. DATASET
# ==============================================================================

class TrainAudioSpectrogramDataset(Dataset):
    """
    Dataset of log-mel spectrograms built from WAV files in category subfolders.

    Each item is a tuple ``(spec, label)``:

    * ``spec``  -- float tensor ``(1, n_mels, max_frames)``: a mono ``log1p`` mel
      spectrogram, zero-padded or cropped along time, then scaled per sample
      into ``[-1, 1]`` by dividing by its own maximum.
    * ``label`` -- float one-hot vector of length ``len(categories)``.

    Args:
        root_dir: Directory containing one subfolder per category.
        categories: Iterable of subfolder names; position defines the label index.
            Missing or empty subfolders are skipped without shifting the indices.
        max_frames: Number of time frames every spectrogram is padded/cropped to.
        fraction: Fraction of each category's ``*.wav`` files to keep. At least
            one file per non-empty category is always kept. Sampling draws from
            the global ``random`` state.
        sample_rate: Target sample rate; files at other rates are resampled.

    Raises:
        RuntimeError: If no ``*.wav`` file is found for any category.
    """

    def __init__(self, root_dir, categories, max_frames=512, fraction=1.0, sample_rate=22050):
        self.root_dir = Path(root_dir)
        self.categories = list(categories)
        self.max_frames = max_frames
        self.sample_rate = sample_rate
        self.file_list = []
        self.class_to_idx = {cat: i for i, cat in enumerate(self.categories)}
        # One Resample transform per source sample rate, created on first use.
        self._resamplers = {}

        for cat_name in self.categories:
            cat_dir = self.root_dir / cat_name
            if not cat_dir.exists() or not cat_dir.is_dir():
                continue
            # Sorted so the sampled subset depends only on the RNG state,
            # not on filesystem listing order.
            files_in_cat = sorted([str(p) for p in cat_dir.glob("*.wav")])
            if len(files_in_cat) == 0:
                continue
            num_to_sample = max(1, int(len(files_in_cat) * fraction))
            num_to_sample = min(num_to_sample, len(files_in_cat))
            sampled_files = random.sample(files_in_cat, num_to_sample)
            label_idx = self.class_to_idx[cat_name]
            self.file_list.extend([(file_path, label_idx) for file_path in sampled_files])

        if len(self.file_list) == 0:
            raise RuntimeError(f"No wav files found in {root_dir} for categories {categories}")

        # Pre-define mel transform to avoid recreating each call
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate, n_fft=1024, hop_length=256, n_mels=128
        )

    def __len__(self):
        """Return the number of (file, label) pairs selected at construction."""
        return len(self.file_list)

    def _resample(self, wav, orig_sr):
        """Resample ``wav`` from ``orig_sr`` to ``self.sample_rate`` using a cached transform."""
        resampler = self._resamplers.get(orig_sr)
        if resampler is None:
            # Building the transform is the costly part; keep it for later files
            # that share this source rate instead of rebuilding it per item.
            resampler = torchaudio.transforms.Resample(orig_sr, self.sample_rate)
            self._resamplers[orig_sr] = resampler
        return resampler(wav)

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(normalized_log_mel, one_hot_label)``."""
        path, label = self.file_list[idx]
        wav, sr = torchaudio.load(path)
        if sr != self.sample_rate:
            wav = self._resample(wav, sr)

        if wav.size(0) > 1:
            wav = wav.mean(dim=0, keepdim=True)  # mono

        mel_spec = self.mel_transform(wav)  # (1, n_mels, frames)
        log_spec = torch.log1p(mel_spec)    # non-negative

        # Pad / crop along the time axis (last dim) to exactly max_frames.
        n_frames = log_spec.size(-1)
        if n_frames < self.max_frames:
            log_spec = F.pad(log_spec, (0, self.max_frames - n_frames))
        else:
            log_spec = log_spec[:, :, :self.max_frames]

        # Per-sample normalization: divide by the sample's own maximum so the
        # values land in [0, 1], then shift/scale to [-1, 1]. The epsilon keeps
        # an all-zero (silent) spectrogram finite; it maps to a constant -1.
        max_val = log_spec.max()
        log_spec = log_spec / (max_val + 1e-8)
        log_spec = 2 * log_spec - 1

        label_vec = F.one_hot(torch.tensor(label), num_classes=len(self.categories)).float()
        return log_spec, label_vec


# ==============================================================================
# 2. GAN MODEL DEFINITIONS (GENERATOR & DISCRIMINATOR)
# ==============================================================================

class CGAN_Generator(nn.Module):
    """
    Conditional GAN generator producing a single-channel spectrogram.

    The noise vector and the one-hot label are concatenated, projected by a
    linear layer to a ``(256, H/16, W/16)`` feature map, and upsampled by four
    stride-2 transposed convolutions to ``(1, H, W)``. The final ``Tanh`` keeps
    outputs in ``[-1, 1]``.

    Args:
        latent_dim: Size of the noise vector ``z``.
        num_categories: Length of the one-hot label vector ``y``.
        spec_shape: ``(H, W)`` of the generated spectrogram. Both values must be
            positive multiples of 16. The default ``(128, 512)`` gives a
            ``(256, 8, 32)`` seed map.

    Raises:
        ValueError: If ``spec_shape`` is not made of positive multiples of 16.
    """

    # Four stride-2 layers double the spatial size four times (2 ** 4).
    _UPSAMPLE_FACTOR = 16
    _SEED_CHANNELS = 256

    def __init__(self, latent_dim, num_categories, spec_shape=(128, 512)):
        super().__init__()
        height, width = spec_shape
        factor = self._UPSAMPLE_FACTOR
        if height <= 0 or width <= 0 or height % factor or width % factor:
            raise ValueError(
                f"spec_shape must contain positive multiples of {factor}, got {spec_shape}"
            )

        self.latent_dim = latent_dim
        self.num_categories = num_categories
        self.spec_shape = spec_shape

        # Seed feature map that the transposed convolutions upsample to spec_shape.
        self.unflatten_shape = (self._SEED_CHANNELS, height // factor, width // factor)
        seed_features = self.unflatten_shape[0] * self.unflatten_shape[1] * self.unflatten_shape[2]
        self.fc = nn.Linear(latent_dim + num_categories, seed_features)

        # Spatial sizes in the trailing comments assume the default spec_shape.
        self.net = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), # 16x64
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 32x128
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),   # 64x256
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),    # 128x512
            nn.Tanh()  # match the dataset normalization [-1,1]
        )

    def forward(self, z, y):
        """
        Generate spectrograms.

        Args:
            z: Noise tensor of shape ``(B, latent_dim)``.
            y: One-hot labels of shape ``(B, num_categories)``.

        Returns:
            Tensor of shape ``(B, 1, H, W)`` with values in ``[-1, 1]``.
        """
        h = torch.cat([z, y], dim=1)
        h = self.fc(h)
        h = h.view(-1, *self.unflatten_shape)
        return self.net(h)

class CGAN_Discriminator(nn.Module):
    """
    Conditional GAN discriminator scoring a spectrogram given its label.

    The one-hot label is projected to an ``H x W`` map and stacked with the
    spectrogram as a second input channel. Four stride-2 convolutions reduce the
    map to ``(256, H/16, W/16)``; a final convolution whose kernel covers that
    whole map yields one logit per sample.

    Args:
        num_categories: Length of the one-hot label vector.
        spec_shape: ``(H, W)`` of the input spectrogram. Both values must be
            positive multiples of 16. The default ``(128, 512)`` gives a final
            kernel of ``(8, 32)``.
        use_spectral_norm: If true, wrap the four strided convolutions in
            ``nn.utils.spectral_norm``.

    Raises:
        ValueError: If ``spec_shape`` is not made of positive multiples of 16.
    """

    # Four stride-2 convolutions halve the spatial size four times (2 ** 4).
    _DOWNSAMPLE_FACTOR = 16

    def __init__(self, num_categories, spec_shape=(128, 512), use_spectral_norm=False):
        super().__init__()
        H, W = spec_shape
        factor = self._DOWNSAMPLE_FACTOR
        if H <= 0 or W <= 0 or H % factor or W % factor:
            raise ValueError(
                f"spec_shape must contain positive multiples of {factor}, got {spec_shape}"
            )

        self.num_categories = num_categories
        self.spec_shape = spec_shape

        self.label_embedding = nn.Linear(num_categories, H * W)

        # Identity wrapper unless spectral normalization was requested.
        # NOTE(review): the final convolution is left unwrapped and BatchNorm
        # layers are kept even with use_spectral_norm=True -- confirm intended.
        wrap = nn.utils.spectral_norm if use_spectral_norm else (lambda layer: layer)

        # The last kernel must span the whole remaining feature map so the
        # output collapses to 1x1 for any valid spec_shape.
        final_kernel = (H // factor, W // factor)

        # Spatial sizes in the trailing comments assume the default spec_shape.
        self.net = nn.Sequential(
            wrap(nn.Conv2d(2, 32, kernel_size=4, stride=2, padding=1)), # 64x256
            nn.LeakyReLU(0.2, inplace=True),

            wrap(nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)), # 32x128
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            wrap(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)),# 16x64
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            wrap(nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)),# 8x32
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 1, kernel_size=final_kernel, stride=1, padding=0) # -> 1x1
        )

    def forward(self, spec, y):
        """
        Score spectrograms.

        Args:
            spec: Tensor of shape ``(B, 1, H, W)``.
            y: One-hot labels of shape ``(B, num_categories)``.

        Returns:
            Raw logits of shape ``(B, 1)`` (no sigmoid applied).
        """
        label_map = self.label_embedding(y).view(-1, 1, *self.spec_shape)
        h = torch.cat([spec, label_map], dim=1)
        logit = self.net(h)
        return logit.view(-1, 1)

# ==============================================================================
# 3. UTILITY FUNCTIONS (GENERATION, SAVING)
# ==============================================================================

def generate_audio_gan(generator, category_idx, num_samples, device, sample_rate=22050,
                       n_fft=1024, hop_length=256, n_iter=32, log_scale=10.0):
    """
    Generate waveforms for one category by inverting generator spectrograms.

    The generator output in ``[-1, 1]`` is mapped to ``[0, 1]``, multiplied by
    ``log_scale`` and passed through ``expm1`` to leave the log domain, then
    converted mel -> linear spectrogram -> waveform (Griffin-Lim).

    The dataset normalizes each sample by its own maximum, so the original
    magnitude is not recoverable; ``log_scale`` is a tunable heuristic, not an
    exact inverse.

    Args:
        generator: Model exposing ``num_categories``, ``latent_dim`` and
            ``forward(z, y)`` returning ``(B, 1, n_mels, frames)``.
        category_idx: Integer class index used to build the one-hot condition.
        num_samples: Number of waveforms to generate.
        device: Device on which generation and inversion run.
        sample_rate: Sample rate used to build the inverse mel filterbank.
        n_fft: FFT size; should match the one used to create the training mels.
        hop_length: Hop length; should match the training mels.
        n_iter: Number of Griffin-Lim iterations.
        log_scale: Multiplier applied before ``expm1``.

    Returns:
        CPU tensor of shape ``(num_samples, samples)``.

    Note:
        The generator is switched to eval mode only for the duration of the
        call; its previous train/eval mode is restored afterwards.
    """
    was_training = generator.training
    generator.eval()
    try:
        # One-hot condition, repeated so every sample in the batch shares it.
        y = F.one_hot(torch.tensor([category_idx]), num_classes=generator.num_categories).float().to(device)
        y = y.repeat(num_samples, 1)
        z = torch.randn(num_samples, generator.latent_dim, device=device)

        with torch.no_grad():
            log_spec_gen = generator(z, y)  # (B,1,H,W) in [-1,1]
    finally:
        # Do not leak eval mode into a caller that is in the middle of training.
        generator.train(was_training)

    # Undo the dataset's final affine step: [-1,1] -> [0,1].
    spec_gen = (log_spec_gen + 1.0) / 2.0

    # Approximate inverse of log1p after per-sample max normalization.
    spec_gen = torch.expm1(spec_gen * log_scale)

    spec_gen = spec_gen.squeeze(1)  # (B, n_mels, frames)

    inverse_mel = torchaudio.transforms.InverseMelScale(
        n_stft=n_fft // 2 + 1, n_mels=spec_gen.size(1), sample_rate=sample_rate
    ).to(device)

    linear_spec = inverse_mel(spec_gen)  # (B, n_fft_bins, frames)

    griffin = torchaudio.transforms.GriffinLim(
        n_fft=n_fft, hop_length=hop_length, win_length=n_fft, n_iter=n_iter
    ).to(device)

    waveform = griffin(linear_spec)  # (B, samples)
    return waveform.cpu()  # return CPU tensor

def save_and_play(wav, sample_rate, filename):
    """
    Save a mono waveform to ``filename`` and show an inline audio player.

    Args:
        wav: Tensor (any device) or array-like of shape ``(samples,)``,
            ``(1, samples)`` or ``(B, samples)``. For ``B > 1`` only the first
            row is saved and played.
        sample_rate: Sample rate written to the file and used for playback.
        filename: Destination path. A missing parent directory is created.
    """
    if isinstance(wav, torch.Tensor):
        # torchaudio.save and .numpy() both need a CPU tensor.
        wav_t = wav.detach().cpu()
    else:
        wav_t = torch.tensor(wav)

    # Collapse a leading batch/channel axis to a single 1-D signal.
    if wav_t.dim() == 2 and wav_t.size(0) >= 1:
        wav_t = wav_t[0]

    # torchaudio.save expects (channels, samples); use one channel.
    if wav_t.dim() == 1:
        wav_t = wav_t.unsqueeze(0)

    # convert to float32
    wav_t = wav_t.float()

    # Create the destination folder for path-like targets so saving does not
    # fail on a fresh working directory.
    if isinstance(filename, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(filename))
        if parent:
            os.makedirs(parent, exist_ok=True)

    torchaudio.save(filename, wav_t, sample_rate=sample_rate)
    print(f"Saved to {filename}")
    display(Audio(data=wav_t.numpy(), rate=sample_rate))

# ==============================================================================
# 4. GAN TRAINING FUNCTION
# ==============================================================================

def _save_epoch_samples(generator, categories, epoch, device, sample_rate):
    """Plot one generated spectrogram per category, then save and play one clip each."""
    print(f"\n--- Generating Samples for Epoch {epoch} ---")
    generator.eval()

    fig, axes = plt.subplots(1, len(categories), figsize=(4 * len(categories), 4))
    if len(categories) == 1:
        axes = [axes]  # plt.subplots returns a bare Axes for a single panel

    for cat_idx, cat_name in enumerate(categories):
        y_cond = F.one_hot(torch.tensor([cat_idx]), num_classes=generator.num_categories).float().to(device)
        z_sample = torch.randn(1, generator.latent_dim).to(device)
        with torch.no_grad():
            spec_gen_log = generator(z_sample, y_cond)

        spec_gen_log_np = spec_gen_log.squeeze().cpu().numpy()
        axes[cat_idx].imshow(spec_gen_log_np, aspect='auto', origin='lower', cmap='viridis')
        axes[cat_idx].set_title(f'{cat_name} (Epoch {epoch})')
        axes[cat_idx].axis('off')

    plt.tight_layout()
    plt.savefig(f'gan_spectrogram_plots/epoch_{epoch:03d}.png')
    plt.show()
    plt.close(fig)  # free the figure; one is created per sampled epoch

    for cat_idx, cat_name in enumerate(categories):
        wav = generate_audio_gan(generator, cat_idx, 1, device, sample_rate=sample_rate)
        fname = f"gan_generated_audio/{cat_name}_ep{epoch}.wav"
        save_and_play(wav, sample_rate=sample_rate, filename=fname)

    generator.train()
    print("--- End of Sample Generation ---\n")


def train_gan(generator, discriminator, dataloader, device, categories, epochs, lr, latent_dim,
              sample_every=1, checkpoint_every=1, sample_rate=22050):
    """
    Train a conditional GAN with the standard BCE-with-logits objective.

    Per batch, the discriminator is updated on real and (detached) fake
    spectrograms, then the generator is updated to make the discriminator label
    its fakes as real. Both use Adam with ``betas=(0.5, 0.999)``.

    Side effects: creates ``checkpoints/``, ``gan_generated_audio/`` and
    ``gan_spectrogram_plots/`` in the working directory and writes plots, WAV
    files and ``checkpoints/gan_epoch_XXX.pt`` files into them.

    Args:
        generator: Generator called as ``generator(z, labels)``.
        discriminator: Discriminator called as ``discriminator(spec, labels)``,
            returning logits of shape ``(B, 1)``.
        dataloader: Iterable yielding ``(real_specs, one_hot_labels)`` batches.
        device: Device the batches and noise are placed on.
        categories: Category names, indexed like the one-hot labels.
        epochs: Number of epochs to run.
        lr: Learning rate for both optimizers.
        latent_dim: Size of the noise vector fed to the generator.
        sample_every: Render plots/audio every this many epochs; ``0`` or
            ``None`` disables sampling.
        checkpoint_every: Save a checkpoint every this many epochs, plus always
            after the last epoch; ``0`` or ``None`` disables checkpoints.
        sample_rate: Sample rate used for the generated audio files.
    """
    # Create checkpoint directory at the start of training
    os.makedirs("checkpoints", exist_ok=True)

    # Optimizers for each model
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

    # Loss function (operates on raw logits)
    criterion = nn.BCEWithLogitsLoss()

    # Create directories for output
    os.makedirs("gan_generated_audio", exist_ok=True)
    os.makedirs("gan_spectrogram_plots", exist_ok=True)

    # Make sure BatchNorm layers use batch statistics even if the caller
    # handed in models that were left in eval mode.
    generator.train()
    discriminator.train()

    for epoch in range(1, epochs + 1):
        loop = tqdm(dataloader, desc=f"Epoch {epoch}/{epochs}", leave=True)
        for real_specs, labels in loop:
            real_specs = real_specs.to(device)
            labels = labels.to(device)
            batch_size = real_specs.size(0)

            # Targets for the BCE loss: 1 = real, 0 = fake.
            real_labels_tensor = torch.ones(batch_size, 1, device=device)
            fake_labels_tensor = torch.zeros(batch_size, 1, device=device)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()
            real_output = discriminator(real_specs, labels)
            loss_D_real = criterion(real_output, real_labels_tensor)

            z = torch.randn(batch_size, latent_dim, device=device)
            fake_specs = generator(z, labels)

            # detach(): the discriminator step must not push gradients into the generator.
            fake_output = discriminator(fake_specs.detach(), labels)
            loss_D_fake = criterion(fake_output, fake_labels_tensor)

            loss_D = loss_D_real + loss_D_fake
            loss_D.backward()
            optimizer_D.step()

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()
            # Same fakes, now scored by the just-updated discriminator and
            # pushed towards the "real" target.
            output = discriminator(fake_specs, labels)
            loss_G = criterion(output, real_labels_tensor)

            loss_G.backward()
            optimizer_G.step()

            loop.set_postfix(loss_D=loss_D.item(), loss_G=loss_G.item())

        # --- End of Epoch: Generate and save samples ---
        if sample_every and epoch % sample_every == 0:
            _save_epoch_samples(generator, categories, epoch, device, sample_rate)

        # -------------------------------
        # Save checkpoint for this epoch
        # -------------------------------
        if checkpoint_every and (epoch % checkpoint_every == 0 or epoch == epochs):
            torch.save({
                'epoch': epoch,
                'generator_state': generator.state_dict(),
                'discriminator_state': discriminator.state_dict(),
                'optG_state': optimizer_G.state_dict(),
                'optD_state': optimizer_D.state_dict()
            }, f'checkpoints/gan_epoch_{epoch:03d}.pt')


In [None]:
# ==============================================================================
# 5. MAIN EXECUTION BLOCK
# ==============================================================================

if __name__ == '__main__':
    # --- Hyper-parameters ---
    if torch.cuda.is_available():
        DEVICE = "cuda"
    else:
        DEVICE = "cpu"
    LATENT_DIM = 100       # size of the generator's noise vector
    EPOCHS = 200
    BATCH_SIZE = 32
    LEARNING_RATE = 2e-4   # shared by the generator and discriminator optimizers

    # --- Dataset location: one subfolder of TRAIN_PATH per category ---
    BASE_PATH = '/content/drive/MyDrive/organized_dataset'
    TRAIN_PATH = os.path.join(BASE_PATH, 'train')
    train_categories = sorted(
        child.name for child in Path(TRAIN_PATH).iterdir() if child.is_dir()
    )
    NUM_CATEGORIES = len(train_categories)

    print(f"Using device: {DEVICE}")
    print(f"Found {NUM_CATEGORIES} categories: {train_categories}")

    # --- Data pipeline ---
    train_dataset = TrainAudioSpectrogramDataset(TRAIN_PATH, train_categories)
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,
    )

    # --- Models ---
    generator = CGAN_Generator(LATENT_DIM, NUM_CATEGORIES).to(DEVICE)
    discriminator = CGAN_Discriminator(NUM_CATEGORIES).to(DEVICE)

    # --- Training run ---
    train_gan(
        generator,
        discriminator,
        train_loader,
        DEVICE,
        train_categories,
        EPOCHS,
        LEARNING_RATE,
        LATENT_DIM,
    )

Using device: cuda
Found 5 categories: ['dog_bark', 'drilling', 'engine_idling', 'siren', 'street_music']


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
Epoch 1/200:  97%|█████████▋| 105/108 [35:50<01:01, 20.48s/it, loss_D=0.0026, loss_G=7.32]


KeyboardInterrupt: 