# CGAN practice (MNIST)

In [None]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torchvision import datasets as dset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# Seed PyTorch's global RNG so weight initialization and noise sampling repeat
# across runs (GPU kernels may still introduce nondeterminism).
torch.manual_seed(0)

<torch._C.Generator at 0x7fed5a3d28d0>

### Generator

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator that maps a flat input vector to an image.

    With the default kernel/stride settings, a ``(N, z_dim)`` input is viewed as
    ``(N, z_dim, 1, 1)`` and upsampled 1x1 -> 3x3 -> 6x6 -> 13x13 -> 28x28,
    giving ``(N, im_chan, 28, 28)`` images in [-1, 1] (final Tanh).

    Args:
        z_dim: length of each input vector (noise, possibly with a condition
            vector concatenated to it).
        im_chan: number of channels in the generated image.
        hidden_dim: base width; the hidden stages use 4x, 2x and 1x of it.
    """

    def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):
        super().__init__()
        self.z_dim = z_dim
        widths = (z_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim)
        self.gen = nn.Sequential(
            self.gen_block(widths[0], widths[1]),
            self.gen_block(widths[1], widths[2], kernel_size=4, stride=1),
            self.gen_block(widths[2], widths[3]),
            self.gen_block(widths[3], im_chan, kernel_size=4, final_layer=True),
        )

    def gen_block(self, in_channel, out_channel, kernel_size=3, stride=2, final_layer=False):
        """Build one transposed-convolution upsampling stage.

        Hidden stages: ConvTranspose2d -> BatchNorm2d -> ReLU.
        Final stage: ConvTranspose2d -> Tanh, so outputs lie in [-1, 1].
        """
        upsample = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride)
        if final_layer:
            return nn.Sequential(upsample, nn.Tanh())
        return nn.Sequential(upsample, nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True))

    def unsqueeze_noise(self, noise):
        """Reshape ``(N, z_dim)`` input to ``(N, z_dim, 1, 1)`` for the conv stack."""
        return noise.view(len(noise), self.z_dim, 1, 1)

    def forward(self, noise):
        """Generate a batch of images from a ``(N, z_dim)`` input batch."""
        return self.gen(self.unsqueeze_noise(noise))

### Discriminator

In [None]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator producing unnormalized logits.

    Three Conv2d stages (kernel 4, stride 2 by default). For a
    ``(N, im_chan, 28, 28)`` input the spatial size goes 28 -> 13 -> 5 -> 1, so
    ``forward`` returns ``(N, 1)``; larger inputs yield a flattened logit map.
    No sigmoid is applied to the output.

    Args:
        im_chan: number of input channels (image channels plus any
            concatenated label channels).
        hidden_dim: width of the first hidden stage; the second uses twice that.
    """

    def __init__(self, im_chan=1, hidden_dim=16):
        super().__init__()
        self.disc = nn.Sequential(
            self.disc_block(im_chan, hidden_dim),
            self.disc_block(hidden_dim, hidden_dim * 2),
            self.disc_block(hidden_dim * 2, 1, final_layer=True),
        )

    def disc_block(self, in_channel, out_channel, kernel_size=4, stride=2, final_layer=False):
        """Build one downsampling stage.

        Hidden stages: Conv2d -> BatchNorm2d -> LeakyReLU(0.2).
        Final stage: a bare Conv2d emitting logits.
        """
        layers = [nn.Conv2d(in_channel, out_channel, kernel_size, stride)]
        if not final_layer:
            layers += [nn.BatchNorm2d(out_channel), nn.LeakyReLU(0.2, inplace=True)]
        return nn.Sequential(*layers)

    def forward(self, image):
        """Score a batch of images; returns logits flattened to ``(N, -1)``."""
        logits = self.disc(image)
        return logits.view(len(logits), -1)

### Noise

In [None]:
def get_noise(n_samples, z_dim, device='cpu'):
    """Sample an ``(n_samples, z_dim)`` tensor of standard-normal noise on ``device``."""
    shape = (n_samples, z_dim)
    return torch.randn(*shape, device=device)

### Class labels (one-hot encoding)

