In [11]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from torchvision import utils
import os

In [12]:
batch_size = 128
image_size = 64
latent_dimension = 100
number_of_epochs = 200
learning_rate = 0.002
beta_01 = 0.5

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [13]:
transform = transforms.Compose([
    transforms.Resize(size=image_size),
    transforms.CenterCrop(size=image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]) # Normalizing to [-1, 1]
])
dataset_path = "./dataset/training-GAN/Test/"
dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [14]:
class Generator(torch.nn.Module):
    def __init__(self, latent_dimension):
        super(Generator, self).__init__()
        
        # First layer
        self.conv_transpose_2d_01 = torch.nn.ConvTranspose2d(
            in_channels=latent_dimension, out_channels=512, kernel_size=4,
            stride=1, padding=0, bias=False
        )
        self.batch_norm_2d_01 = torch.nn.BatchNorm2d(num_features=512)
        self.relu_01 = torch.nn.ReLU(inplace=True)
        
        # Second layer
        self.conv_transpose_2d_02 = torch.nn.ConvTranspose2d(
            in_channels=512, out_channels=256, kernel_size=4,
            stride=2, padding=1, bias=False
        )
        self.batch_norm_2d_02 = torch.nn.BatchNorm2d(num_features=256)
        self.relu_02 = torch.nn.ReLU(inplace=True)
        
        # Third layer
        self.conv_transpose_2d_03 = torch.nn.ConvTranspose2d(
            in_channels=256, out_channels=128, kernel_size=4,
            stride=2, padding=1, bias=False
        )
        self.batch_norm_2d_03 = torch.nn.BatchNorm2d(num_features=128)
        self.relu_03 = torch.nn.ReLU(inplace=True)
        
        # Fourth layer
        self.conv_transpose_2d_04 = torch.nn.ConvTranspose2d(
            in_channels=128, out_channels=64, kernel_size=4,
            stride=2, padding=1, bias=False
        )
        self.batch_norm_2d_04 = torch.nn.BatchNorm2d(num_features=64)
        self.relu_04 = torch.nn.ReLU(inplace=True)
        
        # Fifth layer
        self.conv_transpose_2d_05 = torch.nn.ConvTranspose2d(
            in_channels=64, out_channels=3, kernel_size=4,
            stride=2, padding=1, bias=False
        )
        self.tanh = torch.nn.Tanh()
    
    def forward(self, x):
        x = self.conv_transpose_2d_01(x)
        x = self.batch_norm_2d_01(x)
        x = self.relu_01(x)
        
        x = self.conv_transpose_2d_02(x)
        x = self.batch_norm_2d_02(x)
        x = self.relu_02(x)
        
        x = self.conv_transpose_2d_03(x)
        x = self.batch_norm_2d_03(x)
        x = self.relu_03(x)
        
        x = self.conv_transpose_2d_04(x)
        x = self.batch_norm_2d_04(x)
        x = self.relu_04(x)
        
        x = self.conv_transpose_2d_05(x)
        x = self.tanh(x)
        
        return x

In [15]:
class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        # First layer
        self.conv2d_01 = torch.nn.Conv2d(
            in_channels=3, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False
        )
        self.leaky_relu_01 = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
        
        # Second layer
        self.conv2d_02 = torch.nn.Conv2d(
            in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False
        )
        self.batch_norm_2d_02 = torch.nn.BatchNorm2d(num_features=128)
        self.leaky_relu_02 = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
        
        # Third layer
        self.conv2d_03 = torch.nn.Conv2d(
            in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False
        )
        self.batch_norm_2d_03 = torch.nn.BatchNorm2d(num_features=256)
        self.leaky_relu_03 = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
        
        # Fourth layer
        self.conv2d_04 = torch.nn.Conv2d(
            in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False
        )
        self.batch_norm_2d_04 = torch.nn.BatchNorm2d(num_features=512)
        self.leaky_relu_04 = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
        
        # Fivth layer
        self.conv2d_05 = torch.nn.Conv2d(
            in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False
        )
        self.sigmoid_05 = torch.nn.Sigmoid()
        
    def forward(self, x):
        x = self.conv2d_01(x)
        x = self.leaky_relu_01(x)
        
        x = self.conv2d_02(x)
        x = self.batch_norm_2d_02(x)
        x = self.leaky_relu_02(x)
        
        x = self.conv2d_03(x)
        x = self.batch_norm_2d_03(x)
        x = self.leaky_relu_03(x)
        
        x = self.conv2d_04(x)
        x = self.batch_norm_2d_04(x)
        x = self.leaky_relu_04(x)
        
        x = self.conv2d_05(x)
        x = self.sigmoid_05(x)
        
        return x.view(-1)

In [16]:
generator_model = Generator(latent_dimension=latent_dimension).to(device)
print(f"Generator: {generator_model}\n")

discriminator_model = Discriminator().to(device)
print(f"Discriminator: {discriminator_model}")

