# CycleGAN (toy) với dữ liệu unpaired

Domain X: thư mục trainA, Domain Y: thư mục trainB.

In [9]:
# !pip install torch torchvision pillow

In [10]:
import os
import glob
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
_backend = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_backend)
print("Using device:", device)


Using device: cuda


In [11]:
# Configuration
ROOT_PATH = "/kaggle/input/data-36"
DATA_A = f"{ROOT_PATH}/data/cyclegan_toy/trainA"  # domain X images
DATA_B = f"{ROOT_PATH}/data/cyclegan_toy/trainB"  # domain Y images
OUTPUT_DIR = "/kaggle/working/outputs_cyclegan_toy"  # sample grids are written here

# Optimisation settings
EPOCHS = 100
BATCH_SIZE = 2
LR = 2e-4

# Loss weights for the cycle-consistency and identity terms
LAMBDA_CYCLE = 10.0
LAMBDA_ID = 5.0

# Images are resized to IMG_SIZE x IMG_SIZE before training
IMG_SIZE = 128


In [12]:
class UnpairedDataset(Dataset):
    """Dataset yielding (image_A, image_B) pairs from two unaligned folders.

    Only ``*.png`` files directly inside each directory are used. Item ``idx``
    takes the A image at ``idx`` (wrapping around when A is the shorter list)
    and a B image drawn uniformly at random, so the pairing changes on every
    access.

    Args:
        dir_A: Directory holding the domain-A PNG images.
        dir_B: Directory holding the domain-B PNG images.
        transform: Optional callable applied to each loaded RGB image.

    Raises:
        RuntimeError: If either directory contains no ``*.png`` file.
    """

    def __init__(self, dir_A, dir_B, transform=None):
        pattern_A = os.path.join(dir_A, "*.png")
        pattern_B = os.path.join(dir_B, "*.png")
        # Sorted so that indexing into A is deterministic across runs.
        self.paths_A = sorted(glob.glob(pattern_A))
        self.paths_B = sorted(glob.glob(pattern_B))
        if not (self.paths_A and self.paths_B):
            raise RuntimeError("No images in A or B")
        self.transform = transform

    def __len__(self):
        # The longer domain defines the epoch length.
        count_A = len(self.paths_A)
        count_B = len(self.paths_B)
        return count_A if count_A >= count_B else count_B

    def __getitem__(self, idx):
        file_A = self.paths_A[idx % len(self.paths_A)]
        # B is sampled independently of idx: the two domains are unpaired.
        last_B = len(self.paths_B) - 1
        file_B = self.paths_B[random.randint(0, last_B)]

        picture_A = Image.open(file_A).convert("RGB")
        picture_B = Image.open(file_B).convert("RGB")
        if not self.transform:
            return picture_A, picture_B
        return self.transform(picture_A), self.transform(picture_B)