In [None]:
import torch.nn.functional as F
def get_one_hot_labels(labels, n_classes):
    """One-hot encode integer class ``labels`` into ``n_classes`` columns.

    Returns an integer tensor of shape ``labels.shape + (n_classes,)``.
    """
    return F.one_hot(labels, num_classes=n_classes)

In [None]:
def combine_vectors(x, y):
    """Concatenate ``x`` and ``y`` along dim 1 after casting both to float32.

    Used both for noise|one-hot vectors (2-D) and for image|label-plane
    tensors (4-D, concatenated on the channel axis).
    """
    return torch.cat([t.float() for t in (x, y)], dim=1)

### Hyperparameter settings

In [None]:
# Loss on raw discriminator logits (the discriminator has no final sigmoid).
criterion = nn.BCEWithLogitsLoss()
mnist_shape = (1, 28, 28)  # (channels, height, width) of one MNIST image
n_classes = 10  # digit classes 0-9
n_epochs = 100
z_dim = 64  # noise length, before the one-hot class vector is appended
display_step = 500  # steps between loss printouts / sample grids
batch_size = 128
lr = 2e-4
# Fall back to the CPU when no CUDA device is available instead of crashing
# at the first .to(device) call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Adam betas are not set, so the optimizers use PyTorch's defaults (0.9, 0.999).
# DCGAN-style training commonly uses:
# beta_1 = 0.5
# beta_2 = 0.999

### Data loading

