In [None]:
!pip install torch torchaudio torchvision transformers



In [None]:
# ==============================================================================
# 0. IMPORTS & INITIAL SETUP
# ==============================================================================
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
from torch.utils.data import Dataset, DataLoader
import torch.autograd as autograd
import torch.cuda.amp as amp  # For Mixed Precision
import torchvision.utils as vutils # For saving image grids

from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython.display import Audio, display
import warnings
warnings.filterwarnings("ignore")

# --- Optimizations ---
torch.backends.cudnn.benchmark = True # cuDNN speedup

In [6]:
# Colab-only: mount Google Drive at /content/drive so that the dataset,
# checkpoint and sample paths configured under /content/drive/MyDrive resolve.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [7]:
# ==============================================================================
# 2. GAN MODEL DEFINITIONS
# ==============================================================================

def weights_init(m):
    """Initialiser meant to be passed to ``module.apply``.

    The decision is made on the module's class name:

    * a name containing ``Conv`` -> weights drawn from N(0, 0.02);
    * a name containing ``BatchNorm`` -> weights drawn from N(1, 0.02) and
      the bias set to 0;
    * anything else is left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)

class WGAN_Generator(nn.Module):
    """Label-conditioned generator that upsamples a latent vector to a spectrogram.

    The latent vector and the label embedding are each projected by a linear
    layer to a ``256 x (H/16) x (W/16)`` feature map. The two maps are
    concatenated along the channel axis (512 channels) and passed through four
    stride-2 transposed convolutions, each doubling height and width, so the
    output has shape ``(batch, 1, H, W)``. The final ReLU makes every output
    value non-negative.

    Args:
        latent_dim: size of the noise vector ``z``.
        num_categories: number of classes in the label embedding table.
        spec_shape: ``(H, W)`` of the generated spectrogram. Both values must
            be positive multiples of 16, otherwise the four x2 upsampling
            stages cannot reproduce the requested size.
        embed_dim: width of the label embedding.

    Raises:
        ValueError: if ``spec_shape`` is not made of positive multiples of 16.
    """

    def __init__(self, latent_dim, num_categories, spec_shape=(128, 512), embed_dim=16):
        super().__init__()
        H, W = spec_shape
        # H // 16 and W // 16 floor silently; without this check a shape such
        # as (130, 256) would yield a (128, 256) output that no longer matches
        # spec_shape.
        if H <= 0 or W <= 0 or H % 16 != 0 or W % 16 != 0:
            raise ValueError(
                f"spec_shape must consist of positive multiples of 16, got {tuple(spec_shape)}"
            )

        self.latent_dim = latent_dim
        self.num_categories = num_categories
        self.spec_shape = spec_shape

        # Spatial size of the seed feature map fed to the first deconvolution.
        self.start_h, self.start_w = H // 16, W // 16

        self.label_emb = nn.Embedding(num_categories, embed_dim)

        self.z_proj = nn.Linear(latent_dim, 256 * self.start_h * self.start_w)
        self.emb_proj = nn.Linear(embed_dim, 256 * self.start_h * self.start_w)

        # NOTE: attribute names and layer order are part of the checkpoint
        # format (state_dict keys); do not rename or reorder.
        self.net = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.ReLU()
        )

    def forward(self, z, y_idx):
        """Generate spectrograms.

        Args:
            z: noise tensor whose last dimension is ``latent_dim``.
            y_idx: 2-D per-class score / one-hot tensor; the class index is
                taken as ``argmax`` over dimension 1.

        Returns:
            Tensor of shape ``(batch, 1, H, W)`` with non-negative values.
        """
        z_proj = self.z_proj(z).view(-1, 256, self.start_h, self.start_w)
        emb = self.label_emb(y_idx.argmax(dim=1))
        emb_proj = self.emb_proj(emb).view(-1, 256, self.start_h, self.start_w)

        # Noise and label feature maps are stacked on the channel axis (256 + 256).
        x = torch.cat([z_proj, emb_proj], dim=1)
        return self.net(x)

In [8]:
class WGAN_Critic(nn.Module):
    """Label-conditioned critic that scores a spectrogram with one scalar.

    The label embedding is projected to a full ``H x W`` plane and stacked
    with the input spectrogram as a second channel. Four stride-2
    convolutions with LeakyReLU follow, then a final convolution whose kernel
    is ``(H // 16, W // 16)``. The result is returned with shape ``(-1, 1)``.
    """

    def __init__(self, num_categories, spec_shape=(128, 512), embed_dim=16):
        super().__init__()
        self.num_categories = num_categories
        self.spec_shape = spec_shape
        self.H, self.W = spec_shape

        self.label_emb = nn.Embedding(num_categories, embed_dim)
        self.emb_proj = nn.Linear(embed_dim, 1 * self.H * self.W)

        # Channel widths of the strided stack: 2 -> 32 -> 64 -> 128 -> 256.
        widths = (2, 32, 64, 128, 256)
        stack = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            stack.append(nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
        # Scoring head: kernel sized (H // 16, W // 16), no padding.
        stack.append(
            nn.Conv2d(256, 1, kernel_size=(self.H // 16, self.W // 16), stride=1, padding=0)
        )
        self.net = nn.Sequential(*stack)

    def forward(self, spec, y_idx):
        """Score ``spec`` given labels ``y_idx`` (class index = argmax over dim 1)."""
        class_ids = torch.argmax(y_idx, dim=1)
        label_plane = self.emb_proj(self.label_emb(class_ids)).view(-1, 1, self.H, self.W)

        stacked = torch.cat((spec, label_plane), dim=1)
        scores = self.net(stacked)
        return scores.view(-1, 1)

In [9]:
# ==============================================================================
# 3. UTILITY FUNCTIONS
# ==============================================================================

def generate_audio_gan(generator, category_idx, num_samples, device, sample_rate=22050,
                       n_fft=1024, hop_length=256, n_mels=128, n_iter=32):
    """Sample spectrograms for one category and invert them to waveforms.

    Steps: draw noise, run the generator (under autocast on CUDA), undo the
    log scaling with ``expm1``, map mel bins to linear-frequency bins with
    ``InverseMelScale`` and reconstruct audio with Griffin-Lim.

    Side effect: the generator is switched to ``eval()`` mode and is NOT
    switched back; callers that keep training must call ``generator.train()``.

    Args:
        generator: model exposing ``num_categories`` and ``latent_dim`` and
            callable as ``generator(z, one_hot_labels)``.
        category_idx: integer class index to condition on.
        num_samples: number of clips to generate.
        device: device string or ``torch.device`` on which to run.
        sample_rate: sample rate handed to ``InverseMelScale``.
        n_fft: FFT size; also used as the Griffin-Lim window length.
        hop_length: hop length used by Griffin-Lim.
        n_mels: number of mel bins in the generated spectrogram.
        n_iter: number of Griffin-Lim iterations.

    Returns:
        The Griffin-Lim output moved to the CPU (presumably
        ``(num_samples, time)`` -- TODO confirm against torchaudio).
    """
    generator.eval()
    num_categories = generator.num_categories
    latent_dim = generator.latent_dim

    class_ids = torch.full((num_samples,), category_idx, dtype=torch.long)
    y = F.one_hot(class_ids, num_classes=num_categories).float().to(device)
    z = torch.randn(num_samples, latent_dim, device=device)

    # `device == 'cuda'` is False for torch.device objects and for strings
    # such as 'cuda:0'; compare the device *type* instead.
    on_cuda = torch.device(device).type == 'cuda'
    with torch.no_grad():
        with amp.autocast(enabled=on_cuda):
            log_spec_gen = generator(z, y)

    # Under autocast the generator output can be float16, whose largest finite
    # value is 65504: expm1 overflows to inf for inputs above ~11. Promote to
    # float32 *before* exponentiating.
    spec_gen = torch.expm1(log_spec_gen.float())
    spec_gen = spec_gen.squeeze(1)  # drop the channel axis

    inverse_mel = torchaudio.transforms.InverseMelScale(
        n_stft=n_fft // 2 + 1, n_mels=n_mels, sample_rate=sample_rate
    ).to(device)
    linear_spec = inverse_mel(spec_gen)

    griffin = torchaudio.transforms.GriffinLim(
        n_fft=n_fft, hop_length=hop_length, win_length=n_fft, n_iter=n_iter
    ).to(device)

    waveform = griffin(linear_spec)
    return waveform.cpu()

In [10]:
def save_and_play(wav, sample_rate, filename):
    """Write a waveform to ``filename`` and show an inline audio player.

    Args:
        wav: waveform tensor. A leading singleton dimension is squeezed from
            tensors with more than two dimensions, and a 1-D tensor is
            promoted to ``(1, time)`` because ``torchaudio.save`` is given a
            2-D tensor here.
        sample_rate: sample rate used for both the file and the player.
        filename: output path; missing parent directories are created.
    """
    if wav.dim() > 2:
        wav = wav.squeeze(0)
    elif wav.dim() == 1:
        wav = wav.unsqueeze(0)

    # The save below fails if the target folder does not exist, so create it.
    # Only done for real paths, not for file-like objects.
    if isinstance(filename, (str, os.PathLike)):
        parent = os.path.dirname(filename)
        if parent:
            os.makedirs(parent, exist_ok=True)

    torchaudio.save(filename, wav, sample_rate=sample_rate)
    print(f"Saved to {filename}")
    display(Audio(data=wav.numpy(), rate=sample_rate))

In [11]:
# ==============================================================================
# 3.1 Checkpointing
# ==============================================================================
def save_checkpoint(g, c, g_optim, c_optim, epoch, path):
    """Persist generator/critic weights, both optimizer states and the epoch.

    The checkpoint is first written to ``<path>.tmp`` and then moved over
    ``path`` with ``os.replace``, so an interrupted save (e.g. a dropped Colab
    session) cannot leave a truncated file in place of the previous checkpoint.

    Args:
        g: generator module.
        c: critic module.
        g_optim: generator optimizer.
        c_optim: critic optimizer.
        epoch: epoch number that has just finished.
        path: destination file; its parent directory is created if needed.
    """
    # dirname is '' for a bare file name and os.makedirs('') raises, so only
    # create the directory when there is one.
    ckpt_dir = os.path.dirname(path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    state = {
        'epoch': epoch,
        'g_state_dict': g.state_dict(),
        'c_state_dict': c.state_dict(),
        'g_optim_state_dict': g_optim.state_dict(),
        'c_optim_state_dict': c_optim.state_dict(),
    }

    tmp_path = f"{path}.tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)
    print(f"Checkpoint saved to {path}")

def load_checkpoint(g, c, g_optim, c_optim, path, device):
    """Restore training state from ``path`` when the file exists.

    Loads the generator, critic and both optimizer state dicts in place and
    returns the epoch to resume from (stored epoch + 1). When ``path`` does
    not exist nothing is modified and 0 is returned.
    """
    if os.path.exists(path):
        checkpoint = torch.load(path, map_location=device)
        restore_plan = (
            (g, 'g_state_dict'),
            (c, 'c_state_dict'),
            (g_optim, 'g_optim_state_dict'),
            (c_optim, 'c_optim_state_dict'),
        )
        for target, key in restore_plan:
            target.load_state_dict(checkpoint[key])

        start_epoch = checkpoint['epoch'] + 1
        print(f"Loaded checkpoint from {path}. Resuming at epoch {start_epoch}.")
        return start_epoch

    print(f"No checkpoint found at {path}. Starting from scratch.")
    return 0

In [12]:
def compute_gradient_penalty(critic, real_specs, fake_specs, labels, device, gp_weight=10.0):
    """WGAN-GP penalty evaluated on random real/fake interpolations.

    For each sample a uniform coefficient in [0, 1) blends the real and fake
    spectrogram. The critic is evaluated on the blend and the L2 norm of its
    gradient with respect to the blend is pushed towards 1. The returned
    scalar is ``gp_weight * mean((norm - 1) ** 2)``, and the gradient is built
    with ``create_graph=True`` so the penalty itself can be back-propagated.
    """
    n_samples = real_specs.size(0)

    # One mixing coefficient per sample, broadcast over the remaining axes.
    mix = torch.rand(n_samples, 1, 1, 1, device=device)
    blended = mix * real_specs + (1 - mix) * fake_specs
    blended.requires_grad_(True)

    scores = critic(blended, labels)

    seed_grad = torch.ones(scores.size(), device=device, requires_grad=False)
    (grads,) = autograd.grad(
        outputs=scores,
        inputs=blended,
        grad_outputs=seed_grad,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )

    per_sample_norm = torch.norm(grads.view(n_samples, -1), p=2, dim=1)
    deviation = per_sample_norm - 1
    return gp_weight * deviation.pow(2).mean()

In [13]:
# ==============================================================================
# 4. GAN TRAINING FUNCTION
# ==============================================================================
def _save_epoch_samples(generator, fixed_noise, fixed_labels, categories, epoch,
                        device, use_amp, sample_dir):
    """Write the image grid, the per-category plot and one audio clip per category.

    Leaves ``generator`` in eval mode; the caller switches it back to train mode.
    """
    print(f"\n--- Generating Samples for Epoch {epoch} ---")
    generator.eval()

    with torch.no_grad():
        with amp.autocast(enabled=use_amp):
            fake_specs_grid = generator(fixed_noise, fixed_labels)
        # autocast can return float16; promote before the CPU image/plot code.
        fake_specs_grid = fake_specs_grid.float().cpu()
        vutils.save_image(
            fake_specs_grid,
            os.path.join(sample_dir, f"epoch_{epoch:03d}.png"),
            nrow=len(categories),
            normalize=True,
            value_range=(0, fake_specs_grid.max().item())
        )

    fig, axes = plt.subplots(1, len(categories), figsize=(4 * len(categories), 4))
    if len(categories) == 1:
        axes = [axes]  # plt.subplots returns a bare Axes for a single panel
    for cat_idx, cat_name in enumerate(categories):
        spec_gen_log_np = fake_specs_grid[cat_idx].squeeze().numpy()
        axes[cat_idx].imshow(spec_gen_log_np, aspect='auto', origin='lower', cmap='viridis')
        axes[cat_idx].set_title(f'{cat_name} (Epoch {epoch})')
        axes[cat_idx].axis('off')
    plt.tight_layout()
    plt.savefig(os.path.join(sample_dir, f"epoch_gan_spectrogram_plot_{epoch:03d}.png"))
    plt.show()
    plt.close(fig)

    # Audio goes to <sample_dir>/gan_generated_audio; make sure that folder
    # exists before saving into it.
    audio_dir = os.path.join(sample_dir, "gan_generated_audio")
    os.makedirs(audio_dir, exist_ok=True)
    for cat_idx, cat_name in enumerate(categories):
        wav = generate_audio_gan(generator, cat_idx, 1, device)
        fname = os.path.join(audio_dir, f"{cat_name}_ep{epoch}.wav")
        save_and_play(wav, sample_rate=22050, filename=fname)


def train_wgan_gp(
    generator, critic, dataloader, device, categories, epochs, lr, betas,
    latent_dim, n_critic, gp_weight, use_amp, checkpoint_path, sample_dir,
    sample_every=20
):
    """Train a conditional WGAN-GP, checkpointing after every epoch.

    Per batch the critic is updated ``n_critic`` times on the same real batch
    (Wasserstein loss plus gradient penalty) and the generator once. Training
    resumes from ``checkpoint_path`` when it exists. Epochs run from the value
    returned by ``load_checkpoint`` up to and including ``epochs``.

    Args:
        generator, critic: the two networks, already on ``device``.
        dataloader: yields ``(real_specs, labels)`` batches; labels are passed
            to both networks unchanged.
        device: device for noise and batches.
        categories: class names; one fixed noise vector / label per class is
            used for the periodic samples.
        epochs: last epoch index (inclusive).
        lr, betas: Adam hyper-parameters shared by both optimizers.
        latent_dim: size of the generator noise vector.
        n_critic: critic updates per generator update.
        gp_weight: gradient-penalty coefficient.
        use_amp: enable autocast and gradient scaling.
        checkpoint_path: file used to resume from and save to.
        sample_dir: folder receiving sample images, plots and audio.
        sample_every: write samples when ``epoch % sample_every == 0``;
            a falsy value disables sampling. Defaults to 20.
    """
    optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=betas)
    optimizer_D = optim.Adam(critic.parameters(), lr=lr, betas=betas)

    scaler = amp.GradScaler(enabled=use_amp)

    os.makedirs(sample_dir, exist_ok=True)
    # The audio clips are written below sample_dir, so that is the folder
    # that has to exist.
    os.makedirs(os.path.join(sample_dir, "gan_generated_audio"), exist_ok=True)
    # Working-directory folders created by the original code; kept in case
    # other cells write into them.
    os.makedirs("gan_generated_audio", exist_ok=True)
    os.makedirs("gan_spectrogram_plots", exist_ok=True)

    start_epoch = load_checkpoint(
        generator, critic, optimizer_G, optimizer_D, checkpoint_path, device
    )

    # Fixed inputs (one per category) so samples are comparable across epochs.
    fixed_noise = torch.randn(len(categories), latent_dim, device=device)
    fixed_labels = F.one_hot(torch.arange(len(categories)), num_classes=len(categories)).float().to(device)

    for epoch in range(start_epoch, epochs + 1):
        loop = tqdm(dataloader, desc=f"Epoch {epoch}/{epochs}", leave=True)
        for real_specs, labels in loop:
            real_specs = real_specs.to(device)
            labels = labels.to(device)
            batch_size = real_specs.size(0)

            # ---------------------
            #  Train Critic
            # ---------------------
            for _ in range(n_critic):
                optimizer_D.zero_grad()
                noise = torch.randn(batch_size, latent_dim, device=device)

                with amp.autocast(enabled=use_amp):
                    # The generator is frozen for the critic update.
                    with torch.no_grad():
                        fake_specs = generator(noise, labels).detach()

                    real_output = critic(real_specs, labels)
                    fake_output = critic(fake_specs, labels)
                    gp = compute_gradient_penalty(critic, real_specs, fake_specs, labels, device, gp_weight)
                    loss_D = fake_output.mean() - real_output.mean() + gp

                scaler.scale(loss_D).backward()
                scaler.step(optimizer_D)
                scaler.update()

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()

            with amp.autocast(enabled=use_amp):
                noise = torch.randn(batch_size, latent_dim, device=device)
                fake_specs = generator(noise, labels)
                fake_output = critic(fake_specs, labels)
                loss_G = -fake_output.mean()

            scaler.scale(loss_G).backward()
            scaler.step(optimizer_G)
            scaler.update()

            loop.set_postfix(D_Loss=loss_D.item(), G_Loss=loss_G.item(), GP=gp.item())

        if sample_every and epoch % sample_every == 0:
            _save_epoch_samples(
                generator, fixed_noise, fixed_labels, categories, epoch,
                device, use_amp, sample_dir
            )

        # Sampling puts the generator in eval mode; always return to train mode.
        generator.train()

        save_checkpoint(
            generator, critic, optimizer_G, optimizer_D, epoch, checkpoint_path
        )
        print("--- End of Epoch ---\n")

In [14]:
import glob

class PrecomputedDataset(Dataset):
    """Dataset of pre-computed ``(spec, label)`` pairs stored as ``.pt`` files.

    Files are discovered as ``<root_dir>/<category>/*.pt`` for every category,
    in the order the categories are given and sorted by path within each
    category, so the index -> file mapping does not depend on the filesystem's
    listing order.

    Args:
        root_dir: folder containing one sub-folder per category.
        categories: category (sub-folder) names; missing folders are reported
            and skipped.
        spec_shape: ``(H, W)`` of the all-zero fallback spectrogram returned
            when a file cannot be loaded. Defaults to ``(128, 256)``.
    """
    def __init__(self, root_dir, categories, spec_shape=(128, 256)):
        self.root_dir = root_dir
        self.categories = categories
        self.spec_shape = tuple(spec_shape)
        self.file_list = []

        print("Populating file list from pre-computed dataset...")
        for category in self.categories:
            path = os.path.join(self.root_dir, category)
            if not os.path.isdir(path):
                print(f"Warning: Directory not found: {path}")
                continue

            # Find all pre-computed .pt files (sorted: glob order is arbitrary)
            files = sorted(glob.glob(os.path.join(path, "*.pt")))
            self.file_list.extend(files)

        print(f"Found {len(self.file_list)} pre-computed spectrograms.")
        if len(self.file_list) == 0:
            print("\n*** WARNING: No .pt files found! ***")
            print("Did you run the pre-computation cell and set the correct PRECOMPUTED_PATH?")

    def __len__(self):
        """Number of discovered ``.pt`` files."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Return the ``(spec, label)`` pair stored in file ``idx``.

        Best-effort: if loading fails the error is printed and a zero
        spectrogram of shape ``(1, *spec_shape)`` with a one-hot label for
        class 0 is returned instead of raising.
        """
        # NOTE(review): torch.load unpickles the file -- only use trusted data.
        try:
            # map_location keeps samples on the CPU even if they were saved
            # from a CUDA tensor.
            spec, label = torch.load(self.file_list[idx], map_location="cpu")
            return spec, label
        except Exception as e:
            print(f"\nError loading pre-computed file {self.file_list[idx]}: {e}")
            # Return a dummy sample that matches the configured shape
            dummy_label = F.one_hot(torch.tensor(0), num_classes=len(self.categories)).float()
            return torch.zeros(1, *self.spec_shape), dummy_label

In [None]:
# ==============================================================================
# 6. MAIN EXECUTION BLOCK
# ==============================================================================

# Entry point: configure the run, build the dataset/loader and both networks,
# then start (or resume) WGAN-GP training.
if __name__ == '__main__':
    # --- Configuration ---
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    LATENT_DIM = 100
    EPOCHS = 1000
    # NOTE(review): SAMPLE_RATE is defined here but not passed to anything in
    # this cell.
    SAMPLE_RATE = 22050

    # --- WGAN-GP & Optimization Config ---
    LEARNING_RATE = 1e-4
    BETAS = (0.5, 0.9)  # Adam betas, shared by both optimizers
    N_CRITIC = 5        # critic updates per generator update
    GP_WEIGHT = 10      # gradient-penalty coefficient
    USE_AMP = True      # mixed precision (autocast + GradScaler)

    # --- *** OPTIMIZATION: Increased Batch Size *** ---
    BATCH_SIZE = 64

    # --- *** OPTIMIZATION: Shorter Spectrograms *** ---
    MAX_FRAMES = 256 # Was 512, must match pre-computation
    SPEC_SHAPE = (128, MAX_FRAMES) # (H, W)

    # --- Paths and Data Setup ---
    BASE_PATH = "/content/drive/MyDrive/decibel_duel"
    TRAIN_PATH = os.path.join(BASE_PATH, 'train/train')
    PRECOMPUTED_PATH = os.path.join(BASE_PATH, 'train/precompute')
    CHECKPOINT_PATH = os.path.join(BASE_PATH, 'checkpoints/wgan_audio.pth.tar')
    SAMPLE_DIR = os.path.join(BASE_PATH, 'samples')

    # Categories are the sub-directory names of TRAIN_PATH, sorted so that the
    # class index of each category is stable between runs.
    train_categories = sorted([d for d in os.listdir(TRAIN_PATH) if os.path.isdir(os.path.join(TRAIN_PATH, d))])
    NUM_CATEGORIES = len(train_categories)

    print(f"Using device: {DEVICE}")
    print(f"Found {NUM_CATEGORIES} categories: {train_categories}")
    print(f"Spectrogram Shape: {SPEC_SHAPE}")
    print(f"Batch Size: {BATCH_SIZE}")

    # Spectrograms are read from PRECOMPUTED_PATH/<category>/*.pt.
    train_dataset = PrecomputedDataset(
        root_dir=PRECOMPUTED_PATH,
        categories=train_categories
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True # Good for GANs
    )

    generator = WGAN_Generator(
        LATENT_DIM, NUM_CATEGORIES, spec_shape=SPEC_SHAPE
    ).to(DEVICE)
    critic = WGAN_Critic(
        NUM_CATEGORIES, spec_shape=SPEC_SHAPE
    ).to(DEVICE)

    # Custom initialisation; train_wgan_gp overwrites these weights when it
    # finds a checkpoint at CHECKPOINT_PATH.
    generator.apply(weights_init)
    critic.apply(weights_init)

    # --- Start Training ---
    train_wgan_gp(
        generator=generator,
        critic=critic,
        dataloader=train_loader,
        device=DEVICE,
        categories=train_categories,
        epochs=EPOCHS,
        lr=LEARNING_RATE,
        betas=BETAS,
        latent_dim=LATENT_DIM,
        n_critic=N_CRITIC,
        gp_weight=GP_WEIGHT,
        use_amp=USE_AMP,
        checkpoint_path=CHECKPOINT_PATH,
        sample_dir=SAMPLE_DIR
    )

Using device: cuda
Found 5 categories: ['dog_bark', 'drilling', 'engine_idling', 'siren', 'street_music']
Spectrogram Shape: (128, 256)
Batch Size: 64
Populating file list from pre-computed dataset...
Found 3450 pre-computed spectrograms.
Loaded checkpoint from /content/drive/MyDrive/decibel_duel/checkpoints/wgan_audio.pth.tar. Resuming at epoch 390.


Epoch 390/1000:  68%|██████▊   | 36/53 [07:03<01:51,  6.56s/it, D_Loss=-45.6, GP=14.3, G_Loss=196]