In [None]:
import os
import time
import torch
import numpy as np
from torch import nn
from torch.functional import F
from matplotlib import pyplot as plt
from torch.autograd import grad
from torch.distributions.uniform import Uniform
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.io import read_image, ImageReadMode
from torchvision.transforms import v2
from torchvision.utils import make_grid
from torchinfo import summary
from rich.progress import Progress
from rich.console import Console
from rich.table import Table

In [None]:
class CelebADataset(Dataset):
    """Images from a directory paired with binary attribute vectors.

    The label file is expected to have two header lines followed by one row
    per image: the image file name and then one value per attribute, where
    ``-1`` means "absent" (mapped to 0) and anything else means "present"
    (mapped to 1). Fields may be separated by spaces and/or commas.

    Args:
        img_dir: Directory containing the image files.
        label_path: Path to the attribute annotation file.
        transform: Optional callable applied to every decoded image.
    """

    def __init__(self, img_dir, label_path, transform=None):
        self.img_dir = img_dir
        self.label_path = label_path
        self.transform = transform

        # os.listdir() returns entries in arbitrary order, so list once and
        # sort to get a deterministic index -> file mapping.
        self.img_paths = [
            os.path.join(self.img_dir, name)
            for name in sorted(os.listdir(path=self.img_dir))
        ]
        self.total_images = len(self.img_paths)

        # Close the annotation file deterministically instead of leaking the handle.
        with open(self.label_path, 'r') as label_file:
            self.label_file_lines = label_file.readlines()[2:]

        # Map file name -> row number so labels are looked up by name rather
        # than by position. Positional lookup silently pairs an image with
        # another image's attributes whenever directory order and annotation
        # order differ.
        self.label_row_by_name = {}
        for row, line in enumerate(self.label_file_lines):
            head = line.strip().replace(',', ' ').split(None, 1)
            if head:
                self.label_row_by_name[head[0]] = row

    def __len__(self):
        return self.total_images

    def _label_for(self, img_path):
        """Return the 0/1 float attribute tensor for the image at ``img_path``.

        Raises:
            KeyError: If the annotation file has no row for this file name.
        """
        name = os.path.basename(img_path)
        try:
            line = self.label_file_lines[self.label_row_by_name[name]]
        except KeyError:
            raise KeyError(
                f'no attribute row for image {name!r} in {self.label_path!r}'
            ) from None

        values = [float(x) for x in line.strip().replace(',', ' ').split()[1:]]
        return torch.FloatTensor([0 if v == -1 else 1 for v in values])

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        img = read_image(path=img_path, mode=ImageReadMode.UNCHANGED)
        lbl = self._label_for(img_path)

        if self.transform:
            img = self.transform(img)

        return img, lbl

In [None]:
class ResidualBlock(nn.Module):
    """Channel-preserving residual unit: ``relu(x + BN(Conv(ReLU(BN(Conv(x))))))``.

    Both convolutions are 3x3 with padding 1, so the output has the same
    shape as the input.

    Args:
        in_channels: Number of channels of the input (and of the output).
    """

    def __init__(self, in_channels):
        super().__init__()

        def conv3x3():
            return nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1)

        # Layers are created in the order they are applied; the attribute
        # name `rb` and the child order define the state_dict keys.
        branch = [
            conv3x3(),
            nn.BatchNorm2d(num_features=in_channels),
            nn.ReLU(),
            conv3x3(),
            nn.BatchNorm2d(num_features=in_channels),
        ]
        self.rb = nn.Sequential(*branch)

    def forward(self, x):
        """Add the skip connection to the branch output, then apply ReLU."""
        return F.relu(self.rb(x) + x)

