### Import libraries

In [31]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from tqdm.notebook import tqdm
import imageio
import time

### Configs

In [32]:
IMAGE_SIZE = 64
BATCH_SIZE = 512
NOISE_DIM = 100
EPOCHS = 50
lr = 0.0002

### Setups

In [33]:
# Set the device to CUDA (GPU) if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if not torch.cuda.is_available():
    print("CUDA not found. Training on CPU.")

# Create a directory to save the generated images and GIF
output_dir = 'gan_images'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

Using device: cuda


### Loding dataset and Creating DataLoader

In [34]:
class CustomImageDataset(Dataset):
    """A custom dataset to load images from a single folder."""
    def __init__(self, folder_path, transform=None):
        self.image_paths = glob.glob(os.path.join(folder_path, '*.jpg'))
        self.transform = transform
        
    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

# Define the image transformations
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Create the Dataset
dataset = CustomImageDataset(folder_path='img_align_celeba' , transform=transform)
print(f"Found {len(dataset)} images.")

# Create the DataLoader. num_workers=0 is for stability on Windows.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
print("DataLoader is ready.")

Found 202599 images.
DataLoader is ready.


### Building Generator

In [35]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # Input is Z (noise vector), going into a transposed convolution
            nn.ConvTranspose2d(NOISE_DIM, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # State size: (512) x 4 x 4
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # State size: (256) x 8 x 8
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # State size: (128) x 16 x 16
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # State size: (64) x 32 x 32
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
            # Final state size: (3) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

### Building Discriminator

In [36]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # Input size: (3) x 64 x 64
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # State size: (64) x 32 x 32
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # State size: (128) x 16 x 16
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # State size: (256) x 8 x 8
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # State size: (512) x 4 x 4
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

### Models and Optimizers

In [37]:
# Creating the models
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# We use Binary Cross Entropy as Loss function
criterion = nn.BCELoss()

# Fixed noise for visualizing generator's progress during training
fixed_noise = torch.randn(64, NOISE_DIM, 1, 1, device=device)

# Labels for real and fake images
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both networks
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

print("Models and optimizers are ready.")

Models and optimizers are ready.


### Model training

In [38]:
for epoch in range(EPOCHS):
    # Wrap the dataloader with tqdm for a live progress bar
    batch_pbar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{EPOCHS}")
    
    for i, real_images in enumerate(batch_pbar):
        # Update Discriminator
        discriminator.zero_grad()
        
        # Train with real images
        real_images = real_images.to(device)
        b_size = real_images.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        output = discriminator(real_images).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        
        # Train with fake images
        noise = torch.randn(b_size, NOISE_DIM, 1, 1, device=device)
        fake_images = generator(noise)
        label.fill_(fake_label)
        output = discriminator(fake_images.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        
        errD = errD_real + errD_fake
        optimizer_D.step()

        # Update Generator
        generator.zero_grad()
        label.fill_(real_label)
        output = discriminator(fake_images).view(-1)
        errG = criterion(output, label)
        errG.backward()
        
        optimizer_G.step()
        
        # Update the progress bar with current loss values
        batch_pbar.set_postfix(Loss_D=f"{errD.item():.4f}", Loss_G=f"{errG.item():.4f}")

    # After each epoch, we save a grid of generated images
    with torch.no_grad():
        fake_samples = generator(fixed_noise).detach().cpu()
    save_image(fake_samples, f'{output_dir}/generated_sample_epoch_{epoch+1}.png', normalize=True)

print("Training Finished.")

Epoch 1/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 2/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 3/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 4/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 5/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 6/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 7/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 8/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 9/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 10/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 11/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 12/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 13/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 14/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 15/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 16/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 17/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 18/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 19/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 20/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 21/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 22/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 23/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 24/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 25/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 26/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 27/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 28/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 29/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 30/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 31/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 32/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 33/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 34/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 35/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 36/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 37/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 38/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 39/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 40/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 41/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 42/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 43/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 44/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 45/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 46/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 47/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 48/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 49/50:   0%|          | 0/396 [00:00<?, ?it/s]

Epoch 50/50:   0%|          | 0/396 [00:00<?, ?it/s]

Training Finished.


### Generate gif from generated photos

In [40]:
anim_file = 'dcgan_final.gif'

with imageio.get_writer(anim_file, mode='I') as writer:
    filenames = glob.glob(f'{output_dir}/generated_samples*.png')
    filenames = sorted(filenames)
    for filename in tqdm(filenames, desc="Creating GIF"):
        image = imageio.imread(filename)
        writer.append_data(image)

print("GIF saved:", {anim_file})

Creating GIF: 0it [00:00, ?it/s]

GIF saved: {'dcgan_final.gif'}
