# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from model import EnergyDistanceDiscriminator, Discriminator, Generator, initialize_weights

def train_discriminator(real, device, disc, opt_disc, gen, z_dim, writer, step, weight_clip=0.01):
    """Run one WGAN critic update followed by weight clipping.

    Samples latent noise with the same batch size as ``real``, scores real and
    generated images with ``disc`` and minimises
    ``-(mean(disc(real)) - mean(disc(fake)))``.

    Args:
        real: Batch of real images already on ``device``; ``real.size(0)`` is
            used as the batch size for the generated images.
        device: Device on which the latent noise is sampled.
        disc: Critic network; its output is flattened with ``reshape(-1)``.
        opt_disc: Optimizer over ``disc``'s parameters.
        gen: Generator, called with noise of shape ``(N, z_dim, 1, 1)``.
        z_dim: Latent dimensionality.
        writer: Object exposing ``add_scalar(tag, value, global_step=...)``
            (e.g. a TensorBoard ``SummaryWriter``).
        step: Global step attached to the logged scalar.
        weight_clip: After the optimizer step every critic parameter is
            clamped to ``[-weight_clip, weight_clip]``. Defaults to 0.01, the
            value in ``main``'s config.

    Returns:
        The critic loss as a detached 0-d tensor.
    """
    noise = torch.randn(real.size(0), z_dim, 1, 1, device=device)
    # Only the critic is updated here, so skip recording the generator's graph.
    with torch.no_grad():
        fake = gen(noise)

    disc_real = disc(real).reshape(-1)
    disc_fake = disc(fake).reshape(-1)
    loss_disc = -(torch.mean(disc_real) - torch.mean(disc_fake))

    disc.zero_grad()
    # Nothing backpropagates through this graph again, so it need not be retained.
    loss_disc.backward()
    opt_disc.step()

    # Weight clipping: clamp every critic parameter in place, outside autograd.
    with torch.no_grad():
        for param in disc.parameters():
            param.clamp_(-weight_clip, weight_clip)

    writer.add_scalar("Loss/Discriminator", loss_disc.item(), global_step=step)

    return loss_disc.detach()

def train_generator(device, gen, disc, opt_gen, z_dim, writer, step, batch_size=128):
    """Run one WGAN generator update.

    Generates ``batch_size`` images from fresh latent noise and minimises
    ``-mean(disc(fake))``.

    Args:
        device: Device on which the latent noise is sampled.
        gen: Generator, called with noise of shape ``(batch_size, z_dim, 1, 1)``.
        disc: Critic network; its output is flattened with ``reshape(-1)``.
        opt_gen: Optimizer over ``gen``'s parameters.
        z_dim: Latent dimensionality.
        writer: Object exposing ``add_scalar(tag, value, global_step=...)``
            (e.g. a TensorBoard ``SummaryWriter``).
        step: Global step attached to the logged scalar.
        batch_size: Number of images generated for this update. Defaults to
            128, the value in ``main``'s config.

    Returns:
        The generator loss as a detached 0-d tensor.
    """
    noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
    fake = gen(noise)
    disc_fake = disc(fake).reshape(-1)
    loss_gen = -torch.mean(disc_fake)

    gen.zero_grad()
    # Gradients also accumulate in ``disc``'s parameters here; the critic update
    # must zero them (disc.zero_grad()) before its own backward pass.
    loss_gen.backward()
    opt_gen.step()

    writer.add_scalar("Loss/Generator", loss_gen.item(), global_step=step)

    return loss_gen.detach()

def main():
    """Train a weight-clipping WGAN on MNIST and log progress to TensorBoard.

    Hyperparameters live in the local ``config`` dict. For every data batch the
    critic is updated ``critic_iterations`` times and the generator once; the
    training helpers log their losses to TensorBoard at the current batch
    ``step``. Every 100 batches a progress line is printed and grids of real
    and generated images are written. Logs go to the ``logs_wgan`` directory.
    """
    # Setup device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Hyperparameters
    config = {
        "lr": 2e-4,
        "batch_size": 128,
        "image_size": 64,
        "channels_img": 1,
        "z_dim": 128,
        "num_epochs": 3,
        "features_disc": 64,
        "features_gen": 64,
        "critic_iterations": 5,
        "weight_clip": 0.01,
    }

    # Data preprocessing: resize, convert to tensor, normalise with mean/std 0.5.
    transform = transforms.Compose([
        transforms.Resize(config["image_size"]),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    # Load Dataset
    dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
    loader = DataLoader(dataset, batch_size=config["batch_size"], shuffle=True)

    # Initialize models (already placed on `device`, so no second .to() is needed).
    gen = Generator(config["z_dim"], config["channels_img"], config["features_gen"]).to(device)
    disc = Discriminator(config["channels_img"], config["features_disc"]).to(device)
    initialize_weights(gen)
    initialize_weights(disc)

    # Optimizers (RMSprop, as used in the original weight-clipping WGAN recipe)
    opt_gen = optim.RMSprop(gen.parameters(), lr=config["lr"])
    opt_disc = optim.RMSprop(disc.parameters(), lr=config["lr"])

    # TensorBoard
    writer = SummaryWriter("logs_wgan")
    fixed_noise = torch.randn(32, config["z_dim"], 1, 1, device=device)

    # Use all available GPUs. The optimizers keep working because DataParallel
    # wraps the same parameter objects.
    if torch.cuda.device_count() > 1:
        print(f"Let's use {torch.cuda.device_count()} GPUs!")
        gen = nn.DataParallel(gen)
        disc = nn.DataParallel(disc)

    # One step per data batch, shared by the loss scalars and the image grids.
    step = 0
    try:
        for epoch in range(config["num_epochs"]):
            for batch_idx, (real, _) in enumerate(tqdm(loader)):
                real = real.to(device)

                # Train Discriminator (critic) several times per generator update.
                for _ in range(config["critic_iterations"]):
                    loss_disc = train_discriminator(real, device, disc, opt_disc, gen, config["z_dim"], writer, step)

                # Train Generator
                loss_gen = train_generator(device, gen, disc, opt_gen, config["z_dim"], writer, step)

                # Logging
                if batch_idx % 100 == 0:
                    print(f"Epoch [{epoch}/{config['num_epochs']}] Batch {batch_idx}/{len(loader)} Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}")

                    with torch.no_grad():
                        fake = gen(fixed_noise)
                        img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                        img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
                        writer.add_image("Real", img_grid_real, global_step=step)
                        writer.add_image("Fake", img_grid_fake, global_step=step)

                # Advance every batch so per-batch loss scalars get distinct steps.
                step += 1
    finally:
        # Flush pending events and release the log file even if training fails.
        writer.close()
                

# Script entry point: start training only when this file is run directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()  # no arguments; all settings live inside main()