In [13]:
class ResnetBlock(nn.Module):
    """Residual block that keeps both channel count and spatial size.

    The residual branch is conv3x3 -> InstanceNorm -> ReLU -> conv3x3 ->
    InstanceNorm, and its output is added to the block input.

    Args:
        dim: Number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()

        def conv_norm():
            # 3x3 convolution with stride 1 / padding 1, followed by
            # instance normalisation; the convolution has no bias term.
            return [
                nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False),
                nn.InstanceNorm2d(dim),
            ]

        self.block = nn.Sequential(
            *conv_norm(),
            nn.ReLU(inplace=True),
            *conv_norm(),
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual


class ResnetGenerator(nn.Module):
    """Encoder / residual / decoder image-to-image generator.

    Layer sequence: a 7x7 stem convolution, two stride-2 convolutions that
    each double the channel count, ``n_blocks`` residual blocks, two stride-2
    transposed convolutions that each halve the channel count, and a final
    7x7 convolution followed by ``Tanh`` (so outputs lie in [-1, 1]).

    Args:
        in_ch: Channels of the input image.
        out_ch: Channels of the generated image.
        n_blocks: Number of ``ResnetBlock`` modules in the bottleneck.
        ngf: Channel count after the stem convolution.
    """

    def __init__(self, in_ch=3, out_ch=3, n_blocks=6, ngf=64):
        super().__init__()

        # Stem: full-resolution 7x7 convolution.
        layers = [
            nn.Conv2d(in_ch, ngf, kernel_size=7, stride=1, padding=3, bias=False),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(inplace=True),
        ]

        # Encoder: halve the resolution and double the width, twice.
        width = ngf
        for _ in range(2):
            wider = width * 2
            layers.extend([
                nn.Conv2d(width, wider, kernel_size=3, stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(wider),
                nn.ReLU(inplace=True),
            ])
            width = wider

        # Bottleneck: residual blocks at the reduced resolution.
        layers.extend(ResnetBlock(width) for _ in range(n_blocks))

        # Decoder: double the resolution and halve the width, twice.
        for _ in range(2):
            narrower = width // 2
            layers.extend([
                nn.ConvTranspose2d(width, narrower, kernel_size=3, stride=2,
                                   padding=1, output_padding=1, bias=False),
                nn.InstanceNorm2d(narrower),
                nn.ReLU(inplace=True),
            ])
            width = narrower

        # Head: project to the output channels and squash with tanh.
        layers.append(nn.Conv2d(width, out_ch, kernel_size=7, stride=1, padding=3))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


In [14]:
class PatchDiscriminator(nn.Module):
    """Convolutional discriminator that returns a one-channel score map.

    Three stride-2 4x4 convolutions (the first without normalisation), each
    followed by LeakyReLU(0.2), then a stride-1 4x4 convolution down to a
    single channel. No sigmoid is applied, so the output is a raw score for
    every spatial location rather than a single scalar per image.

    Args:
        in_ch: Channels of the input image.
        ndf: Channel count after the first convolution.
    """

    def __init__(self, in_ch=3, ndf=64):
        super().__init__()

        # (in_channels, out_channels, use_instance_norm) per stride-2 stage.
        stage_specs = (
            (in_ch, ndf, False),
            (ndf, ndf * 2, True),
            (ndf * 2, ndf * 4, True),
        )

        layers = []
        for src, dst, use_norm in stage_specs:
            layers.append(
                nn.Conv2d(src, dst, kernel_size=4, stride=2, padding=1, bias=False)
            )
            if use_norm:
                layers.append(nn.InstanceNorm2d(dst))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Final projection to one score channel (keeps its bias term).
        layers.append(nn.Conv2d(ndf * 4, 1, kernel_size=4, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


In [15]:
def _gan_target(pred, is_real):
    """Return an all-ones (real) or all-zeros (fake) target shaped like ``pred``.

    The discriminators return a spatial score map, so the target is built
    with the prediction's own shape instead of relying on broadcasting inside
    ``nn.MSELoss``, which emits a size-mismatch ``UserWarning``.
    """
    return torch.ones_like(pred) if is_real else torch.zeros_like(pred)


def _save_translation_grid(source, translate, translate_back, path):
    """Save a three-row grid: source batch, its translation, its reconstruction."""
    translated = translate(source)
    reconstructed = translate_back(translated)
    rows = torch.cat([source, translated, reconstructed], dim=0)
    # Inputs are normalised with mean 0.5 / std 0.5 and generators end in
    # tanh, so map [-1, 1] back to [0, 1] before writing the image.
    save_image((rows + 1) / 2, path, nrow=source.size(0))


def train_cyclegan(dataA=DATA_A,
                   dataB=DATA_B,
                   outdir=OUTPUT_DIR,
                   epochs=EPOCHS,
                   batch_size=BATCH_SIZE,
                   lr=LR,
                   lambda_cycle=LAMBDA_CYCLE,
                   lambda_id=LAMBDA_ID,
                   img_size=IMG_SIZE):
    """Train a small CycleGAN on two unpaired image folders.

    Two generators (``G``: X -> Y, ``F``: Y -> X) and two discriminators
    (``D_X``, ``D_Y``) are trained with a least-squares (MSE) adversarial
    loss, an L1 cycle-consistency loss and an L1 identity loss. After every
    epoch two sample grids (``X2Y_epoch_NNN.png`` and ``Y2X_epoch_NNN.png``)
    are written to ``outdir``.

    Args:
        dataA: Directory with the domain-X training images.
        dataB: Directory with the domain-Y training images.
        outdir: Directory for the sample grids; created if missing.
        epochs: Number of passes over the dataset.
        batch_size: Mini-batch size.
        lr: Learning rate shared by all four Adam optimisers.
        lambda_cycle: Weight of the cycle-consistency loss.
        lambda_id: Weight of the identity loss.
        img_size: Images are resized to ``img_size`` x ``img_size``.

    Returns:
        None. Progress is printed and sample images are written to disk.
    """
    os.makedirs(outdir, exist_ok=True)

    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5),
                             (0.5, 0.5, 0.5))
    ])

    dataset = UnpairedDataset(dataA, dataB, transform)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    G = ResnetGenerator(3, 3, n_blocks=3, ngf=64).to(device)
    F = ResnetGenerator(3, 3, n_blocks=3, ngf=64).to(device)
    D_X = PatchDiscriminator(3, ndf=64).to(device)
    D_Y = PatchDiscriminator(3, ndf=64).to(device)

    adv_loss = nn.MSELoss()
    l1 = nn.L1Loss()

    opt_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_F = optim.Adam(F.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_DX = optim.Adam(D_X.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_DY = optim.Adam(D_Y.parameters(), lr=lr, betas=(0.5, 0.999))

    for epoch in range(epochs):
        for i, (real_X, real_Y) in enumerate(loader):
            real_X = real_X.to(device)
            real_Y = real_Y.to(device)

            # ---- Generator step (G and F are updated together) ----
            opt_G.zero_grad()
            opt_F.zero_grad()

            fake_Y = G(real_X)
            fake_X = F(real_Y)

            # Adversarial term: generators want their fakes scored as real.
            # Gradients that reach D_X / D_Y here are discarded by the
            # zero_grad() calls that precede each discriminator step.
            pred_fake_Y = D_Y(fake_Y)
            pred_fake_X = D_X(fake_X)
            loss_G_gan = adv_loss(pred_fake_Y, _gan_target(pred_fake_Y, True))
            loss_F_gan = adv_loss(pred_fake_X, _gan_target(pred_fake_X, True))

            # Cycle consistency: X -> Y -> X and Y -> X -> Y should round-trip.
            rec_X = F(fake_Y)
            rec_Y = G(fake_X)
            loss_cycle_X = l1(rec_X, real_X)
            loss_cycle_Y = l1(rec_Y, real_Y)
            loss_cycle_total = (loss_cycle_X + loss_cycle_Y) * lambda_cycle

            # Identity: a generator fed an image already in its target
            # domain should leave it unchanged.
            id_X = F(real_X)
            id_Y = G(real_Y)
            loss_id_X = l1(id_X, real_X) * lambda_id
            loss_id_Y = l1(id_Y, real_Y) * lambda_id

            loss_G_total = loss_G_gan + loss_F_gan + loss_cycle_total + loss_id_X + loss_id_Y
            loss_G_total.backward()
            opt_G.step()
            opt_F.step()

            # ---- Discriminator X step ----
            opt_DX.zero_grad()
            pred_real_X = D_X(real_X)
            # detach(): do not backpropagate into the generator here.
            pred_fake_X = D_X(fake_X.detach())
            loss_DX_real = adv_loss(pred_real_X, _gan_target(pred_real_X, True))
            loss_DX_fake = adv_loss(pred_fake_X, _gan_target(pred_fake_X, False))
            loss_DX = (loss_DX_real + loss_DX_fake) * 0.5
            loss_DX.backward()
            opt_DX.step()

            # ---- Discriminator Y step ----
            opt_DY.zero_grad()
            pred_real_Y = D_Y(real_Y)
            pred_fake_Y = D_Y(fake_Y.detach())
            loss_DY_real = adv_loss(pred_real_Y, _gan_target(pred_real_Y, True))
            loss_DY_fake = adv_loss(pred_fake_Y, _gan_target(pred_fake_Y, False))
            loss_DY = (loss_DY_real + loss_DY_fake) * 0.5
            loss_DY.backward()
            opt_DY.step()

            if i % 50 == 0:
                print(f"[CycleGAN] Epoch [{epoch+1}/{epochs}] "
                      f"Step [{i}/{len(loader)}] "
                      f"Loss_G: {loss_G_total.item():.4f} "
                      f"Loss_DX: {loss_DX.item():.4f} "
                      f"Loss_DY: {loss_DY.item():.4f}")

        # End of epoch: dump qualitative samples from one freshly drawn batch.
        with torch.no_grad():
            real_X_sample, real_Y_sample = next(iter(loader))
            real_X_sample = real_X_sample.to(device)
            real_Y_sample = real_Y_sample.to(device)

            _save_translation_grid(
                real_X_sample, G, F,
                os.path.join(outdir, f"X2Y_epoch_{epoch+1:03d}.png"))
            _save_translation_grid(
                real_Y_sample, F, G,
                os.path.join(outdir, f"Y2X_epoch_{epoch+1:03d}.png"))

    print("Huấn luyện CycleGAN xong, ảnh kết quả trong:", outdir)


In [16]:
# Launch training with the default arguments of train_cyclegan.
train_cyclegan()

  return F.mse_loss(input, target, reduction=self.reduction)


[CycleGAN] Epoch [1/100] Step [0/250] Loss_G: 29.5296 Loss_DX: 0.8536 Loss_DY: 0.4803
[CycleGAN] Epoch [1/100] Step [50/250] Loss_G: 2.3372 Loss_DX: 0.2236 Loss_DY: 0.2253
[CycleGAN] Epoch [1/100] Step [100/250] Loss_G: 1.6240 Loss_DX: 0.2741 Loss_DY: 0.2885
[CycleGAN] Epoch [1/100] Step [150/250] Loss_G: 1.4008 Loss_DX: 0.2425 Loss_DY: 0.2980
[CycleGAN] Epoch [1/100] Step [200/250] Loss_G: 1.2899 Loss_DX: 0.2526 Loss_DY: 0.2462
[CycleGAN] Epoch [2/100] Step [0/250] Loss_G: 1.3399 Loss_DX: 0.2364 Loss_DY: 0.2402
[CycleGAN] Epoch [2/100] Step [50/250] Loss_G: 1.2483 Loss_DX: 0.2564 Loss_DY: 0.2558
[CycleGAN] Epoch [2/100] Step [100/250] Loss_G: 1.6086 Loss_DX: 0.1870 Loss_DY: 0.1743
[CycleGAN] Epoch [2/100] Step [150/250] Loss_G: 1.2373 Loss_DX: 0.2713 Loss_DY: 0.2509
[CycleGAN] Epoch [2/100] Step [200/250] Loss_G: 1.3401 Loss_DX: 0.2284 Loss_DY: 0.2175
[CycleGAN] Epoch [3/100] Step [0/250] Loss_G: 1.3601 Loss_DX: 0.2139 Loss_DY: 0.2166
[CycleGAN] Epoch [3/100] Step [50/250] Loss_G: 1.2