<a href="https://colab.research.google.com/github/omridrori/generative-models/blob/main/gan%20models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class GANLossGenerator(nn.Module):
    """
    Generator loss of the standard (minimax) GAN, evaluated on the discriminator output for generated samples.
    """

    def __init__(self) -> None:
        """
        Constructor method.
        """
        super().__init__()

    def forward(self, discriminator_prediction_fake: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Compute the generator loss.

        :param discriminator_prediction_fake: Discriminator output for generated samples
        :param kwargs: Ignored
        :return: Scalar loss tensor, the negated mean softplus of the prediction
        """
        # softplus folds the sigmoid and the logarithm into one numerically stable call
        softplus_fake = F.softplus(discriminator_prediction_fake)
        return softplus_fake.mean().neg()


class GANLossDiscriminator(nn.Module):
    """
    Discriminator loss of the standard GAN, computed from the discriminator outputs on real and generated samples.
    """

    def __init__(self) -> None:
        """
        Constructor method.
        """
        super().__init__()

    def forward(self, discriminator_prediction_real: torch.Tensor,
                discriminator_prediction_fake: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Compute the discriminator loss.

        :param discriminator_prediction_real: Discriminator output for real samples
        :param discriminator_prediction_fake: Discriminator output for generated samples
        :param kwargs: Ignored
        :return: Scalar loss tensor (sum of the real and the fake term)
        """
        # softplus folds the sigmoid and the logarithm into one numerically stable call
        real_term = F.softplus(discriminator_prediction_real.neg()).mean()
        fake_term = F.softplus(discriminator_prediction_fake).mean()
        return real_term + fake_term


class NSGANLossGenerator(nn.Module):
    """
    Generator loss of the non-saturating GAN, evaluated on the discriminator output for generated samples.
    """

    def __init__(self) -> None:
        """
        Constructor method.
        """
        super().__init__()

    def forward(self, discriminator_prediction_fake: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Compute the generator loss.

        :param discriminator_prediction_fake: Discriminator output for generated samples
        :param kwargs: Ignored
        :return: Scalar loss tensor, the mean softplus of the negated prediction
        """
        # softplus folds the sigmoid and the logarithm into one numerically stable call
        negated_prediction = discriminator_prediction_fake.neg()
        return F.softplus(negated_prediction).mean()


class NSGANLossDiscriminator(GANLossDiscriminator):
    """
    Discriminator loss of the non-saturating GAN; the forward pass is inherited unchanged from GANLossDiscriminator.
    """

    def __init__(self) -> None:
        """
        Constructor method.
        """
        super().__init__()


class WassersteinGANLossGenerator(nn.Module):
    """
    Generator loss of the Wasserstein GAN.
    """

    def __init__(self) -> None:
        """
        Constructor method.
        """
        # Call super constructor
        super(WassersteinGANLossGenerator, self).__init__()

    def forward(self, discriminator_prediction_fake: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Compute the generator loss.

        :param discriminator_prediction_fake: Discriminator output for generated samples
        :param kwargs: Ignored
        :return: Scalar loss tensor, the negated mean of the prediction
        """
        return - discriminator_prediction_fake.mean()


class WassersteinGANLossDiscriminator(nn.Module):
    """
    Discriminator loss of the Wasserstein GAN.
    """

    def __init__(self) -> None:
        """
        Constructor method.
        """
        super().__init__()

    def forward(self, discriminator_prediction_real: torch.Tensor,
                discriminator_prediction_fake: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Compute the discriminator loss.

        :param discriminator_prediction_real: Discriminator output for real samples
        :param discriminator_prediction_fake: Discriminator output for generated samples
        :param kwargs: Ignored
        :return: Scalar loss tensor, mean fake prediction minus mean real prediction
        """
        mean_real = discriminator_prediction_real.mean()
        mean_fake = discriminator_prediction_fake.mean()
        return mean_fake - mean_real


class WassersteinGANLossGPGenerator(WassersteinGANLossGenerator):
    """
    Generator loss of the Wasserstein GAN with gradient penalty; the forward pass is inherited unchanged from
    WassersteinGANLossGenerator.
    """

    def __init__(self) -> None:
        """
        Constructor method.
        """
        # Call super constructor
        super(WassersteinGANLossGPGenerator, self).__init__()


class WassersteinGANLossGPDiscriminator(nn.Module):
    """
    Discriminator loss of the Wasserstein GAN with gradient penalty.
    """

    def __init__(self) -> None:
        """
        Constructor method.
        """
        super(WassersteinGANLossGPDiscriminator, self).__init__()

    def forward(self, discriminator_prediction_real: torch.Tensor,
                discriminator_prediction_fake: torch.Tensor,
                discriminator: nn.Module,
                real_samples: torch.Tensor,
                fake_samples: torch.Tensor,
                lambda_gradient_penalty: Optional[float] = 2., **kwargs) -> torch.Tensor:
        """
        Compute the Wasserstein discriminator loss plus a gradient penalty on random interpolations.

        :param discriminator_prediction_real: Discriminator output for real samples
        :param discriminator_prediction_fake: Discriminator output for generated samples
        :param discriminator: Discriminator network, called once more on the interpolated samples
        :param real_samples: Real samples with the batch in the first dimension
        :param fake_samples: Generated samples of the same shape as real_samples
        :param lambda_gradient_penalty: Weight of the gradient penalty term
        :param kwargs: Ignored
        :return: Scalar loss tensor
        """
        # Detach both inputs so the interpolation is a leaf tensor; enabling gradients on a non-leaf tensor
        # (e.g. fake samples that still carry the generator graph) is not permitted
        real_detached = real_samples.detach()
        fake_detached = fake_samples.detach()
        # One interpolation factor per sample, broadcast over every remaining dimension
        alpha_shape = (real_detached.shape[0],) + (1,) * (real_detached.dim() - 1)
        alpha = torch.rand(alpha_shape, device=real_detached.device)
        # Make interpolated samples
        samples_interpolated = (alpha * real_detached + (1. - alpha) * fake_detached).requires_grad_(True)
        # Make discriminator prediction
        discriminator_prediction_interpolated = discriminator(samples_interpolated)
        # Calc gradients; create_graph keeps the penalty differentiable w.r.t. the discriminator parameters
        gradients = torch.autograd.grad(outputs=discriminator_prediction_interpolated.sum(),
                                        inputs=samples_interpolated,
                                        create_graph=True,
                                        retain_graph=True)[0]
        # Calc gradient penalty: squared deviation of the per-sample gradient norm from one
        # (reshape instead of view so non-contiguous gradients are handled as well)
        gradient_penalty = (gradients.reshape(gradients.shape[0], -1).norm(dim=1) - 1.).pow(2).mean()
        return - discriminator_prediction_real.mean() \
               + discriminator_prediction_fake.mean() \
               + lambda_gradient_penalty * gradient_penalty


class LSGANLossGenerator(nn.Module):
    """
    Generator loss of the least-squares GAN: pulls the discriminator output for generated samples towards the
    "real" label 1, the same label LSGANLossDiscriminator uses for real samples.
    """

    def __init__(self) -> None:
        """
        Constructor method.
        """
        super(LSGANLossGenerator, self).__init__()

    def forward(self, discriminator_prediction_fake: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Compute the generator loss.

        :param discriminator_prediction_fake: Discriminator output for generated samples
        :param kwargs: Ignored
        :return: Scalar loss tensor, half the mean squared distance of the prediction to 1
        """
        # The loss is minimized, so the squared distance must enter with a positive sign
        return 0.5 * (discriminator_prediction_fake - 1.).pow(2).mean()


class LSGANLossDiscriminator(nn.Module):
    """
    Discriminator loss of the least-squares GAN: regresses real samples to the label 1 and generated samples to
    the label 0.
    """

    def __init__(self) -> None:
        """
        Constructor method.
        """
        # Call super constructor
        super(LSGANLossDiscriminator, self).__init__()

    def forward(self, discriminator_prediction_real: torch.Tensor,
                discriminator_prediction_fake: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Compute the discriminator loss.

        :param discriminator_prediction_real: Discriminator output for real samples
        :param discriminator_prediction_fake: Discriminator output for generated samples
        :param kwargs: Ignored
        :return: Scalar loss tensor
        """
        # Real target is 1 (matching the generator target), fake target is 0
        return 0.5 * ((discriminator_prediction_real - 1.).pow(2).mean()
                      + discriminator_prediction_fake.pow(2).mean())


class HingeGANLossGenerator(WassersteinGANLossGenerator):
    """
    Generator loss of the hinge GAN; the forward pass is inherited unchanged from WassersteinGANLossGenerator.
    """

    def __init__(self) -> None:
        """
        Build the loss module; there is no state of its own to set up.
        """
        super().__init__()


class HingeGANLossDiscriminator(nn.Module):
    """
    Discriminator loss of the hinge GAN.
    """

    def __init__(self) -> None:
        """
        Constructor method.
        """
        super().__init__()

    def forward(self, discriminator_prediction_real: torch.Tensor,
                discriminator_prediction_fake: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Compute the discriminator loss.

        :param discriminator_prediction_real: Discriminator output for real samples
        :param discriminator_prediction_fake: Discriminator output for generated samples
        :param kwargs: Ignored
        :return: Scalar loss tensor
        """
        # Scalar zeros living on the device of the respective prediction
        zero_real = torch.tensor(0., dtype=torch.float, device=discriminator_prediction_real.device)
        zero_fake = torch.tensor(0., dtype=torch.float, device=discriminator_prediction_fake.device)
        # min(0, D(x) - 1): only real predictions below 1 contribute
        real_term = torch.minimum(zero_real, discriminator_prediction_real - 1.).mean()
        # min(0, -D(G(z)) - 1): only fake predictions above -1 contribute
        fake_term = torch.minimum(zero_fake, - discriminator_prediction_fake - 1.).mean()
        return - real_term - fake_term

In [None]:
from typing import Optional

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm
import numpy as np


def get_generator(latent_size: int) -> nn.Module:
    """
    Build the fully connected generator network.

    :param latent_size: Number of input features (size of the latent vector)
    :return: nn.Sequential of five linear layers mapping latent_size inputs to 2 outputs
    """
    hidden_width = 256
    # Activation that follows each of the four hidden linear layers, in order
    hidden_activations = (nn.LeakyReLU, nn.LeakyReLU, nn.LeakyReLU, nn.Tanh)
    stack = []
    in_features = latent_size
    for activation in hidden_activations:
        stack.append(nn.Linear(in_features, hidden_width, bias=True))
        stack.append(activation())
        in_features = hidden_width
    # Output layer without activation
    stack.append(nn.Linear(hidden_width, 2, bias=True))
    return nn.Sequential(*stack)


def get_discriminator(use_spectral_norm: bool) -> nn.Module:
    """
    Build the fully connected discriminator network.

    :param use_spectral_norm: If true, every linear layer is wrapped with spectral normalization
    :return: nn.Sequential of five linear layers mapping 2 inputs to 1 output, with LeakyReLU in between
    """
    # Import under a private alias: the notebook's configuration cell rebinds the global name
    # `spectral_norm` to a flag value, which would otherwise make the call below fail
    from torch.nn.utils import spectral_norm as apply_spectral_norm

    # (in_features, out_features) of the linear layers, shared by both variants
    feature_sizes = ((2, 256), (256, 256), (256, 256), (256, 256), (256, 1))
    layers = []
    for position, (fan_in, fan_out) in enumerate(feature_sizes):
        if position > 0:
            layers.append(nn.LeakyReLU())
        linear = nn.Linear(fan_in, fan_out, bias=True)
        layers.append(apply_spectral_norm(linear) if use_spectral_norm else linear)
    return nn.Sequential(*layers)


def get_data(samples: Optional[int] = 400, variance: Optional[float] = 0.05) -> torch.Tensor:
    """
    Sample a 2D toy data set of eight Gaussian blobs whose means are evenly spaced on the unit circle.

    :param samples: Total number of points; must be a positive multiple of 8 (every blob gets samples / 8 points)
    :param variance: Spread of every blob. NOTE: the value is handed to torch.normal as its standard deviation
    :return: Float tensor of shape [samples, 2] with the rows in random order
    """
    assert samples % 8 == 0 and samples > 0, "Number of samples must be a multiple of 8 and bigger than 0"
    # Angles of the eight means: 1/8, 2/8, ..., 8/8 of a full turn
    angles = (2 * np.pi / 8) * torch.arange(1, 9, dtype=torch.float)
    # Convert angles to 2D coordinates, one mean per row
    means = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)
    # Repeat every mean samples / 8 times and draw all points with a single vectorized call
    centers = means.repeat_interleave(samples // 8, dim=0)
    data = torch.normal(centers, variance)
    # Shuffle data
    data = data[torch.randperm(data.shape[0])]
    return data.float()

In [None]:
# Hyperparameters of the training run
# Prefer the GPU but fall back to the CPU when CUDA is not available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
epochs = 500
d_updates = 1
plot_frequency = 10
lr = 0.0001
latent_size = 32
samples = 10000
batch_size = 500
loss = 'hinge'  # ['standard', 'non-saturating', 'hinge', 'wasserstein', 'wasserstein-gp', 'least-squares']
# Must be a real boolean: the string 'false' is non-empty and therefore truthy.
# NOTE: this name shadows the `spectral_norm` function imported from torch.nn.utils in an earlier cell.
spectral_norm = False
# Bound used to clamp the discriminator weights after every update (values <= 0 disable clipping)
clip_weights = 1
topk = True

In [None]:
# Get arguments

import torch
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt
import os



if __name__ == '__main__':
    # Script entry point: trains the generator/discriminator pair on the toy data with the configured loss
    # and periodically saves scatter plots of real versus generated samples.
    # The flag may be a bool or a string such as 'false'; a non-empty string is always truthy,
    # so it is normalized to a real boolean before being used in any condition
    use_spectral_norm: bool = str(spectral_norm).strip().lower() in ('true', '1', 'yes')
    # Make directory to save plots
    path = os.path.join(os.getcwd(), 'plots', loss + ("_top_k" if topk else "") + ("_sn" if use_spectral_norm else "") + ("_clip" if clip_weights else ""))
    os.makedirs(path, exist_ok=True)
    # Fixed noise that is reused for every plot
    fixed_generator_noise: torch.Tensor = torch.randn([samples // 10, latent_size], device=device)
    # Get data
    data: torch.Tensor = get_data(samples=samples).to(device)
    # Get generator
    generator: nn.Module = get_generator(latent_size=latent_size)
    # Get discriminator
    discriminator: nn.Module = get_discriminator(use_spectral_norm=use_spectral_norm)
    # Init Loss function (any unlisted loss name falls back to the least-squares losses)
    if loss == 'standard':
        loss_generator: nn.Module = GANLossGenerator()
        loss_discriminator: nn.Module = GANLossDiscriminator()
    elif loss == 'non-saturating':
        loss_generator: nn.Module = NSGANLossGenerator()
        loss_discriminator: nn.Module = NSGANLossDiscriminator()
    elif loss == 'hinge':
        loss_generator: nn.Module = HingeGANLossGenerator()
        loss_discriminator: nn.Module = HingeGANLossDiscriminator()
    elif loss == 'wasserstein':
        loss_generator: nn.Module = WassersteinGANLossGenerator()
        loss_discriminator: nn.Module = WassersteinGANLossDiscriminator()
    elif loss == 'wasserstein-gp':
        loss_generator: nn.Module = WassersteinGANLossGPGenerator()
        loss_discriminator: nn.Module = WassersteinGANLossGPDiscriminator()
    else:
        loss_generator: nn.Module = LSGANLossGenerator()
        loss_discriminator: nn.Module = LSGANLossDiscriminator()
    # Networks to train mode
    generator.train()
    discriminator.train()
    # Models to device
    generator.to(device)
    discriminator.to(device)
    # Init optimizer
    generator_optimizer: torch.optim.Optimizer = torch.optim.RMSprop(generator.parameters(), lr=lr)
    discriminator_optimizer: torch.optim.Optimizer = torch.optim.RMSprop(discriminator.parameters(), lr=lr)
    # Init progress bar
    progress_bar = tqdm(total=epochs)
    # Training loop
    for epoch in range(epochs):  # type: int
        # Update progress bar
        progress_bar.update(n=1)
        # Update discriminator more often than generator to train it till optimality and get more reliable gradients of Wasserstein
        for _ in range(d_updates):  # type: int
            # Shuffle data
            data = data[torch.randperm(data.shape[0], device=device)]
            for index in range(0, samples, batch_size):  # type:int
                # Get batch
                batch: torch.Tensor = data[index:index + batch_size]
                # Get noise for generator; sized by the actual batch so that a shorter final batch
                # still matches the fake samples (needed for the gradient penalty interpolation)
                noise: torch.Tensor = torch.randn([batch.shape[0], latent_size], device=device)
                # Optimize discriminator
                discriminator_optimizer.zero_grad()
                generator_optimizer.zero_grad()
                # The generator is not trained in this step, so no graph is built for it
                with torch.no_grad():
                    fake_samples: torch.Tensor = generator(noise)
                prediction_real: torch.Tensor = discriminator(batch)
                prediction_fake: torch.Tensor = discriminator(fake_samples)
                if isinstance(loss_discriminator, WassersteinGANLossGPDiscriminator):
                    loss_d: torch.Tensor = loss_discriminator(prediction_real, prediction_fake, discriminator, batch,
                                                            fake_samples)
                else:
                    loss_d: torch.Tensor = loss_discriminator(prediction_real, prediction_fake)
                loss_d.backward()
                discriminator_optimizer.step()

                # Clip weights to enforce Lipschitz constraint as proposed in Wasserstein GAN paper
                if clip_weights > 0:
                    with torch.no_grad():
                        for param in discriminator.parameters():
                            param.clamp_(-clip_weights, clip_weights)

            # One generator step follows each full discriminator pass over the data
            # Get noise for generator
            noise: torch.Tensor = torch.randn([batch_size, latent_size], device=device)
            # Optimize generator
            discriminator_optimizer.zero_grad()
            generator_optimizer.zero_grad()
            fake_samples: torch.Tensor = generator(noise)
            prediction_fake: torch.Tensor = discriminator(fake_samples)
            # In the second half of training only the top half of the fake predictions is kept
            if topk and (epoch >= 0.5 * epochs):
                prediction_fake = torch.topk(input=prediction_fake[:, 0], k=prediction_fake.shape[0] // 2)[0]
            loss_g: torch.Tensor = loss_generator(prediction_fake)
            loss_g.backward()
            generator_optimizer.step()
            # Update progress bar description
            progress_bar.set_description(
                'Epoch {}, Generator loss {:.4f}, Discriminator loss {:.4f}'.format(epoch, loss_g.item(),
                                                                                    loss_d.item()))
        # Plot samples of generator
        if ((epoch + 1) % plot_frequency) == 0:
            generator.eval()
            # Plotting needs no autograd graph
            with torch.no_grad():
                generator_samples = generator(fixed_generator_noise)
            generator_samples = generator_samples.cpu().numpy()
            plt.scatter(data[::10, 0].cpu(), data[::10, 1].cpu(), color='blue', label='Samples from $p_{data}$', s=2, alpha=0.5)
            plt.scatter(generator_samples[:, 0], generator_samples[:, 1], color='red',
                        label='Samples from generator $G$', s=2, alpha=0.5)
            plt.legend(loc=1)
            plt.title('Step {}'.format((epoch + 1) * samples // batch_size))
            plt.xlim((-1.5, 1.5))
            plt.ylim((-1.5, 1.75))
            plt.grid()
            plt.savefig(os.path.join(path, '{}.png'.format(str(epoch + 1).zfill(4))))
            plt.close()
            generator.train()
    # Release the progress bar once training has finished
    progress_bar.close()

Epoch 499, Generator loss -0.8292, Discriminator loss 1.9206: 100%|██████████| 500/500 [01:42<00:00,  5.28it/s]