https://theaisummer.com/gan-computer-vision/

In [1]:
import torch

def ones_target(size):
    '''
    Build a column of "real" labels.

    Used as the target for real samples when training D, and for fake
    samples when training G (G is rewarded when D calls its output real).
    Returns a float tensor of ones with shape (size, 1).
    '''
    shape = (size, 1)
    return torch.ones(*shape)

def zeros_target(size):
    '''
    Build a column of "fake" labels.

    Used as the target for generated samples when training D.
    Returns a float tensor of zeros with shape (size, 1).
    '''
    shape = (size, 1)
    return torch.zeros(*shape)


def train_discriminator(discriminator, optimizer, real_data, fake_data, loss):
    '''
    Run one optimisation step of the discriminator on a real and a fake batch.

    Args:
        discriminator: module mapping a batch of samples to probabilities.
        optimizer: optimizer over the discriminator's parameters.
        real_data: batch of real samples, labelled 1.
        fake_data: batch of generated samples, labelled 0. Pass it detached
            from the generator's graph; otherwise backward() here also
            propagates gradients into the generator.
        loss: criterion called as loss(prediction, target), e.g. nn.BCELoss().

    Returns:
        (total_error, prediction_real, prediction_fake), where total_error is
        the sum of the real and fake losses.
    '''
    optimizer.zero_grad()

    # Real data should be classified as 1. ones_like() makes the target match
    # the prediction's shape, dtype and device, so this works for GPU models
    # and for fake/real batches of different sizes.
    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real, torch.ones_like(prediction_real))
    error_real.backward()

    # Fake data should be classified as 0. Its gradients accumulate on top of
    # the real-data gradients before the single optimizer step below.
    prediction_fake = discriminator(fake_data)
    error_fake = loss(prediction_fake, torch.zeros_like(prediction_fake))
    error_fake.backward()

    optimizer.step()
    return error_real + error_fake, prediction_real, prediction_fake


def train_generator(discriminator, optimizer, fake_data, loss):
    '''
    Run one optimisation step of the generator.

    Args:
        discriminator: module scoring samples; used only to compute the loss.
        optimizer: optimizer over the generator's parameters.
        fake_data: generator output with its autograd graph still attached,
            so the loss can flow back through the discriminator into the
            generator.
        loss: criterion called as loss(prediction, target), e.g. nn.BCELoss().

    Returns:
        The generator loss tensor. The generator is rewarded when the
        discriminator labels its samples as real (target 1).
    '''
    optimizer.zero_grad()
    # Score the already-generated samples; no noise is sampled here.
    prediction = discriminator(fake_data)

    # ones_like() keeps the target on the prediction's shape/dtype/device.
    error = loss(prediction, torch.ones_like(prediction))
    # backward() also accumulates gradients on the discriminator's parameters.
    # Assuming `optimizer` holds only generator parameters, those are not
    # applied here and are cleared by the discriminator optimizer's next zero_grad().
    error.backward()
    optimizer.step()
    return error

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Generator Network
class Generator(nn.Module):
    """MLP generator mapping a latent vector to a sample in (-1, 1).

    Layout: Linear -> ReLU -> Linear -> ReLU -> Linear -> Tanh. Both hidden
    layers have width ``hidden_size``; the output has ``output_size`` features.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        widths = (input_size, hidden_size, hidden_size, output_size)
        last = len(widths) - 2
        stack = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            stack.append(nn.Linear(fan_in, fan_out))
            # Hidden layers use ReLU; the final layer is squashed by Tanh.
            stack.append(nn.Tanh() if idx == last else nn.ReLU())
        self.main = nn.Sequential(*stack)

    def forward(self, x):
        """Map a batch of latent vectors ``x`` to generated samples."""
        out = self.main(x)
        return out

# Discriminator Network
class Discriminator(nn.Module):
    """MLP discriminator mapping a sample to a probability in (0, 1).

    Two hidden Linear layers of width ``hidden_size``, each followed by
    LeakyReLU(0.2), then a single-unit Linear layer and a Sigmoid. For an
    input of shape (batch, input_size) the output is (batch, 1).
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        blocks = []
        for fan_in in (input_size, hidden_size):
            blocks += [nn.Linear(fan_in, hidden_size), nn.LeakyReLU(0.2)]
        # Single-unit head turned into a probability by the Sigmoid.
        blocks += [nn.Linear(hidden_size, 1), nn.Sigmoid()]
        self.main = nn.Sequential(*blocks)

    def forward(self, x):
        """Score a batch ``x``; larger outputs mean "more likely real"."""
        scores = self.main(x)
        return scores

