# 11. GANs

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/pytorch_tutorial/blob/main/11_gan/demo.ipynb)

---

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

## Generator & Discriminator

In [None]:
class Generator(nn.Module):
    """Fully connected generator that maps latent vectors to images.

    The network is three linear layers with LeakyReLU(0.2) between them and
    a final Tanh, so every output value lies in [-1, 1].

    Args:
        latent_dim: Size of each input noise vector.
        img_shape: Shape of one generated sample, e.g. (channels, height,
            width). The default (1, 28, 28) gives 784 output features.
    """

    def __init__(self, latent_dim=100, img_shape=(1, 28, 28)):
        super().__init__()
        # Keep the target shape so forward() reshapes to exactly what the
        # last linear layer was sized for (a single source of truth).
        self.img_shape = tuple(img_shape)
        n_pixels = torch.Size(self.img_shape).numel()
        self.main = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, n_pixels),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate a batch of images from the latent batch ``z``.

        Args:
            z: Tensor whose last dimension is ``latent_dim``.

        Returns:
            Tensor of shape ``(-1, *img_shape)``; the leading dimension is
            inferred from the number of elements produced.
        """
        return self.main(z).view(-1, *self.img_shape)

class Discriminator(nn.Module):
    """Fully connected discriminator that scores images.

    The input is flattened, passed through three linear layers with
    LeakyReLU(0.2) between them, and squashed by a Sigmoid, so each sample
    receives a single score in [0, 1].

    Args:
        img_shape: Shape of one input sample, e.g. (channels, height,
            width). The default (1, 28, 28) gives 784 input features.
    """

    def __init__(self, img_shape=(1, 28, 28)):
        super().__init__()
        # Size the first linear layer from the image shape instead of a
        # hard-coded 784 so other image sizes can be scored as well.
        n_pixels = torch.Size(tuple(img_shape)).numel()
        self.main = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_pixels, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Tensor with a leading batch dimension; the remaining
                dimensions are flattened and must total the number of
                elements in ``img_shape``.

        Returns:
            Tensor of shape ``(batch, 1)`` holding the Sigmoid outputs.
        """
        return self.main(x)

# Build both networks (generator first, then discriminator) and report how
# many parameters each one holds.
G, D = Generator(), Discriminator()
n_gen_params = sum(weight.numel() for weight in G.parameters())
n_disc_params = sum(weight.numel() for weight in D.parameters())
print('Generator:', n_gen_params)
print('Discriminator:', n_disc_params)

In [None]:
# Sample 16 latent vectors and display what the generator produces for them
z = torch.randn(16, 100)
fake = G(z)

fig, axes = plt.subplots(nrows=2, ncols=8, figsize=(12, 3))
# Pair each subplot with one generated sample; detach() drops autograd
# tracking and squeeze() removes the singleton channel axis for imshow.
for ax, img in zip(axes.ravel(), fake.detach()):
    ax.imshow(img.squeeze(), cmap='gray')
    ax.axis('off')
fig.suptitle('Random (Untrained) Generator Output')
plt.show()