In [26]:
"""Imports."""

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
            

In [None]:
"""MNIST GAN."""

# Define the generator network
class Generator(nn.Module):
    """Coordinate-conditioned convolutional generator.

    The latent vector is tiled over every pixel position, concatenated with a
    two-channel grid holding the x and y coordinate of each pixel (both
    spanning [-1, 1]), and passed through three 3x3 same-padded convolutions.
    The final ``Tanh`` keeps the output in [-1, 1].

    Args:
        latent_dim: Number of entries in the latent vector of one sample.
        img_shape: ``(channels, height, width)`` of the images to generate.
    """

    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()

        height, width = img_shape[1], img_shape[2]

        # Create coordinate grid. With numpy's default "xy" indexing,
        # meshgrid(x, y) returns arrays of shape (len(y), len(x)), so x must be
        # sized by the width and y by the height for the grid to come out as
        # (height, width) and line up with the tiled latent in forward().
        x = np.linspace(-1, 1, width)
        y = np.linspace(-1, 1, height)
        x_grid, y_grid = np.meshgrid(x, y)
        x_grid = torch.tensor(x_grid, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        y_grid = torch.tensor(y_grid, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

        # Register as a buffer so the grid follows the module through
        # .to(device) / .cuda(); persistent=False keeps it out of state_dict,
        # so the set of checkpoint keys is the same as with a plain attribute.
        self.register_buffer(
            "_coordinates", torch.cat((x_grid, y_grid), dim=1), persistent=False
        )

        # Create the model
        self.model = nn.Sequential(
            nn.Conv2d(latent_dim + 2, 32, kernel_size=3, stride=1, padding="same"),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding="same"),
            nn.ReLU(),
            nn.Conv2d(32, img_shape[0], kernel_size=3, stride=1, padding="same"),
            nn.Tanh()
        )
        self.img_shape = img_shape

    def forward(self, z):
        """Generate a batch of images from a batch of latent vectors.

        Args:
            z: Latent tensor whose first dimension is the batch size; the
                remaining dimensions are flattened into the channel axis.

        Returns:
            Tensor of shape ``(batch, img_shape[0], img_shape[1], img_shape[2])``.
        """
        batch_size = z.size(0)
        height, width = self.img_shape[1], self.img_shape[2]

        # Tile the latent vector to match the image shape. expand() creates a
        # broadcast view instead of the copy that repeat() would allocate.
        z = z.view(batch_size, -1, 1, 1).expand(-1, -1, height, width)

        # Concatenate the coordinates with the latent vector
        coordinates = self._coordinates.expand(batch_size, -1, -1, -1)
        z = torch.cat((z, coordinates), dim=1)

        # Generate the image
        return self.model(z)

# Define the discriminator network
class Discriminator(nn.Module):
    """Convolutional discriminator producing one probability per image.

    Three 3x3 stride-2 convolutions downsample the image; the resulting
    feature map is flattened and mapped by a linear layer plus ``Sigmoid`` to
    a single value in [0, 1].

    Args:
        img_shape: ``(channels, height, width)`` of the input images.

    Raises:
        ValueError: If ``img_shape`` is too small to survive the three
            downsampling convolutions.
    """

    def __init__(self, img_shape):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(img_shape[0], 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=0),
            nn.ReLU(),
        )

        # Derive the flattened feature size from img_shape instead of
        # hard-coding 32 * 3 * 3, which is only correct for 28x28 inputs.
        channels, height, width = img_shape[0], img_shape[1], img_shape[2]
        for layer in self.model:
            if isinstance(layer, nn.Conv2d):
                channels = layer.out_channels
                height = self._conv_out_size(
                    height, layer.kernel_size[0], layer.stride[0], layer.padding[0]
                )
                width = self._conv_out_size(
                    width, layer.kernel_size[1], layer.stride[1], layer.padding[1]
                )
        if height < 1 or width < 1:
            raise ValueError(
                f"img_shape {tuple(img_shape)} is too small for the discriminator"
            )

        self._linear = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * height * width, 1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _conv_out_size(size, kernel_size, stride, padding):
        """Return the output length of a dilation-1 convolution along one axis."""
        return (size + 2 * padding - kernel_size) // stride + 1

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Image batch whose per-sample shape matches the ``img_shape``
                given at construction.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in [0, 1].
        """
        features = self.model(img)
        return self._linear(features)

# Seed torch's global RNG so weight initialisation, data shuffling and latent
# sampling are repeatable under "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

# Load the MNIST dataset. ToTensor() yields values in [0, 1]; Normalize with
# mean 0.5 and std 0.5 rescales them to [-1, 1].
BATCH_SIZE = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(mnist_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Define the hyperparameters
latent_dim = 10          # length of the latent vector fed to the generator
img_shape = (1, 28, 28)  # (channels, height, width)
lr = 0.0002              # learning rate shared by both Adam optimizers
epochs = 50              # number of passes over the training set

# Initialize the generator and discriminator networks
generator = Generator(latent_dim, img_shape)
discriminator = Discriminator(img_shape)

# Define the loss function and optimizers. BCELoss expects probabilities.
loss_function = nn.BCELoss()
optimizer_generator = optim.Adam(generator.parameters(), lr=lr)
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=lr)

def show_image_grid(images, rows=5, cols=5):
    """Display up to ``rows * cols`` images from a batch on a grid.

    Args:
        images: Batch tensor; ``images[j]`` must squeeze to a 2-D image.
        rows: Number of grid rows.
        cols: Number of grid columns.
    """
    images = images.detach().cpu()
    fig, axes = plt.subplots(rows, cols, figsize=(10, 10), squeeze=False)
    for j, ax in enumerate(axes.flat):
        ax.axis('off')
        # Leave the cell blank if the batch holds fewer images than the grid.
        if j < images.size(0):
            ax.imshow(images[j].numpy().squeeze(), cmap='gray')
    plt.tight_layout()
    plt.show()
    # Release the figure; otherwise one stays alive per progress report.
    plt.close(fig)


# Training loop
for epoch in range(epochs):
    for i, (imgs, _) in enumerate(dataloader):
        num_imgs = imgs.size(0)

        # Adversarial ground truths
        valid = torch.ones((num_imgs, 1))
        fake = torch.zeros((num_imgs, 1))

        # Sample the latent batch once; both updates below use the same fakes.
        z = torch.randn(num_imgs, latent_dim)
        gen_imgs = generator(z)

        # Train the discriminator. detach() stops this loss from
        # back-propagating into the generator.
        optimizer_discriminator.zero_grad()
        real_loss = loss_function(discriminator(imgs), valid)
        fake_loss = loss_function(discriminator(gen_imgs.detach()), fake)
        discriminator_loss = (real_loss + fake_loss) / 2
        discriminator_loss.backward()
        optimizer_discriminator.step()

        # Train the generator. Its weights have not changed since gen_imgs was
        # computed, so the existing output (with its graph intact) is reused
        # rather than running a second identical forward pass.
        optimizer_generator.zero_grad()
        generator_loss = loss_function(discriminator(gen_imgs), valid)
        generator_loss.backward()
        optimizer_generator.step()

        # Print the progress
        if i % 100 == 0:
            print(f"Epoch [{epoch}/{epochs}] Batch [{i}/{len(dataloader)}] "
                  f"Discriminator Loss: {discriminator_loss.item():.4f} Generator Loss: {generator_loss.item():.4f}")

            # Plot generated images
            show_image_grid(gen_imgs)
            

# Sample from the generator. No gradients are needed for inference, so skip
# building the autograd graph; the result can then be passed to .numpy()
# directly.
z = torch.randn(25, latent_dim)
with torch.no_grad():
    generated_images = generator(z)

> [0;32m/var/folders/39/y4130jp93gx8kmc6bw85pr1r0000gn/T/ipykernel_6930/3854986393.py[0m(17)[0;36m__init__[0;34m()[0m
[0;32m     15 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     16 [0;31m[0;34m[0m[0m
[0m[0;32m---> 17 [0;31m        self.model = nn.Sequential(
[0m[0;32m     18 [0;31m            [0mnn[0m[0;34m.[0m[0mConv2d[0m[0;34m([0m[0mlatent_dim[0m [0;34m+[0m [0;36m2[0m[0;34m,[0m [0;36m32[0m[0;34m,[0m [0mkernel_size[0m[0;34m=[0m[0;36m3[0m[0;34m,[0m [0mstride[0m[0;34m=[0m[0;36m1[0m[0;34m,[0m [0mpadding[0m[0;34m=[0m[0;34m"same"[0m[0;34m)[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     19 [0;31m            [0mnn[0m[0;34m.[0m[0mReLU[0m[0;34m([0m[0;34m)[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m
torch.Size([1, 2, 28, 28])