In [None]:
# Convert PIL images to tensors in [0, 1], then rescale to [-1, 1] so real
# images share the range of the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Shuffled mini-batches of (image, digit label) from the MNIST training split,
# downloaded into the current directory if missing.
dataloader = DataLoader(
    MNIST('.', download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True,
)

### Models and optimizers

In [None]:
def get_input_dimensions(z_dim, mnist_shape, n_classes):
    """Return ``(generator_input_dim, discriminator_im_chan)`` for a conditional GAN.

    The generator consumes noise concatenated with a one-hot class vector, so
    its input length is ``z_dim + n_classes``. The discriminator consumes the
    image channels (``mnist_shape[0]``) plus one label plane per class.
    """
    return z_dim + n_classes, mnist_shape[0] + n_classes

In [None]:
# Conditional inputs: the generator sees z|one-hot, the discriminator sees
# image|label planes, so both input sizes grow by n_classes.
generator_input_dim, discriminator_im_chan = get_input_dimensions(
    z_dim, mnist_shape, n_classes
)

gen = Generator(z_dim=generator_input_dim).to(device)
disc = Discriminator(im_chan=discriminator_im_chan).to(device)

# One Adam optimizer per network, same learning rate (default betas).
gen_opt, disc_opt = (
    torch.optim.Adam(model.parameters(), lr=lr) for model in (gen, disc)
)

def weights_init(m):
    """Initialize one module in place, DCGAN-style (meant for ``Module.apply``).

    Conv2d / ConvTranspose2d weights ~ N(0, 0.02). BatchNorm2d weights
    ~ N(1, 0.02) and biases set to 0. Conv biases and every other module
    type are left untouched.
    """
    init = torch.nn.init
    if isinstance(m, nn.BatchNorm2d):
        init.normal_(m.weight, mean=1.0, std=0.02)
        init.constant_(m.bias, 0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        init.normal_(m.weight, mean=0.0, std=0.02)

# Module.apply recurses into every submodule, mutates in place and returns the
# module itself, so no rebinding is needed.
gen.apply(weights_init)
disc.apply(weights_init)

## Image Display

In [None]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    """Plot the first ``num_images`` images of a batch as a 5-column grid.

    Pixel values are mapped with ``(x + 1) / 2``, i.e. [-1, 1] -> [0, 1]; this
    assumes inputs in the Tanh / Normalize(0.5, 0.5) range. ``size`` is
    accepted for compatibility but not used.
    """
    rescaled = ((image_tensor + 1) / 2).detach().cpu()
    grid = make_grid(rescaled[:num_images], nrow=5)
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

### Training

In [None]:
# Conditional GAN training: for every batch, one discriminator step followed by
# one generator step; per-step losses are logged and periodically plotted.
cur_step = 0
generator_losses = []
discriminator_losses = []

# Placeholders; each is overwritten on the first iteration of the loop.
noise_and_labels = False
fake = False

fake_image_and_labels = False
real_image_and_labels = False
disc_fake_pred = False
disc_real_pred = False

for epoch in range(n_epochs):
    for real, labels in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        # Class condition as a one-hot vector (for the generator) and as
        # n_classes constant planes of size mnist_shape[1] x mnist_shape[2]
        # (for the discriminator).
        one_hot_labels = get_one_hot_labels(labels.to(device), n_classes)
        image_one_hot_labels = one_hot_labels[:, :, None, None]
        image_one_hot_labels = image_one_hot_labels.repeat(1, 1, mnist_shape[1], mnist_shape[2])

        # ---- Update discriminator ----
        disc_opt.zero_grad()

        fake_noise = get_noise(cur_batch_size, z_dim, device=device)  # z

        noise_and_labels = combine_vectors(fake_noise, one_hot_labels)  # z | c
        fake = gen(noise_and_labels)   # G(z|c)

        fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)
        real_image_and_labels = combine_vectors(real, image_one_hot_labels)

        # detach(): no gradient flows into the generator during the disc step.
        disc_fake_pred = disc(fake_image_and_labels.detach())  # D(G(z|c))
        disc_real_pred = disc(real_image_and_labels)  # D(x|c)

        disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))  # -log(1 - D(G(z|c)))
        disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))  # -log(D(x|c))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2

        # The disc graph never reaches the generator (its fake input was
        # detached) and the generator step below builds a fresh graph, so
        # there is nothing to retain; freeing it saves memory.
        disc_loss.backward()
        disc_opt.step()

        # Track the disc loss (averaged at display time).
        discriminator_losses += [disc_loss.item()]

        # ---- Update generator ----
        gen_opt.zero_grad()

        fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)  # G(z|c) | c
        disc_fake_pred = disc(fake_image_and_labels)  # D(G(z|c)) with the just-updated D
        # Non-saturating generator loss: -log(D(G(z|c))).
        gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
        # This also accumulates grads into the disc parameters; they are
        # cleared by disc_opt.zero_grad() at the start of the next step.
        gen_loss.backward()
        gen_opt.step()

        # Track the gen loss (averaged at display time).
        generator_losses += [gen_loss.item()]

        # Every display_step steps: print mean losses over the last window,
        # show generated and real samples, and plot losses averaged over bins
        # of step_bins steps.
        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(generator_losses[-display_step:]) / display_step
            disc_mean = sum(discriminator_losses[-display_step:]) / display_step
            print(f"step {cur_step}: Generator loss : {gen_mean}, Discriminator loss: {disc_mean}")
            show_tensor_images(fake)
            show_tensor_images(real)
            step_bins = 20
            num_examples = (len(generator_losses) // step_bins) * step_bins

            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Generator loss"
            )

            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(discriminator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Discriminator loss"
            )

            plt.legend()
            plt.show()

        cur_step += 1

Output hidden; open in https://colab.research.google.com to view.

# CGAN on the CelebA dataset

### Imports

In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision import datasets as dset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
# Seed PyTorch's global RNG so weight initialization and noise sampling repeat
# across runs (GPU kernels may still introduce nondeterminism).
torch.manual_seed(0)

<torch._C.Generator at 0x7fb5d31bea70>

