In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [4]:
# Discriminator Network
class Discriminator(nn.Module):
    """Fully connected binary classifier that scores images as real or generated.

    Args:
        img_dim: Number of values in one flattened image. The default of 784
            corresponds to a 28x28 single-channel image.

    The layers live in ``self.model`` (an ``nn.Sequential``), so parameter
    names in ``state_dict()`` are ``model.0.*``, ``model.2.*`` and ``model.4.*``.
    """

    def __init__(self, img_dim=784):
        super().__init__()
        self.img_dim = img_dim
        self.model = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # squashes the score into (0, 1) as required by BCELoss
        )

    def forward(self, x):
        """Return one probability per sample, shaped ``(batch, 1)``.

        ``x`` may already be flat ``(batch, img_dim)`` or image shaped, e.g.
        ``(batch, 1, 28, 28)``; everything after the batch axis is flattened.
        """
        # reshape (unlike view) also accepts non-contiguous tensors such as
        # transposed or sliced batches, copying only when it has to.
        x = x.reshape(x.size(0), self.img_dim)
        validity = self.model(x)
        return validity

In [5]:
# Generator Network
class Generator(nn.Module):
    """Fully connected network that maps noise vectors to images.

    Args:
        latent_dim: Length of each input noise vector (default 100).
        img_shape: Shape of one generated image without the batch axis
            (default ``(1, 28, 28)``).

    The layers live in ``self.model`` (an ``nn.Sequential``), so parameter
    names in ``state_dict()`` are ``model.0.*``, ``model.2.*`` and ``model.4.*``.
    """

    def __init__(self, latent_dim=100, img_shape=(1, 28, 28)):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape = tuple(img_shape)

        # Width of the last linear layer = number of values in one image.
        img_dim = 1
        for extent in self.img_shape:
            img_dim *= extent

        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, img_dim),
            nn.Tanh(),  # bounds every output value to [-1, 1]
        )

    def forward(self, x):
        """Turn noise of shape ``(batch, latent_dim)`` into ``(batch, *img_shape)`` images."""
        img = self.model(x)
        img = img.view(img.size(0), *self.img_shape)
        return img

#### Define the training loop for our GAN

- We'll train the discriminator and generator networks alternately. 

- In each iteration, we'll feed real and fake images to the discriminator, compute the losses, and update the networks accordingly.

In [6]:
# Seed PyTorch's global RNG before any weights are drawn, so that a fresh
# "Restart & Run All" on the same setup starts from the same initial networks.
SEED = 42
torch.manual_seed(SEED)

# Initialize the networks
discriminator = Discriminator()
generator     = Generator()

In [7]:
# Adversarial objective: binary cross-entropy on the discriminator's output
adversarial_loss = nn.BCELoss()

# One Adam optimizer per network, both with the same step size
optimizer_D = optim.Adam(params=discriminator.parameters(), lr=2e-4)
optimizer_G = optim.Adam(params=generator.parameters(), lr=2e-4)

In [8]:
# Training hyperparameters
batch_size = 100  # samples per mini-batch handed to the DataLoader
num_epochs = 200  # number of complete passes over the training set

In [9]:
# Load the MNIST dataset
# ToTensor yields values in [0, 1]; Normalize with mean 0.5 and std 0.5 then
# maps them to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Dataset location: a project-relative folder by default. Set the
# MNIST_DATA_ROOT environment variable to reuse an existing download instead
# of hard-coding a machine-specific absolute path in the notebook.
data_root = os.environ.get("MNIST_DATA_ROOT", "data")

dataset = datasets.MNIST(
    root=data_root,
    train=True,
    transform=transform,
    download=True,
)

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [10]:
for epoch in range(num_epochs):

    # One pass over the dataset; the digit labels are not used
    for i, (real_images, _) in enumerate(dataloader):

        current_batch = real_images.size(0)

        # ---- Discriminator step ----
        optimizer_D.zero_grad()

        # BCE targets: ones stand for "real", zeros for "generated"
        valid = torch.ones(current_batch, 1)
        fake  = torch.zeros(current_batch, 1)

        # Score the real batch against the "real" target
        real_images = real_images.view(current_batch, -1)
        real_pred   = discriminator(real_images)
        d_real_loss = adversarial_loss(real_pred, valid)

        # Score generated images against the "generated" target;
        # detach() keeps this backward pass from reaching the generator
        z           = torch.randn(current_batch, 100)
        fake_images = generator(z)
        fake_pred   = discriminator(fake_images.detach())
        d_fake_loss = adversarial_loss(fake_pred, fake)

        # Sum both halves and update the discriminator
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        optimizer_D.step()

        # ---- Generator step ----
        optimizer_G.zero_grad()

        # Fresh noise; the loss is low when the discriminator calls the output "real"
        z           = torch.randn(current_batch, 100)
        fake_images = generator(z)
        fake_pred   = discriminator(fake_images)
        g_loss      = adversarial_loss(fake_pred, valid)

        g_loss.backward()
        optimizer_G.step()

        # Report progress on every 100th batch only
        if i % 100 != 0:
            continue

        print(
            f"Epoch [{epoch}/{num_epochs}], Batch Step [{i}/{len(dataloader)}],"
            f" D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}"
        )

Epoch [0/200], Batch Step [0/600], D_loss: 1.3756, G_loss: 0.6617
Epoch [0/200], Batch Step [100/600], D_loss: 0.1239, G_loss: 2.8561
Epoch [0/200], Batch Step [200/600], D_loss: 0.7386, G_loss: 1.0211
Epoch [0/200], Batch Step [300/600], D_loss: 0.0309, G_loss: 4.6599
Epoch [0/200], Batch Step [400/600], D_loss: 0.2189, G_loss: 5.9777
Epoch [0/200], Batch Step [500/600], D_loss: 0.0260, G_loss: 5.1129
Epoch [1/200], Batch Step [0/600], D_loss: 0.2660, G_loss: 3.7912
Epoch [1/200], Batch Step [100/600], D_loss: 0.1805, G_loss: 4.1844
Epoch [1/200], Batch Step [200/600], D_loss: 0.1618, G_loss: 5.4391
Epoch [1/200], Batch Step [300/600], D_loss: 0.3433, G_loss: 2.8980
Epoch [1/200], Batch Step [400/600], D_loss: 0.7820, G_loss: 1.5787
Epoch [1/200], Batch Step [500/600], D_loss: 0.1943, G_loss: 3.0256
Epoch [2/200], Batch Step [0/600], D_loss: 0.3072, G_loss: 3.4207
Epoch [2/200], Batch Step [100/600], D_loss: 0.6036, G_loss: 2.6584
Epoch [2/200], Batch Step [200/600], D_loss: 0.1136, G

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

# Generate and visualize fake images (run after the training loop)
num_samples = 10

# Sampling needs no gradients; tensors created under no_grad are already
# detached from the autograd graph, so no explicit detach() is required.
with torch.no_grad():
    z = torch.randn(num_samples, 100)
    fake_images = generator(z)

fake_images = fake_images.view(-1, 28, 28)
fake_images = fake_images * 0.5 + 0.5  # Denormalize: [-1, 1] -> [0, 1]
# .cpu() is a no-op for CPU tensors and makes .numpy() work if the
# generator has been moved to an accelerator.
fake_images = fake_images.cpu().numpy()

# squeeze=False always returns a 2-D array of Axes, so the loop below also
# works for num_samples == 1 (where subplots would otherwise return a bare Axes).
fig, axes = plt.subplots(1, num_samples, figsize=(num_samples, 1), squeeze=False)
for ax, image in zip(axes[0], fake_images):
    ax.axis('off')
    ax.imshow(image, cmap='gray')
plt.show()

In [None]:
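# # Save the GAN model together with both optimizer states (disabled; the active save cell below stores only the network weights)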
# torch.save({
#     'generator_state_dict'    : generator.state_dict(),
#     'discriminator_state_dict': discriminator.state_dict(),
#     'optimizer_G_state_dict'  : optimizer_G.state_dict(),
#     'optimizer_D_state_dict'  : optimizer_D.state_dict()
# }, 'gan_model_mnist.pth')

In [None]:
# Persist the weights of both networks in a single checkpoint file
checkpoint_state = {
    'generator_state_dict': generator.state_dict(),
    'discriminator_state_dict': discriminator.state_dict(),
}
torch.save(checkpoint_state, 'gan_model_mnist.pt')

In [None]:
# Load the GAN model
# The file name must match the one written by the save cell ('gan_model_mnist.pt').
# NOTE: torch.load unpickles the file, so only load checkpoints you created yourself.
checkpoint = torch.load('gan_model_mnist.pt')
generator.load_state_dict(checkpoint['generator_state_dict'])
discriminator.load_state_dict(checkpoint['discriminator_state_dict'])

# The save cell stores only the network weights, so restore optimizer state
# only when the checkpoint actually carries it (avoids a KeyError otherwise).
if 'optimizer_G_state_dict' in checkpoint:
    optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
if 'optimizer_D_state_dict' in checkpoint:
    optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])