In [1]:
#!/usr/bin/env python
# coding: utf-8

"""Minimal Pix2Pix toy training for paired A|B images. Run directly with defaults."""

import glob
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Dataset root; not referenced by any other code visible in this file.
BASE_DIR = "/kaggle/input/data-36/"
# Directory holding the paired training images (default for train_pix2pix).
DATA_DIR = "/kaggle/input/data-36/data/pix2pix_toy/train"
# Directory where per-epoch sample grids are written.
OUTPUT_DIR = "/kaggle/working/outputs_pix2pix_toy"
EPOCHS = 50  # default number of passes over the dataset
BATCH_SIZE = 4  # default DataLoader batch size
LR = 2e-4  # default Adam learning rate for both networks
LAMBDA_L1 = 100.0  # default multiplier applied to the generator's L1 loss
IMG_SIZE = 256  # default side length each image half is resized to


class Pix2PixDataset(Dataset):
    """Paired-image dataset read from side-by-side ``A|B`` ``.png`` files.

    Every file is split vertically at ``width // 2``: the left part is
    returned as image A and the right part as image B.

    Args:
        root_dir: Directory searched (non-recursively) for ``*.png`` files.
        transform: Optional callable applied separately to each half.

    Raises:
        RuntimeError: If ``root_dir`` contains no ``.png`` files.
    """

    def __init__(self, root_dir, transform=None):
        super().__init__()
        self.root_dir = root_dir
        # Sorted so the index -> file mapping is the same on every run.
        self.paths = sorted(glob.glob(os.path.join(root_dir, "*.png")))
        if not self.paths:
            raise RuntimeError(f"No .png files found in {root_dir}")
        self.transform = transform

    def __len__(self):
        """Return the number of paired image files found."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return the ``(img_a, img_b)`` halves of the file at ``idx``."""
        # Close the file handle deterministically; convert() yields an
        # in-memory image that stays valid after the source is closed.
        with Image.open(self.paths[idx]) as src:
            img = src.convert("RGB")
        w, h = img.size
        w2 = w // 2
        # For an odd width the right half is one pixel wider than the left.
        img_a = img.crop((0, 0, w2, h))
        img_b = img.crop((w2, 0, w, h))
        if self.transform:
            img_a = self.transform(img_a)
            img_b = self.transform(img_b)
        return img_a, img_b


class UNetDown(nn.Module):
    """Encoder stage: 4x4 stride-2 convolution, optional batch norm, LeakyReLU.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        norm: When true, insert ``BatchNorm2d`` between the convolution and
            the activation.
    """

    def __init__(self, in_ch, out_ch, norm=True):
        super().__init__()
        conv = nn.Conv2d(
            in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False
        )
        activation = nn.LeakyReLU(0.2, inplace=True)
        if norm:
            stages = (conv, nn.BatchNorm2d(out_ch), activation)
        else:
            stages = (conv, activation)
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the stage to ``x`` and return the downsampled features."""
        return self.model(x)


class UNetUp(nn.Module):
    """Decoder stage: 4x4 stride-2 transposed conv, batch norm, ReLU, optional dropout.

    The stage output is concatenated with a skip tensor along the channel axis.

    Args:
        in_ch: Number of input channels.
        out_ch: Channels produced by the transposed convolution (before the
            skip tensor is appended).
        dropout: Dropout probability; a ``Dropout`` layer is added only when
            this is greater than zero.
    """

    def __init__(self, in_ch, out_ch, dropout=0.0):
        super().__init__()
        deconv = nn.ConvTranspose2d(
            in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False
        )
        optional_dropout = [nn.Dropout(dropout)] if dropout > 0 else []
        self.model = nn.Sequential(
            deconv,
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            *optional_dropout,
        )

    def forward(self, x, skip):
        """Upsample ``x`` and append ``skip`` to it on the channel dimension."""
        upsampled = self.model(x)
        return torch.cat((upsampled, skip), dim=1)


class UNetGenerator(nn.Module):
    """U-Net generator with seven downsampling and seven upsampling stages.

    Encoder outputs are fed to the mirrored decoder stages as skip
    connections; the final layer is a transposed convolution followed by
    ``Tanh``, so outputs lie in ``[-1, 1]``. Input heights and widths that
    are multiples of 128 (e.g. 256) keep the skip tensors shape-aligned.

    Args:
        in_ch: Channels of the input image.
        out_ch: Channels of the generated image.
    """

    _NUM_DOWN = 7
    _NUM_UP = 6

    def __init__(self, in_ch=3, out_ch=3):
        super().__init__()
        # (in_channels, out_channels, use_batch_norm) for down1..down7.
        encoder_specs = (
            (in_ch, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, False),
        )
        for index, (c_in, c_out, use_norm) in enumerate(encoder_specs, start=1):
            setattr(self, f"down{index}", UNetDown(c_in, c_out, norm=use_norm))

        # (in_channels, out_channels, dropout) for up1..up6; the input widths
        # after up1 include the concatenated skip channels.
        decoder_specs = (
            (512, 512, 0.5),
            (1024, 512, 0.5),
            (1024, 512, 0.5),
            (1024, 256, 0.0),
            (512, 128, 0.0),
            (256, 64, 0.0),
        )
        for index, (c_in, c_out, drop) in enumerate(decoder_specs, start=1):
            setattr(self, f"up{index}", UNetUp(c_in, c_out, dropout=drop))

        self.last = nn.Sequential(
            nn.ConvTranspose2d(128, out_ch, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Translate the input batch ``x`` into an output image batch."""
        # Run the encoder, remembering every stage output for the skips.
        encoder_outputs = []
        features = x
        for index in range(1, self._NUM_DOWN + 1):
            features = getattr(self, f"down{index}")(features)
            encoder_outputs.append(features)

        # The deepest output is the bottleneck; the rest are consumed
        # deepest-first as skip connections.
        features = encoder_outputs.pop()
        for index in range(1, self._NUM_UP + 1):
            features = getattr(self, f"up{index}")(features, encoder_outputs.pop())

        return self.last(features)


