# 1) Import

In [None]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision import datasets as dset
import matplotlib.pyplot as plt
# Seed PyTorch's global RNG so weight init, noise and shuffling are repeatable.
SEED = 0
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fa8c9812a50>

# 2) CelebA data

In [None]:
# Mount Google Drive so the dataset directory under MyDrive is reachable from this runtime.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# Where the CelebA files live on the mounted Drive.
data_root = '/content/drive/MyDrive/dataset'
# Training images are resized (shorter side) to this size and then center-cropped square.
image_size = 64
batch_size = 128

# Resize + center-crop to image_size x image_size, convert to a [0, 1] tensor,
# then Normalize(0.5, 0.5) maps every channel to [-1, 1] (the generator's Tanh range).
celeba_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# NOTE(review): with download=True the files are re-verified on every run (see the
# "Using downloaded and verified file" output), which is slow over a Drive mount —
# consider download=False once the dataset is in place.
celeba_data = dset.CelebA(data_root, download=True, transform=celeba_transform)

dataloader = DataLoader(celeba_data, batch_size=batch_size, shuffle=True)

Using downloaded and verified file: /content/drive/MyDrive/dataset/celeba/img_align_celeba.zip
Using downloaded and verified file: /content/drive/MyDrive/dataset/celeba/list_attr_celeba.txt
Using downloaded and verified file: /content/drive/MyDrive/dataset/celeba/identity_CelebA.txt
Using downloaded and verified file: /content/drive/MyDrive/dataset/celeba/list_bbox_celeba.txt
Using downloaded and verified file: /content/drive/MyDrive/dataset/celeba/list_landmarks_align_celeba.txt
Using downloaded and verified file: /content/drive/MyDrive/dataset/celeba/list_eval_partition.txt


# 3) Generator

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator: flat input vector -> image.

    The input vector (noise concatenated with a one-hot class vector in this
    notebook) is viewed as a (N, input_dim, 1, 1) map and upsampled by five
    transposed-conv blocks: 1 -> 4 -> 8 -> 16 -> 32 -> 64 pixels with the default
    kernel/stride/padding. The final Tanh keeps outputs in [-1, 1].

    Args:
        input_dim: length of each input vector.
        im_chan: number of channels of the generated image.
        hidden_dim: base channel width; blocks use 8x, 4x, 2x, 1x of it.
    """

    def __init__(self, input_dim=10, im_chan=3, hidden_dim=32):
        super().__init__()
        self.input_dim = input_dim
        widths = [input_dim, hidden_dim * 8, hidden_dim * 4, hidden_dim * 2, hidden_dim]
        # Stride 1 / padding 0 turns the 1x1 input into a 4x4 map.
        blocks = [self.gen_block(widths[0], widths[1], stride=1, padding=0)]
        # Default stride 2 / padding 1 doubles the spatial size at each block.
        blocks += [self.gen_block(c_in, c_out) for c_in, c_out in zip(widths[1:-1], widths[2:])]
        blocks.append(self.gen_block(widths[-1], im_chan, final_layer=True))
        self.gen = nn.Sequential(*blocks)

    def gen_block(self, input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, dilation = 1, output_padding = 0, final_layer = False):
        """Build one upsampling block.

        Returns ConvTranspose2d -> BatchNorm2d -> ReLU, or ConvTranspose2d -> Tanh
        when ``final_layer`` is True (no normalization on the output layer).
        """
        upsample = nn.ConvTranspose2d(
            input_channels,
            output_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            output_padding=output_padding,
        )
        if final_layer:
            return nn.Sequential(upsample, nn.Tanh())
        return nn.Sequential(upsample, nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True))

    def unsqueeze_noise(self, noise):
        """View a (N, input_dim) batch as (N, input_dim, 1, 1) for the conv stack."""
        batch = len(noise)
        return noise.view(batch, self.input_dim, 1, 1)

    def forward(self, noise):
        """Map a (N, input_dim) batch of input vectors to a batch of images."""
        return self.gen(self.unsqueeze_noise(noise))

# 4) Discriminator

In [None]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator: image (+ label planes) -> one raw logit per sample.

    Four stride-2 conv blocks halve the spatial size (64 -> 32 -> 16 -> 8 -> 4 for
    a 64x64 input), then a 4x4 stride-1 conv reduces the map to 1x1. There is no
    sigmoid at the end; outputs are logits.

    Args:
        im_chan: number of input channels (image channels plus any label channels).
            NOTE(review): the default of 13 looks like a typo; this notebook always
            passes im_chan explicitly.
        hidden_dim: base channel width; blocks use 1x, 2x, 4x, 8x of it.
    """

    def __init__(self, im_chan=13, hidden_dim=16):
        super().__init__()
        widths = [im_chan, hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8]
        layers = [self.disc_block(c_in, c_out) for c_in, c_out in zip(widths, widths[1:])]
        layers.append(self.disc_block(widths[-1], 1, final_layer=True))
        self.disc = nn.Sequential(*layers)

    def disc_block(self, input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, final_layer = False):
        """Build one downsampling block.

        Returns Conv2d -> BatchNorm2d -> LeakyReLU(0.2). When ``final_layer`` is True,
        returns a lone Conv2d with stride 1 and padding 0; the ``stride`` and
        ``padding`` arguments are ignored in that case.
        """
        if final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride=1, padding=0)
            )
        return nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(output_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, image):
        """Score a batch of images; returns a (N, C*H*W) tensor, i.e. (N, 1) for 64x64 inputs."""
        scores = self.disc(image)
        return scores.view(len(scores), -1)

