In [None]:
import torch
import torch.nn as nn

class Generator(nn.Module):
    """Fully connected GAN generator: noise vectors -> flattened images.

    Layout is ``noise_dim -> 256 -> 512 -> img_dim`` with ReLU after each
    hidden layer and a final Tanh, so every output value lies in [-1, 1].

    Args:
        noise_dim: size of the last dimension of the input noise tensor.
        img_dim: number of output features per sample (e.g. ``28 * 28``
            for a flattened MNIST image).
    """

    def __init__(self, noise_dim, img_dim):
        super().__init__()
        hidden_widths = [noise_dim, 256, 512]
        blocks = []
        # Hidden stack: Linear + ReLU for each consecutive width pair.
        for width_in, width_out in zip(hidden_widths, hidden_widths[1:]):
            blocks.extend([nn.Linear(width_in, width_out), nn.ReLU()])
        # Output head; Tanh keeps samples in the same [-1, 1] range as
        # normalised real images.
        blocks.extend([nn.Linear(hidden_widths[-1], img_dim), nn.Tanh()])
        # Keeping a single Sequential named `model` preserves the
        # state_dict keys (model.0, model.2, model.4) for checkpoints.
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Map noise of shape (..., noise_dim) to outputs of shape (..., img_dim)."""
        generated = self.model(x)
        return generated


In [None]:
class Discriminator(nn.Module):
    """Fully connected GAN discriminator: flattened images -> P(real).

    Layout is ``img_dim -> 512 -> 256 -> 1`` with LeakyReLU(0.2) after each
    hidden layer and a final Sigmoid, so each output is a probability in (0, 1)
    suitable for ``nn.BCELoss``.

    Args:
        img_dim: number of input features per sample (e.g. ``28 * 28``).
    """

    def __init__(self, img_dim):
        super().__init__()
        hidden_widths = [img_dim, 512, 256]
        blocks = []
        # Hidden stack: Linear + LeakyReLU(0.2) for each consecutive width pair.
        for width_in, width_out in zip(hidden_widths, hidden_widths[1:]):
            blocks.extend([nn.Linear(width_in, width_out), nn.LeakyReLU(0.2)])
        # Single-logit head squashed to a probability.
        blocks.extend([nn.Linear(hidden_widths[-1], 1), nn.Sigmoid()])
        # Keeping a single Sequential named `model` preserves the
        # state_dict keys (model.0, model.2, model.4) for checkpoints.
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Score inputs of shape (..., img_dim); returns shape (..., 1)."""
        probability = self.model(x)
        return probability


In [None]:
import torch.optim as optim

# Parameters
noise_dim = 100
img_dim = 28 * 28  # For MNIST images
lr = 0.0002
# beta1 = 0.5 follows the DCGAN recommendation; Adam's default beta1 = 0.9
# tends to make adversarial training oscillate.
betas = (0.5, 0.999)
epochs = 100
seed = 0

# Seed before building the models so weight init and noise draws are reproducible.
torch.manual_seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# `data_loader` is not created anywhere in this section; fail fast with a clear
# message instead of a bare NameError in the middle of the training loop.
# NOTE(review): assumed to yield (images, labels) batches whose images are scaled
# to [-1, 1] (e.g. transforms.Normalize((0.5,), (0.5,))) so real samples match the
# generator's Tanh output range — confirm where the loader is built.
if "data_loader" not in globals():
    raise NameError(
        "data_loader is not defined: build a DataLoader over the training images "
        "in an earlier cell before running the training loop"
    )

# Models
generator = Generator(noise_dim, img_dim).to(device)
discriminator = Discriminator(img_dim).to(device)

# Loss and optimizers
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=betas)
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=betas)

for epoch in range(epochs):
    # Sum per-batch losses so the log reports an epoch mean rather than the
    # noisy loss of whichever batch happened to come last.
    d_loss_total = 0.0
    g_loss_total = 0.0
    num_batches = 0

    for real_images, _ in data_loader:
        batch_size = real_images.size(0)
        # Flatten images. reshape works on non-contiguous tensors (view does not),
        # and pinning the batch dimension raises if a sample is not img_dim values
        # instead of silently regrouping pixels across samples.
        real_images = real_images.to(device).reshape(batch_size, img_dim)

        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # Train Discriminator: push D(real) -> 1 and D(G(z)) -> 0.
        # no_grad: this loss must not reach the generator, so there is no need
        # to record the generator's autograd graph for these samples.
        z = torch.randn(batch_size, noise_dim, device=device)
        with torch.no_grad():
            fake_images = generator(z)

        real_loss = criterion(discriminator(real_images), real_labels)
        fake_loss = criterion(discriminator(fake_images), fake_labels)
        d_loss = real_loss + fake_loss

        optimizer_d.zero_grad()
        d_loss.backward()
        optimizer_d.step()

        # Train Generator with the non-saturating loss: score fresh fakes
        # against "real" labels. This backward also writes gradients into the
        # discriminator; optimizer_d.zero_grad() clears them before the next D step.
        z = torch.randn(batch_size, noise_dim, device=device)
        fake_images = generator(z)
        g_loss = criterion(discriminator(fake_images), real_labels)

        optimizer_g.zero_grad()
        g_loss.backward()
        optimizer_g.step()

        d_loss_total += d_loss.item()
        g_loss_total += g_loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("data_loader yielded no batches; nothing to train on")

    d_loss_mean = d_loss_total / num_batches
    g_loss_mean = g_loss_total / num_batches
    print(f"Epoch [{epoch+1}/{epochs}] - D Loss: {d_loss_mean:.4f}, G Loss: {g_loss_mean:.4f}")
