In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
import os

# Settings
latent_dim = 100
epochs = 50
batch_size = 128
sample_dir = 'gan_samples'

os.makedirs(sample_dir, exist_ok=True)

# Data Loader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # normalize to [-1, 1]
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(root='.', train=True, transform=transform, download=True),
    batch_size=batch_size,
    shuffle=True
)

# Generator
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 28 * 28),
            nn.Tanh()
        )

    def forward(self, z):
        out = self.model(z)
        return out.view(z.size(0), 1, 28, 28)

# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        return self.model(img_flat)

# Initialize models
G = Generator()
D = Discriminator()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G.to(device)
D.to(device)

# Loss and optimizers
loss_fn = nn.BCELoss()
g_optimizer = optim.Adam(G.parameters(), lr=0.0002)
d_optimizer = optim.Adam(D.parameters(), lr=0.0002)

# Training loop
for epoch in range(epochs):
    for real_imgs, _ in train_loader:
        real_imgs = real_imgs.to(device)
        batch_size = real_imgs.size(0)

        # Labels
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # Train Discriminator
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = G(z)
        real_loss = loss_fn(D(real_imgs), real_labels)
        fake_loss = loss_fn(D(fake_imgs.detach()), fake_labels)
        d_loss = real_loss + fake_loss

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # Train Generator
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = G(z)
        g_loss = loss_fn(D(fake_imgs), real_labels)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

    # Save samples
    print(f"Epoch [{epoch+1}/{epochs}]  D Loss: {d_loss.item():.4f}  G Loss: {g_loss.item():.4f}")
    with torch.no_grad():
        test_z = torch.randn(64, latent_dim).to(device)
        generated = G(test_z)
        save_image(generated, f"{sample_dir}/epoch_{epoch+1:03d}.png", normalize=True, nrow=8)


100%|██████████| 9.91M/9.91M [00:00<00:00, 17.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 479kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.44MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 11.5MB/s]


Epoch [1/50]  D Loss: 0.0391  G Loss: 6.5052
Epoch [2/50]  D Loss: 0.0247  G Loss: 10.2173
Epoch [3/50]  D Loss: 0.2784  G Loss: 7.0374
Epoch [4/50]  D Loss: 0.3516  G Loss: 9.4681
Epoch [5/50]  D Loss: 0.7943  G Loss: 3.9339
Epoch [6/50]  D Loss: 0.8930  G Loss: 2.4434
Epoch [7/50]  D Loss: 0.9027  G Loss: 1.7047
Epoch [8/50]  D Loss: 0.4572  G Loss: 2.3680
Epoch [9/50]  D Loss: 0.6267  G Loss: 2.1660
Epoch [10/50]  D Loss: 0.9494  G Loss: 1.4111
Epoch [11/50]  D Loss: 1.3591  G Loss: 0.9439
Epoch [12/50]  D Loss: 0.4286  G Loss: 2.1348
Epoch [13/50]  D Loss: 0.2814  G Loss: 3.0554
Epoch [14/50]  D Loss: 0.7245  G Loss: 3.0511
Epoch [15/50]  D Loss: 0.8021  G Loss: 2.3209
Epoch [16/50]  D Loss: 0.6389  G Loss: 5.6900
Epoch [17/50]  D Loss: 0.3148  G Loss: 4.9226
Epoch [18/50]  D Loss: 0.2203  G Loss: 6.1628
Epoch [19/50]  D Loss: 0.0330  G Loss: 6.8429
Epoch [20/50]  D Loss: 0.1397  G Loss: 4.8779
Epoch [21/50]  D Loss: 0.0679  G Loss: 10.8177
Epoch [22/50]  D Loss: 0.3811  G Loss: 7.

In [2]:
import imageio
import glob

images = []
for filename in sorted(glob.glob('gan_samples/epoch_*.png')):
    images.append(imageio.imread(filename))
imageio.mimsave('gan_training.gif', images, duration=0.5)


  images.append(imageio.imread(filename))


In [2]:
# Assuming we've stored loss values during training
plt.figure(figsize=(10, 5))
plt.plot(d_losses, label='Discriminator Loss')
plt.plot(g_losses, label='Generator Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.title('GAN Training Loss')
plt.savefig('loss_curve.png')
plt.close()

NameError: name 'plt' is not defined