## Conditional GAN

### Chuẩn bị data

In [1]:
# Import the required libraries
import torch
from torch import nn, Tensor
import numpy as np
from torchvision.utils import save_image

# Run on the GPU when PyTorch reports one as available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
# Side length (in pixels) every image is resized to, and the number of digit classes.
image_size = 32
num_classes = 10

# Preprocessing pipeline: resize, convert to a tensor, then normalize with
# mean 0.5 / std 0.5.
# Bound to `transform` (singular) so the imported `torchvision.transforms`
# module is not shadowed by the pipeline object.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# MNIST training split, downloaded into ./data on first use.
dataset = datasets.MNIST(root='data', train=True, transform=transform, download=True)

100%|██████████| 9.91M/9.91M [00:39<00:00, 249kB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 115kB/s]
100%|██████████| 1.65M/1.65M [00:09<00:00, 178kB/s] 
100%|██████████| 4.54k/4.54k [00:00<00:00, 883kB/s]


In [4]:
# Number of samples per mini-batch.
batch_size = 32

# Shuffled loader over the training set, using two worker processes.
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

In [5]:
# Fetch a single batch from data_loader
images, labels = next(iter(data_loader))

# Print the shape of the image batch (the message says "Batch size", but the value shown is the full tensor shape)
print(f"Batch size: {images.shape}")


Batch size: torch.Size([32, 1, 32, 32])


### Xây dựng model

In [6]:
# Build the Generator as an MLP conditioned on the class label.
# The last hidden layer has 1024 units; the output layer emits prod(image_shape)
# values, which forward() reshapes back to image_shape.

class Generator(nn.Module):
    """MLP generator for a conditional GAN.

    A noise vector is concatenated with a learned embedding of the class
    label and mapped through fully connected layers to a flat image, which is
    then reshaped to ``image_shape``. The final ``Tanh`` keeps outputs in
    [-1, 1].

    Args:
        image_shape: Shape of one generated image, e.g. ``(channels, height, width)``.
        num_classes: Number of distinct labels accepted by the embedding.
        embedding_dim: Size of the label embedding vector.
        latent_dim: Size of the input noise vector.
    """

    def __init__(self, image_shape, num_classes, embedding_dim, latent_dim):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        # Keep the target shape on the module so forward() does not rely on a
        # notebook-level global with the same name.
        self.image_shape = tuple(image_shape)

        self.label_embedding = nn.Embedding(num_classes, embedding_dim)

        self.model = nn.Sequential(
            nn.Linear(latent_dim + embedding_dim, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2, inplace=True),

            # Output layer: no BatchNorm/LeakyReLU here. A BatchNorm1d hard-coded
            # to 1024 features only matches when prod(image_shape) == 1024, and a
            # LeakyReLU in front of Tanh compresses the negative half of the range.
            nn.Linear(1024, int(np.prod(self.image_shape))),
            nn.Tanh()
        )

    def forward(self, z: Tensor, labels) -> Tensor:
        """Generate one image per (noise, label) pair.

        Args:
            z: Noise tensor of shape ``(batch, latent_dim)``.
            labels: Integer class indices of shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, *image_shape)`` with values in [-1, 1].
        """
        label_embed = self.label_embedding(labels)
        # Condition the noise on the label by concatenating along the feature axis.
        gen_input = torch.cat((z, label_embed), dim=1)
        output = self.model(gen_input)
        return output.view(output.size(0), *self.image_shape)

In [7]:
class Discriminator(nn.Module):
    """MLP discriminator for a conditional GAN.

    The image is flattened, concatenated with a learned embedding of its class
    label, and scored by a fully connected network ending in a sigmoid.

    Args:
        image_shape: Shape of one input image; its product sets the input width.
        num_classes: Number of distinct labels accepted by the embedding.
        embedding_dim: Size of the label embedding vector.
        latent_dim: Accepted for signature parity with the generator; not used here.
    """

    def __init__(self, image_shape, num_classes, embedding_dim, latent_dim):
        super().__init__()

        self.label_embedding = nn.Embedding(num_classes, embedding_dim)

        flat_pixels = int(np.prod(image_shape))
        layers = [
            nn.Linear(flat_pixels + embedding_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x: Tensor, labels) -> Tensor:
        """Score a batch of images given their labels.

        Args:
            x: Image batch; everything after the first dimension is flattened.
            labels: Integer class indices, one per image.

        Returns:
            Tensor of shape ``(batch, 1)`` with sigmoid outputs in (0, 1).
        """
        flat = x.view(x.size(0), -1)
        cond = self.label_embedding(labels)
        # Image features and label embedding are joined along the feature axis.
        return self.model(torch.cat((flat, cond), dim=1))

### Training model

In [8]:
import os
# A sample grid is written every `save_interval` epochs by the training loop.
save_interval = 10

# Folder that receives the sample grids; no error if it already exists.
os.makedirs("images", exist_ok=True)

In [9]:
# Hyperparameters shared by both networks.
latent_dim = 100
embedding_dim = 16

# Read the channel count from a real batch instead of hard-coding it.
images_batch, labels = next(iter(data_loader))
image_channels = images_batch.shape[1]  # 1
image_shape = (image_channels, image_size, image_size)

# Build the generator first, then the discriminator, and move both to the selected device.
generator = Generator(
    image_shape, num_classes, embedding_dim, latent_dim
).to(device)
discriminator = Discriminator(
    image_shape, num_classes, embedding_dim, latent_dim
).to(device)

In [11]:
from tqdm import tqdm

EPOCHS = 150

optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0001)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002)