In [None]:
class Discriminator(nn.Module):
    """Convolutional critic that maps an image stack to one unbounded score.

    Four stride-2 convolutions each halve the spatial size, and the final
    linear layer expects a 1024 x 4 x 4 feature map, so inputs must be
    64 x 64. The output is a raw score of shape ``(batch, 1)``; there is no
    sigmoid.

    Args:
        in_channels: Channels of the input stack. Defaults to 43, the value
            that was previously hard-coded; the notebook feeds 3 image
            channels concatenated with 40 label channels.
    """

    def __init__(self, in_channels=43):
        super(Discriminator, self).__init__()
        # NOTE(review): the ResidualBlocks below contain BatchNorm, which
        # couples samples within a batch, while the gradient penalty is a
        # per-sample constraint. Consider a per-sample normalisation (or
        # none) here; left unchanged so existing checkpoints still load.
        self.main = nn.Sequential(
            nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=128, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),

            nn.Conv2d(kernel_size=4, in_channels=128, out_channels=256, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),

            ResidualBlock(in_channels=256),

            nn.Conv2d(kernel_size=4, in_channels=256, out_channels=512, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),

            ResidualBlock(in_channels=512),

            nn.Conv2d(kernel_size=4, in_channels=512, out_channels=1024, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),

            nn.Flatten(),
            nn.Linear(in_features=1024 * 4 * 4, out_features=1, bias=False),
        )

    def forward(self, x):
        """Return the critic score for ``x`` of shape ``(batch, in_channels, 64, 64)``."""
        return self.main(x)

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator producing 3 x 64 x 64 images in [-1, 1].

    The input vector is projected to a 1024 x 4 x 4 feature map and then
    upsampled by a factor of two four times (4 -> 8 -> 16 -> 32 -> 64).
    The last stage ends in ``tanh``.

    Args:
        latent_dim: Width of the input vector. In this notebook that is the
            noise size plus the label size, because labels are concatenated
            to the noise before the forward pass.
    """

    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim

        self.fc = nn.Sequential(
            nn.Linear(in_features=self.latent_dim, out_features=1024 * 4 * 4),
            nn.BatchNorm1d(1024 * 4 * 4),
            nn.ReLU(inplace=True),
            nn.Unflatten(dim=1, unflattened_size=(1024, 4, 4))
        )

        self.gen_0 = self._upsample_stage(in_channels=1024, out_channels=512)
        self.gen_1 = self._upsample_stage(in_channels=512, out_channels=256)
        self.gen_2 = self._upsample_stage(in_channels=256, out_channels=128)

        self.gen_3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    @staticmethod
    def _upsample_stage(in_channels, out_channels):
        """Build one x2 upsampling stage: ConvTranspose -> BatchNorm -> ReLU -> ResidualBlock."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            # PyTorch's `momentum` is the weight given to the NEW batch
            # statistic: running = (1 - momentum) * running + momentum * batch.
            # The previous value of 0.9 made the running statistics track
            # almost only the most recent batch, which destabilises
            # eval-mode sampling. 0.1 matches the other BatchNorm layers in
            # this file and equals a decay of 0.9 in the Keras convention.
            nn.BatchNorm2d(num_features=out_channels, momentum=0.1),
            nn.ReLU(),
            ResidualBlock(in_channels=out_channels)
        )

    def forward(self, x):
        """Map ``x`` of shape ``(batch, latent_dim)`` to images of shape ``(batch, 3, 64, 64)``."""
        x = self.fc(x)
        x = self.gen_0(x)
        x = self.gen_1(x)
        x = self.gen_2(x)
        x = self.gen_3(x)

        return x

In [None]:
# Adam settings, shared by both networks.
learning_rate_generator = 2e-4
learning_rate_discriminator = 2e-4
beta_1, beta_2 = 0.5, 0.999

# Training schedule.
batch_size = 64
epochs = 25
critic_steps = 3   # critic updates per generator update
gp_weight = 10     # multiplier applied to the gradient penalty term

# Generator input: `latent_dim` noise values followed by `label_dim` label values.
latent_dim = 100
label_dim = 40

In [None]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    torch.cuda.set_device(0)
    # Only query the current device when CUDA exists: on a CPU-only machine
    # torch.cuda.current_device() raises, which would defeat the fallback above.
    _ = torch.cuda.current_device()

In [None]:
# Build the critic on the selected device (Module.to() returns the module itself).
discriminator = Discriminator().to(device)

# Layer-by-layer summary for a single 43-channel 64x64 input.
summary(discriminator, input_size=[(1, 43, 64, 64)])

In [None]:
# The generator input is the noise vector concatenated with the label vector.
generator = Generator(latent_dim=latent_dim + label_dim)
generator.to(device)

# Summarise with a batch of 2 rather than 1: the first BatchNorm1d layer
# rejects a single sample in training mode ("Expected more than 1 value per
# channel"), and recent torchinfo releases run the forward pass in the model's
# current mode. A batch of 2 works regardless of the installed version.
summary(generator, input_size=[(2, latent_dim + label_dim)])

