In [None]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import os
from models import VAE_WGAN

In [None]:
# Configuration
BATCH_SIZE = 64
LR = 0.0002
EPOCHS = 50
LATENT_DIM = 128
# Fixed seed for torch's global RNG. Anything drawing from it becomes reproducible
# on Restart & Run All: DataLoader shuffling, and presumably the model's weight
# init and reparameterization sampling.
SEED = 42
# Path to the class folder that holds the normal (non-anomalous) training images,
# i.e. "<ImageFolder root>/<class folder name>".
DATA_PATH = 'modis_dataset/normal_reference'

torch.manual_seed(SEED)  # seeds the RNG on all devices (CPU and CUDA)

In [None]:
# 1. Prepare Data
# Images are loaded with ImageFolder, which expects "<root>/<class_name>/<image>";
# e.g. modis_dataset/normal_reference/img1.png.
# Per-image preprocessing:
#   - resize to 64x64,
#   - convert to a float tensor in [0, 1],
#   - map each of the 3 channels to [-1, 1] via (x - 0.5) / 0.5.
# The 3-value mean/std assumes 3-channel (RGB) input.
transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

In [None]:
# ImageFolder expects "<root>/<class_name>/<image>" and assigns one integer label
# per class subfolder of <root>. Images nested deeper inside a class folder are
# picked up as well.
# The root is derived from DATA_PATH (its parent directory) rather than
# hard-coded, so editing the config cell is enough to retarget the notebook.
# `or os.curdir` handles a DATA_PATH with no parent component.
# This dataset contains EVERY class folder under the root. Restrict it to the
# normal class before training.
DATASET_ROOT = os.path.dirname(os.path.normpath(DATA_PATH)) or os.curdir
dataset = datasets.ImageFolder(root=DATASET_ROOT, transform=transform)

In [None]:
# Filter: train on the NORMAL class only.
# ImageFolder labels every class subfolder under its root, so any other folder
# there (e.g. anomalous imagery) would otherwise be mixed into the training set.
# The normal class is the folder named by the last component of DATA_PATH.
NORMAL_CLASS = os.path.basename(os.path.normpath(DATA_PATH))
if NORMAL_CLASS not in dataset.class_to_idx:
    raise ValueError(
        f"Class folder {NORMAL_CLASS!r} not found under the ImageFolder root; "
        f"found classes: {sorted(dataset.class_to_idx)}"
    )
normal_label = dataset.class_to_idx[NORMAL_CLASS]
normal_indices = [idx for idx, label in enumerate(dataset.targets) if label == normal_label]
train_dataset = torch.utils.data.Subset(dataset, normal_indices)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# 2. Initialize Model
# Use the default CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Build the VAE-GAN and move all of its parameters to the chosen device.
model = VAE_WGAN(latent_dim=LATENT_DIM).to(device)

In [None]:
# Optimizers
# One Adam optimizer per sub-network, all using learning rate LR. This lets the
# discriminator and the encoder/decoder be stepped independently in the loop.
optimizer_enc, optimizer_dec, optimizer_dis = (
    optim.Adam(subnet.parameters(), lr=LR)
    for subnet in (model.encoder, model.decoder, model.discriminator)
)

In [None]:
# Loss Functions
# Adversarial term: binary cross-entropy, averaged over elements.
# NOTE(review): BCELoss requires inputs in [0, 1]. This assumes the discriminator
# ends in a sigmoid. A Wasserstein critic (unbounded scores) would need a
# different loss; confirm in models.VAE_WGAN.
criterion_adv = torch.nn.BCELoss(reduction="mean")
# Reconstruction term: squared error averaged over every pixel/channel element.
criterion_recon = torch.nn.MSELoss(reduction="mean")

In [None]:
print(f"Training on {device}...")

In [None]:
# 3. Training Loop
# Each batch performs one discriminator update, then one joint encoder+decoder
# (generator) update.
# NOTE(review): criterion_adv is BCELoss, which requires discriminator outputs in
# [0, 1] (e.g. a final sigmoid). A Wasserstein critic emits unbounded scores and
# would need a different loss. Confirm what models.VAE_WGAN's discriminator returns.
model.train()  # keeps re-runs correct even if eval() was called elsewhere
num_batches = len(train_loader)
if num_batches == 0:
    raise RuntimeError("train_loader yields no batches; check DATA_PATH and the dataset filter.")

for epoch in range(EPOCHS):
    total_loss = 0.0    # summed generator loss over the epoch
    total_d_loss = 0.0  # summed discriminator loss over the epoch
    for imgs, _ in train_loader:
        imgs = imgs.to(device)
        # The final batch can be smaller than BATCH_SIZE, so always use the real size.
        batch_size = imgs.size(0)
        num_pixels = imgs.size(-2) * imgs.size(-1)

        # --- A. Train Discriminator ---
        # Real images = label 1, fake (reconstructed) = label 0.
        optimizer_dis.zero_grad()

        # 1. Real
        outputs_real = model.discriminator(imgs)
        # Targets take the output's shape/device/dtype, so BCELoss never compares
        # a (B,) target against a (B, 1) prediction.
        d_loss_real = criterion_adv(outputs_real, torch.ones_like(outputs_real))

        # 2. Fake (reconstruct images through the VAE)
        mu, logvar = model.encoder(imgs)
        z = model.reparameterize(mu, logvar)
        recon_imgs = model.decoder(z)

        # detach(): the discriminator step must not push gradients into encoder/decoder.
        outputs_fake = model.discriminator(recon_imgs.detach())
        d_loss_fake = criterion_adv(outputs_fake, torch.zeros_like(outputs_fake))

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_dis.step()

        # --- B. Train Encoder & Decoder (Generator) ---
        optimizer_enc.zero_grad()
        optimizer_dec.zero_grad()

        # 1. Reconstruction loss: the output should match the input.
        recon_loss = criterion_recon(recon_imgs, imgs)

        # 2. Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over all elements
        # of mu/logvar. It is then scaled by (actual batch size * pixels per image).
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        kl_loss = kl_loss / (batch_size * num_pixels)

        # 3. Adversarial loss: push the (just-updated) discriminator to label
        # reconstructions as real.
        outputs_fake_for_gen = model.discriminator(recon_imgs)
        adv_loss = criterion_adv(outputs_fake_for_gen, torch.ones_like(outputs_fake_for_gen))

        # Total generator loss; reconstruction is weighted highest.
        g_loss = (10 * recon_loss) + (0.1 * kl_loss) + adv_loss

        # This backward also accumulates grads on the discriminator's parameters.
        # They are reset by optimizer_dis.zero_grad() at the top of the next batch,
        # before the discriminator's next step.
        g_loss.backward()
        optimizer_enc.step()
        optimizer_dec.step()

        total_loss += g_loss.item()
        total_d_loss += d_loss.item()

    print(
        f"Epoch [{epoch+1}/{EPOCHS}] Loss: {total_loss / num_batches:.4f} "
        f"D_loss: {total_d_loss / num_batches:.4f}"
    )

In [None]:
# Save the trained model
# Only the state_dict (parameter and buffer tensors) is written, not the module
# object. To reload: construct VAE_WGAN(latent_dim=LATENT_DIM), then call
# load_state_dict on it.
torch.save(obj=model.state_dict(), f="vae_wgan_brazil_fire.pth")
print("Model saved!")