# Utility functions
def ones_target(size):
    """Return a float tensor of ones with shape ``size`` (label "real")."""
    labels = torch.ones(size)
    return labels

def zeros_target(size):
    """Return a float tensor of zeros with shape ``size`` (label "fake")."""
    labels = torch.zeros(size)
    return labels

def train_discriminator(discriminator, optimizer, real_data, fake_data, loss):
    """Run one discriminator update on a real batch (label 1) and a fake batch (label 0).

    The discriminator's outputs are flattened to 1-D before the loss, so a
    discriminator returning (N, 1) probabilities is scored against (N,) targets.

    Args:
        discriminator: module mapping a batch of samples to probabilities.
        optimizer: optimizer over the discriminator's parameters.
        real_data: batch of real samples.
        fake_data: batch of generated samples, detached from the generator's
            graph (the training loop passes ``fake_batch.detach()``).
        loss: criterion called as loss(prediction, target), e.g. nn.BCELoss().

    Returns:
        The summed real + fake loss (a scalar tensor for nn.BCELoss()).
    """
    optimizer.zero_grad()

    # reshape(-1) instead of squeeze(): squeeze() also removes the batch
    # dimension when N == 1, yielding a 0-d prediction that BCELoss rejects
    # against a (1,)-shaped target.
    prediction_real = discriminator(real_data).reshape(-1)
    # ones_like/zeros_like keep targets on the prediction's dtype and device.
    error_real = loss(prediction_real, torch.ones_like(prediction_real))
    error_real.backward()

    prediction_fake = discriminator(fake_data).reshape(-1)
    error_fake = loss(prediction_fake, torch.zeros_like(prediction_fake))
    error_fake.backward()

    # Gradients from both backward() calls are applied in a single step.
    optimizer.step()
    return error_real + error_fake

def train_generator(discriminator, optimizer, fake_data, loss):
    """Run one generator update: push the discriminator to label ``fake_data`` as real.

    Args:
        discriminator: module scoring samples; used only to compute the loss.
        optimizer: optimizer over the generator's parameters.
        fake_data: generator output with its autograd graph attached, so the
            loss can back-propagate into the generator.
        loss: criterion called as loss(prediction, target), e.g. nn.BCELoss().

    Returns:
        The generator loss (a scalar tensor for nn.BCELoss()).
    """
    optimizer.zero_grad()

    # reshape(-1) instead of squeeze() so a batch of one keeps its (1,) shape
    # and matches the target.
    prediction = discriminator(fake_data).reshape(-1)
    error = loss(prediction, torch.ones_like(prediction))
    error.backward()

    optimizer.step()
    return error

# ---- Hyperparameters ----
input_size = 100   # length of the noise vector fed to the generator
hidden_size = 256  # width of every hidden layer in both networks
output_size = 784  # flattened 28x28 image
batch_size = 64
epochs = 200

