In [11]:
import torch
import torch.nn as nn

from torch.nn import functional as F

In [12]:
def get_loss(logits, labels):
    """Binary cross-entropy between raw logits and target labels.

    Args:
        logits: Unnormalised scores; dimension 1 is squeezed away before
            the loss is computed, so a (batch, 1) tensor becomes (batch,).
        labels: Target values compared element-wise against the squeezed
            logits.

    Returns:
        Scalar tensor holding the mean loss over all elements.
    """
    flat_logits = logits.squeeze(1)
    # The sigmoid is applied inside the loss, so raw logits are passed in.
    return F.binary_cross_entropy_with_logits(flat_logits, labels)

In [13]:
def calculate_discriminator_loss(logits_real, logits_fake):
    """Discriminator loss: real samples should score 1, fake samples 0.

    Args:
        logits_real: Discriminator logits for real samples; the first
            dimension is the batch size.
        logits_fake: Discriminator logits for generated samples; the first
            dimension is the batch size.

    Returns:
        Scalar tensor: the loss on the real batch plus the loss on the
        fake batch.
    """
    # Create targets with the logits' device and dtype so the loss also
    # works when the models run on an accelerator or in another precision.
    labels_real = torch.ones(
        logits_real.shape[0], device=logits_real.device, dtype=logits_real.dtype
    )
    labels_fake = torch.zeros(
        logits_fake.shape[0], device=logits_fake.device, dtype=logits_fake.dtype
    )

    loss_real = get_loss(logits_real, labels_real)
    loss_fake = get_loss(logits_fake, labels_fake)

    return loss_real + loss_fake

In [14]:
def calculate_generator_loss(logits_fake):
    """Generator loss: generated samples should be scored as real (1).

    Args:
        logits_fake: Discriminator logits for generated samples; the first
            dimension is the batch size.

    Returns:
        Scalar tensor with the loss against all-ones targets.
    """
    # Targets follow the logits' device and dtype so the loss also works
    # when the models run on an accelerator or in another precision.
    labels_fake = torch.ones(
        logits_fake.shape[0], device=logits_fake.device, dtype=logits_fake.dtype
    )

    loss_fake = get_loss(logits_fake, labels_fake)

    return loss_fake

In [15]:
def get_noise(batch_size, noise_dim):
    """Sample a (batch_size, noise_dim) noise tensor uniformly from [-1, 1)."""
    unit_uniform = torch.rand((batch_size, noise_dim))  # values in [0, 1)
    # Stretch and shift [0, 1) onto [-1, 1).
    return 2 * unit_uniform - 1

