# MLP-GAN và Conv-GAN trên MNIST

Notebook này chứa hai biến thể:
- **MLP-GAN** trên ảnh MNIST 28×28
- **Conv-GAN** trên ảnh MNIST resize 64×64

Bạn có thể chỉnh tham số trực tiếp trong cell cấu hình bên dưới.

In [1]:
# Cài đặt (chỉ cần chạy nếu thiếu thư viện)
# !pip install torch torchvision matplotlib


In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image

# Use the CUDA device when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [10]:
# Cấu hình tham số (chỉnh trực tiếp ở đây)
BASE_DIR = "/kaggle/working/"
ARCH = "mlp"  # 'mlp' or 'conv'
EPOCHS = 20
BATCH_SIZE = 128
LATENT_DIM = 100
LR = 2e-4
OUTPUT_DIR = os.path.join(BASE_DIR, "outputs_mnist")


In [4]:
# Định nghĩa MLP-GAN

class MLPGenerator(nn.Module):
    """Fully connected GAN generator that maps noise vectors to square images.

    Args:
        latent_dim: Length of each input noise vector.
        img_size: Number of pixels per generated image (height * width). It
            must be a perfect square; images are single-channel
            ``sqrt(img_size) x sqrt(img_size)``.

    Raises:
        ValueError: If ``img_size`` is not a perfect square.

    ``forward`` expects ``z`` of shape ``(batch, latent_dim)`` and returns a
    tensor of shape ``(batch, 1, side, side)`` with values in ``[-1, 1]``
    (the final layer is ``Tanh``).
    """

    def __init__(self, latent_dim=100, img_size=28 * 28):
        super().__init__()
        side = int(round(img_size ** 0.5))
        if side * side != img_size:
            raise ValueError(
                f"img_size must be a perfect square (height * width), got {img_size}"
            )
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, img_size),
            nn.Tanh()
        )
        self.img_size = img_size
        # Side length of the square output image, used to un-flatten in forward().
        self.img_side = side

    def forward(self, z):
        x = self.model(z)
        # Reshape the flat pixel vector into (batch, 1 channel, side, side).
        return x.view(x.size(0), 1, self.img_side, self.img_side)


class MLPDiscriminator(nn.Module):
    """Fully connected GAN discriminator operating on flattened images.

    Args:
        img_size: Number of input features per sample after flattening
            everything except the batch dimension.

    ``forward`` returns a ``(batch, 1)`` tensor of probabilities produced by
    a final ``Sigmoid``.
    """

    def __init__(self, img_size=28 * 28):
        super().__init__()
        widths = [img_size, 512, 256]
        layers = []
        # Linear -> LeakyReLU for each consecutive pair of widths.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2, inplace=True)]
        # Final projection to a single score squashed into (0, 1).
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        flat = x.view(x.shape[0], -1)
        return self.model(flat)


In [5]:
# Định nghĩa Conv-GAN (kiểu DCGAN nhỏ cho MNIST 64x64)

class ConvGenerator(nn.Module):
    """DCGAN-style generator mapping ``(batch, nz, 1, 1)`` noise to images.

    A 4x4 transposed convolution first expands the 1x1 input to 4x4. Four
    stride-2 transposed convolutions then double the resolution each time,
    giving a ``(batch, nc, 64, 64)`` output in ``[-1, 1]`` (``Tanh``).

    Args:
        nz: Number of noise channels in the input.
        ngf: Base feature width; hidden layers use 8*ngf, 4*ngf, 2*ngf, ngf.
        nc: Number of output image channels.
    """

    def __init__(self, nz=100, ngf=64, nc=1):
        super().__init__()

        def up_block(c_in, c_out, stride, padding):
            # ConvTranspose -> BatchNorm -> ReLU, kernel size 4, no conv bias.
            return [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]

        widths = [ngf * 8, ngf * 4, ngf * 2, ngf]
        layers = up_block(nz, widths[0], 1, 0)
        for c_in, c_out in zip(widths, widths[1:]):
            layers.extend(up_block(c_in, c_out, 2, 1))
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, z):
        return self.main(z)


