In [None]:


# Cell 1: Import

import os, glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
from tqdm import tqdm

# Report which accelerators the kernel can see before any training setup runs.
for label, probe in (("CUDA available:", torch.cuda.is_available),
                     ("GPU count:", torch.cuda.device_count)):
    print(label, probe())


# Cell 2: Dataset

class LHQDataset(Dataset):
    """Images from one split folder of the LHQ dataset (``<root>/<split>/``).

    Only files directly inside the split folder are used (no recursion). Each item
    is the image converted to RGB, passed through ``transform`` if one is given.

    Args:
        root: dataset root containing ``train``/``valid``/``test`` sub-folders.
        split: which sub-folder to read; one of 'train', 'valid', 'test'.
        transform: optional callable applied to each PIL RGB image.

    Raises:
        ValueError: if ``split`` is not one of the allowed names.
        RuntimeError: if no supported image files are found in the split folder.
    """

    # Compared against the lower-cased filename, so ``.JPG``/``.PNG`` also match.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')

    def __init__(self, root, split='train', transform=None):
        self.transform = transform

        if split not in ['train', 'valid', 'test']:
            raise ValueError("split must be in ['train', 'valid', 'test']")
        data_path = os.path.join(root, split)

        names = os.listdir(data_path) if os.path.isdir(data_path) else []
        # Sorted so the index -> file mapping does not depend on filesystem order.
        # Dot-files (e.g. macOS "._x.jpg" metadata) are skipped, as glob's "*" did,
        # and directories that merely end in an image extension are ignored.
        self.image_paths = sorted(
            os.path.join(data_path, name)
            for name in names
            if not name.startswith('.')
            and name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(data_path, name))
        )

        if not self.image_paths:
            raise RuntimeError(f"No images found in {data_path}")

    def __getitem__(self, index):
        path = self.image_paths[index]
        # convert() returns a new, loaded image, so the file handle can be closed here.
        with Image.open(path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

    def __len__(self):
        return len(self.image_paths)

# Cell 3: Model (Generator & Discriminator)

class Generator(nn.Module):
    """DCGAN generator: noise of shape (N, z_dim, 1, 1) -> images (N, channels_img, 64, 64) in [-1, 1].

    Spatial size goes 1 -> 4 -> 8 -> 16 -> 32 -> 64 while the channel width halves
    from ``features_g * 16`` down to ``features_g * 2`` before the RGB projection.
    """

    def __init__(self, z_dim=128, channels_img=3, features_g=64):
        super().__init__()
        # Channel widths entering/leaving the four BatchNorm stages.
        widths = [z_dim] + [features_g * mult for mult in (16, 8, 4, 2)]
        stages = [self._block(widths[0], widths[1], 4, 1, 0)]  # 1x1 -> 4x4
        for c_in, c_out in zip(widths[1:-1], widths[2:]):
            stages.append(self._block(c_in, c_out, 4, 2, 1))  # each doubles H and W
        # Final upsample to 64x64 and squash to [-1, 1] (no BatchNorm on the output).
        stages += [nn.ConvTranspose2d(widths[-1], channels_img, 4, 2, 1), nn.Tanh()]
        self.net = nn.Sequential(*stages)

    def _block(self, in_c, out_c, k, s, p):
        """Transposed conv (bias omitted since BatchNorm follows) + BatchNorm + ReLU."""
        upsample = nn.ConvTranspose2d(in_c, out_c, k, s, p, bias=False)
        return nn.Sequential(upsample, nn.BatchNorm2d(out_c), nn.ReLU(inplace=True))

    def forward(self, x):
        return self.net(x)


class Discriminator(nn.Module):
    """DCGAN discriminator: images (N, channels_img, 64, 64) -> raw logits (N, 1, 1, 1).

    No sigmoid is applied; the training cell pairs it with BCEWithLogitsLoss.
    """

    def __init__(self, channels_img=3, features_d=64):
        super().__init__()
        # Input layer: downsample 64x64 -> 32x32 without BatchNorm.
        layers = [
            nn.Conv2d(channels_img, features_d, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        width = features_d
        for _ in range(3):  # 32 -> 16 -> 8 -> 4, doubling channels each stage
            layers.append(self._block(width, width * 2, 4, 2, 1))
            width *= 2
        layers.append(nn.Conv2d(width, 1, 4, 1, 0))  # 4x4 -> single logit
        self.net = nn.Sequential(*layers)

    def _block(self, in_c, out_c, k, s, p):
        """Conv (bias omitted since BatchNorm follows) + BatchNorm + LeakyReLU(0.2)."""
        layers = [
            nn.Conv2d(in_c, out_c, k, s, p, bias=False),
            nn.BatchNorm2d(out_c),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


CUDA available: False
GPU count: 0


In [None]:

# Cell 4: Train setup

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using:", device)

lr = 2e-4
z_dim = 128
batch_size = 4096
image_size = 64       # Generator outputs 64x64, so real images are resized to match
channels_img = 3
epochs = 2000
features_g = 64
features_d = 64
seed = 42             # seeds weight init, DataLoader shuffling and noise sampling

torch.manual_seed(seed)

os.makedirs("/kaggle/working/samples", exist_ok=True)
os.makedirs("/kaggle/working/checkpoints", exist_ok=True)

# Dataset: resize to image_size and map each channel from [0, 1] to [-1, 1],
# the range of the generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*channels_img, [0.5]*channels_img),
])


dataset = LHQDataset(root="/kaggle/input/datasetlhq/dataset_LHQ_64_quantize_16", split='train', transform=transform)
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    drop_last=True,
    pin_memory=device.type == "cuda",  # page-locked host memory speeds up copies to the GPU
)
# With drop_last=True a dataset smaller than one batch yields zero batches, which
# would make the training loop finish instantly without training anything.
if len(loader) == 0:
    raise RuntimeError(
        f"Dataset has {len(dataset)} images, fewer than batch_size={batch_size}; "
        "lower batch_size or set drop_last=False."
    )

# Models
gen = Generator(z_dim, channels_img, features_g)
disc = Discriminator(channels_img, features_d)

# Multi-GPU: DataParallel splits each batch across all visible GPUs. Checkpoints
# later save the inner `.module` so they load into an unwrapped model.
if torch.cuda.device_count() > 1:
    print("Using DataParallel with", torch.cuda.device_count(), "GPUs")
    gen = nn.DataParallel(gen)
    disc = nn.DataParallel(disc)

gen, disc = gen.to(device), disc.to(device)

# Optimizers & loss: Adam with beta1=0.5 as in DCGAN. The Discriminator outputs raw
# logits, so BCEWithLogitsLoss applies the sigmoid internally.
opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))
criterion = nn.BCEWithLogitsLoss()