# 5) Noise

In [None]:
def get_noise(n_samples, input_dim, device = 'cuda'):
    """Draw an (n_samples, input_dim) batch of standard-normal noise on ``device``."""
    shape = (n_samples, input_dim)
    return torch.randn(shape, device=device)

# 6) Class-label helpers (one-hot encoding and concatenation)

In [None]:
# one-hot vector
import torch.nn.functional as F
def get_one_hot_labels(labels, n_classes):
    """One-hot encode integer class labels.

    ``labels`` is an integer (int64) tensor of class indices in [0, n_classes),
    as required by ``F.one_hot``. The result has one extra trailing dimension of
    size ``n_classes`` and keeps the integer dtype.
    """
    return F.one_hot(labels, num_classes=n_classes)
   
# Concatenate two tensors along dim 1 (e.g. noise + one-hot labels, or images + label planes).
def combine_vectors(x, y):
    """Return ``torch.cat`` of ``x`` and ``y`` along dim 1, with both cast to float32 first."""
    parts = [x.float(), y.float()]
    return torch.cat(parts, dim=1)

# 7) Hyperparameters

In [None]:
# Shape of one training image (channels, height, width); must agree with image_size
# used by the dataset transform.
celeba_shape = (3, 64, 64)
# Number of label classes: the training loop one-hot encodes a single 0/1 CelebA attribute.
n_classes = 2

# The discriminator's last layer is a plain Conv2d (no sigmoid), so use the logit-based loss.
criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
z_dim = 64
display_step = 500
# Informational only: the DataLoader was already built with its own batch_size.
batch_size = 128
lr = 0.0002
# Fall back to CPU so the notebook still runs on a runtime without a GPU
# (a hard-coded 'cuda' raises as soon as a tensor or model is moved there).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 8) Models, optimizers, and weight initialization

In [None]:
def get_input_dimensions(z_dim, celeba_shape, n_classes):
    """Input sizes for a conditional GAN whose condition is a one-hot class vector.

    Args:
        z_dim: length of the noise vector.
        celeba_shape: image shape; only its first entry (channel count) is used.
        n_classes: length of the one-hot class vector.

    Returns:
        (generator_input_dim, discriminator_im_chan): the generator sees noise plus
        one-hot, and the discriminator sees image channels plus one plane per class.
    """
    image_channels = celeba_shape[0]
    return z_dim + n_classes, image_channels + n_classes

generator_input_dim, discriminator_im_chan = get_input_dimensions(z_dim, celeba_shape, n_classes)

# Adam momentum terms for DCGAN-style training: the DCGAN paper reports that the default
# beta1 = 0.9 makes training oscillate and that beta1 = 0.5 stabilizes it.
beta_1, beta_2 = 0.5, 0.999

gen = Generator(input_dim = generator_input_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr = lr, betas = (beta_1, beta_2))
disc = Discriminator(im_chan = discriminator_im_chan).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr = lr, betas = (beta_1, beta_2))

