# In [None]:  (Jupyter cell marker left over from the notebook export)
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from model import EnergyDistanceDiscriminator, Generator, initialize_weights

def main():
    """Train the generator on MNIST by minimising the critic's output.

    The function:

    * downloads MNIST into ``dataset/`` (if needed), resizes the images to
      ``config["image_size"]`` and normalises them with mean/std 0.5;
    * builds ``Generator`` and ``EnergyDistanceDiscriminator`` (both imported
      from the project's ``model`` module) and wraps them in
      ``nn.DataParallel`` when more than one GPU is visible;
    * updates only the generator, with RMSprop, using the value returned by
      ``critic(real, fake)`` as the loss;
    * writes the loss (every batch) and real/fake image grids (every 100
      batches) to TensorBoard under ``logs_egan/losses``.

    The function takes no arguments and returns ``None``.
    """
    # Setup device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Hyperparameters
    config = {
        "lr": 5e-5,
        "batch_size": 200,
        "image_size": 64,
        "channels_img": 1,
        "z_dim": 128,
        "num_epochs": 20,
        "features_gen": 64,
        "weight_clip": 0.01,  # Note: Weight clipping is not used in ED discriminator
    }

    # Data preprocessing: resize, convert to tensor, normalise the single
    # grayscale channel with mean 0.5 / std 0.5.
    transform = transforms.Compose([
        transforms.Resize(config["image_size"]),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    # Load dataset
    dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
    loader = DataLoader(dataset, batch_size=config["batch_size"], shuffle=True)

    # Initialize models
    gen = Generator(
        config["z_dim"], config["channels_img"], config["features_gen"]
    ).to(device)
    critic = EnergyDistanceDiscriminator(config["channels_img"]).to(device)
    initialize_weights(gen)
    # NOTE(review): the critic is not passed through initialize_weights and has
    # no optimizer, so it is never updated here -- confirm that
    # EnergyDistanceDiscriminator has no learnable parameters.

    # Optimizer (generator only). Built before the DataParallel wrap; the
    # wrapper shares the same parameter tensors, so the optimizer stays valid.
    opt_gen = optim.RMSprop(gen.parameters(), lr=config["lr"])

    # Adapt models for multiple GPUs
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        gen = nn.DataParallel(gen)
        critic = nn.DataParallel(critic)

    # TensorBoard
    writer_loss = SummaryWriter("logs_egan/losses")
    step = 0  # global iteration counter, advanced once per batch

    try:
        # Main training loop
        for epoch in range(config["num_epochs"]):
            for batch_idx, (data, _) in enumerate(tqdm(loader)):
                data = data.to(device)
                cur_batch_size = data.shape[0]

                # Generate fake images; sample the noise directly on the
                # target device instead of creating it on CPU and copying.
                noise = torch.randn(
                    cur_batch_size, config["z_dim"], 1, 1, device=device
                )
                fake = gen(noise)

                # Train generator with the energy distance.
                # nn.DataParallel gathers 0-dim outputs into a vector with one
                # entry per device; .mean() reduces that to a scalar so that
                # backward()/item() work. It is a no-op for a single-element
                # loss.
                loss_gen = critic(data, fake).mean()
                gen.zero_grad()
                loss_gen.backward()
                opt_gen.step()

                # Log the generator loss
                loss_value = loss_gen.item()
                writer_loss.add_scalar("Loss/ED", loss_value, global_step=step)

                # Periodically log and visualize the training progress
                if batch_idx % 100 == 0:
                    print(
                        f"Epoch [{epoch}/{config['num_epochs']}] "
                        f"Batch {batch_idx}/{len(loader)} "
                        f"Batch Size: {cur_batch_size}, loss G: {loss_value:.4f}"
                    )

                    with torch.no_grad():
                        # Re-run the generator after the update so the grid
                        # reflects the current weights.
                        fake = gen(noise)
                        img_grid_real = torchvision.utils.make_grid(
                            data[:32], normalize=True
                        )
                        img_grid_fake = torchvision.utils.make_grid(
                            fake[:32], normalize=True
                        )
                        writer_loss.add_image(
                            "Real Images", img_grid_real, global_step=step
                        )
                        writer_loss.add_image(
                            "Fake Images", img_grid_fake, global_step=step
                        )

                # Advance the step on EVERY batch. Previously it only advanced
                # inside the logging branch, so ~100 consecutive loss values
                # were written to the same global_step.
                step += 1

                # Reset models to training mode
                gen.train()
                # NOTE(review): the critic's mode is never toggled; add
                # critic.train() here if it ever gains mode-dependent layers.
    finally:
        # Flush and release the event file even if training is interrupted.
        writer_loss.close()

# Run the training loop only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