# Fixed noise reused every epoch so the saved sample grids are comparable over time.
fixed_noise = torch.randn(64, z_dim, 1, 1).to(device)

# Cell 5: Training loop

checkpoint_dir = "/kaggle/working/checkpoints"  # same folder the setup cell creates
os.makedirs(checkpoint_dir, exist_ok=True)
# NOTE(review): each checkpoint stores both models and both Adam states, so one file per
# epoch for `epochs` epochs may exhaust the /kaggle/working disk quota. Raise this to
# save less often; the final epoch is always saved.
checkpoint_every = 1

for epoch in range(epochs):
    # infer() switches the generator to eval mode; put both models back in train mode
    # so BatchNorm uses batch statistics while optimizing.
    gen.train()
    disc.train()

    loop = tqdm(loader, desc=f"Epoch [{epoch}/{epochs}]", colour='magenta')
    for idx, real in enumerate(loop):
        real = real.to(device)
        # Size the noise from the actual batch rather than assuming `batch_size`.
        noise = torch.randn(real.size(0), z_dim, 1, 1, device=device)
        fake = gen(noise)

        # Train Discriminator: push real -> 1 and fake -> 0. `detach()` keeps this
        # loss from back-propagating into the generator.
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))

        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # Train Generator: label fakes as real (non-saturating loss), scored by the
        # just-updated discriminator.
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        loop.set_postfix(loss_gen=loss_gen.item(), loss_disc=loss_disc.item())

        # After the last batch, save a grid generated from the fixed noise.
        # NOTE(review): gen is still in train mode here, so BatchNorm normalizes with
        # the fixed-noise batch statistics and updates its running stats from them.
        if idx == len(loop) - 1:
            with torch.no_grad():
                samples = gen(fixed_noise) * 0.5 + 0.5  # [-1, 1] -> [0, 1]
                save_image(samples, f"/kaggle/working/samples/fake_epoch_{epoch}_{idx}.png", nrow=8)

    # Save checkpoints. The inner module is stored when wrapped in DataParallel, so the
    # state_dict keys carry no "module." prefix.
    if (epoch + 1) % checkpoint_every == 0 or epoch == epochs - 1:
        torch.save({
            'gen_state_dict': gen.module.state_dict() if isinstance(gen, nn.DataParallel) else gen.state_dict(),
            'disc_state_dict': disc.module.state_dict() if isinstance(disc, nn.DataParallel) else disc.state_dict(),
            'opt_gen_state_dict': opt_gen.state_dict(),
            'opt_disc_state_dict': opt_disc.state_dict(),
            'epoch': epoch,
        }, os.path.join(checkpoint_dir, f"gan_checkpoint_epoch_{epoch}.pth"))

