In [1]:
import torch
import torch.nn as nn
import torch.optim as optim


In [2]:
class Generator(nn.Module):
    """Two-layer MLP generator mapping noise vectors to flat samples.

    Architecture: Linear(noise_dim -> hidden_dim) -> ReLU
    -> Linear(hidden_dim -> out_dim) -> Tanh, so every output element lies in [-1, 1].

    Args:
        noise_dim: Size of the last dimension of the input noise ``z``.
        out_dim: Size of the last dimension of the generated output.
            NOTE(review): 784 presumably means flattened 28x28 images; confirm.
        hidden_dim: Width of the single hidden layer (128 by default, as before).
    """

    def __init__(self, noise_dim=64, out_dim=784, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
            nn.Tanh(),  # squashes generated values into [-1, 1]
        )

    def forward(self, z):
        """Map noise ``z`` (last dim == noise_dim) to samples (last dim == out_dim)."""
        return self.net(z)


In [3]:
class Discriminator(nn.Module):
    """Two-layer MLP discriminator scoring flat samples as real vs. fake.

    Architecture: Linear(in_dim -> hidden_dim) -> LeakyReLU(negative_slope)
    -> Linear(hidden_dim -> 1) -> Sigmoid, so each sample gets one score in [0, 1].

    Args:
        in_dim: Size of the last dimension of the input ``x``.
        hidden_dim: Width of the single hidden layer (128 by default, as before).
        negative_slope: Slope of LeakyReLU for negative inputs (0.2 by default, as before).
    """

    def __init__(self, in_dim=784, hidden_dim=128, negative_slope=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Linear(hidden_dim, 1),
            # Outputs probabilities, so pair with a probability loss such as
            # nn.BCELoss (not nn.BCEWithLogitsLoss, which expects raw logits).
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score ``x`` (last dim == in_dim); returns a tensor whose last dim is 1."""
        return self.net(x)


In [4]:
# Seed torch's global RNG before building the models so weight init and the
# later torch.randn draws are repeatable on Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

noise_dim = 64   # length of each generator input noise vector
LR = 1e-3        # shared Adam learning rate for G and D

G, D = Generator(noise_dim), Discriminator()
loss_fn = nn.BCELoss()  # D ends in Sigmoid, so it outputs probabilities, not logits
g_opt = optim.Adam(G.parameters(), lr=LR)
d_opt = optim.Adam(D.parameters(), lr=LR)


  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Dummy training loop: one discriminator update followed by one generator
# update on synthetic data, to check that G, D, the loss and optimizers fit together.
batch_size = 16
n_epochs = 1
for epoch in range(n_epochs):
    z = torch.randn(batch_size, noise_dim)
    fake = G(z)
    # Stand-in "real" batch. Sample uniformly in [-1, 1] so it lives in the same
    # range as G's Tanh output; N(0, 1) samples would mostly fall outside that
    # range and let D separate real from fake by magnitude alone. The data size
    # is taken from `fake` so it always matches G's output (and D's input).
    real = torch.rand(batch_size, fake.shape[-1]) * 2 - 1
    real_labels = torch.ones(batch_size, 1)
    fake_labels = torch.zeros(batch_size, 1)

    # Discriminator step: push D(real) -> 1 and D(fake) -> 0.
    # fake.detach() stops this loss from back-propagating into G.
    d_loss_real = loss_fn(D(real), real_labels)
    d_loss_fake = loss_fn(D(fake.detach()), fake_labels)
    d_loss = (d_loss_real + d_loss_fake) / 2
    d_opt.zero_grad()
    d_loss.backward()
    d_opt.step()

    # Generator step (non-saturating loss): push D(fake) -> 1.
    # This backward also writes gradients into D's parameters; they are not
    # applied here and are cleared by d_opt.zero_grad() before D's next update.
    g_loss = loss_fn(D(fake), real_labels)
    g_opt.zero_grad()
    g_loss.backward()
    g_opt.step()

print("One epoch GAN training completed.")

One epoch GAN training completed.