In [2]:
class Generator(nn.Module):
    """DCGAN-style generator that maps a flat input vector to an image.

    With the default kernel/stride settings, a ``(N, z_dim)`` input is viewed as
    ``(N, z_dim, 1, 1)`` and upsampled 1x1 -> 3x3 -> 6x6 -> 13x13 -> 28x28,
    giving ``(N, im_chan, 28, 28)`` images in [-1, 1] (final Tanh).

    Args:
        z_dim: length of each input vector (noise, possibly with a condition
            vector concatenated to it).
        im_chan: number of channels in the generated image.
        hidden_dim: base width; the hidden stages use 4x, 2x and 1x of it.
    """

    def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):
        super().__init__()
        self.z_dim = z_dim
        widths = (z_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim)
        self.gen = nn.Sequential(
            self.gen_block(widths[0], widths[1]),
            self.gen_block(widths[1], widths[2], kernel_size=4, stride=1),
            self.gen_block(widths[2], widths[3]),
            self.gen_block(widths[3], im_chan, kernel_size=4, final_layer=True),
        )

    def gen_block(self, in_channel, out_channel, kernel_size=3, stride=2, final_layer=False):
        """Build one transposed-convolution upsampling stage.

        Hidden stages: ConvTranspose2d -> BatchNorm2d -> ReLU.
        Final stage: ConvTranspose2d -> Tanh, so outputs lie in [-1, 1].
        """
        upsample = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride)
        if final_layer:
            return nn.Sequential(upsample, nn.Tanh())
        return nn.Sequential(upsample, nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True))

    def unsqueeze_noise(self, noise):
        """Reshape ``(N, z_dim)`` input to ``(N, z_dim, 1, 1)`` for the conv stack."""
        return noise.view(len(noise), self.z_dim, 1, 1)

    def forward(self, noise):
        """Generate a batch of images from a ``(N, z_dim)`` input batch."""
        return self.gen(self.unsqueeze_noise(noise))

### Discriminator

