In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.
import kagglehub
# Fetch the dataset through kagglehub and keep whatever `dataset_download`
# returns (presumably the local download directory — TODO confirm).
# NOTE(review): the training cell further down reads from a hard-coded
# "/kaggle/input/..." path instead of this variable.
dataexplorerx_img_align_celeba_path = kagglehub.dataset_download(
    'dataexplorerx/img-align-celeba'
)
print('Data source import complete.')


In [None]:
import glob
import random
import os
import numpy as np
import torch

from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms


class FaceAttributeDataset(Dataset):
    """Image dataset pairing every ``*.jpg`` file with binary attribute labels.

    ``dataset_path`` must contain the images and one ``*.txt`` annotation
    file. The annotation file is parsed as follows: the first line is ignored,
    the second line holds the whitespace-separated attribute names, and every
    following non-blank line holds an image file name followed by one value
    per attribute. A value equal to the string ``"1"`` becomes label 1, any
    other value becomes label 0.

    Args:
        dataset_path: Directory holding the images and the annotation file.
        transform_ops: Sequence of transforms, composed and applied to each
            image. ``None`` applies no transform at all.
        data_mode: ``"train"`` keeps every image except the last 2000 (in
            sorted file-name order); any other value keeps the last 2000.
        selected_attributes: Attribute names returned as labels, in this
            order. ``None`` selects every attribute of the annotation file.

    Raises:
        FileNotFoundError: If ``dataset_path`` contains no ``*.txt`` file.
        ValueError: If a selected attribute is not listed in the annotation
            file.
    """

    def __init__(self, dataset_path, transform_ops=None, data_mode="train", selected_attributes=None):
        # Compose(None) would only blow up on the first image access, so fall
        # back to an empty pipeline when no transforms are given.
        self.image_transform = transforms.Compose(transform_ops if transform_ops is not None else [])

        self.target_attributes = selected_attributes
        image_files = sorted(glob.glob("%s/*.jpg" % dataset_path))
        self.image_files = image_files[:-2000] if data_mode == "train" else image_files[-2000:]

        annotation_files = glob.glob("%s/*.txt" % dataset_path)
        if not annotation_files:
            raise FileNotFoundError("No *.txt annotation file found in %r" % (dataset_path,))
        self.annotation_file_path = annotation_files[0]
        self.labels = self.load_labels()

    def load_labels(self):
        """Parse the annotation file into ``{image file name: [0/1, ...]}``.

        Also stores the attribute names of the file in
        ``self.attribute_names``.
        """
        # Context manager so the file handle is closed deterministically.
        with open(self.annotation_file_path, "r") as annotation_file:
            lines = [line.rstrip() for line in annotation_file]

        self.attribute_names = lines[1].split()
        wanted = self.attribute_names if self.target_attributes is None else self.target_attributes
        # Resolve each attribute's column once instead of once per image.
        columns = [self.attribute_names.index(attribute) for attribute in wanted]

        label_dict = {}
        for line in lines[2:]:
            fields = line.split()
            if not fields:
                continue  # tolerate blank (e.g. trailing) lines
            img_name, attribute_values = fields[0], fields[1:]
            label_dict[img_name] = [int(attribute_values[col] == "1") for col in columns]
        return label_dict

    def __getitem__(self, idx):
        """Return ``(transformed image, float label tensor)`` for ``idx``.

        ``idx`` wraps around modulo the number of images.
        """
        image_path = self.image_files[idx % len(self.image_files)]
        # Labels are keyed by bare file name; basename also copes with
        # platform-specific path separators.
        img_name = os.path.basename(image_path)
        img = self.image_transform(Image.open(image_path))
        labels = torch.tensor(self.labels[img_name], dtype=torch.float32)

        return img, labels

    def __len__(self):
        """Number of images selected by ``data_mode``."""
        return len(self.image_files)


In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch


def init_weights_normal(layer):
    """Re-draw convolution weights from a normal distribution.

    Meant to be passed to ``module.apply``: a layer whose class name contains
    ``"Conv"`` gets its ``weight`` overwritten in place with samples of mean
    0.0 and standard deviation 0.02. Every other layer is left untouched.
    """
    if "Conv" not in type(layer).__name__:
        return
    torch.nn.init.normal_(layer.weight.data, mean=0.0, std=0.02)