class ConvDiscriminator(nn.Module):
    """DCGAN-style discriminator for ``(batch, nc, 64, 64)`` images.

    Four stride-2 convolutions shrink a 64x64 input to 4x4. A final 4x4
    convolution without padding reduces that to one score per image, which
    a ``Sigmoid`` maps into ``(0, 1)``.

    Args:
        nc: Number of input image channels.
        ndf: Base feature width; layers use ndf, 2*ndf, 4*ndf, 8*ndf channels.
    """

    def __init__(self, nc=1, ndf=64):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of real/fake probabilities.

        Raises:
            ValueError: If the input's spatial size does not reduce to a 1x1
                score map (e.g. inputs other than 64x64). Reshaping such a
                map with ``view(-1, 1)`` would silently fold the extra
                positions into the batch dimension and misalign scores with
                labels.
        """
        out = self.main(x)
        if out.shape[1:] != (1, 1, 1):
            raise ValueError(
                "ConvDiscriminator expects 64x64 inputs; got input of shape "
                f"{tuple(x.shape)}, which produced a score map of shape {tuple(out.shape)}"
            )
        return out.view(x.size(0), 1)


In [6]:
# Hàm sinh noise

def sample_noise_mlp(batch_size, latent_dim, device):
    """Draw standard-normal latent vectors shaped ``(batch_size, latent_dim)`` on ``device``."""
    noise_shape = (batch_size, latent_dim)
    return torch.randn(noise_shape, device=device)


def sample_noise_conv(batch_size, nz, device):
    """Draw standard-normal noise shaped ``(batch_size, nz, 1, 1)`` on ``device``.

    The trailing 1x1 spatial dims let the noise feed a transposed-convolution
    generator directly.
    """
    return torch.randn((batch_size, nz, 1, 1), device=device)


In [7]:
# Vòng lặp huấn luyện chung cho cả MLP và Conv

def train_gan_mnist(arch="mlp", epochs=20, batch_size=128, latent_dim=100, lr=2e-4,
                    outdir="outputs_mnist", data_root="data"):
    """Train an MLP-GAN or Conv-GAN on MNIST and save a sample grid every epoch.

    Args:
        arch: ``"mlp"`` (28x28 inputs, MLPGenerator/MLPDiscriminator) or
            ``"conv"`` (inputs resized to 64x64, ConvGenerator/ConvDiscriminator).
        epochs: Number of passes over the MNIST training split.
        batch_size: Mini-batch size for the shuffled DataLoader.
        latent_dim: Length of the generator's noise input.
        lr: Adam learning rate for both networks (betas=(0.5, 0.999)).
        outdir: Directory for per-epoch PNG grids; created if missing.
        data_root: Directory where torchvision downloads/caches MNIST.

    Raises:
        ValueError: If ``arch`` is neither ``"mlp"`` nor ``"conv"``.
    """
    if arch not in ("mlp", "conv"):
        raise ValueError(f"arch must be 'mlp' or 'conv', got {arch!r}")
    os.makedirs(outdir, exist_ok=True)

    # Pick the image size, models and noise sampler for the architecture once.
    # Normalize((0.5,), (0.5,)) maps pixels to [-1, 1], matching the Tanh output.
    if arch == "mlp":
        img_size = 28
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        G = MLPGenerator(latent_dim, img_size * img_size).to(device)
        D = MLPDiscriminator(img_size * img_size).to(device)
        sample_noise = sample_noise_mlp
    else:
        img_size = 64
        transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        G = ConvGenerator(nz=latent_dim, ngf=64, nc=1).to(device)
        D = ConvDiscriminator(nc=1, ndf=64).to(device)
        sample_noise = sample_noise_conv

    dataset = datasets.MNIST(root=data_root, train=True, transform=transform, download=True)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    criterion = nn.BCELoss()
    optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

    # Fixed noise batch, so the per-epoch sample grids show the same latent
    # points and training progress is comparable across epochs.
    fixed_z = sample_noise(64, latent_dim, device)

    for epoch in range(epochs):
        for i, (real_imgs, _) in enumerate(loader):
            real_imgs = real_imgs.to(device)
            cur_bs = real_imgs.size(0)  # last batch may be smaller than batch_size
            labels_real = torch.ones(cur_bs, 1, device=device)
            labels_fake = torch.zeros(cur_bs, 1, device=device)

            # Discriminator step: push D(real) -> 1 and D(G(z)) -> 0.
            optimizer_D.zero_grad()
            loss_D_real = criterion(D(real_imgs), labels_real)
            # detach() keeps this backward pass from reaching G's parameters.
            fake_imgs = G(sample_noise(cur_bs, latent_dim, device)).detach()
            loss_D_fake = criterion(D(fake_imgs), labels_fake)
            loss_D = loss_D_real + loss_D_fake
            loss_D.backward()
            optimizer_D.step()

            # Generator step (non-saturating loss): push D(G(z)) -> 1.
            # Gradients this leaves on D are cleared by the next optimizer_D.zero_grad().
            optimizer_G.zero_grad()
            fake_imgs = G(sample_noise(cur_bs, latent_dim, device))
            loss_G = criterion(D(fake_imgs), labels_real)
            loss_G.backward()
            optimizer_G.step()

            if i % 200 == 0:
                print(f"[{arch}] Epoch [{epoch+1}/{epochs}] "
                      f"Step [{i}/{len(loader)}] "
                      f"Loss_D: {loss_D.item():.4f}, Loss_G: {loss_G.item():.4f}")

        # Save an 8x8 grid of samples from the fixed noise after each epoch.
        with torch.no_grad():
            fake = G(fixed_z).cpu()
        fake = (fake + 1) / 2  # map Tanh range [-1, 1] to [0, 1] for save_image
        grid = make_grid(fake, nrow=8)
        save_image(grid, os.path.join(outdir, f"{arch}_fake_epoch_{epoch+1:03d}.png"))

    print("Hoàn thành huấn luyện, ảnh sinh ra lưu trong:", outdir)


In [8]:
# RUN THE EXPERIMENT
# Run this cell after editing the values in the configuration cell.

run_kwargs = {
    "arch": ARCH,
    "epochs": EPOCHS,
    "batch_size": BATCH_SIZE,
    "latent_dim": LATENT_DIM,
    "lr": LR,
    "outdir": OUTPUT_DIR,
}
train_gan_mnist(**run_kwargs)


100%|██████████| 9.91M/9.91M [00:00<00:00, 17.9MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 479kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.47MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.37MB/s]


[conv] Epoch [1/20] Step [0/469] Loss_D: 1.5851, Loss_G: 1.8070
[conv] Epoch [1/20] Step [200/469] Loss_D: 0.5102, Loss_G: 4.7396
[conv] Epoch [1/20] Step [400/469] Loss_D: 1.0117, Loss_G: 3.8880
[conv] Epoch [2/20] Step [0/469] Loss_D: 0.8209, Loss_G: 0.7068
[conv] Epoch [2/20] Step [200/469] Loss_D: 0.8529, Loss_G: 2.2823
[conv] Epoch [2/20] Step [400/469] Loss_D: 0.3622, Loss_G: 1.8367
[conv] Epoch [3/20] Step [0/469] Loss_D: 0.7558, Loss_G: 2.1386
[conv] Epoch [3/20] Step [200/469] Loss_D: 0.6497, Loss_G: 1.3512
[conv] Epoch [3/20] Step [400/469] Loss_D: 1.1309, Loss_G: 1.2816
[conv] Epoch [4/20] Step [0/469] Loss_D: 1.3422, Loss_G: 7.9889
[conv] Epoch [4/20] Step [200/469] Loss_D: 0.2854, Loss_G: 3.6206
[conv] Epoch [4/20] Step [400/469] Loss_D: 0.4566, Loss_G: 4.2112
[conv] Epoch [5/20] Step [0/469] Loss_D: 0.0675, Loss_G: 4.3343
[conv] Epoch [5/20] Step [200/469] Loss_D: 0.0495, Loss_G: 5.1526
[conv] Epoch [5/20] Step [400/469] Loss_D: 0.3908, Loss_G: 3.0748
[conv] Epoch [6/20] 

In [11]:
# Re-run training with the current configuration values.
run_kwargs = {
    "arch": ARCH,
    "epochs": EPOCHS,
    "batch_size": BATCH_SIZE,
    "latent_dim": LATENT_DIM,
    "lr": LR,
    "outdir": OUTPUT_DIR,
}
train_gan_mnist(**run_kwargs)

[mlp] Epoch [1/20] Step [0/469] Loss_D: 1.4150, Loss_G: 0.6960
[mlp] Epoch [1/20] Step [200/469] Loss_D: 1.2797, Loss_G: 0.8584
[mlp] Epoch [1/20] Step [400/469] Loss_D: 1.4556, Loss_G: 0.9871
[mlp] Epoch [2/20] Step [0/469] Loss_D: 0.9636, Loss_G: 1.3793
[mlp] Epoch [2/20] Step [200/469] Loss_D: 0.8600, Loss_G: 1.5466
[mlp] Epoch [2/20] Step [400/469] Loss_D: 1.1156, Loss_G: 1.2645
[mlp] Epoch [3/20] Step [0/469] Loss_D: 0.9919, Loss_G: 1.3817
[mlp] Epoch [3/20] Step [200/469] Loss_D: 1.1555, Loss_G: 2.6992
[mlp] Epoch [3/20] Step [400/469] Loss_D: 0.8441, Loss_G: 1.5528
[mlp] Epoch [4/20] Step [0/469] Loss_D: 0.9335, Loss_G: 0.9194
[mlp] Epoch [4/20] Step [200/469] Loss_D: 1.0834, Loss_G: 1.9505
[mlp] Epoch [4/20] Step [400/469] Loss_D: 0.8373, Loss_G: 2.7357
[mlp] Epoch [5/20] Step [0/469] Loss_D: 0.7174, Loss_G: 1.3401
[mlp] Epoch [5/20] Step [200/469] Loss_D: 0.9970, Loss_G: 1.5264
[mlp] Epoch [5/20] Step [400/469] Loss_D: 0.8729, Loss_G: 1.1432
[mlp] Epoch [6/20] Step [0/469] Los