In [1]:
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils


In [2]:
def load_data(dataset_name='cifar10', dataroot='./data', image_size=64, batch_size=64):
    if dataset_name == 'cifar10':
        dataset = dset.CIFAR10(
            root=dataroot, download=True,
            transform=transforms.Compose([
                transforms.Resize(image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
        )
        nc = 3  # CIFAR-10 có 3 kênh (RGB)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=1
    )
    return dataloader, nc

In [3]:
class Generator(nn.Module):
    def __init__(self, nz, ngf, nc):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)


In [4]:
class Critic(nn.Module):
    def __init__(self, ndf, nc):
        super(Critic, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False)
        )

    def forward(self, input):
        return self.main(input).view(-1)  # Output is a scalar score per input


In [5]:
def gradient_penalty(critic, real_data, fake_data, device):
    batch_size = real_data.size(0)
    epsilon = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolated = epsilon * real_data + (1 - epsilon) * fake_data
    interpolated.requires_grad_(True)

    score_interpolated = critic(interpolated)

    grad_outputs = torch.ones_like(score_interpolated, device=device)
    gradients = torch.autograd.grad(
        outputs=score_interpolated,
        inputs=interpolated,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]

    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    penalty = ((gradient_norm - 1) ** 2).mean()
    return penalty


In [6]:
nz = 100  # Latent vector size
ngf = 64  # Generator feature map size
ndf = 64  # Critic feature map size
ngpu = 1  # Number of GPUs

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

netG = Generator(ngpu, nz, ngf, 3).to(device)
netC = Critic(ngpu, ndf, 3).to(device)

optimizerG = optim.Adam(netG.parameters(), lr=0.0001, betas=(0.0, 0.9))
optimizerC = optim.Adam(netC.parameters(), lr=0.0001, betas=(0.0, 0.9))


In [7]:
n_critic = 5  # Update critic 5 times per generator update
lambda_gp = 10  # Gradient penalty weight
# Khởi tạo dataloader và số lượng kênh
dataloader, nc = load_data('cifar10')

for epoch in range(25):
    for i, data in enumerate(dataloader):
        # Train Critic
        netC.zero_grad()
        real_data = data[0].to(device)

        batch_size = real_data.size(0)
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake_data = netG(noise).detach()

        real_score = netC(real_data).mean()
        fake_score = netC(fake_data).mean()

        # Compute gradient penalty
        gp = gradient_penalty(netC, real_data, fake_data, device)

        # Critic loss
        lossC = fake_score - real_score + lambda_gp * gp
        lossC.backward()
        optimizerC.step()

        # Update Generator every n_critic iterations
        if i % n_critic == 0:
            netG.zero_grad()
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            fake_data = netG(noise)
            lossG = -netC(fake_data).mean()
            lossG.backward()
            optimizerG.step()

        if i % 100 == 0:
            print(f'[{epoch}/{25}][{i}/{len(dataloader)}] '
                  f'Loss_C: {lossC.item():.4f} Loss_G: {lossG.item():.4f}')


Files already downloaded and verified
[0/25][0/782] Loss_C: 212.3717 Loss_G: -0.0305
[0/25][100/782] Loss_C: -1.0466 Loss_G: 1.0596
[0/25][200/782] Loss_C: -0.3749 Loss_G: 0.6686
[0/25][300/782] Loss_C: -0.2862 Loss_G: 0.5059
[0/25][400/782] Loss_C: -0.2509 Loss_G: 0.8586
[0/25][500/782] Loss_C: -0.3415 Loss_G: 1.1334
[0/25][600/782] Loss_C: -0.1243 Loss_G: 1.1573
[0/25][700/782] Loss_C: -0.0881 Loss_G: 1.1234
[1/25][0/782] Loss_C: -0.0946 Loss_G: 1.2625
[1/25][100/782] Loss_C: -0.1168 Loss_G: 1.3746
[1/25][200/782] Loss_C: -0.1947 Loss_G: 1.5372
[1/25][300/782] Loss_C: -0.2804 Loss_G: 1.5514
[1/25][400/782] Loss_C: -0.2805 Loss_G: 1.6638
[1/25][500/782] Loss_C: -0.2814 Loss_G: 1.4599
[1/25][600/782] Loss_C: -0.3002 Loss_G: 1.4184
[1/25][700/782] Loss_C: -0.2523 Loss_G: 1.2166
[2/25][0/782] Loss_C: -0.0579 Loss_G: 1.2048
[2/25][100/782] Loss_C: -0.2658 Loss_G: 1.1655
[2/25][200/782] Loss_C: -0.2455 Loss_G: 1.1759
[2/25][300/782] Loss_C: -0.1925 Loss_G: 1.1379
[2/25][400/782] Loss_C: -0