In [1]:
import torch
import torch.optim as optim
import torch.autograd as autograd
import torch.nn.functional as F
import os
from tqdm import tqdm

In [2]:
# Import our custom modules
from dataset import get_data_loaders
from model import VAE_WGAN

In [3]:
# ================= CONFIGURATION =================
# Reproducibility
SEED = 42           # Seed for PyTorch's global RNGs (e.g. the torch.rand draw in the gradient penalty)
torch.manual_seed(SEED)

# Hyperparameters
BATCH_SIZE = 64
LR = 5e-5           # Learning Rate
EPOCHS = 100        # Dataset is small (9k), so 100 epochs is reasonable
Z_DIM = 128         # Latent dimension
                    # NOTE(review): not passed to VAE_WGAN() in this notebook -- confirm the
                    # model's own default latent size matches this value.
LAMBDA_GP = 10      # Gradient Penalty weight (Standard WGAN-GP value)
BETA_KL = 1.0       # Weight for KL Divergence
GAMMA_REC = 1000.0  # Weight for Reconstruction Loss (High to prioritize accuracy).
                    # Equals the reconstruction weight applied in the training loop.
ALPHA_ADV = 0.5     # Weight for Adversarial Loss (Texture realism)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directory for the periodic checkpoints; created up front so torch.save never
# fails on a missing folder (exist_ok makes re-running this cell harmless).
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [4]:
# ================= HELPER FUNCTIONS =================

