# Conditional GAN
A vanilla GAN is class-agnostic: it gives us no control over which class a generated sample belongs to. In many applications, however, we want samples from a specific class, such as images of a particular digit or of a particular person. To get them, we condition the GAN on the class label. The resulting model is called a conditional GAN.

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

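## 1. Add class information as learned embeddings
We can add class information to the GAN by concatenating a class embedding with the noise vector. The class embedding is a learned vector representation of the class label: `nn.Embedding` stores one trainable vector per class, and these vectors are updated together with the rest of the network.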

In [5]:
# example
batch_size = 32
emb = nn.Embedding(10, 5) # 10 classes, 5 dimensions for each class
labels = torch.randint(0, 10, (batch_size,)) # 32 labels
noise = torch.randn(batch_size, 100) # 32 latent vectors

emb_labels = emb(labels) # 32 label embeddings
print(emb_labels.shape) # torch.Size([32, 5])

gen_input = torch.cat((emb_labels, noise), -1) # 32 latent vectors concatenated with 32 label embeddings
print(gen_input.shape) # torch.Size([32, 105])

torch.Size([32, 5])
torch.Size([32, 105])
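
To show where this concatenation would sit inside a model, here is a minimal sketch of a conditional generator. The class name `ConditionalGenerator`, the hidden size of 256, and the output size of 784 (a flattened 28x28 image) are assumptions made for illustration; they do not come from this notebook.

```python
import torch
import torch.nn as nn

class ConditionalGenerator(nn.Module):
    """Hypothetical MLP generator conditioned on a learned label embedding."""

    def __init__(self, num_classes=10, emb_dim=5, latent_dim=100, out_dim=784):
        super().__init__()
        # One trainable vector per class, as in the example above
        self.emb = nn.Embedding(num_classes, emb_dim)
        # Layer sizes are illustrative assumptions, not tuned values
        self.net = nn.Sequential(
            nn.Linear(emb_dim + latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, out_dim),
            nn.Tanh(),  # squashes outputs into [-1, 1]
        )

    def forward(self, noise, labels):
        # Concatenate the label embeddings with the noise along the feature axis
        gen_input = torch.cat((self.emb(labels), noise), dim=-1)
        return self.net(gen_input)

gen = ConditionalGenerator()
noise = torch.randn(32, 100)
labels = torch.randint(0, 10, (32,))
fake = gen(noise, labels)
print(fake.shape)  # torch.Size([32, 784])
```

Because the embedding is a submodule of the generator, its weights receive gradients from the generator loss like any other layer.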


## 2. Add class information as one-hot encoded vector
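We can also add class information to the GAN by concatenating a one-hot encoding of the label with the noise vector. A one-hot vector has one entry per class, with a 1 at the index of the label and 0 everywhere else. Unlike the embedding, it is fixed rather than learned.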

In [7]:
# example
batch_size = 32
num_classes = 10
labels = torch.randint(0, num_classes, (batch_size,)) # 32 labels
noise = torch.randn(batch_size, 100) # 32 latent vectors

# one-hot encoding
one_hot_labels = F.one_hot(labels, num_classes) # 32 one-hot encoded labels


In [12]:
print(labels[:5])
print(one_hot_labels[:5])

tensor([3, 2, 1, 4, 7])
tensor([[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0]])


In [22]:
# concatenate one-hot encoded labels with random noise
# F.one_hot returns an integer tensor, so cast it to float to match the noise dtype
gen_input = torch.cat((one_hot_labels.float(), noise), -1) # 32 one-hot encoded labels concatenated with 32 latent vectors
print(gen_input.shape) # torch.Size([32, 110])

torch.Size([32, 110])
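
The two approaches are closely related: multiplying a one-hot vector by an embedding's weight matrix selects the same row that the embedding lookup returns. The short check below illustrates this. The variable names `via_matmul`, `via_lookup`, and `same` are introduced here for illustration only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, emb_dim = 10, 5
emb = nn.Embedding(num_classes, emb_dim)
labels = torch.randint(0, num_classes, (32,))

# One-hot route: a (32, 10) matrix times the (10, 5) embedding table
one_hot_labels = F.one_hot(labels, num_classes).float()
via_matmul = one_hot_labels @ emb.weight

# Embedding route: index directly into the same table
via_lookup = emb(labels)

same = torch.allclose(via_matmul, via_lookup)
print(same)
```

In this sense, feeding one-hot labels into a first linear layer lets the network learn something like an embedding implicitly, while `nn.Embedding` makes that lookup explicit and keeps the conditioning vector small.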