criterion = torch.nn.BCELoss() # Binary cross entropy loss/ BCE
optimizer_generator = torch.optim.Adam(generator_model.parameters(), lr=learning_rate, betas=(beta_01, 0.999))
optimizer_discriminator = torch.optim.Adam(discriminator_model.parameters(), lr=learning_rate, betas=(beta_01, 0.999))

Generator: Generator(
  (conv_transpose_2d_01): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
  (batch_norm_2d_01): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu_01): ReLU(inplace=True)
  (conv_transpose_2d_02): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (batch_norm_2d_02): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu_02): ReLU(inplace=True)
  (conv_transpose_2d_03): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (batch_norm_2d_03): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu_03): ReLU(inplace=True)
  (conv_transpose_2d_04): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (batch_norm_2d_04): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu_04): ReLU(inplace=Tr

In [None]:
# Training
fixed_noise = torch.randn(64, latent_dimension, 1, 1, device=device)

for epoch in range(number_of_epochs):
    for i, (data, _) in enumerate(data_loader):
        
        # Updating discriminator.
        discriminator_model.zero_grad()  # Zeroing/clearing gradients.
        real_data = data.to(device)
        batch_size = real_data.size(0)
        label_real = torch.ones(batch_size, device=device)  # [1, 1, ..., batch_size]
        label_fake = torch.zeros(batch_size, device=device)  # [0, 0, ..., batch_size]
        
        # Discriminator loss on real data
        output_real = discriminator_model(real_data)  # y predicted
        loss_real = criterion(output_real, label_real)  # BCE - Binary Cross Entropy
        
        # Discriminator loss on fake data
        noise = torch.randn(batch_size, latent_dimension, 1, 1, device=device)
        fake_data = generator_model(noise)
        output_fake = discriminator_model(fake_data.detach())  # Detach to avoid backprop through generator
        loss_fake = criterion(output_fake, label_fake)
        
        # Total discriminator loss
        loss_discriminator = loss_real + loss_fake
        loss_discriminator.backward()  # Calculating gradients.
        optimizer_discriminator.step()  # Updating the weights by using gradients.
        
        # Updating generator.
        generator_model.zero_grad()  # Zeroing/clearing gradients.
        label_gen = torch.ones(batch_size, device=device)  # Trick discriminator into believing fakes are real
        output_fake = discriminator_model(fake_data)
        loss_generator = criterion(output_fake, label_gen)
        loss_generator.backward()  # Calculating and propagating gradients.
        optimizer_generator.step()
        
        # Logging progress
        if i % 50 == 0:
            print(f"Epoch [{epoch}/{number_of_epochs}] Batch {i}/{len(data_loader)} "
                  f"Loss D: {loss_discriminator:.4f}, Loss G: {loss_generator:.4f}")
    
    # Save generated images after each epoch
    with torch.no_grad():
        fake_images = generator_model(fixed_noise).detach().cpu()
        utils.save_image(fake_images, f"output_epoch_{epoch}.png", normalize=True)

Epoch [0/200] Batch 0/1 Loss D: 6.2317, Loss G: 18.8154
Epoch [1/200] Batch 0/1 Loss D: 8.3506, Loss G: 10.3048
Epoch [2/200] Batch 0/1 Loss D: 0.8120, Loss G: 5.8034
Epoch [3/200] Batch 0/1 Loss D: 1.1487, Loss G: 3.5989
Epoch [4/200] Batch 0/1 Loss D: 2.1630, Loss G: 3.5016
Epoch [5/200] Batch 0/1 Loss D: 1.5439, Loss G: 7.0314
Epoch [6/200] Batch 0/1 Loss D: 3.1593, Loss G: 1.5379
Epoch [7/200] Batch 0/1 Loss D: 1.3801, Loss G: 1.7186
Epoch [8/200] Batch 0/1 Loss D: 1.2217, Loss G: 4.3166
Epoch [9/200] Batch 0/1 Loss D: 1.7265, Loss G: 1.8303
Epoch [10/200] Batch 0/1 Loss D: 1.3054, Loss G: 4.3833
Epoch [11/200] Batch 0/1 Loss D: 0.9011, Loss G: 2.7761
Epoch [12/200] Batch 0/1 Loss D: 1.0390, Loss G: 4.5660
Epoch [13/200] Batch 0/1 Loss D: 0.9102, Loss G: 2.3344
Epoch [14/200] Batch 0/1 Loss D: 2.2780, Loss G: 6.7019
Epoch [15/200] Batch 0/1 Loss D: 3.3561, Loss G: 2.0557
Epoch [16/200] Batch 0/1 Loss D: 1.2099, Loss G: 2.0518
Epoch [17/200] Batch 0/1 Loss D: 1.4635, Loss G: 5.1419


In [38]:
generator_model.eval() # Activating evaluation mode to deactive dropout or batch normalization updates.
noise = torch.randn(1, latent_dimension, 1, 1, device=device) # Single image (batch size = 1)

with torch.no_grad(): # No gradient calculation is necessary/needed
    fake_image = generator_model(noise).detach().cpu()
    
utils.save_image(fake_image, "generated_image.png", normalize=True)