In [None]:
# One Adam optimiser per network, both using the same (beta_1, beta_2) pair.
adam_betas = (beta_1, beta_2)

discriminator_optimizer = torch.optim.Adam(
    params=discriminator.parameters(),
    lr=learning_rate_discriminator,
    betas=adam_betas,
)
generator_optimizer = torch.optim.Adam(
    params=generator.parameters(),
    lr=learning_rate_generator,
    betas=adam_betas,
)

In [None]:
class WassersteinLoss(nn.Module):
    """Wasserstein loss for either the critic or the generator.

    With both scores given it returns ``mean(fake_score) - mean(real_score)``,
    the quantity the critic minimises. With only ``fake_score`` it returns
    ``-mean(fake_score)``, the quantity the generator minimises.
    """

    def __init__(self):
        super(WassersteinLoss, self).__init__()

    def forward(self, fake_score, real_score=None):
        """Compute the loss.

        Args:
            fake_score: Critic scores for generated samples.
            real_score: Critic scores for real samples, or ``None`` for the
                generator loss.

        Returns:
            A scalar tensor.
        """
        # Identity test, not `!= None`: comparing a tensor with `!=` goes
        # through the tensor's rich-comparison machinery and only happens to
        # yield a plain bool for None.
        if real_score is not None:
            return -1 * (torch.mean(real_score) - torch.mean(fake_score))

        return -1 * torch.mean(fake_score)

In [None]:
class GradientPenalty(nn.Module):
    """WGAN-GP gradient penalty: ``gp_weight * mean((||grad D(x_hat)||_2 - 1)^2)``.

    ``x_hat`` is a per-sample random interpolation between a fake and a real
    input, and the gradient is taken with respect to ``x_hat``.

    Args:
        discriminator: The critic whose input gradient is penalised.
        gp_weight: Multiplier applied to the penalty.
    """

    def __init__(self, discriminator, gp_weight):
        super(GradientPenalty, self).__init__()
        self.discriminator = discriminator
        self.gp_weight = gp_weight

    def forward(self, fake_img, real_img):
        """Return the scalar penalty for one batch.

        Args:
            fake_img: Generated critic inputs, batch first.
            real_img: Real critic inputs with the same shape as ``fake_img``.

        Returns:
            A scalar tensor that is differentiable with respect to the
            critic's parameters.
        """
        batch = fake_img.shape[0]

        # One coefficient per sample, drawn from U[0, 1) so x_hat lies on the
        # segment between the two inputs (a normal draw would extrapolate
        # beyond both endpoints). It is created on the inputs' own
        # device/dtype rather than on a module-level `device` global, and is
        # broadcast over all non-batch dimensions.
        alpha_shape = (batch,) + (1,) * (fake_img.dim() - 1)
        alpha = torch.rand(alpha_shape, device=real_img.device, dtype=real_img.dtype)

        # Detach so x_hat is a leaf, then track gradients on the whole
        # interpolate. Previously `.requires_grad_` bound to the second term
        # only, because of operator precedence.
        inter = (alpha * fake_img + (1 - alpha) * real_img).detach().requires_grad_(True)
        score = self.discriminator(inter)

        grads = grad(
            outputs=score,
            inputs=inter,
            grad_outputs=torch.ones_like(score),
            create_graph=True,   # the penalty itself is backpropagated
            retain_graph=True
        )[0]

        # reshape (not view): the gradient is not guaranteed to be contiguous.
        norms = grads.reshape(batch, -1).norm(2, dim=1)

        return torch.mean(torch.pow((norms - 1), 2)) * self.gp_weight

In [None]:
# Separate loss instances for the critic and the generator, plus the gradient
# penalty bound to the critic being trained.
discriminator_loss, generator_loss = WassersteinLoss(), WassersteinLoss()
gradient_penalty = GradientPenalty(discriminator, gp_weight)

In [None]:
# Preprocessing applied to every decoded image, in order: wrap as an image
# tensor, convert to float with value scaling, shift/scale with mean 0.5 and
# std 0.5, then resize to 64x64.
preprocessing_steps = [
    v2.ToImage(),
    v2.ToDtype(torch.float, scale=True),
    v2.Normalize(mean=[0.5], std=[0.5]),
    v2.Resize((64, 64)),
]
transform = v2.Compose(preprocessing_steps)