# The discriminator ends in a sigmoid, so plain binary cross-entropy applies.
criterion = nn.BCELoss()

# Per-batch loss history for both networks.
g_losses, d_losses = [], []

# The inference cell puts the generator in eval mode. Switch back explicitly so
# re-running this cell does not train with frozen BatchNorm statistics.
generator.train()
discriminator.train()

for epoch in range(EPOCHS):
    with tqdm(total=len(data_loader), desc=f"Epoch {epoch+1}/{EPOCHS}", unit="batch") as pbar:
        for imgs, labels in data_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            n_samples = imgs.size(0)

            # BCE targets, created directly on the target device: 1 = real, 0 = fake.
            real_labels = torch.ones(n_samples, 1, device=device)
            fake_labels = torch.zeros(n_samples, 1, device=device)

            ############ Generator ##########
            noise = torch.randn(n_samples, latent_dim, device=device)

            fake_imgs = generator(noise, labels)
            g_dis_output = discriminator(fake_imgs, labels)
            # The generator's loss is low when the discriminator scores its fakes as real.
            g_loss = criterion(g_dis_output, real_labels)
            g_losses.append(g_loss.item())

            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

            ############ Discriminator ##########
            real_dis_output = discriminator(imgs, labels)
            real_loss = criterion(real_dis_output, real_labels)

            # detach() keeps the discriminator update from back-propagating into the generator.
            fake_dis_output = discriminator(fake_imgs.detach(), labels)
            fake_loss = criterion(fake_dis_output, fake_labels)

            d_loss = (real_loss + fake_loss) / 2
            d_losses.append(d_loss.item())

            # zero_grad() also clears the discriminator gradients left over from g_loss.backward().
            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

            # Advance the progress bar
            pbar.update(1)

    # Every `save_interval` epochs, save up to 25 fakes from the last batch as a 5-per-row grid.
    if epoch % save_interval == 0:
        save_image(fake_imgs.detach()[:25], f"images/epoch_{epoch}.png", nrow=5, normalize=True)



Epoch 1/150: 100%|██████████| 1875/1875 [00:41<00:00, 45.64batch/s]
Epoch 2/150: 100%|██████████| 1875/1875 [00:38<00:00, 48.69batch/s]
Epoch 3/150: 100%|██████████| 1875/1875 [00:53<00:00, 35.34batch/s]
Epoch 4/150: 100%|██████████| 1875/1875 [01:00<00:00, 30.92batch/s]
Epoch 5/150: 100%|██████████| 1875/1875 [00:50<00:00, 37.15batch/s]
Epoch 6/150: 100%|██████████| 1875/1875 [00:30<00:00, 61.19batch/s]
Epoch 7/150: 100%|██████████| 1875/1875 [00:36<00:00, 52.04batch/s]
Epoch 8/150: 100%|██████████| 1875/1875 [00:35<00:00, 53.08batch/s]
Epoch 9/150: 100%|██████████| 1875/1875 [00:43<00:00, 43.33batch/s]
Epoch 10/150: 100%|██████████| 1875/1875 [00:43<00:00, 43.10batch/s]
Epoch 11/150: 100%|██████████| 1875/1875 [00:39<00:00, 47.05batch/s]
Epoch 12/150: 100%|██████████| 1875/1875 [01:00<00:00, 30.82batch/s]
Epoch 13/150: 100%|██████████| 1875/1875 [01:01<00:00, 30.31batch/s]
Epoch 14/150: 100%|██████████| 1875/1875 [01:01<00:00, 30.64batch/s]
Epoch 15/150: 100%|██████████| 1875/1875 [0

KeyboardInterrupt: 

### Inference

In [None]:
%matplotlib inline
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

In [None]:
# Eval mode makes BatchNorm use its running statistics instead of batch statistics.
generator.eval()

# Number of images generated per class.
num_sample = 5

# No gradients are needed for sampling, so skip building the autograd graph.
with torch.no_grad():
    for target_class in range(num_classes):
        z = torch.randn(num_sample, latent_dim, device=device)
        # Every sample in the row is conditioned on the same class.
        condition_labels = torch.full((num_sample,), target_class, dtype=torch.long, device=device)

        gen_imgs = generator(z, condition_labels).cpu()

        # make_grid returns (C, H, W); matplotlib expects (H, W, C).
        grid = make_grid(gen_imgs, nrow=num_sample, normalize=True).permute(1, 2, 0).numpy()
        plt.imshow(grid)
        plt.title(f"Generated samples for class {target_class}")
        plt.axis("off")
        plt.show()