Using: cpu


In [None]:

# Cell 6: Inference

def infer(gen, z_dim=128, num_images=100, out_dir="/kaggle/working/infer_samples", batch_size=64, seed=None):
    """Generate ``num_images`` images with ``gen`` and save each one as its own PNG.

    Args:
        gen: generator module mapping noise of shape (N, z_dim, 1, 1) to images in
            [-1, 1] (the output is rescaled to [0, 1] before saving).
        z_dim: number of noise channels; must match the generator's input channels.
        num_images: total number of images to write (0 writes nothing).
        out_dir: output folder, created if missing. Files are named
            ``generated_0000.png``, ``generated_0001.png``, ...
        batch_size: number of images generated per forward pass.
        seed: if given, passed to ``torch.manual_seed`` (this reseeds torch's global
            RNG) so generation is repeatable.

    Raises:
        ValueError: if ``batch_size`` < 1 or ``num_images`` < 0.

    The module is switched to eval mode for generation and its previous
    train/eval mode is restored afterwards, even if generation fails.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if num_images < 0:
        raise ValueError(f"num_images must be >= 0, got {num_images}")

    os.makedirs(out_dir, exist_ok=True)

    device = next(gen.parameters()).device
    was_training = gen.training
    gen.eval()

    if seed is not None:
        torch.manual_seed(seed)

    try:
        with torch.no_grad():
            for start in range(0, num_images, batch_size):
                current_batch = min(batch_size, num_images - start)
                noise = torch.randn(current_batch, z_dim, 1, 1, device=device)
                fake = gen(noise) * 0.5 + 0.5  # map from [-1, 1] to [0, 1]

                for i in range(current_batch):
                    save_path = os.path.join(out_dir, f"generated_{start + i:04d}.png")
                    save_image(fake[i], save_path)
    finally:
        # Leaving the caller's model in eval mode would silently freeze BatchNorm
        # statistics if training is resumed afterwards.
        gen.train(was_training)

    print(f"Đã sinh {num_images} ảnh vào thư mục: {out_dir}")


# Load generator weights from the folder the training cell writes checkpoints to.
checkpoint = torch.load(
    "/kaggle/working/checkpoints/gan_checkpoint_epoch_98.pth",
    map_location=device,
    # The checkpoint holds only tensors and plain containers/numbers, so the
    # restricted unpickler suffices and arbitrary objects are never executed.
    weights_only=True,
)
# Checkpoints store the unwrapped module's state_dict (no "module." prefix), so load
# into the inner module when the generator is wrapped in DataParallel.
gen_to_load = gen.module if isinstance(gen, nn.DataParallel) else gen
gen_to_load.load_state_dict(checkpoint['gen_state_dict'])

infer(gen, z_dim=z_dim, num_images=1000, out_dir="./gen_samples", batch_size=32, seed=42)


  checkpoint = torch.load("../kaggle/working/checkpoints/gan_checkpoint_epoch_98.pth", map_location=device)


Đã sinh 1000 ảnh vào thư mục: ./gen_samples