In [3]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator producing unnormalized logits.

    Three Conv2d stages (kernel 4, stride 2 by default). For a
    ``(N, im_chan, 28, 28)`` input the spatial size goes 28 -> 13 -> 5 -> 1, so
    ``forward`` returns ``(N, 1)``; larger inputs yield a flattened logit map.
    No sigmoid is applied to the output.

    Args:
        im_chan: number of input channels (image channels plus any
            concatenated label channels).
        hidden_dim: width of the first hidden stage; the second uses twice that.
    """

    def __init__(self, im_chan=1, hidden_dim=16):
        super().__init__()
        self.disc = nn.Sequential(
            self.disc_block(im_chan, hidden_dim),
            self.disc_block(hidden_dim, hidden_dim * 2),
            self.disc_block(hidden_dim * 2, 1, final_layer=True),
        )

    def disc_block(self, in_channel, out_channel, kernel_size=4, stride=2, final_layer=False):
        """Build one downsampling stage.

        Hidden stages: Conv2d -> BatchNorm2d -> LeakyReLU(0.2).
        Final stage: a bare Conv2d emitting logits.
        """
        layers = [nn.Conv2d(in_channel, out_channel, kernel_size, stride)]
        if not final_layer:
            layers += [nn.BatchNorm2d(out_channel), nn.LeakyReLU(0.2, inplace=True)]
        return nn.Sequential(*layers)

    def forward(self, image):
        """Score a batch of images; returns logits flattened to ``(N, -1)``."""
        logits = self.disc(image)
        return logits.view(len(logits), -1)

### Noise

In [4]:
def get_noise(n_samples, z_dim, device='cpu'):
    """Sample an ``(n_samples, z_dim)`` tensor of standard-normal noise on ``device``."""
    shape = (n_samples, z_dim)
    return torch.randn(*shape, device=device)

### Class labels (one-hot encoding)

In [5]:
import torch.nn.functional as F
def get_one_hot_labels(labels, n_classes):
    """One-hot encode integer class ``labels`` into ``n_classes`` columns.

    Returns an integer tensor of shape ``labels.shape + (n_classes,)``.
    """
    return F.one_hot(labels, num_classes=n_classes)

In [6]:
def combine_vectors(x, y):
    """Concatenate ``x`` and ``y`` along dim 1 after casting both to float32.

    Used both for noise|one-hot vectors (2-D) and for image|label-plane
    tensors (4-D, concatenated on the channel axis).
    """
    return torch.cat([t.float() for t in (x, y)], dim=1)

### Hyperparameters

In [7]:
n_epochs = 50
display_step = 500  # steps between loss printouts / sample grids
lr = 2e-4

# Spatial size of training images, images are resized to this size.
image_size = 64
batch_size = 128

# Adam betas (the DCGAN convention is (0.5, 0.999)).
beta_1 = 0.5
beta_2 = 0.999
# NOTE(review): c_lambda and crit_repeats are not read by any cell shown here
# (they look like WGAN-GP settings) - confirm before removing.
c_lambda = 10
crit_repeats = 5
# Fall back to the CPU when no CUDA device is available instead of crashing
# at the first .to(device) call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

workers = 2 # Number of workers for dataloader
nc = 3 # Number of channels in the training images. For color images this is 3
nz = 100 # Size of z latent vector (i.e. size of generator input)
# NOTE(review): ngf, ndf and ngpu are not read by any cell shown here.
ngf = 64 # Size of feature maps in generator
ndf = 64 # Size of feature maps in discriminator
ngpu = 1 # Number of GPUs available. Use 0 for CPU mode.

# Names read by the model-building and training cells below; without them
# those cells raise NameError in this session.
criterion = nn.BCEWithLogitsLoss()  # loss on raw discriminator logits
z_dim = nz
# Training conditions on a single CelebA attribute column; torchvision's
# CelebA attributes take values 0/1, hence two classes.
n_classes = 2
# (channels, height, width) of one real image after the transform below.
# The MNIST-era name is kept because the later cells read it.
mnist_shape = (nc, image_size, image_size)

### Data loading

In [8]:
# Root directory for the dataset
data_root = "data"

# Load CelebA, resized and center-cropped to image_size x image_size, with each
# channel rescaled from [0, 1] (ToTensor) to [-1, 1] (Normalize 0.5 / 0.5).
# NOTE(review): the recorded output of this cell is a RuntimeError. torchvision's
# CelebA download is hosted on Google Drive and can fail (e.g. quota limits) -
# confirm the cause; a manually downloaded copy under data_root can be used
# with download=False.
celeba_dataset = dset.CelebA(
    data_root,
    download=True,
    transform=transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
        ]
    ),
)

RuntimeError: ignored

In [None]:
import torchvision.utils as vutils

# Create the dataloader: shuffled mini-batches loaded by `workers` processes.
dataloader = torch.utils.data.DataLoader(
    celeba_dataset,
    batch_size=batch_size,
    num_workers=workers,
    shuffle=True,
)

# Plot up to 64 training images. make_grid(normalize=True) rescales pixel values
# to [0, 1] by min/max for display. The batch is only displayed, so it stays on
# the CPU (no device round trip; also works without CUDA).
real_batch = next(iter(dataloader))
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0][:64], padding=2, normalize=True), (1, 2, 0)))

### Models and optimizers

In [None]:
def get_input_dimensions(z_dim, mnist_shape, n_classes):
    """Return ``(generator_input_dim, discriminator_im_chan)`` for a conditional GAN.

    The generator consumes noise concatenated with a one-hot class vector, so
    its input length is ``z_dim + n_classes``. The discriminator consumes the
    image channels (``mnist_shape[0]``) plus one label plane per class.
    """
    return z_dim + n_classes, mnist_shape[0] + n_classes

In [None]:
# Conditional inputs: the generator sees z|one-hot, the discriminator sees
# image|label planes, so both input sizes grow by n_classes.
generator_input_dim, discriminator_im_chan = get_input_dimensions(z_dim, mnist_shape, n_classes)

# NOTE(review): Generator is built with its default im_chan=1 and, with the
# layer settings above, emits 1x28x28 images, while the CelebA transform
# yields 3 x image_size x image_size tensors - the architecture needs adapting
# before real and fake batches can be mixed.
gen = Generator(z_dim=generator_input_dim).to(device)
# Use the Adam betas declared in the hyperparameter cell (previously ignored).
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
disc = Discriminator(im_chan=discriminator_im_chan).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))

def weights_init(m):
    """Initialize one module in place, DCGAN-style (meant for ``Module.apply``).

    Conv2d / ConvTranspose2d weights ~ N(0, 0.02). BatchNorm2d weights
    ~ N(1, 0.02) and biases set to 0. Conv biases and every other module
    type are left untouched.
    """
    init = torch.nn.init
    if isinstance(m, nn.BatchNorm2d):
        init.normal_(m.weight, mean=1.0, std=0.02)
        init.constant_(m.bias, 0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        init.normal_(m.weight, mean=0.0, std=0.02)

# Module.apply recurses into every submodule, mutates in place and returns the
# module itself, so no rebinding is needed.
gen.apply(weights_init)
disc.apply(weights_init)

## Image Display

In [None]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    """Plot the first ``num_images`` images of a batch as a 5-column grid.

    Pixel values are mapped with ``(x + 1) / 2``, i.e. [-1, 1] -> [0, 1]; this
    assumes inputs in the Tanh / Normalize(0.5, 0.5) range. ``size`` is
    accepted for compatibility but not used.
    """
    rescaled = ((image_tensor + 1) / 2).detach().cpu()
    grid = make_grid(rescaled[:num_images], nrow=5)
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

### Training

In [None]:
# Conditional GAN training on CelebA: for every batch, one discriminator step
# followed by one generator step; per-step losses are logged and plotted.
# NOTE(review): this loop reads z_dim, n_classes, mnist_shape and criterion,
# which must be defined by an earlier cell of this session. The Generator
# defined above emits 1x28x28 images, which do not match the
# 3 x image_size x image_size CelebA batches - confirm the architecture is
# adapted before running.
cur_step = 0
generator_losses = []
discriminator_losses = []

# Placeholders; each is overwritten on the first iteration of the loop.
noise_and_labels = False
fake = False

fake_image_and_labels = False
real_image_and_labels = False
disc_fake_pred = False
disc_real_pred = False

for epoch in range(n_epochs):
    for real, tlabels in tqdm(dataloader):

        # Condition on a single attribute column (index 21) of the CelebA
        # target; presumably a binary 0/1 attribute - confirm the choice.
        labels = tlabels[:, 21]

        cur_batch_size = len(real)
        real = real.to(device)

        # Class condition as a one-hot vector (for the generator) and as
        # n_classes constant planes of size mnist_shape[1] x mnist_shape[2]
        # (for the discriminator).
        one_hot_labels = get_one_hot_labels(labels.to(device), n_classes)
        image_one_hot_labels = one_hot_labels[:, :, None, None]
        image_one_hot_labels = image_one_hot_labels.repeat(1, 1, mnist_shape[1], mnist_shape[2])

        # ---- Update discriminator ----
        disc_opt.zero_grad()

        fake_noise = get_noise(cur_batch_size, z_dim, device=device)  # z

        noise_and_labels = combine_vectors(fake_noise, one_hot_labels)  # z | c
        fake = gen(noise_and_labels)   # G(z|c)

        fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)
        real_image_and_labels = combine_vectors(real, image_one_hot_labels)

        # detach(): no gradient flows into the generator during the disc step.
        disc_fake_pred = disc(fake_image_and_labels.detach())  # D(G(z|c))
        disc_real_pred = disc(real_image_and_labels)  # D(x|c)

        disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))  # -log(1 - D(G(z|c)))
        disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))  # -log(D(x|c))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2

        # The disc graph never reaches the generator (its fake input was
        # detached) and the generator step below builds a fresh graph, so
        # there is nothing to retain; freeing it saves memory.
        disc_loss.backward()
        disc_opt.step()

        # Track the disc loss (averaged at display time).
        discriminator_losses += [disc_loss.item()]

        # ---- Update generator ----
        gen_opt.zero_grad()

        fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)  # G(z|c) | c
        disc_fake_pred = disc(fake_image_and_labels)  # D(G(z|c)) with the just-updated D
        # Non-saturating generator loss: -log(D(G(z|c))).
        gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
        # This also accumulates grads into the disc parameters; they are
        # cleared by disc_opt.zero_grad() at the start of the next step.
        gen_loss.backward()
        gen_opt.step()

        # Track the gen loss (averaged at display time).
        generator_losses += [gen_loss.item()]

        # Every display_step steps: print mean losses over the last window,
        # show generated and real samples, and plot losses averaged over bins
        # of step_bins steps.
        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(generator_losses[-display_step:]) / display_step
            disc_mean = sum(discriminator_losses[-display_step:]) / display_step
            print(f"step {cur_step}: Generator loss : {gen_mean}, Discriminator loss: {disc_mean}")
            show_tensor_images(fake)
            show_tensor_images(real)
            step_bins = 20
            num_examples = (len(generator_losses) // step_bins) * step_bins

            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Generator loss"
            )

            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(discriminator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Discriminator loss"
            )

            plt.legend()
            plt.show()

        cur_step += 1