def compute_gradient_penalty(D, real_samples, fake_samples):
    """
    Calculates the gradient penalty loss for WGAN-GP.
    Enforces the 1-Lipschitz constraint.

    The critic is evaluated on random per-sample interpolations between the
    real and fake batches; the penalty is the batch mean of
    (||dD/dx||_2 - 1)^2 taken at those points.

    Args:
        D: Callable (the critic) mapping a batch of samples to a tensor of scores.
           Any output shape is accepted, e.g. (N, 1) or (N,).
        real_samples: Floating-point tensor with the batch on dim 0.
        fake_samples: Tensor broadcast-compatible with ``real_samples``.

    Returns:
        0-dim tensor holding the penalty. It stays attached to the autograd
        graph (``create_graph=True``) so it can be back-propagated into ``D``.
    """
    batch_size = real_samples.size(0)

    # One mixing coefficient per sample, broadcast over every non-batch dim.
    # Created directly on the inputs' device/dtype so the helper does not depend
    # on a notebook-level DEVICE global and works for inputs of any rank.
    alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype)

    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)

    d_interpolates = D(interpolates)

    # Upstream gradient of ones, shaped like whatever the critic returned, so the
    # call does not break when the critic's output is not exactly (N, 1).
    grad_outputs = torch.ones_like(d_interpolates)

    # Get gradient w.r.t. interpolates. create_graph=True keeps the gradient
    # differentiable, which is required because the penalty itself is optimised.
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
    )[0]

    # reshape (rather than view) also handles non-contiguous gradient layouts.
    gradients = gradients.reshape(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

In [5]:
def loss_function_vae(recon_x, x, mu, logvar):
    """
    Compute the two VAE loss terms separately (the caller applies the weights).

    Args:
        recon_x: Reconstructed batch.
        x: Target batch that ``recon_x`` is compared against.
        mu: Mean of the approximate posterior.
        logvar: Log-variance of the approximate posterior.

    Returns:
        ``(recon_loss, kl_div)`` -- two 0-dim tensors:
        the mean squared error between ``recon_x`` and ``x``, and the KL
        divergence to a standard normal averaged over every element of
        ``mu``/``logvar`` (a mean over batch and latent entries, not a
        per-sample sum).
    """
    # KL term, element-wise: 1 + log(sigma^2) - mu^2 - sigma^2
    kl_elements = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.mean(kl_elements)

    # Reconstruction term: element-wise MSE averaged over the whole batch
    mse_term = F.mse_loss(recon_x, x, reduction='mean')

    return mse_term, kl_term

### ***================= TRAINING LOOP =================***

In [6]:
# ================= TRAINING LOOP =================
# Report which device (selected in the configuration cell) training will run on.
print(f"Training on {DEVICE}...")

Training on cuda...


In [7]:
# 1. Load Data
# Note: the dataset root is 'modis_dataset_brazil'; presumably it is resolved
# relative to the current working directory -- confirm in dataset.py.
# Only the first returned loader is used for training; the second return value is discarded.
train_loader, _ = get_data_loaders('modis_dataset_brazil', batch_size=BATCH_SIZE)

Total Normal: 9298
Total Fire: 9299
--- Split Summary ---
Train Set (Normal Only): 7438 images
Test Set (Normal):       1860 images
Test Set (Fire):         1860 images
Total Test Set:          3720 images


In [8]:
# 2. Initialize Model
# Built with the constructor's default arguments (Z_DIM is not passed in) and
# moved to the device selected in the configuration cell.
model = VAE_WGAN().to(DEVICE)

In [9]:
# 3. Optimizers
# Two Adam optimizers with identical settings (lr=LR, betas=(0.5, 0.9)):
#   - opt_disc updates only the discriminator (critic);
#   - opt_vae updates the encoder and decoder together (the generator side).
opt_disc = optim.Adam(model.discriminator.parameters(), lr=LR, betas=(0.5, 0.9))
opt_vae = optim.Adam(
    [*model.encoder.parameters(), *model.decoder.parameters()],
    lr=LR,
    betas=(0.5, 0.9),
)

In [10]:
# Training State
# Running count of processed batches; the training loop increments it once per iteration.
global_step = 0

In [11]:
# Alternating optimisation, one critic update followed by one VAE update per batch.
# All loss weights (LAMBDA_GP, GAMMA_REC, BETA_KL, ALPHA_ADV) come from the
# configuration cell.
for epoch in range(EPOCHS):
    loop = tqdm(train_loader, leave=True)
    loop.set_description(f"Epoch [{epoch+1}/{EPOCHS}]")

    # The second element of each batch is not used during training.
    for real_imgs, _ in loop:
        real_imgs = real_imgs.to(DEVICE)

        # ============================================
        #  STEP 1: TRAIN DISCRIMINATOR (The Critic)
        # ============================================
        # In WGAN, we update the critic more often than the generator (e.g., n_critic=5)
        # But for VAE-WGANs, 1:1 is often sufficient. We stick to 1:1 for speed.

        opt_disc.zero_grad()

        # Forward pass VAE to get reconstructed images (Fake).
        # The VAE is not updated in this step, so run it without building an
        # autograd graph; the result is already detached from the VAE weights.
        with torch.no_grad():
            fake_imgs, _, _ = model(real_imgs)

        # Critic scores
        real_validity = model.discriminator(real_imgs)
        fake_validity = model.discriminator(fake_imgs)

        # Gradient Penalty
        gp = compute_gradient_penalty(model.discriminator, real_imgs, fake_imgs)

        # WGAN Loss: D(fake) - D(real) + lambda * GP
        # We want to minimize this (which maximizes D(real) - D(fake))
        d_loss = torch.mean(fake_validity) - torch.mean(real_validity) + LAMBDA_GP * gp

        d_loss.backward()
        opt_disc.step()

        # ============================================
        #  STEP 2: TRAIN VAE (Encoder + Decoder)
        # ============================================
        opt_vae.zero_grad()

        # Forward pass (this time with gradients, since the VAE is being updated)
        recon_imgs, mu, logvar = model(real_imgs)

        # VAE Losses
        recon_loss, kl_loss = loss_function_vae(recon_imgs, real_imgs, mu, logvar)

        # Adversarial Loss (Generator tries to fool Critic)
        # We want to minimize -D(fake)
        fake_validity = model.discriminator(recon_imgs)
        adv_loss = -torch.mean(fake_validity)

        # Total VAE Loss
        # We weigh Reconstruction heavily so the Anomaly Detection works well
        vae_loss = (GAMMA_REC * recon_loss) + (BETA_KL * kl_loss) + (ALPHA_ADV * adv_loss)

        # backward() also deposits gradients on the critic's parameters through
        # adv_loss; they are never applied (only opt_vae steps here) and are
        # cleared by opt_disc.zero_grad() at the start of the next iteration.
        vae_loss.backward()
        opt_vae.step()

        # Update Progress Bar (values shown are for the most recent batch)
        loop.set_postfix(
            Recon=f"{recon_loss.item():.4f}",
            Critic=f"{d_loss.item():.4f}"
        )

        global_step += 1

    # ================= SAVE CHECKPOINT =================
    # Save every 10 epochs and the final one (the second condition matters
    # only when EPOCHS is not a multiple of 10).
    if (epoch + 1) % 10 == 0 or (epoch + 1) == EPOCHS:
        save_path = os.path.join(CHECKPOINT_DIR, f"vae_wgan_epoch_{epoch+1}.pth")
        torch.save(model.state_dict(), save_path)
        print(f"Saved checkpoint: {save_path}")

Epoch [1/100]: 100%|██████████| 117/117 [00:18<00:00,  6.26it/s, Critic=-10.4139, Recon=0.2151]
Epoch [2/100]: 100%|██████████| 117/117 [00:30<00:00,  3.90it/s, Critic=-20.6930, Recon=0.0935]
Epoch [3/100]: 100%|██████████| 117/117 [00:27<00:00,  4.24it/s, Critic=-16.4874, Recon=0.0376]
Epoch [4/100]: 100%|██████████| 117/117 [00:16<00:00,  6.94it/s, Critic=-6.6881, Recon=0.0129]
Epoch [5/100]: 100%|██████████| 117/117 [00:16<00:00,  7.13it/s, Critic=-4.9368, Recon=0.0127]
Epoch [6/100]: 100%|██████████| 117/117 [00:16<00:00,  7.09it/s, Critic=-4.8372, Recon=0.0114]
Epoch [7/100]: 100%|██████████| 117/117 [00:15<00:00,  7.32it/s, Critic=-4.6338, Recon=0.0153]
Epoch [8/100]: 100%|██████████| 117/117 [00:15<00:00,  7.43it/s, Critic=-5.3819, Recon=0.0111]
Epoch [9/100]: 100%|██████████| 117/117 [00:16<00:00,  7.28it/s, Critic=-4.6393, Recon=0.0077]
Epoch [10/100]: 100%|██████████| 117/117 [00:16<00:00,  7.17it/s, Critic=-4.9042, Recon=0.0141]


Saved checkpoint: checkpoints\vae_wgan_epoch_10.pth


Epoch [11/100]: 100%|██████████| 117/117 [00:16<00:00,  7.24it/s, Critic=-4.8760, Recon=0.0074]
Epoch [12/100]: 100%|██████████| 117/117 [00:16<00:00,  7.15it/s, Critic=-4.8014, Recon=0.0101]
Epoch [13/100]: 100%|██████████| 117/117 [00:16<00:00,  7.28it/s, Critic=-5.9214, Recon=0.0100]
Epoch [14/100]: 100%|██████████| 117/117 [00:16<00:00,  7.26it/s, Critic=-5.9843, Recon=0.0097]
Epoch [15/100]: 100%|██████████| 117/117 [00:17<00:00,  6.73it/s, Critic=-3.4153, Recon=0.0059]
Epoch [16/100]: 100%|██████████| 117/117 [00:26<00:00,  4.39it/s, Critic=-5.2086, Recon=0.0154]
Epoch [17/100]: 100%|██████████| 117/117 [00:35<00:00,  3.26it/s, Critic=-4.9095, Recon=0.0119]
Epoch [18/100]: 100%|██████████| 117/117 [00:23<00:00,  5.01it/s, Critic=-6.3497, Recon=0.0133]
Epoch [19/100]: 100%|██████████| 117/117 [00:29<00:00,  4.01it/s, Critic=-4.3956, Recon=0.0075]
Epoch [20/100]: 100%|██████████| 117/117 [00:26<00:00,  4.46it/s, Critic=-6.1292, Recon=0.0097]


Saved checkpoint: checkpoints\vae_wgan_epoch_20.pth


Epoch [21/100]: 100%|██████████| 117/117 [00:25<00:00,  4.55it/s, Critic=-4.7325, Recon=0.0066]
Epoch [22/100]: 100%|██████████| 117/117 [00:32<00:00,  3.60it/s, Critic=-5.2834, Recon=0.0097]
Epoch [23/100]: 100%|██████████| 117/117 [00:35<00:00,  3.27it/s, Critic=-4.7906, Recon=0.0069]
Epoch [24/100]: 100%|██████████| 117/117 [00:40<00:00,  2.86it/s, Critic=-4.7675, Recon=0.0073]
Epoch [25/100]: 100%|██████████| 117/117 [00:37<00:00,  3.08it/s, Critic=-4.8985, Recon=0.0071]
Epoch [26/100]: 100%|██████████| 117/117 [00:16<00:00,  7.18it/s, Critic=-4.9424, Recon=0.0078]
Epoch [27/100]: 100%|██████████| 117/117 [00:19<00:00,  6.01it/s, Critic=-3.4823, Recon=0.0056]
Epoch [28/100]: 100%|██████████| 117/117 [00:35<00:00,  3.28it/s, Critic=-4.7850, Recon=0.0083]
Epoch [29/100]: 100%|██████████| 117/117 [00:32<00:00,  3.56it/s, Critic=-5.2181, Recon=0.0098]
Epoch [30/100]: 100%|██████████| 117/117 [00:33<00:00,  3.54it/s, Critic=-6.0259, Recon=0.0102]


Saved checkpoint: checkpoints\vae_wgan_epoch_30.pth


Epoch [31/100]: 100%|██████████| 117/117 [00:35<00:00,  3.27it/s, Critic=-5.7532, Recon=0.0083]
Epoch [32/100]: 100%|██████████| 117/117 [00:21<00:00,  5.37it/s, Critic=-4.0976, Recon=0.0094]
Epoch [33/100]: 100%|██████████| 117/117 [00:18<00:00,  6.33it/s, Critic=-4.7077, Recon=0.0076]
Epoch [34/100]: 100%|██████████| 117/117 [00:18<00:00,  6.19it/s, Critic=-4.8039, Recon=0.0074]
Epoch [35/100]: 100%|██████████| 117/117 [00:18<00:00,  6.49it/s, Critic=-5.0737, Recon=0.0103]
Epoch [36/100]: 100%|██████████| 117/117 [00:28<00:00,  4.13it/s, Critic=-3.6557, Recon=0.0064]
Epoch [37/100]: 100%|██████████| 117/117 [00:41<00:00,  2.83it/s, Critic=-3.9812, Recon=0.0058]
Epoch [38/100]: 100%|██████████| 117/117 [00:33<00:00,  3.54it/s, Critic=-3.5946, Recon=0.0062]
Epoch [39/100]: 100%|██████████| 117/117 [00:41<00:00,  2.83it/s, Critic=-3.3662, Recon=0.0066]
Epoch [40/100]: 100%|██████████| 117/117 [00:35<00:00,  3.32it/s, Critic=-5.4640, Recon=0.0074]


Saved checkpoint: checkpoints\vae_wgan_epoch_40.pth


Epoch [41/100]: 100%|██████████| 117/117 [00:41<00:00,  2.80it/s, Critic=-4.3549, Recon=0.0065]
Epoch [42/100]: 100%|██████████| 117/117 [00:33<00:00,  3.53it/s, Critic=-4.2003, Recon=0.0064]
Epoch [43/100]: 100%|██████████| 117/117 [00:40<00:00,  2.87it/s, Critic=-3.7687, Recon=0.0069]
Epoch [44/100]: 100%|██████████| 117/117 [00:30<00:00,  3.84it/s, Critic=-4.4351, Recon=0.0081]
Epoch [45/100]: 100%|██████████| 117/117 [00:33<00:00,  3.50it/s, Critic=-4.1797, Recon=0.0058]
Epoch [46/100]: 100%|██████████| 117/117 [00:33<00:00,  3.46it/s, Critic=-3.2497, Recon=0.0062]
Epoch [47/100]: 100%|██████████| 117/117 [00:37<00:00,  3.12it/s, Critic=-4.3183, Recon=0.0080]
Epoch [48/100]: 100%|██████████| 117/117 [00:34<00:00,  3.40it/s, Critic=-3.7957, Recon=0.0076]
Epoch [49/100]: 100%|██████████| 117/117 [00:33<00:00,  3.47it/s, Critic=-3.1187, Recon=0.0062]
Epoch [50/100]: 100%|██████████| 117/117 [00:33<00:00,  3.48it/s, Critic=-4.6473, Recon=0.0072]


Saved checkpoint: checkpoints\vae_wgan_epoch_50.pth


Epoch [51/100]: 100%|██████████| 117/117 [00:29<00:00,  3.93it/s, Critic=-6.3240, Recon=0.0078]
Epoch [52/100]: 100%|██████████| 117/117 [00:18<00:00,  6.38it/s, Critic=-3.7725, Recon=0.0056]
Epoch [53/100]: 100%|██████████| 117/117 [00:16<00:00,  7.13it/s, Critic=-4.8719, Recon=0.0085]
Epoch [54/100]: 100%|██████████| 117/117 [00:16<00:00,  7.11it/s, Critic=-3.2914, Recon=0.0048]
Epoch [55/100]: 100%|██████████| 117/117 [00:16<00:00,  7.12it/s, Critic=-4.1663, Recon=0.0072]
Epoch [56/100]: 100%|██████████| 117/117 [00:16<00:00,  7.04it/s, Critic=-3.7560, Recon=0.0054]
Epoch [57/100]: 100%|██████████| 117/117 [00:16<00:00,  7.09it/s, Critic=-4.3846, Recon=0.0085]
Epoch [58/100]: 100%|██████████| 117/117 [00:16<00:00,  7.05it/s, Critic=-4.9348, Recon=0.0068]
Epoch [59/100]: 100%|██████████| 117/117 [00:16<00:00,  7.13it/s, Critic=-3.6266, Recon=0.0090]
Epoch [60/100]: 100%|██████████| 117/117 [00:16<00:00,  7.08it/s, Critic=-3.4182, Recon=0.0045]


Saved checkpoint: checkpoints\vae_wgan_epoch_60.pth


Epoch [61/100]: 100%|██████████| 117/117 [00:18<00:00,  6.47it/s, Critic=-6.2569, Recon=0.0101]
Epoch [62/100]: 100%|██████████| 117/117 [00:16<00:00,  6.92it/s, Critic=-3.0377, Recon=0.0051]
Epoch [63/100]: 100%|██████████| 117/117 [00:17<00:00,  6.65it/s, Critic=-3.4785, Recon=0.0072]
Epoch [64/100]: 100%|██████████| 117/117 [00:17<00:00,  6.87it/s, Critic=-4.3444, Recon=0.0069]
Epoch [65/100]: 100%|██████████| 117/117 [00:16<00:00,  6.93it/s, Critic=-3.7194, Recon=0.0048]
Epoch [66/100]: 100%|██████████| 117/117 [00:16<00:00,  6.97it/s, Critic=-3.9990, Recon=0.0066]
Epoch [67/100]: 100%|██████████| 117/117 [00:16<00:00,  7.03it/s, Critic=-3.4488, Recon=0.0061]
Epoch [68/100]: 100%|██████████| 117/117 [00:16<00:00,  6.94it/s, Critic=-3.9586, Recon=0.0072]
Epoch [69/100]: 100%|██████████| 117/117 [00:16<00:00,  7.15it/s, Critic=-4.1064, Recon=0.0061]
Epoch [70/100]: 100%|██████████| 117/117 [00:16<00:00,  7.14it/s, Critic=-3.3904, Recon=0.0059]


Saved checkpoint: checkpoints\vae_wgan_epoch_70.pth


Epoch [71/100]: 100%|██████████| 117/117 [00:18<00:00,  6.47it/s, Critic=-2.3244, Recon=0.0047]
Epoch [72/100]: 100%|██████████| 117/117 [00:16<00:00,  7.13it/s, Critic=-4.2423, Recon=0.0079]
Epoch [73/100]: 100%|██████████| 117/117 [00:16<00:00,  6.98it/s, Critic=-4.6422, Recon=0.0069]
Epoch [74/100]: 100%|██████████| 117/117 [00:16<00:00,  7.13it/s, Critic=-5.3878, Recon=0.0093]
Epoch [75/100]: 100%|██████████| 117/117 [00:16<00:00,  6.99it/s, Critic=-4.6053, Recon=0.0067]
Epoch [76/100]: 100%|██████████| 117/117 [00:16<00:00,  7.19it/s, Critic=-4.0779, Recon=0.0064]
Epoch [77/100]: 100%|██████████| 117/117 [00:17<00:00,  6.65it/s, Critic=-3.3121, Recon=0.0054]
Epoch [78/100]: 100%|██████████| 117/117 [00:15<00:00,  7.31it/s, Critic=-3.7277, Recon=0.0058]
Epoch [79/100]: 100%|██████████| 117/117 [00:16<00:00,  7.00it/s, Critic=-3.2604, Recon=0.0049]
Epoch [80/100]: 100%|██████████| 117/117 [00:16<00:00,  7.05it/s, Critic=-3.7584, Recon=0.0052]


Saved checkpoint: checkpoints\vae_wgan_epoch_80.pth


Epoch [81/100]: 100%|██████████| 117/117 [00:18<00:00,  6.45it/s, Critic=-4.5437, Recon=0.0069]
Epoch [82/100]: 100%|██████████| 117/117 [00:16<00:00,  6.99it/s, Critic=-2.3051, Recon=0.0045]
Epoch [83/100]: 100%|██████████| 117/117 [00:16<00:00,  7.21it/s, Critic=-2.8449, Recon=0.0048]
Epoch [84/100]: 100%|██████████| 117/117 [00:16<00:00,  7.05it/s, Critic=-4.3666, Recon=0.0080]
Epoch [85/100]: 100%|██████████| 117/117 [00:16<00:00,  7.07it/s, Critic=-3.3175, Recon=0.0054]
Epoch [86/100]: 100%|██████████| 117/117 [00:17<00:00,  6.81it/s, Critic=-3.1839, Recon=0.0057]
Epoch [87/100]: 100%|██████████| 117/117 [00:16<00:00,  7.03it/s, Critic=-3.0003, Recon=0.0054]
Epoch [88/100]: 100%|██████████| 117/117 [00:15<00:00,  7.33it/s, Critic=-3.1524, Recon=0.0049]
Epoch [89/100]: 100%|██████████| 117/117 [00:16<00:00,  7.23it/s, Critic=-3.9586, Recon=0.0054]
Epoch [90/100]: 100%|██████████| 117/117 [00:16<00:00,  7.16it/s, Critic=-4.8378, Recon=0.0071]


Saved checkpoint: checkpoints\vae_wgan_epoch_90.pth


Epoch [91/100]: 100%|██████████| 117/117 [00:18<00:00,  6.40it/s, Critic=-4.3837, Recon=0.0061]
Epoch [92/100]: 100%|██████████| 117/117 [00:16<00:00,  7.12it/s, Critic=-4.2317, Recon=0.0068]
Epoch [93/100]: 100%|██████████| 117/117 [00:16<00:00,  6.89it/s, Critic=-3.0403, Recon=0.0067]
Epoch [94/100]: 100%|██████████| 117/117 [00:16<00:00,  7.14it/s, Critic=-4.3444, Recon=0.0072]
Epoch [95/100]: 100%|██████████| 117/117 [00:17<00:00,  6.88it/s, Critic=-3.8765, Recon=0.0049]
Epoch [96/100]: 100%|██████████| 117/117 [00:16<00:00,  7.29it/s, Critic=-3.5523, Recon=0.0049]
Epoch [97/100]: 100%|██████████| 117/117 [00:16<00:00,  7.15it/s, Critic=-2.5455, Recon=0.0038]
Epoch [98/100]: 100%|██████████| 117/117 [00:16<00:00,  6.91it/s, Critic=-3.7409, Recon=0.0054]
Epoch [99/100]: 100%|██████████| 117/117 [00:16<00:00,  6.91it/s, Critic=-2.9510, Recon=0.0080]
Epoch [100/100]: 100%|██████████| 117/117 [00:17<00:00,  6.78it/s, Critic=-3.5842, Recon=0.0082]

Saved checkpoint: checkpoints\vae_wgan_epoch_100.pth





In [12]:
# Announce completion, then write the final weights (state_dict only) to
# "vae_wgan_final.pth". The path is relative and is not placed inside
# CHECKPOINT_DIR, unlike the periodic checkpoints saved by the training loop.
print("Training Complete.")
torch.save(model.state_dict(), "vae_wgan_final.pth")

Training Complete.