class PatchDiscriminator(nn.Module):
    """Convolutional discriminator that scores a concatenated image pair.

    Three stride-2 convolution stages (the first without batch norm) are
    followed by a stride-1 convolution producing a one-channel map of raw
    logits.

    Args:
        in_ch: Channels of the two input images after concatenation.
    """

    def __init__(self, in_ch=6):
        super().__init__()

        widths = (in_ch, 64, 128, 256)
        stack = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            stack.append(
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False)
            )
            # Only the first stage skips normalisation.
            if stage != 0:
                stack.append(nn.BatchNorm2d(c_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
        stack.append(nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1))

        self.model = nn.Sequential(*stack)

    def forward(self, x, y):
        """Return the logit map for images ``x`` and ``y`` joined on channels."""
        pair = torch.cat((x, y), dim=1)
        return self.model(pair)


def denorm(x):
    """Map values from the ``[-1, 1]`` range to ``[0, 1]``."""
    return 0.5 * (x + 1)


def train_pix2pix(
    data_dir=DATA_DIR,
    outdir=OUTPUT_DIR,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    lr=LR,
    lambda_L1=LAMBDA_L1,
    img_size=IMG_SIZE,
):
    """Train a U-Net generator against a patch discriminator on paired images.

    After every epoch a grid containing input, target and generated images
    for one batch is written to ``outdir`` as ``epoch_NNN.png``.

    Args:
        data_dir: Directory with side-by-side ``A|B`` ``.png`` training files.
        outdir: Directory for the per-epoch sample grids (created if missing).
        epochs: Number of passes over the dataset.
        batch_size: Number of image pairs per optimisation step.
        lr: Adam learning rate used for both networks.
        lambda_L1: Weight of the L1 reconstruction term in the generator loss.
        img_size: Side length both halves are resized to.

    Raises:
        RuntimeError: If ``data_dir`` contains no ``.png`` files.
    """
    os.makedirs(outdir, exist_ok=True)

    # Normalising with mean/std 0.5 maps pixels to [-1, 1], matching the
    # generator's Tanh output range.
    transform = transforms.Compose(
        [
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    dataset = Pix2PixDataset(data_dir, transform=transform)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    G = UNetGenerator(in_ch=3, out_ch=3).to(device)
    D = PatchDiscriminator(in_ch=6).to(device)

    bce = nn.BCEWithLogitsLoss()
    l1 = nn.L1Loss()

    optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

    for epoch in range(epochs):
        for i, (img_a, img_b) in enumerate(loader):
            img_a = img_a.to(device)
            img_b = img_b.to(device)

            # Single generator forward pass per step: the detached result
            # trains the discriminator, the attached one trains the generator.
            fake_b = G(img_a)

            # --- Discriminator step ---
            optimizer_D.zero_grad()
            pred_real = D(img_a, img_b)
            # detach() keeps discriminator gradients out of the generator.
            pred_fake = D(img_a, fake_b.detach())

            target_real = torch.ones_like(pred_real)
            target_fake = torch.zeros_like(pred_fake)

            loss_D_real = bce(pred_real, target_real)
            loss_D_fake = bce(pred_fake, target_fake)
            loss_D = (loss_D_real + loss_D_fake) * 0.5
            loss_D.backward()
            optimizer_D.step()

            # --- Generator step (scored by the just-updated discriminator) ---
            optimizer_G.zero_grad()
            pred_fake_for_G = D(img_a, fake_b)
            target_real_for_G = torch.ones_like(pred_fake_for_G)

            loss_G_GAN = bce(pred_fake_for_G, target_real_for_G)
            loss_G_L1 = l1(fake_b, img_b) * lambda_L1
            loss_G = loss_G_GAN + loss_G_L1
            loss_G.backward()
            optimizer_G.step()

            if i % 50 == 0:
                print(
                    f"[Pix2Pix] Epoch [{epoch+1}/{epochs}] "
                    f"Step [{i}/{len(loader)}] "
                    f"Loss_D: {loss_D.item():.4f}, Loss_G: {loss_G.item():.4f}"
                )

        # Save a qualitative sample in eval mode, then resume training mode.
        G.eval()
        with torch.no_grad():
            # The loader shuffles, so this is a different batch each epoch.
            img_a_sample, img_b_sample = next(iter(loader))
            img_a_sample = img_a_sample.to(device)
            img_b_sample = img_b_sample.to(device)
            fake_b_sample = G(img_a_sample)

            # Rows of the grid: inputs, targets, generated images.
            grid = torch.cat([denorm(img_a_sample), denorm(img_b_sample), denorm(fake_b_sample)], dim=0)
            save_image(grid, os.path.join(outdir, f"epoch_{epoch+1:03d}.png"), nrow=img_a_sample.size(0))
        G.train()

    print("Training complete. Images saved to:", outdir)


# Start training with the module-level defaults when run as a script.
if __name__ == "__main__":
    train_pix2pix()

Using device: cuda
[Pix2Pix] Epoch [1/50] Step [0/50] Loss_D: 0.7096, Loss_G: 88.7422
[Pix2Pix] Epoch [2/50] Step [0/50] Loss_D: 0.0943, Loss_G: 41.0403
[Pix2Pix] Epoch [3/50] Step [0/50] Loss_D: 0.0863, Loss_G: 35.8605
[Pix2Pix] Epoch [4/50] Step [0/50] Loss_D: 0.6972, Loss_G: 38.1654
[Pix2Pix] Epoch [5/50] Step [0/50] Loss_D: 0.1545, Loss_G: 31.2015
[Pix2Pix] Epoch [6/50] Step [0/50] Loss_D: 0.2015, Loss_G: 33.5882
[Pix2Pix] Epoch [7/50] Step [0/50] Loss_D: 0.1627, Loss_G: 33.9019
[Pix2Pix] Epoch [8/50] Step [0/50] Loss_D: 0.3330, Loss_G: 28.8964
[Pix2Pix] Epoch [9/50] Step [0/50] Loss_D: 0.2204, Loss_G: 33.3103
[Pix2Pix] Epoch [10/50] Step [0/50] Loss_D: 0.2215, Loss_G: 26.0978
[Pix2Pix] Epoch [11/50] Step [0/50] Loss_D: 0.2143, Loss_G: 28.7904
[Pix2Pix] Epoch [12/50] Step [0/50] Loss_D: 0.3388, Loss_G: 32.3535
[Pix2Pix] Epoch [13/50] Step [0/50] Loss_D: 0.4621, Loss_G: 29.5012
[Pix2Pix] Epoch [14/50] Step [0/50] Loss_D: 0.3156, Loss_G: 19.4465
[Pix2Pix] Epoch [15/50] Step [0/50] Lo