train_dataset = CelebADataset(
    img_dir='dataset/celeba/img_align/',
    label_path='dataset/celeba/list_attr_celeba.txt',
    transform=transform,
)
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
def train_loop(dataloader,
               discriminator,
               generator,
               discriminator_optimizer,
               generator_optimizer,
               discriminator_loss,
               generator_loss,
               gradient_penalty,
               current_epoch,
               total_epochs,
               critic_steps,
               device):
    """Train the critic and the generator for one epoch and print a summary.

    For every batch the critic is updated ``critic_steps`` times on the
    Wasserstein loss plus gradient penalty, then the generator is updated
    once. Labels are appended to the noise vector for the generator and
    broadcast to extra image channels for the critic.

    Args:
        dataloader: Yields ``(images, labels)`` batches.
        discriminator: Critic network.
        generator: Generator network.
        discriminator_optimizer: Optimiser for the critic.
        generator_optimizer: Optimiser for the generator.
        discriminator_loss: Callable ``(fake_scores, real_scores) -> loss``.
        generator_loss: Callable ``(fake_scores) -> loss``.
        gradient_penalty: Callable ``(fake_inputs, real_inputs) -> penalty``.
        current_epoch: 1-based epoch number, used for display only.
        total_epochs: Total number of epochs, used for display only.
        critic_steps: Critic updates per generator update.
        device: Device the batches are moved to.

    Note:
        The noise width is read from the module-level ``latent_dim``.
    """
    start_time = time.time()

    total_batches = len(dataloader)
    discriminator_total_loss = 0
    generator_total_loss = 0

    # Switch both networks to training mode BEFORE any forward pass. The
    # generator is left in eval mode by the sampling helper run between
    # epochs, and it was previously switched back only after the first fakes
    # had already been produced.
    discriminator.train()
    generator.train()

    with Progress() as progress:
        # Use the `total_epochs` argument (not the global `epochs`) and keep
        # the task id in its own variable instead of overwriting `total_batches`.
        task_id = progress.add_task(f'[red]Epoch {current_epoch}/{total_epochs}', total=total_batches)

        # A single pass over the loader. The former `while not progress.finished`
        # wrapper never terminated when the loader was empty.
        for current_batch_data, current_batch_label in dataloader:
            # Loop-invariant work hoisted out of the critic loop.
            real_images = current_batch_data.to(device)
            labels = current_batch_label.to(device)
            current_batch_size = real_images.shape[0]

            # Broadcast each label value over the image's own spatial size
            # (previously hard-coded to 64x64).
            label_maps = labels.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, *real_images.shape[-2:])
            disc_input_real = torch.cat((real_images, label_maps), dim=1)

            # ======================
            # Discriminator Training
            # ======================
            for _ in range(critic_steps):
                discriminator_optimizer.zero_grad()

                # The critic update needs no gradient through the generator.
                with torch.no_grad():
                    noise = torch.randn(size=(current_batch_size, latent_dim), device=device)
                    fake_images = generator(torch.cat((noise, labels), dim=1))
                disc_input_fake = torch.cat((fake_images, label_maps), dim=1)

                real_preds = discriminator(disc_input_real)
                fake_preds = discriminator(disc_input_fake)

                disc_loss = discriminator_loss(fake_preds, real_preds)
                grad_pena = gradient_penalty(disc_input_fake, disc_input_real)
                w_loss_gp = disc_loss + grad_pena

                w_loss_gp.backward()
                discriminator_optimizer.step()

                discriminator_total_loss += w_loss_gp.item()

            # ==================
            # Generator Training
            # ==================
            generator_optimizer.zero_grad()

            # Fresh noise and a fresh forward pass: the generator is scored on
            # samples the just-updated critic has not been trained on. This
            # also works when critic_steps is 0.
            noise = torch.randn(size=(current_batch_size, latent_dim), device=device)
            fake_images = generator(torch.cat((noise, labels), dim=1))
            fake_preds_gen = discriminator(torch.cat((fake_images, label_maps), dim=1))

            gen_loss = generator_loss(fake_preds_gen)

            # This also accumulates gradients on the critic's parameters; they
            # are cleared by zero_grad() at the start of the next critic step.
            gen_loss.backward()
            generator_optimizer.step()

            generator_total_loss += gen_loss.item()

            progress.update(task_id, advance=1)

    # Guard the averages so an empty loader or critic_steps == 0 reports NaN
    # instead of raising ZeroDivisionError.
    critic_updates = total_batches * critic_steps
    avg_disc_loss = discriminator_total_loss / critic_updates if critic_updates else float('nan')
    avg_genr_loss = generator_total_loss / total_batches if total_batches else float('nan')

    end_time = time.time()

    table = Table(title=f'Epoch {current_epoch} Summary')

    table.add_column("Discriminator Loss", justify="left", style="cyan", no_wrap=True)
    table.add_column("Generator Loss", justify="left", style="cyan", no_wrap=True)
    table.add_column("Time Taken (seconds)", justify="left", style="cyan", no_wrap=True)

    table.add_row(f'{avg_disc_loss}', f'{avg_genr_loss}', f'{end_time - start_time}')

    console = Console()
    console.print(table)

