# Conditional Generative Adversarial Network Tutorial

This tutorial explains how to implement and train a [Conditional GAN](https://arxiv.org/pdf/1411.1784.pdf) in PyTorch. I assume that the reader is already familiar with implementing the original GAN. I build on PyTorch's official GAN implementation to create a Conditional GAN, trained on the CIFAR-10 dataset.

In short, a Conditional GAN is like the original GAN, except that its output is not driven by random noise alone: it is also conditioned on an extra variable. This means we can tell the model which class of image to generate. For example, if a Conditional GAN is trained on CIFAR-10 images, we can make it generate only dog images.

## GAN Architecture

### Discriminator

<img src="images/disc_arch.png" alt="drawing" width="800"/>

### Generator

<img src="images/gen_arch.png" alt="drawing" width="800"/>

## Task

The architectures above are from PyTorch's official GAN implementation for the CIFAR-10 dataset. Our task is to convert them into a Conditional GAN (CGAN).

If you look at Equation 2 in [1], you can see that the value function differs slightly from the original GAN's value function: the authors introduce a <strong>y</strong> term, which is the extra information the model is conditioned on.
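For reference, here are the two value functions side by side: first the original GAN objective, then the CGAN objective as written in Equation 2 of [1]. The only difference is that both the discriminator and the generator additionally receive y.

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
$$

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x \mid y)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z \mid y))\right)\right]
$$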

The following passage from the paper summarizes the role of the <strong>y</strong> term:

> "Generative adversarial nets can be extended to a conditional model if both the generator and discriminator are conditioned on some extra information y. y could be any kind of auxiliary information, such as class labels or data from other modalities." [1]

So, we can transform the GAN implementation into a CGAN by feeding the <strong>y</strong> term to both the discriminator and the generator.

The CGAN paper demonstrates the idea on MNIST rather than CIFAR-10, so the rest of this tutorial departs from the paper's setup.

## GAN Architecture Code

Let's look at the original GAN code.

### Discriminator Network

```python
import torch
import torch.nn as nn

# Hyperparameters from PyTorch's DCGAN tutorial
nc = 3    # number of image channels
nz = 100  # size of the latent vector z
ngf = 64  # generator feature-map size
ndf = 64  # discriminator feature-map size


class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    # forward is a method of the class, not a function nested inside __init__
    def forward(self, x):
        return self.main(x)
```
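The shape comments in this listing assume 64 x 64 inputs, while CIFAR-10 images are natively 32 x 32. The sketch below traces the spatial size through the five `Conv2d` layers using the standard output-size formula `out = (in + 2*padding - kernel) // stride + 1` (dilation 1). `conv_out` and `trace_discriminator` are hypothetical helper names. A 64 x 64 input ends at 1 x 1, whereas a 32 x 32 input has shrunk to 2 x 2 before the final 4 x 4 kernel, which `nn.Conv2d` rejects.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of nn.Conv2d with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) of the five Conv2d layers in the discriminator above
D_LAYERS = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]


def trace_discriminator(size):
    """Return the spatial size after each layer, or raise if a kernel no longer fits."""
    sizes = [size]
    for kernel, stride, padding in D_LAYERS:
        if sizes[-1] + 2 * padding < kernel:
            raise ValueError(
                f"{sizes[-1]}x{sizes[-1]} input is smaller than the {kernel}x{kernel} kernel")
        sizes.append(conv_out(sizes[-1], kernel, stride, padding))
    return sizes


print(trace_discriminator(64))  # [64, 32, 16, 8, 4, 1]
try:
    trace_discriminator(32)
except ValueError as err:
    print(err)                  # 2x2 input is smaller than the 4x4 kernel
```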

### Generator Network

```python
import torch.nn as nn

# Uses the global hyperparameters nz, ngf and nc (100, 64 and 3 in PyTorch's DCGAN tutorial)
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(     nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2,     ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(    ngf,      nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    # forward is a method of the class, not a function nested inside __init__
    def forward(self, x):
        return self.main(x)
```
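Tracing the generator the same way, with the `ConvTranspose2d` output-size formula `out = (in - 1)*stride - 2*padding + kernel` (dilation 1, no output padding), shows that a 1 x 1 latent input always grows to 64 x 64. The generated images are therefore 64 x 64, so the discriminator has to accept 64 x 64 inputs; PyTorch's DCGAN example resizes CIFAR-10 images to 64 x 64 (its default `--imageSize`). The helper names below are hypothetical.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Spatial output size of nn.ConvTranspose2d with dilation 1 and output_padding 0."""
    return (size - 1) * stride - 2 * padding + kernel


# (kernel, stride, padding) of the five ConvTranspose2d layers in the generator above
G_LAYERS = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]


def trace_generator(size=1):
    """Return the spatial size after each transposed convolution."""
    sizes = [size]
    for kernel, stride, padding in G_LAYERS:
        sizes.append(conv_transpose_out(sizes[-1], kernel, stride, padding))
    return sizes


print(trace_generator())  # [1, 4, 8, 16, 32, 64]
```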

### Adding the Conditional Variable

The authors state that "the adversarial training framework allows for considerable flexibility in how this hidden representation is composed." [1]

My reading of this is that the <strong>y</strong> term can be any representation, as long as the same <strong>y</strong> is used for both the generator and the discriminator.

Since we are using the CIFAR-10 dataset, there are 10 classes. Let's create an embedding matrix that maps each of the 10 classes to a 100-dimensional vector; that vector is the <strong>y</strong> term for the class.

This can be done easily using PyTorch.

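```python
import torch.nn as nn

# 10 CIFAR-10 classes, each mapped to a fixed 100-dimensional vector y
embeddings = nn.Embedding(10, 100)
embeddings.weight.requires_grad = False  # keep y fixed: the table is never trained
```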

The embedding matrix should not be trained at all. The particular values it is initialized with do not matter numerically; they could be anything. What matters is that the conditional variable for each class stays the same throughout training and evaluation. So make sure the weights of the embedding matrix are non-trainable by setting `requires_grad` to `False`.
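A quick check of the two properties this paragraph relies on, assuming the same `nn.Embedding(10, 100)` setup: a lookup is just row indexing into the weight matrix, so a given class always maps to the same y, and with `requires_grad = False` no gradient ever reaches the table. Labels 3 and 5 are `cat` and `dog` in CIFAR-10's class order; the `nn.Linear` head is a hypothetical stand-in for the rest of a network.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embeddings = nn.Embedding(10, 100)
embeddings.weight.requires_grad = False

labels = torch.tensor([3, 5, 3])  # cat, dog, cat
y = embeddings(labels)            # one 100-dimensional y per label

print(y.shape)                                    # torch.Size([3, 100])
print(torch.equal(y, embeddings.weight[labels]))  # True: lookup is row indexing
print(torch.equal(y[0], y[2]))                    # True: same class, same y

# Back-propagate through a stand-in layer: the embedding table receives no gradient
head = nn.Linear(100, 1)
loss = head(y).sum()
loss.backward()
print(embeddings.weight.grad)                     # None
print(head.weight.grad is not None)               # True
```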

Now that we have the <strong>y</strong> vector, how should we add it to the model?

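To keep things simple, I decided not to change the forward pass of the original GAN architecture. Instead, for both the discriminator and the generator, I merge the input (noise for the generator, an image for the discriminator) with the conditional variable <strong>y</strong> and feed the result into the original GAN architecture. This works because the merged input has the same dimensions as the original input for both models. Since the official architecture works on 64 x 64 images (PyTorch's DCGAN example resizes CIFAR-10's 32 x 32 images to 64 x 64), the discriminator's label map is 3 x 64 x 64. The merging steps for the generator and the discriminator are:

**Generator**

```python
import torch
import torch.nn as nn

batchsize = 8
embedding_layer = nn.Embedding(10, 100)         # the frozen y table
label = torch.randint(0, 10, (batchsize,))      # class indices, [batchsize]

noise = torch.randn(batchsize, 100)             # random noise, [batchsize x 100]
linear_layer = nn.Linear(200, 100)

label_embed = embedding_layer(label)            # [batchsize x 100]
noise = torch.cat([noise, label_embed], dim=1)  # [batchsize x 200]
noise = linear_layer(noise)                     # [batchsize x 100]
```

**Discriminator**

```python
im = torch.randn(batchsize, 3, 64, 64)          # random stand-in for a batch of images
linear_layer = nn.Linear(100, 3 * 64 * 64)
conv_layer = nn.Conv2d(6, 3, kernel_size=1, bias=False)

label_embed = embedding_layer(label)            # [batchsize x 100]
label_map = linear_layer(label_embed)           # [batchsize x 12288]
label_map = label_map.view(-1, 3, 64, 64)       # [batchsize x 3 x 64 x 64]
im = torch.cat([im, label_map], dim=1)          # [batchsize x 6 x 64 x 64]
out = conv_layer(im)                            # [batchsize x 3 x 64 x 64]
```

`label_embed` is the 100-dimensional vector (<strong>y</strong>). These steps keep the input dimensions of both models unchanged: the generator still receives a [batchsize x 100] noise matrix (reshaped to [batchsize x 100 x 1 x 1] before the first transposed convolution), and the discriminator still receives a [batchsize x 3 x 64 x 64] image tensor.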

After these changes, the discriminator and generator code looks like this.

### Discriminator Network

```python
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, embeddings, nc=3, ndf=64):
        super(Discriminator, self).__init__()

        self.nc = nc
        self.embeddings = embeddings  # shared, frozen nn.Embedding(10, 100)
        # Project y to an image-sized map; 64 x 64 matches the input size of `main`
        self.label_to_image = nn.Linear(100, nc * 64 * 64)
        # 1x1 convolution that mixes the 2*nc stacked channels back down to nc
        self.conv1 = nn.Conv2d(nc * 2, nc, 1, 1, 0, bias=False)

        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x, labels):
        # x: images [batchsize x nc x 64 x 64]; labels: class indices [batchsize]
        label_embed = self.embeddings(labels)            # [batchsize x 100]

        label_map = self.label_to_image(label_embed)     # [batchsize x nc*64*64]
        label_map = label_map.view(-1, self.nc, 64, 64)  # [batchsize x nc x 64 x 64]

        x = torch.cat([x, label_map], dim=1)             # [batchsize x 2*nc x 64 x 64]

        out = self.conv1(x)                              # [batchsize x nc x 64 x 64]
        output = self.main(out)                          # [batchsize x 1 x 1 x 1]

        return output
```
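The `conv1` layer is a 1 x 1 convolution: at every pixel it takes a learned linear combination of the 2 x nc stacked channels (image plus label map) and outputs nc channels, so the spatial size is unchanged and `main` receives the shape it expects. The sketch below, assuming nc = 3 and a random input, checks this against an explicit per-pixel matrix product written with `torch.einsum`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv1 = nn.Conv2d(6, 3, kernel_size=1, stride=1, padding=0, bias=False)
x = torch.randn(2, 6, 64, 64)  # image channels stacked with the label map

out = conv1(x)
# The weight has shape [3, 6, 1, 1]; dropping the 1x1 kernel dims leaves a 3x6 matrix
W = conv1.weight[:, :, 0, 0]
manual = torch.einsum("oc,bchw->bohw", W, x)  # same 3x6 matrix applied at every pixel

print(out.shape)                              # torch.Size([2, 3, 64, 64])
print(torch.allclose(out, manual, atol=1e-5)) # True
```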

### Generator Network

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, embeddings, nc=3, nz=100, ngf=64):
        super(Generator, self).__init__()

        self.nz = nz
        self.embeddings = embeddings  # shared, frozen nn.Embedding(10, 100)
        # Merge z with y and project back to nz dimensions (200 -> 100 for nz = 100)
        self.linear = nn.Linear(nz + 100, nz)

        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(     nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2,     ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(    ngf,      nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, x, labels):
        # x: noise [batchsize x nz x 1 x 1]; labels: class indices [batchsize]
        label_embed = self.embeddings(labels)   # [batchsize x 100]

        x = x.view(-1, self.nz)                 # [batchsize x nz]
        x = torch.cat([x, label_embed], dim=1)  # [batchsize x (nz + 100)]

        x = self.linear(x)                      # [batchsize x nz]
        x = x.unsqueeze(2).unsqueeze(3)         # [batchsize x nz x 1 x 1]

        output = self.main(x)                   # [batchsize x nc x 64 x 64]
        return output
```
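To see how the generator's conditioning step hands off to the unchanged DCGAN stack, the sketch below (assuming nz = 100, ngf = 64 and a batch of 4) runs only the merge, the `unsqueeze` reshape, and a copy of the first `ConvTranspose2d` layer. The 1 x 1 spatial size created by the two `unsqueeze` calls is what lets the first transposed convolution expand the merged vector into a 4 x 4 feature map.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batchsize, nz, ngf = 4, 100, 64

embeddings = nn.Embedding(10, 100)
embeddings.weight.requires_grad = False
linear = nn.Linear(nz + 100, nz)                                    # role of Generator.linear
first_layer = nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False)  # first layer of Generator.main

noise = torch.randn(batchsize, nz, 1, 1)
labels = torch.randint(0, 10, (batchsize,))

x = torch.cat([noise.view(-1, nz), embeddings(labels)], dim=1)  # [4, 200]
x = linear(x)                                                    # [4, 100]
x = x.unsqueeze(2).unsqueeze(3)                                  # [4, 100, 1, 1]
h = first_layer(x)                                               # [4, 512, 4, 4]

print(tuple(x.shape), tuple(h.shape))  # (4, 100, 1, 1) (4, 512, 4, 4)
```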

And it's done. We have implemented the CGAN. Now, we can train and look at the outputs.  

## Training

Training is the same as for the original GAN, except that the class labels are passed to both networks.

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batchsize = 200
epochs = 500
nz = 100

# The notebook's own `Data` class is not shown; torchvision's CIFAR10 is used as a
# stand-in. With root=".", it reads "./cifar-10-batches-py/". Images are resized to
# 64 x 64 and scaled to [-1, 1] (matching the generator's Tanh), as in PyTorch's DCGAN example.
train_data = dset.CIFAR10(
    root=".", train=True, download=True,
    transform=transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]))
train_data_loader = torch.utils.data.DataLoader(
    train_data, batch_size=batchsize, shuffle=True, drop_last=True)

# One frozen embedding table shared by both networks, so G and D see the same y
embeddings = nn.Embedding(10, 100)
embeddings.weight.requires_grad = False

netD = Discriminator(embeddings).to(device)
netG = Generator(embeddings).to(device)

optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

netD.train()
netG.train()

criterion = nn.BCELoss()

real_label = torch.ones(batchsize, 1, device=device)
fake_label = torch.zeros(batchsize, 1, device=device)

for epoch in range(epochs):
    for i, (images, labels) in enumerate(train_data_loader):
        # Fresh noise for every batch
        noise = torch.randn(batchsize, nz, 1, 1, device=device)

        images = images.to(device)
        labels = labels.to(device)  # class indices in [0, 9]

        # Update D network: maximize log(D(x|y)) + log(1 - D(G(z|y)|y))
        D_real_result = netD(images, labels)
        D_real_loss = criterion(D_real_result.view(batchsize, -1), real_label)

        G_result = netG(noise, labels)
        # detach() stops the discriminator loss from back-propagating into G
        D_fake_result = netD(G_result.detach(), labels)
        D_fake_loss = criterion(D_fake_result.view(batchsize, -1), fake_label)

        D_train_loss = (D_real_loss + D_fake_loss) / 2

        netD.zero_grad()
        D_train_loss.backward()
        optimizerD.step()

        # Update G network: maximize log(D(G(z|y)|y)) for randomly drawn class labels
        new_labels = torch.randint(0, 10, (batchsize,), device=device)

        G_result = netG(noise, new_labels)
        D_fake_result = netD(G_result, new_labels)
        G_train_loss = criterion(D_fake_result.view(batchsize, -1), real_label)

        netG.zero_grad()
        G_train_loss.backward()
        optimizerG.step()

        print("D_loss:%f\tG_loss:%f" % (D_train_loss.item(), G_train_loss.item()))
```
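The comments in the loop state what each update maximizes; `nn.BCELoss` gets there through the target labels. With target 1 the per-sample loss is `-log(p)` and with target 0 it is `-log(1 - p)`, averaged over the batch. Minimizing the discriminator's loss therefore maximizes log D(x|y) + log(1 - D(G(z|y)|y)), and training G against target 1 maximizes log D(G(z|y)|y), the generator objective the original GAN paper recommends in place of minimizing log(1 - D(G(z))). The discriminator outputs below are made-up numbers used only to check the arithmetic.

```python
import math

import torch
import torch.nn as nn

criterion = nn.BCELoss()               # mean reduction by default
d_real = torch.tensor([[0.9], [0.6]])  # made-up D(x|y) outputs
d_fake = torch.tensor([[0.2], [0.7]])  # made-up D(G(z|y)|y) outputs

# Discriminator: real samples get target 1, fake samples get target 0
loss_real = criterion(d_real, torch.ones_like(d_real)).item()
loss_fake = criterion(d_fake, torch.zeros_like(d_fake)).item()

# Generator: fake samples get target 1
loss_g = criterion(d_fake, torch.ones_like(d_fake)).item()

expected_real = -(math.log(0.9) + math.log(0.6)) / 2
expected_fake = -(math.log(1 - 0.2) + math.log(1 - 0.7)) / 2
expected_g = -(math.log(0.2) + math.log(0.7)) / 2

print(math.isclose(loss_real, expected_real, rel_tol=1e-5))  # True
print(math.isclose(loss_fake, expected_fake, rel_tol=1e-5))  # True
print(math.isclose(loss_g, expected_g, rel_tol=1e-5))        # True
```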

## Generating Images

I trained the model with two different hyperparameter settings; the generated images are below. Each row is conditioned on a specific class. Some images are of poor quality and there are signs of mode collapse, but I expect that training longer with more compute would improve the results. Nevertheless, the conditioning clearly works: each row contains images from a single class.
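Grids like the ones below, with one class per row, can be produced by giving every sample in a row the same class label. The notebook's sampling code is not shown, so `sample_class_grid` is a hypothetical helper; it assumes a generator with the `forward(noise, labels)` signature used in this tutorial and switches it to evaluation mode so that BatchNorm uses its running statistics.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def sample_class_grid(netG, nz=100, num_classes=10, per_row=8, device="cpu"):
    """Return (images, labels) laid out as num_classes rows of per_row samples,
    where row r is conditioned on class r."""
    netG.eval()  # BatchNorm uses running statistics; call netG.train() to resume training
    labels = torch.arange(num_classes, device=device).repeat_interleave(per_row)
    noise = torch.randn(labels.numel(), nz, 1, 1, device=device)
    images = netG(noise, labels)  # [num_classes * per_row, nc, 64, 64], values in [-1, 1]
    return images, labels


# Usage with a trivial stand-in generator; a trained netG would be passed in practice
class StandInGenerator(nn.Module):
    def forward(self, noise, labels):
        return torch.tanh(torch.randn(noise.size(0), 3, 64, 64))


images, labels = sample_class_grid(StandInGenerator())
print(images.shape)          # torch.Size([80, 3, 64, 64])
print(labels[:10].tolist())  # [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
```

Passing the result to `torchvision.utils.save_image(images, "grid.png", nrow=8, normalize=True)` lays the 80 images out as 10 rows of 8.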

<img src="images/results1.png" alt="drawing" width="400"/>

<img src="images/results2.png" alt="drawing" width="400"/>

## References

[1] Mehdi Mirza and Simon Osindero. Conditional Generative Adversarial Nets. arXiv:1411.1784, 2014. [arXiv](https://arxiv.org/pdf/1411.1784.pdf)