# ---- Networks, optimizers, loss ----
# Construction order matters for reproducibility: each constructor draws its
# initial weights from the global RNG.
generator = Generator(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
discriminator = Discriminator(input_size=output_size, hidden_size=hidden_size)

# Both networks are trained with Adam at the same learning rate (2e-4).
g_optimizer = optim.Adam(params=generator.parameters(), lr=2e-4)
d_optimizer = optim.Adam(params=discriminator.parameters(), lr=2e-4)

loss = nn.BCELoss()

# Visualization function
def visualize_gan_results(generator, num_samples=25):
    """Sample ``num_samples`` images from ``generator`` and show them in a grid.

    Noise vectors of length ``input_size`` (module-level hyperparameter) are
    drawn from a standard normal, and each output is reshaped to 28x28. The
    grid is the smallest near-square layout holding every sample (5x5 for the
    default 25); unused cells are left blank.

    The generator is switched to eval mode for sampling, and its previous
    train/eval mode is restored afterwards.

    Raises:
        ValueError: if ``num_samples`` is less than 1.
    """
    import math

    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            noise = torch.randn(num_samples, input_size)
            generated_samples = generator(noise).view(-1, 28, 28)
    finally:
        # Leave the model in the mode the caller had it in.
        generator.train(was_training)

    # Size the grid from num_samples so every sample gets exactly one cell.
    ncols = math.ceil(math.sqrt(num_samples))
    nrows = math.ceil(num_samples / ncols)
    # squeeze=False keeps `axes` 2-D even for a 1x1 grid, so `.flat` always works.
    fig, axes = plt.subplots(nrows, ncols, figsize=(2 * ncols, 2 * nrows), squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < num_samples:
            ax.imshow(generated_samples[i].numpy(), cmap='gray')
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Training loop
# Build the MNIST pipeline once: nothing in it changes between epochs, and
# shuffle=True already reshuffles on every pass over the loader.
transform = transforms.Compose([
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 maps ToTensor's [0, 1] pixels to [-1, 1], the range of
    # the generator's Tanh output.
    transforms.Normalize((0.5,), (0.5,))
])
dataset = dsets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

for epoch in range(epochs):
    for real_batch, _ in dataloader:
        # Flatten each image into a 784-long vector for the MLP discriminator.
        real_batch = real_batch.view(-1, 784)
        # The last batch of an epoch can be smaller than batch_size; draw the
        # same number of fake samples so D sees balanced real/fake batches.
        n_samples = real_batch.size(0)

        # Train Discriminator on a detached fake batch so G gets no gradients here.
        noise = torch.randn(n_samples, input_size)
        fake_batch = generator(noise)
        d_error = train_discriminator(discriminator, d_optimizer, real_batch, fake_batch.detach(), loss)

        # Train Generator on a fresh fake batch that keeps its autograd graph.
        noise = torch.randn(n_samples, input_size)
        fake_batch = generator(noise)
        g_error = train_generator(discriminator, g_optimizer, fake_batch, loss)

    # Print progress (losses from the last batch of the epoch)
    if epoch % 10 == 0:
        print(f'Epoch [{epoch}/{epochs}], D Loss: {d_error.item()}, G Loss: {g_error.item()}')

# Visualize results
visualize_gan_results(generator)

Epoch [0/200], D Loss: 0.2032044678926468, G Loss: 4.607448577880859


In [None]:
import matplotlib.pyplot as plt

# After training, generate and visualize samples
def visualize_gan_results(generator, num_samples=25):
    """Sample ``num_samples`` images from ``generator`` and show them in a grid.

    Noise vectors of length ``input_size`` (module-level hyperparameter) are
    drawn from a standard normal, and each output is reshaped to 28x28. The
    grid is the smallest near-square layout holding every sample (5x5 for the
    default 25); unused cells are left blank.

    The generator is switched to eval mode for sampling, and its previous
    train/eval mode is restored afterwards.

    Raises:
        ValueError: if ``num_samples`` is less than 1.
    """
    import math

    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    # Generate samples
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            noise = torch.randn(num_samples, input_size)
            generated_samples = generator(noise).view(-1, 28, 28)
    finally:
        # Leave the model in the mode the caller had it in.
        generator.train(was_training)

    # Plot samples: size the grid from num_samples so every sample gets a cell.
    ncols = math.ceil(math.sqrt(num_samples))
    nrows = math.ceil(num_samples / ncols)
    # squeeze=False keeps `axes` 2-D even for a 1x1 grid, so `.flat` always works.
    fig, axes = plt.subplots(nrows, ncols, figsize=(2 * ncols, 2 * nrows), squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < num_samples:
            ax.imshow(generated_samples[i].numpy(), cmap='gray')
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Call visualization after training
visualize_gan_results(generator)