In [None]:
!pip install labml_nn

In [None]:
import math
from pathlib import Path
from typing import Iterator, Tuple
import numpy as np
import os
import torch
import torch.utils.data
import torchvision
from PIL import Image

from torchvision import datasets, transforms, utils

from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
from labml_nn.utils import cycle_dataloader

In [None]:
from tqdm.notebook import tqdm

In [None]:
class Dataset(torch.utils.data.Dataset):
    """
    ## Dataset

    This loads the training dataset and resizes each image to the given image size.
    """

    def __init__(self, path: str, image_size: int):
        """
        * `path` path to the folder containing the images (searched recursively)
        * `image_size` height and width the images are resized to
        """
        super().__init__()

        # Get the paths of all `jpg` files.
        # Sorted so that the index -> file mapping does not depend on the
        # (unspecified) directory listing order of the file system.
        self.paths = sorted(Path(path).glob('**/*.jpg'))

        # Transformation
        self.transform = torchvision.transforms.Compose([
            # Resize the image to a square of `image_size` pixels
            torchvision.transforms.Resize((image_size, image_size)),
            # Convert to PyTorch tensor
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        """Number of images"""
        return len(self.paths)

    def __getitem__(self, index):
        """Get the `index`-th image as a transformed tensor"""
        path = self.paths[index]
        # Open inside a context manager so the file handle is released promptly,
        # and force three channels so grayscale/CMYK files still produce
        # tensors with the same channel count as every other sample in the batch.
        with Image.open(path) as img:
            img = img.convert('RGB')
        return self.transform(img)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

In [None]:
# <a id="dataset_path"></a>
# Folder that holds the training images; it is handed to `Dataset`, which
# collects every `jpg` file below it.
# The original comment states this was trained on the
# [CelebA-HQ dataset](https://github.com/tkarras/progressive_growing_of_gans);
# download instructions are in this
# [discussion on fast.ai](https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3).
# NOTE(review): the folder name below looks like the aligned CelebA layout
# rather than CelebA-HQ - confirm which dataset is actually expected here.
dataset_path: str = os.path.join('data_faces', 'img_align_celeba')

# Height/width of the (square) training images
image_size: int = 64
# Number of images per batch
batch_size: int = 32

# Dimensionality of $z$ and $w$
d_latent: int = 512
# Number of layers in the mapping network
mapping_network_layers: int = 8


In [None]:
# Gradient penalty coefficient $\gamma$
gradient_penalty_coefficient: float = 10.0
# [Gradient Penalty Regularization Loss](index.html#gradient_penalty)
gradient_penalty = GradientPenalty()

# [Path length penalty](index.html#path_length_penalty);
# only declared here, the instance is created together with the models
path_length_penalty: PathLengthPenalty

In [None]:
### Initialize

# Training images, resized to `image_size`
dataset = Dataset(dataset_path, image_size)
# Shuffled batches; the last incomplete batch of every epoch is dropped
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
# Continuous [cyclic loader](../../utils.html#cycle_dataloader) consumed with `next(loader)`
loader = cycle_dataloader(dataloader)

In [None]:
# $\log_2$ of image resolution.
# `int()` truncates, so a non power-of-two `image_size` is silently rounded down.
log_resolution = int(math.log2(image_size))

# Create discriminator and generator and move them to the training device.
# NOTE(review): the modules are built in this fixed order; parameter
# initialization presumably draws from the global RNG, so reordering these
# lines would change the initial weights under a fixed seed - confirm.
discriminator = Discriminator(log_resolution).to(device)
generator = Generator(log_resolution, d_latent).to(device)
# Get number of generator blocks for creating style and noise inputs
n_gen_blocks = generator.n_blocks
# Create mapping network ($z \rightarrow w$)
mapping_network = MappingNetwork(d_latent, mapping_network_layers).to(device)
# Create path length penalty loss.
# NOTE(review): `0.99` is presumably the moving-average decay used by
# `PathLengthPenalty` - confirm against the labml_nn implementation.
path_length_penalty = PathLengthPenalty(0.99).to(device)


In [None]:
# Learning rate shared by the generator and the discriminator
learning_rate: float = 1e-3
# Learning rate of the mapping network ($100 \times$ lower than the others)
mapping_network_learning_rate: float = 1e-5
# $\beta_1$ and $\beta_2$ for the Adam optimizers
adam_betas: Tuple[float, float] = (0.0, 0.99)

# Number of batches whose gradients are accumulated before each optimizer step.
# Use this to increase the effective batch size.
gradient_accumulate_steps: int = 1
# Probability of mixing two styles when sampling $w$
style_mixing_prob: float = 0.9

In [None]:
# Loss modules for the discriminator and the generator
discriminator_loss = DiscriminatorLoss().to(device)
generator_loss = GeneratorLoss().to(device)

# One Adam optimizer per network; the mapping network has its own learning rate
discriminator_optimizer = torch.optim.Adam(
    discriminator.parameters(),
    lr=learning_rate,
    betas=adam_betas,
)
generator_optimizer = torch.optim.Adam(
    generator.parameters(),
    lr=learning_rate,
    betas=adam_betas,
)
mapping_network_optimizer = torch.optim.Adam(
    mapping_network.parameters(),
    lr=mapping_network_learning_rate,
    betas=adam_betas,
)

In [None]:
# Gradient penalty is only computed every this many steps
lazy_gradient_penalty_interval: int = 4
# Path length penalty is only computed every this many steps ...
lazy_path_penalty_interval: int = 32
# ... and is skipped entirely during the initial phase of training
lazy_path_penalty_after: int = 5_000

# Number of steps between logging generated images
log_generated_interval: int = 500
# Number of steps between saving model checkpoints
save_checkpoint_interval: int = 2_000

In [None]:
# Make sure the folder for sample images and model checkpoints exists.
# `exist_ok=True` avoids the check-then-create race of a separate `exists()` test.
os.makedirs("checkpoints", exist_ok=True)

In [None]:
def get_w(batch_size: int):
    """
    ### Sample $w$

    Draws random $z$ and maps it to $w$ with the mapping network. The result
    has one copy of $w$ per generator block, i.e. shape
    `[n_gen_blocks, batch_size, d_latent]`.

    With probability `style_mixing_prob` style mixing is applied: two latents
    $z_1$ and $z_2$ are mapped to $w_1$ and $w_2$, a random cross-over point is
    drawn, and $w_1$ is used for the generator blocks before that point while
    $w_2$ is used for the remaining blocks.
    """
    mix_styles = torch.rand(()).item() < style_mixing_prob

    # Without mixing: a single $w$ shared by every generator block
    if not mix_styles:
        latent = torch.randn(batch_size, d_latent).to(device)
        style = mapping_network(latent)
        return style.unsqueeze(0).expand(n_gen_blocks, -1, -1)

    # Random cross-over point in `[0, n_gen_blocks)`
    split = int(torch.rand(()).item() * n_gen_blocks)
    # Sample the two latents ($z_2$ is drawn first, then $z_1$)
    z_second = torch.randn(batch_size, d_latent).to(device)
    z_first = torch.randn(batch_size, d_latent).to(device)
    # Map them to $w_1$ and $w_2$
    w_first = mapping_network(z_first)
    w_second = mapping_network(z_second)
    # $w_1$ for the first `split` blocks, $w_2$ for the rest
    per_block = (
        w_first.unsqueeze(0).expand(split, -1, -1),
        w_second.unsqueeze(0).expand(n_gen_blocks - split, -1, -1),
    )
    return torch.cat(per_block, dim=0)

def get_noise(batch_size: int):
    """
    ### Generate noise

    Creates the noise inputs for each [generator block](index.html#generator_block).

    Returns a list with one `(n1, n2)` tuple per block. Each noise tensor has
    shape `[batch_size, 1, resolution, resolution]`, where the resolution
    starts at $4$ and doubles with every block. `n1` is `None` for the first
    block.
    """

    def sample(resolution: int):
        # One single-channel noise map per image at the given resolution
        return torch.randn(batch_size, 1, resolution, resolution, device=device)

    noise = []
    for block_idx in range(n_gen_blocks):
        # Resolution is $4$ for the first block and $2 \times$ larger for each next one
        resolution = 4 * 2 ** block_idx
        # The first block has only one $3 \times 3$ convolution, so it gets
        # no noise for a first convolution layer
        first = sample(resolution) if block_idx > 0 else None
        # Noise to add after the second convolution layer
        second = sample(resolution)
        noise.append((first, second))

    return noise

def generate_images(batch_size: int):
    """
    ### Generate images

    Samples $w$ and per-block noise for `batch_size` images and runs the
    generator on them.

    Returns a tuple `(images, w)` so that callers can reuse the sampled $w$.
    """
    # $w$ is sampled before the noise
    styles = get_w(batch_size)
    images = generator(styles, get_noise(batch_size))
    return images, styles


'   \ndef step( idx: int):\n        """\n        ### Training Step\n        """\n\n        # Train the discriminator\n        with monit.section(\'Discriminator\'):\n            # Reset gradients\n            discriminator_optimizer.zero_grad()\n\n            # Accumulate gradients for `gradient_accumulate_steps`\n            for i in range(gradient_accumulate_steps):\n               \n                    # Sample images from generator\n                    generated_images, _ = generate_images(batch_size)\n                    # Discriminator classification for generated images\n                    fake_output = discriminator(generated_images.detach())\n\n                    # Get real images from the data loader\n                    real_images = next(loader).to(device)\n                    # We need to calculate gradients w.r.t. real images for gradient penalty\n                    if (idx + 1) % lazy_gradient_penalty_interval == 0:\n                        real_images.requires_grad_(

In [None]:
def step(idx: int):
    """
    ### Training Step

    Runs one discriminator update followed by one generator / mapping network
    update, each accumulating gradients over `gradient_accumulate_steps`
    batches.

    * `idx` is the zero-based index of the training step. It decides on which
      steps the gradient penalty and the path length penalty are applied,
      when a sample image grid is written and when checkpoints are saved.

    Side effects: updates the three networks through their optimizers, writes
    `checkpoints/sample.png` every `log_generated_interval` steps and the
    model state dicts every `save_checkpoint_interval` steps.
    """
    # The regularizers are only evaluated on some of the steps
    apply_gradient_penalty = (idx + 1) % lazy_gradient_penalty_interval == 0
    apply_path_penalty = idx > lazy_path_penalty_after and (idx + 1) % lazy_path_penalty_interval == 0

    # #### Train the discriminator

    # Reset gradients
    discriminator_optimizer.zero_grad()

    # Accumulate gradients for `gradient_accumulate_steps`
    for _ in range(gradient_accumulate_steps):
        # Sample images from generator
        generated_images, _w = generate_images(batch_size)
        # Discriminator classification for generated images;
        # detached so that no gradients flow back into the generator here
        fake_output = discriminator(generated_images.detach())

        # Get real images from the data loader
        real_images = next(loader).to(device)
        # We need to calculate gradients w.r.t. real images for gradient penalty
        if apply_gradient_penalty:
            real_images.requires_grad_()
        # Discriminator classification for real images
        real_output = discriminator(real_images)

        # Get discriminator loss
        real_loss, fake_loss = discriminator_loss(real_output, fake_output)
        disc_loss = real_loss + fake_loss

        # Add gradient penalty
        if apply_gradient_penalty:
            gp = gradient_penalty(real_images, real_output)
            # Multiply by the coefficient; the extra factor of
            # `lazy_gradient_penalty_interval` scales the penalty up on the
            # steps where it is applied
            disc_loss = disc_loss + 0.5 * gradient_penalty_coefficient * gp * lazy_gradient_penalty_interval

        # Compute gradients
        disc_loss.backward()

    # Clip gradients for stabilization
    torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
    # Take optimizer step
    discriminator_optimizer.step()

    # #### Train the generator and the mapping network

    # Reset gradients
    generator_optimizer.zero_grad()
    mapping_network_optimizer.zero_grad()

    # Accumulate gradients for `gradient_accumulate_steps`
    for _ in range(gradient_accumulate_steps):
        # Sample images from generator
        generated_images, w = generate_images(batch_size)
        # Discriminator classification for generated images
        fake_output = discriminator(generated_images)

        # Get generator loss
        gen_loss = generator_loss(fake_output)

        # Add path length penalty
        if apply_path_penalty:
            plp = path_length_penalty(w, generated_images)
            # Ignore if `nan`
            if not torch.isnan(plp):
                gen_loss = gen_loss + plp

        # Calculate gradients
        gen_loss.backward()

    # Clip gradients for stabilization
    torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
    torch.nn.utils.clip_grad_norm_(mapping_network.parameters(), max_norm=1.0)

    # Take optimizer step
    generator_optimizer.step()
    mapping_network_optimizer.step()

    # Write a grid of generated and real samples, but only every
    # `log_generated_interval` steps: encoding a PNG on every single step
    # needlessly slows training down.
    if (idx + 1) % log_generated_interval == 0:
        # Detached copies: the grid is only for inspection, no graph is needed
        samples = torch.cat([generated_images[:6].detach(), real_images[:3].detach()], dim=0)
        # NOTE(review): the dataset transform only applies `ToTensor()`, while
        # `value_range=(-1, 1)` is used for normalization here - confirm the
        # intended pixel scaling of real vs. generated images.
        utils.save_image(
            samples,
            os.path.join('checkpoints','sample.png'),
            nrow=3,
            normalize=True,
            value_range=(-1, 1),
        )

    # Save model checkpoints
    if (idx + 1) % save_checkpoint_interval == 0:
        torch.save(generator.state_dict(), os.path.join('checkpoints','generator.pth'))
        torch.save(mapping_network.state_dict(), os.path.join('checkpoints','mapping_network.pth'))
        torch.save(discriminator.state_dict(), os.path.join('checkpoints','discriminator.pth'))

In [None]:
# Total number of training steps
training_steps: int = 150_000

# Run `step` once per training step, with a progress bar
for i in tqdm(range(training_steps)):
    step(i)



  0%|          | 0/150000 [00:00<?, ?it/s]