In [None]:
def plot_and_save_generated_images_from_noise(n, generator, current_epoch):
    """Sample ``n`` images for a fixed label vector, then save and show them as a grid.

    The grid has 4 images per row and is written to
    ``epoch-{current_epoch}.png`` in the working directory. The generator's
    train/eval mode is restored before returning.

    Args:
        n: Number of images to generate.
        generator: Generator network; called on noise concatenated with labels.
        current_epoch: Epoch number used in the output file name.

    Note:
        Reads the module-level ``latent_dim`` and ``device``.
    """
    # Fixed 40-value binary label vector shared by every sample. Which
    # attribute each position selects depends on the column order of the
    # label file, which is not visible here.
    fixed_labels = [
        0., 1., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 1.,
        0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 1., 1., 0., 1.
    ]

    # Remember the current mode so sampling between epochs does not leave the
    # generator in eval mode for the next training step.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            z = torch.randn(size=(n, latent_dim)).to(device)
            labels = torch.Tensor([fixed_labels]).repeat(n, 1).to(device)
            generated = generator(torch.cat((z, labels), dim=1)).detach().cpu()
    finally:
        generator.train(was_training)

    grid = make_grid(generated, nrow=4, normalize=True)
    # Channels-first -> channels-last for matplotlib, via the tensor API
    # rather than np.transpose applied to a tensor.
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.axis("off")
    plt.savefig(f'epoch-{current_epoch}.png', bbox_inches='tight')
    plt.show()

In [None]:
# Arguments that stay the same for every epoch.
shared_training_args = dict(
    dataloader=train_data_loader,
    discriminator=discriminator,
    generator=generator,
    discriminator_optimizer=discriminator_optimizer,
    generator_optimizer=generator_optimizer,
    discriminator_loss=discriminator_loss,
    generator_loss=generator_loss,
    gradient_penalty=gradient_penalty,
    total_epochs=epochs,
    critic_steps=critic_steps,
    device=device,
)

# Train one epoch at a time and save a 16-image sample grid after each.
for i in range(epochs):
    epoch_number = i + 1  # epochs are reported 1-based

    train_loop(current_epoch=epoch_number, **shared_training_args)
    plot_and_save_generated_images_from_noise(n=16, generator=generator, current_epoch=epoch_number)

print("Training complete!")

In [None]:
def plot_generated_images_from_noise(n, generator):
    """Sample ``n`` images for a fixed label vector and show them as a grid.

    The grid has 4 images per row. Nothing is written to disk. The
    generator's train/eval mode is restored before returning.

    Args:
        n: Number of images to generate.
        generator: Generator network; called on noise concatenated with labels.

    Note:
        Reads the module-level ``latent_dim`` and ``device``.
    """
    # Fixed 40-value binary label vector shared by every sample. Which
    # attribute each position selects depends on the column order of the
    # label file, which is not visible here.
    fixed_labels = [
        0., 1., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 1.,
        0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 1., 1., 0., 1.
    ]

    # Sample in eval mode, then put the generator back in whatever mode the
    # caller had it in, so this helper has no hidden side effect.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            z = torch.randn(size=(n, latent_dim)).to(device)
            labels = torch.Tensor([fixed_labels]).repeat(n, 1).to(device)
            generated = generator(torch.cat((z, labels), dim=1)).detach().cpu()
    finally:
        generator.train(was_training)

    grid = make_grid(generated, nrow=4, normalize=True)
    # Channels-first -> channels-last for matplotlib, via the tensor API
    # rather than np.transpose applied to a tensor.
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.axis("off")
    plt.show()

In [None]:
# Show a final grid of 20 samples from the trained generator.
plot_generated_images_from_noise(generator=generator, n=20)