In [16]:
class Generator(nn.Module):
    """Three-layer fully connected generator.

    Maps a noise vector of length ``noise_dim`` to a vector of length
    ``image_size`` whose entries are squashed into [-1, 1] by a final tanh.
    """

    def __init__(self, noise_dim, image_size):
        super().__init__()

        hidden_units = 1024

        # Input projection.
        self.fc1 = nn.Linear(noise_dim, hidden_units)
        self.relu1 = nn.ReLU()

        # Hidden layer.
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.relu2 = nn.ReLU()

        # Output projection followed by tanh.
        self.fc3 = nn.Linear(hidden_units, image_size)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Run the noise batch ``x`` through the network and return the result."""
        hidden = self.fc1(x)
        hidden = self.relu1(hidden)

        hidden = self.fc2(hidden)
        hidden = self.relu2(hidden)

        generated = self.fc3(hidden)
        return self.tanh(generated)

In [17]:
class Discriminator(nn.Module):
    """Three-layer fully connected discriminator.

    Maps a vector of length ``input_size`` to a single raw logit; no
    sigmoid is applied here.
    """

    def __init__(self, input_size):
        super().__init__()

        hidden_units = 256
        negative_slope = 0.01

        # Input projection.
        self.fc1 = nn.Linear(input_size, hidden_units)
        self.leaky_relu1 = nn.LeakyReLU(negative_slope)

        # Hidden layer.
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.leaky_relu2 = nn.LeakyReLU(negative_slope)

        # One output unit per sample.
        self.fc3 = nn.Linear(hidden_units, 1)

    def forward(self, x):
        """Return one logit per row of ``x``."""
        hidden = self.fc1(x)
        hidden = self.leaky_relu1(hidden)

        hidden = self.fc2(hidden)
        hidden = self.leaky_relu2(hidden)

        return self.fc3(hidden)

In [18]:
# Length of one flattened image (presumably channels * height * width).
image_size = 1 * 28 * 28

# Number of samples processed per step.
batch_size = 1

# Length of the noise vector fed to the generator.
noise_dim = 128

In [21]:
def test_models():
    """Smoke-test the models and loss functions end to end.

    Builds a fresh generator and discriminator, pushes one noise batch
    through both, and prints the resulting discriminator and generator
    losses. Nothing is trained here.
    """
    gen = Generator(noise_dim, image_size)
    disc = Discriminator(image_size)

    z = get_noise(batch_size, noise_dim)

    # The same noise is generated from twice; the first result only stands
    # in for the "real" logits expected by the discriminator loss.
    logits_first = disc(gen(z))
    logits_second = disc(gen(z))

    d_loss = calculate_discriminator_loss(logits_first, logits_second)
    g_loss = calculate_generator_loss(logits_second)

    print("Discriminator loss: {} | Generator loss: {}".format(d_loss, g_loss))


test_models()

Discriminator loss: 1.3865917921066284 | Generator loss: 0.6760484576225281


In [25]:
# Fresh networks for the training run.
generator = Generator(noise_dim, image_size)
discriminator = Discriminator(image_size)

# One Adam optimizer per network, both with the same hyperparameters.
generator_optimizer = torch.optim.Adam(
    generator.parameters(),
    lr=1e-3,
    betas=(0.5, 0.999),
)
discriminator_optimizer = torch.optim.Adam(
    discriminator.parameters(),
    lr=1e-3,
    betas=(0.5, 0.999),
)

In [26]:
num_iters = 5

# Alternating GAN updates: one discriminator step, then one generator step.
for i in range(num_iters):
    # ---- Discriminator step ----
    real_images = torch.randn(batch_size, image_size) # For study purposes

    logits_real = discriminator(real_images)

    noise = get_noise(batch_size, noise_dim)
    # Detach the fakes: this step only updates the discriminator, so there
    # is no need to backpropagate through the generator. Without detach the
    # generator gradients were computed and then thrown away by
    # generator_optimizer.zero_grad() below.
    fake_images = generator(noise).detach()
    logits_fake = discriminator(fake_images)

    discriminator_loss = calculate_discriminator_loss(logits_real, logits_fake)

    discriminator_optimizer.zero_grad()
    discriminator_loss.backward()
    discriminator_optimizer.step()

    # ---- Generator step ----
    # Fresh noise, scored by the just-updated discriminator. The graph is
    # kept intact here so gradients reach the generator's parameters.
    noise = get_noise(batch_size, noise_dim)
    fake_images = generator(noise)
    generator_logits_fake = discriminator(fake_images)

    generator_loss = calculate_generator_loss(generator_logits_fake)

    generator_optimizer.zero_grad()
    generator_loss.backward()
    generator_optimizer.step()

    print(f"Epoch {i+1}/{num_iters} | Generator loss: {generator_loss} | Discriminator loss: {discriminator_loss}")


Epoch 1/5 | Generator loss: 0.7187022566795349 | Discriminator loss: 1.3558391332626343
Epoch 2/5 | Generator loss: 0.692669153213501 | Discriminator loss: 1.356467366218567
Epoch 3/5 | Generator loss: 0.6644152998924255 | Discriminator loss: 1.4128437042236328
Epoch 4/5 | Generator loss: 0.6127821207046509 | Discriminator loss: 1.541396141052246
Epoch 5/5 | Generator loss: 0.5826628804206848 | Discriminator loss: 1.5353467464447021