def weights_init(m):
    """DCGAN-style initialization, applied per module via ``model.apply(weights_init)``.

    - Conv2d / ConvTranspose2d weights ~ N(0, 0.02); their biases are left untouched.
    - BatchNorm2d scale (``weight``) ~ N(1, 0.02) and shift (``bias``) = 0.

    The BatchNorm scale is centered on 1 rather than 0. A scale drawn from N(0, 0.02)
    shrinks every block's normalized output to ~2% of its scale and flips the sign for
    roughly half the channels at the start of training.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

# Module.apply visits every submodule recursively and returns the module itself;
# the generator is initialized first, then the discriminator.
gen, disc = gen.apply(weights_init), disc.apply(weights_init)

# 9) Image display

In [None]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 64, 64), nrow = 5, show = True):
    """Plot the first ``num_images`` images of a batch as a grid.

    Args:
        image_tensor: batch of images, presumably (N, C, H, W) with values in [-1, 1]
            (the generator's Tanh range and the Normalize(0.5, 0.5) data range);
            TODO confirm for other callers.
        num_images: how many images from the front of the batch to draw.
        size: unused; kept so existing calls keep working.
        nrow: number of images per grid row.
        show: call ``plt.show()`` when True; otherwise leave the figure open.
    """
    # Detach and move to CPU before any arithmetic, so no autograd graph is built
    # (on the GPU) for a display-only computation.
    images = image_tensor.detach().cpu()
    # Map [-1, 1] -> [0, 1]; clamp so rounding never leaves the range matplotlib
    # accepts for float RGB data (it would clip with a warning).
    images = ((images + 1) / 2).clamp(0, 1)
    image_grid = make_grid(images[:num_images], nrow = nrow)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    if show:
        plt.show()

# 10) Training

In [None]:
def _plot_running_losses(gen_losses, disc_losses, step_bins=20):
    """Plot both loss histories, averaged over consecutive windows of ``step_bins`` steps."""
    num_examples = (len(gen_losses) // step_bins) * step_bins
    x = range(num_examples // step_bins)
    for losses, name in ((gen_losses, "Generator loss"), (disc_losses, "Discriminator loss")):
        plt.plot(x, torch.Tensor(losses[:num_examples]).view(-1, step_bins).mean(1), label = name)
    plt.legend()
    plt.show()


cur_step = 0
generator_losses = []
discriminator_losses = []

# Column of the CelebA attribute tensor used as the binary class label.
# NOTE(review): in torchvision's CelebA attribute order, index 21 is
# "Mouth_Slightly_Open" (index 20 is "Male"); confirm this is the intended
# condition, e.g. by checking celeba_data.attr_names[label_attr_idx].
label_attr_idx = 21

for epoch in range(n_epochs):
    for real, tlabels in tqdm(dataloader):
        labels = tlabels[:, label_attr_idx]

        cur_batch_size = len(real)
        real = real.to(device)

        # The condition y as a (N, n_classes) one-hot vector for the generator input...
        one_hot_labels = get_one_hot_labels(labels.to(device), n_classes)
        # ...and as constant (N, n_classes, H, W) planes stacked onto the image channels
        # for the discriminator input.
        image_one_hot_labels = one_hot_labels[:, :, None, None].repeat(1, 1, celeba_shape[1], celeba_shape[2])

        # ---- Update discriminator ----
        disc_opt.zero_grad()
        fake_noise = get_noise(cur_batch_size, z_dim, device = device)
        noise_and_labels = combine_vectors(fake_noise, one_hot_labels)
        fake = gen(noise_and_labels)  # G(z|y)

        fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)
        real_image_and_labels = combine_vectors(real, image_one_hot_labels)  # x|y

        # detach(): the discriminator step must not send gradients into the generator.
        disc_fake_pred = disc(fake_image_and_labels.detach())  # D(G(z|y))
        disc_real_pred = disc(real_image_and_labels)  # D(x|y)

        disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
        disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2  # discriminator loss
        # No retain_graph: the generator step runs a fresh discriminator forward pass, and
        # detach() kept the generator graph out of this backward, so nothing here is reused.
        disc_loss.backward()
        disc_opt.step()

        discriminator_losses.append(disc_loss.item())

        # ---- Update generator ----
        gen_opt.zero_grad()
        # Re-score the same fakes with the just-updated discriminator, this time without
        # detach() so the gradients reach the generator.
        disc_fake_pred = disc(fake_image_and_labels)  # D(G(z|y))
        gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
        gen_loss.backward()
        gen_opt.step()

        generator_losses.append(gen_loss.item())

        # ---- Visualize progress ----
        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(generator_losses[-display_step:]) / display_step
            disc_mean = sum(discriminator_losses[-display_step:]) / display_step
            print(f"Step {cur_step}: Generator loss: {gen_mean}, Discriminator loss: {disc_mean}")
            show_tensor_images(fake)
            show_tensor_images(real)
            _plot_running_losses(generator_losses, discriminator_losses, step_bins = 20)
        elif cur_step == 0:
            print("Congratulations.")
        cur_step += 1

Output hidden; open in https://colab.research.google.com to view.