class ResBlock(nn.Module):
    """Residual block that keeps the channel count and spatial size.

    The residual branch is conv3x3 -> instance norm -> ReLU -> conv3x3 ->
    instance norm; its output is added to the block input.
    """

    def __init__(self, input_channels):
        super().__init__()

        def conv3x3():
            # Bias-free 3x3 convolution that preserves the spatial size.
            return nn.Conv2d(input_channels, input_channels, kernel_size=3, stride=1, padding=1, bias=False)

        def norm():
            return nn.InstanceNorm2d(input_channels, affine=True, track_running_stats=True)

        self.conv_block = nn.Sequential(
            conv3x3(),
            norm(),
            nn.ReLU(inplace=True),
            conv3x3(),
            norm(),
        )

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual


class ResNetGenerator(nn.Module):
    """Label-conditioned image-to-image generator.

    Layout: a 7x7 stem convolution, two stride-2 downsampling convolutions,
    ``num_res_blocks`` residual blocks, two stride-2 transposed convolutions
    and a 7x7 output convolution followed by ``tanh``. The target label vector
    is broadcast over the spatial dimensions and concatenated to the input
    image as ``class_dim`` extra channels.

    Args:
        image_shape: 3-tuple whose first entry is the image channel count.
        num_res_blocks: Number of residual blocks in the bottleneck.
        class_dim: Length of the label vector passed to ``forward``.
    """

    def __init__(self, image_shape=(3, 128, 128), num_res_blocks=9, class_dim=5):
        super().__init__()
        channels, _, _ = image_shape

        def norm_relu(num_features):
            # Normalisation + activation pair that follows every hidden conv.
            return [
                nn.InstanceNorm2d(num_features, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            ]

        width = 64

        # Stem: image channels plus one constant plane per label entry.
        layers = [nn.Conv2d(channels + class_dim, width, 7, stride=1, padding=3, bias=False)]
        layers.extend(norm_relu(width))

        # Downsampling: halve the resolution and double the width, twice.
        for _ in range(2):
            layers.append(nn.Conv2d(width, width * 2, 4, stride=2, padding=1, bias=False))
            layers.extend(norm_relu(width * 2))
            width *= 2

        # Bottleneck of residual blocks at the reduced resolution.
        layers.extend(ResBlock(width) for _ in range(num_res_blocks))

        # Upsampling: mirror of the downsampling path.
        for _ in range(2):
            layers.append(nn.ConvTranspose2d(width, width // 2, 4, stride=2, padding=1, bias=False))
            layers.extend(norm_relu(width // 2))
            width //= 2

        # Output layer: back to image channels, squashed into [-1, 1].
        layers.append(nn.Conv2d(width, channels, 7, stride=1, padding=3))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, input_img, input_class):
        """Translate ``input_img`` towards the labels in ``input_class``."""
        height, width = input_img.size(2), input_img.size(3)
        # Turn each label into a constant (height x width) feature plane.
        class_planes = input_class.view(input_class.size(0), input_class.size(1), 1, 1)
        class_planes = class_planes.repeat(1, 1, height, width)
        return self.model(torch.cat((input_img, class_planes), 1))


class ImageDiscriminator(nn.Module):
    """Convolutional critic with an adversarial head and a classification head.

    The trunk is a stack of stride-2 4x4 convolutions with LeakyReLU(0.01),
    starting at 64 channels and doubling with every further layer.

    Args:
        image_shape: 3-tuple; the first entry is the channel count and the
            second (height) sizes the classification head's kernel.
        class_dim: Number of attribute logits produced per image.
        num_layers: Number of stride-2 trunk layers.
    """

    def __init__(self, image_shape=(3, 128, 128), class_dim=5, num_layers=6):
        super().__init__()
        channels, img_height, _ = image_shape

        # Channel plan: input -> 64 -> 128 -> ... (one doubling per extra layer).
        widths = [channels, 64] + [64 * 2 ** depth for depth in range(1, num_layers)]

        trunk = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            trunk.append(nn.Conv2d(in_features, out_features, 4, stride=2, padding=1))
            trunk.append(nn.LeakyReLU(0.01))
        self.model = nn.Sequential(*trunk)

        final_width = widths[-1]
        # Adversarial head: one score per spatial location of the feature map.
        self.adv_output = nn.Conv2d(final_width, 1, 3, padding=1, bias=False)
        # Classification head: kernel size is the input height divided by
        # 2 ** num_layers, i.e. the height left after the stride-2 layers.
        self.class_output = nn.Conv2d(final_width, class_dim, img_height // 2 ** num_layers, bias=False)

    def forward(self, image):
        """Return ``(adversarial score map, flattened class logits)``."""
        features = self.model(image)
        adversarial = self.adv_output(features)
        class_logits = self.class_output(features)
        return adversarial, class_logits.view(class_logits.size(0), -1)


In [None]:
# Import necessary libraries
import os
import numpy as np
import math
import itertools
import time
import datetime
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.autograd as autograd
from PIL import Image
# Output directories: sample grids are written to "images/", checkpoints to
# "saved_models/". (A stray `wq` token that raised NameError was removed here.)
os.makedirs("images", exist_ok=True)
os.makedirs("saved_models", exist_ok=True)

class Config:
    """Hyper-parameters and run settings read by the training code below.

    All values are plain class attributes; override them on the instance (or
    on the class) before the rest of the cell runs.
    """

    epoch = 0  # epoch to start from; non-zero loads saved_models/*_{epoch}.pth
    n_epochs = 200  # upper bound (exclusive) of the epoch loop
    dataset_name = "img_align_celeba"  # not referenced by the code visible in this notebook
    batch_size = 16  # training batch size
    lr = 0.0002  # Adam learning rate for both optimizers
    b1 = 0.5  # Adam beta1
    b2 = 0.999  # Adam beta2
    decay_epoch = 100  # not referenced by the code visible in this notebook
    n_cpu = 8  # DataLoader worker processes for the training loader
    img_height = 128  # side length images are cropped/resized to
    img_width = 128
    channels = 3  # image channels
    sample_interval = 400  # save a sample grid when global_step is a multiple of this
    checkpoint_interval = -1  # epochs between checkpoints; -1 disables saving
    residual_blocks = 6  # residual blocks in the generator
    selected_attrs = ["Black_Hair", "Blond_Hair", "Brown_Hair", "Male", "Young"]  # label columns, in order
    n_critic = 5  # the generator is updated on every n_critic-th batch

    def __repr__(self):
        """Return ``Config(name=value, ...)`` so ``print(config)`` shows the settings."""
        fields = {}
        # Walk the MRO base-first so subclasses and instance overrides win
        # while keeping definition order.
        for klass in reversed(type(self).__mro__):
            fields.update(vars(klass))
        fields.update(vars(self))
        shown = ", ".join(
            f"{name}={value!r}"
            for name, value in fields.items()
            if not name.startswith("_") and not callable(value)
        )
        return f"{type(self).__name__}({shown})"

# Instantiate the run settings and echo them into the cell output.
config = Config()
print(config)

# Quantities derived from the settings.
image_shape = (config.channels, config.img_height, config.img_width)
c_dim = len(config.selected_attrs)  # one label per selected attribute
cuda = torch.cuda.is_available()

# L1 loss, used by the training loop as the cycle-reconstruction criterion.
criterion_cycle = nn.L1Loss()
def criterion_cls(logit, target):
    """Attribute-classification loss.

    Binary cross-entropy on raw logits, summed over all elements and divided
    by the size of the first dimension of ``logit`` (the batch size).
    """
    batch_size = logit.size(0)
    summed_bce = F.binary_cross_entropy_with_logits(logit, target, reduction="sum")
    return summed_bce / batch_size

# Weights of the loss terms combined in the training loop:
#   lambda_cls -> attribute-classification loss (discriminator and generator)
#   lambda_rec -> cycle-reconstruction L1 loss (generator)
#   lambda_gp  -> gradient penalty (discriminator)
lambda_cls, lambda_rec, lambda_gp = 1, 10, 10

# Build the two networks. Class and keyword names must match the definitions
# in the model cell (ResNetGenerator / ImageDiscriminator); the previous
# GeneratorResNet / Discriminator names were undefined in this notebook.
generator = ResNetGenerator(image_shape=image_shape, num_res_blocks=config.residual_blocks, class_dim=c_dim)
discriminator = ImageDiscriminator(image_shape=image_shape, class_dim=c_dim)

if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion_cycle = criterion_cycle.cuda()

if config.epoch != 0:
    # Resume: restore the weights saved for epoch `config.epoch`.
    # map_location lets a checkpoint written on GPU load on a CPU-only machine.
    generator.load_state_dict(
        torch.load(f"saved_models/generator_{config.epoch}.pth", map_location="cuda" if cuda else "cpu")
    )
    discriminator.load_state_dict(
        torch.load(f"saved_models/discriminator_{config.epoch}.pth", map_location="cuda" if cuda else "cpu")
    )
else:
    # Fresh run: normal(0, 0.02) initialisation of every convolution weight.
    generator.apply(init_weights_normal)
    discriminator.apply(init_weights_normal)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=config.lr, betas=(config.b1, config.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=config.lr, betas=(config.b1, config.b2))

# Training augmentation: upscale to 1.12x the target height, take a random
# square crop of side `img_height`, flip at random, then scale to [-1, 1].
train_transforms = [
    transforms.Resize(int(1.12 * config.img_height), Image.BICUBIC),
    transforms.RandomCrop(config.img_height),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# The dataset class defined above is FaceAttributeDataset and takes
# transform_ops / data_mode / selected_attributes; the previous CelebADataset
# name and keywords were undefined in this notebook.
train_dataloader = DataLoader(
    FaceAttributeDataset(
        "/kaggle/input/img-align-celeba/img_align_celeba/",
        transform_ops=train_transforms,
        data_mode="train",
        selected_attributes=config.selected_attrs,
    ),
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.n_cpu,
)

# Validation images are only resized (no random augmentation).
val_transforms = [
    transforms.Resize((config.img_height, config.img_width), Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Any data_mode other than "train" selects the held-out last 2000 images.
validation_dataloader = DataLoader(
    FaceAttributeDataset(
        "/kaggle/input/img-align-celeba/img_align_celeba/",
        transform_ops=val_transforms,
        data_mode="val",
        selected_attributes=config.selected_attrs,
    ),
    batch_size=10,
    shuffle=True,
    num_workers=1,
)

# Tensor type used to move batches onto the training device.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

def compute_gradient_penalty(D, real_samples, fake_samples):
    """Gradient penalty evaluated on random blends of real and fake samples.

    Every real sample is mixed with its fake counterpart using one uniform
    random weight per sample. The penalty is the batch mean of
    ``(||grad||_2 - 1) ** 2``, where ``grad`` is the gradient of the first
    output of ``D`` with respect to the blended input, flattened per sample.

    Args:
        D: Callable returning a 2-tuple; only its first element (the critic
            output) is used.
        real_samples: 4-D batch of real images.
        fake_samples: Batch of generated images with the same shape.

    Returns:
        Scalar tensor. The gradient is taken with ``create_graph=True`` so the
        penalty itself can be back-propagated.
    """
    # Draw the mixing weights on the inputs' own device/dtype instead of going
    # through numpy and a module-level tensor-type alias.
    alpha = torch.rand(
        real_samples.size(0), 1, 1, 1, device=real_samples.device, dtype=real_samples.dtype
    )
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates, _ = D(interpolates)
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    # reshape (not view) so a non-contiguous gradient is handled as well.
    gradients = gradients.reshape(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

# Label edits applied by `sample_images`: one entry per generated variant, each
# a tuple of (label column, new value) pairs where -1 means "flip the current
# value". Column order follows `Config.selected_attrs`.
label_changes = [
    # Columns 0-2: switch exactly one of them on and the other two off.
    *(tuple((col, int(col == chosen)) for col in range(3)) for chosen in range(3)),
    ((3, -1),),  # flip column 3
    ((4, -1),),  # flip column 4
]

def sample_images(global_step):
    """Save a grid of attribute-edited validation images.

    One batch is drawn from ``validation_dataloader``. For each of its first
    (at most) 10 images a row is built: the input image followed by one
    generated image per entry of ``label_changes``. The rows are stacked
    vertically and written to ``images/{global_step}.png``.
    """
    val_imgs, val_labels = next(iter(validation_dataloader))
    val_imgs = val_imgs.type(Tensor)
    val_labels = val_labels.type(Tensor)

    n_variants = len(label_changes)
    rows = []
    # Sampling needs no gradients; no_grad avoids building an autograd graph
    # through the generator for every sample grid.
    with torch.no_grad():
        # Slicing caps the grid at 10 rows and tolerates a smaller batch.
        for img, label in zip(val_imgs[:10], val_labels[:10]):
            # One copy of the image and its labels per label edit.
            imgs = img.repeat(n_variants, 1, 1, 1)
            labels = label.repeat(n_variants, 1)
            for sample_i, changes in enumerate(label_changes):
                for col, val in changes:
                    # -1 flips the current label, anything else overwrites it.
                    labels[sample_i, col] = 1 - labels[sample_i, col] if val == -1 else val
            gen_imgs = generator(imgs, labels)
            # Lay the original and its variants side by side (width axis).
            rows.append(torch.cat([img, *gen_imgs], -1))
    img_samples = torch.cat(rows, -2)  # stack the rows along the height axis
    save_image(img_samples.view(1, *img_samples.shape), f"images/{global_step}.png", normalize=True)

start_time = time.time()
batches_per_epoch = len(train_dataloader)
total_steps = config.n_epochs * batches_per_epoch
# Global step at which this run starts (non-zero when resuming), so the ETA is
# based on the steps actually executed since `start_time`.
first_step = config.epoch * batches_per_epoch

for epoch in range(config.epoch, config.n_epochs):
    for i, (real_images, labels) in enumerate(train_dataloader):
        real_images = real_images.type(Tensor)
        labels = labels.type(Tensor)
        # Random binary target labels for this batch.
        sampled_c = Tensor(np.random.randint(0, 2, (real_images.size(0), c_dim)))
        # The discriminator step only needs detached fakes, so skip building
        # the generator's autograd graph here.
        with torch.no_grad():
            generated_images = generator(real_images, sampled_c)

        # ---- Discriminator update (every batch) ----
        optimizer_D.zero_grad()
        real_validity, pred_cls = discriminator(real_images)
        fake_validity, _ = discriminator(generated_images)
        gradient_penalty = compute_gradient_penalty(discriminator, real_images, generated_images)
        loss_D_adv = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
        loss_D_cls = criterion_cls(pred_cls, labels)
        loss_D = loss_D_adv + lambda_cls * loss_D_cls
        loss_D.backward()
        optimizer_D.step()

        # ---- Generator update (every n_critic-th batch) ----
        optimizer_G.zero_grad()
        if i % config.n_critic == 0:
            gen_imgs = generator(real_images, sampled_c)
            # Translate back to the original labels for the cycle loss.
            recov_imgs = generator(gen_imgs, labels)
            fake_validity, pred_cls = discriminator(gen_imgs)
            loss_G_adv = -torch.mean(fake_validity)
            loss_G_cls = criterion_cls(pred_cls, sampled_c)
            loss_G_rec = criterion_cycle(recov_imgs, real_images)
            loss_G = loss_G_adv + lambda_cls * loss_G_cls + lambda_rec * loss_G_rec
            loss_G.backward()
            optimizer_G.step()

            global_step = epoch * batches_per_epoch + i
            steps_done = global_step - first_step + 1
            time_left = datetime.timedelta(
                seconds=(total_steps - global_step) * (time.time() - start_time) / steps_done
            )

            print(f"[Epoch {epoch}/{config.n_epochs}] [Batch {i}/{len(train_dataloader)}] [D adv: {loss_D_adv.item()}, aux: {loss_D_cls.item()}] [G loss: {loss_G.item()}, adv: {loss_G_adv.item()}, aux: {loss_G_cls.item()}, cycle: {loss_G_rec.item()}] ETA: {time_left}")

            if global_step % config.sample_interval == 0:
                sample_images(global_step)

    # A non-positive interval disables checkpointing (and avoids `epoch % 0`).
    if config.checkpoint_interval > 0 and epoch % config.checkpoint_interval == 0:
        torch.save(generator.state_dict(), f"saved_models/generator_{epoch}.pth")
        torch.save(discriminator.state_dict(), f"saved_models/discriminator_{epoch}.pth")
