In [13]:
import os
from datetime import datetime
import glob
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils

# ========== Paths ==========
data_root = "./dataset/raw/img_align_celeba/img_align_celeba"
save_dir = "./dataset/save/"
os.makedirs(save_dir, exist_ok=True)

# ========== Hyperparams ==========
batch_size = 512
lr = 2e-4
num_epochs = 100
latent_dim = 128
sample_every = 5
num_sample_images = 8
image_size = 64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
])

In [14]:
# import os

# # Install the Kaggle API client
# !pip install kaggle

# # Create the .kaggle directory and copy the kaggle.json file
# !mkdir -p ~/.kaggle
# # You will need to upload your kaggle.json file to your Colab environment.
# # In the left sidebar, click on the folder icon, then click on the upload icon.
# # Upload your kaggle.json file there.
# # Once uploaded, move it to the correct directory:
# # !mv kaggle.json ~/.kaggle/

# # If you have stored your Kaggle API key and username as Colab secrets, you can use the following:
# from google.colab import userdata

# os.environ['KAGGLE_USERNAME'] = userdata.get('KAGGLE_USERNAME')
# os.environ['KAGGLE_KEY'] = userdata.get('KAGGLE_KEY')

# # Set permissions for the kaggle.json file
# !chmod 600 ~/.kaggle/kaggle.json

# # Download the dataset
# !kaggle datasets download -d jessicali9530/celeba-dataset -p ./dataset/raw --unzip

In [15]:
# ========== Dataset ==========
class CelebADataset(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.paths = sorted(glob.glob(os.path.join(root, "*.jpg")))
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img_path = self.paths[idx]
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img

dataset = CelebADataset(root=data_root, transform=transform)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                    num_workers=2, pin_memory=True)
print("Num Images: ", len(dataset))

# ========== Utilities ==========
def save_checkpoint(model, optim, epoch, path):
    state = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optim_state": optim.state_dict()
    }
    torch.save(state, path)


def save_image_grid(tensor, filename, nrow=8):
    tensor = torch.clamp(tensor, 0, 1)
    utils.save_image(tensor, filename, nrow=nrow, padding=2)

Num Images:  202599


In [16]:
# ========== Loss ==========
def vae_loss(recon_x, x, mu, logvar):
    recon_loss = F.mse_loss(recon_x, x, reduction='sum') / x.size(0)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)

    # β-VAE: 可以引入一个权重来平衡两项，初期可以设 beta=1
    beta = 1.0

    return recon_loss + beta * kld, recon_loss, kld


In [17]:
# ========== VAE Model ==========
class VAE(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.ReLU(),
            nn.Flatten(),
        )
        self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim)

        self.decoder_input = nn.Linear(latent_dim, 256 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.Unflatten(1, (256, 4, 4)),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),
            nn.Sigmoid(),
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = self.decoder_input(z)
        return self.decoder(h)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [18]:
from tqdm import tqdm
model = VAE(latent_dim=latent_dim).to(device)
optim = torch.optim.Adam(model.parameters(), lr=lr)
global_step = 0
for epoch in range(1, num_epochs + 1):
    model.train()
    epoch_loss = 0.0
    epoch_recon = 0.0
    epoch_kld = 0.0

    for batch_idx, imgs in enumerate(tqdm(loader)):
        imgs = imgs.to(device, non_blocking=True)
        optim.zero_grad()
        recon_imgs, mu, logvar = model(imgs)
        loss, recon_l, kld = vae_loss(recon_imgs, imgs, mu, logvar)
        loss.backward()
        optim.step()

        epoch_loss += loss.item()
        epoch_recon += recon_l.item()
        epoch_kld += kld.item()
        global_step += 1

        if batch_idx % 100 == 0:
            print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
                    f"Epoch {epoch}/{num_epochs} Batch {batch_idx}/{len(loader)} "
                    f"Loss {loss.item():.4f} "
                    f"(recon {recon_l.item():.4f}, kld {kld.item():.4f})")

    n_samples = len(loader.dataset)
    print(f"=== Epoch {epoch} finished. Avg loss: {epoch_loss / n_samples:.4f} "
            f"(recon {epoch_recon / n_samples:.4f}, kld {epoch_kld / n_samples:.4f}) ===")

    # 保存样本
    if epoch % sample_every == 0 or epoch == 1:
        # 保存检查点
        ckpt_path = os.path.join(save_dir, f"vae_epoch{epoch}.pth")
        save_checkpoint(model, optim, epoch, ckpt_path)
        model.eval()
        with torch.no_grad():
            # 重建样本
            imgs = next(iter(loader))
            imgs = imgs.to(device)[:num_sample_images]
            recon_imgs, _, _ = model(imgs)
            combined = torch.cat([imgs, recon_imgs], dim=0)  # 不再需要clamp
            save_image_grid(combined, os.path.join(save_dir, f"recon_epoch{epoch}.png"), nrow=8)

            # 生成样本
            z = torch.randn(num_sample_images, latent_dim).to(device)
            samples = model.decode(z)
            save_image_grid(samples, os.path.join(save_dir, f"sample_epoch{epoch}.png"), nrow=8)
        model.train()

print("Training complete.")

  0%|          | 1/396 [00:01<09:41,  1.47s/it]

[2025-10-06 09:07:41] Epoch 1/100 Batch 0/396 Loss 1135.2826 (recon 1135.2656, kld 0.0170)


 25%|██▌       | 100/396 [00:48<02:05,  2.36it/s]

[2025-10-06 09:08:28] Epoch 1/100 Batch 100/396 Loss 640.2681 (recon 612.3834, kld 27.8847)


 51%|█████     | 200/396 [01:36<02:06,  1.55it/s]

[2025-10-06 09:09:16] Epoch 1/100 Batch 200/396 Loss 415.4881 (recon 371.2596, kld 44.2285)


 76%|███████▌  | 301/396 [02:21<00:36,  2.61it/s]

[2025-10-06 09:10:01] Epoch 1/100 Batch 300/396 Loss 345.3352 (recon 296.2002, kld 49.1350)


100%|██████████| 396/396 [03:05<00:00,  2.13it/s]


=== Epoch 1 finished. Avg loss: 1.0048 (recon 0.9319, kld 0.0729) ===


  1%|          | 2/396 [00:01<04:19,  1.52it/s]

[2025-10-06 09:10:49] Epoch 2/100 Batch 0/396 Loss 313.2043 (recon 263.1500, kld 50.0544)


 26%|██▌       | 101/396 [00:47<02:08,  2.29it/s]

[2025-10-06 09:11:35] Epoch 2/100 Batch 100/396 Loss 275.7783 (recon 226.1701, kld 49.6082)


 51%|█████     | 201/396 [01:32<01:36,  2.02it/s]

[2025-10-06 09:12:20] Epoch 2/100 Batch 200/396 Loss 275.4307 (recon 225.6817, kld 49.7490)


 76%|███████▌  | 301/396 [02:21<00:45,  2.08it/s]

[2025-10-06 09:13:09] Epoch 2/100 Batch 300/396 Loss 261.7190 (recon 212.4102, kld 49.3087)


100%|██████████| 396/396 [03:04<00:00,  2.15it/s]


=== Epoch 2 finished. Avg loss: 0.5379 (recon 0.4418, kld 0.0961) ===


  0%|          | 1/396 [00:00<06:23,  1.03it/s]

[2025-10-06 09:13:53] Epoch 3/100 Batch 0/396 Loss 249.8823 (recon 200.6285, kld 49.2538)


 26%|██▌       | 101/396 [00:47<02:58,  1.66it/s]

[2025-10-06 09:14:39] Epoch 3/100 Batch 100/396 Loss 246.0801 (recon 194.5008, kld 51.5793)


 51%|█████     | 201/396 [01:32<01:15,  2.58it/s]

[2025-10-06 09:15:25] Epoch 3/100 Batch 200/396 Loss 229.2155 (recon 178.2279, kld 50.9876)


 76%|███████▌  | 301/396 [02:18<00:40,  2.37it/s]

[2025-10-06 09:16:10] Epoch 3/100 Batch 300/396 Loss 230.8549 (recon 178.9533, kld 51.9017)


100%|██████████| 396/396 [03:03<00:00,  2.16it/s]


=== Epoch 3 finished. Avg loss: 0.4586 (recon 0.3590, kld 0.0996) ===


  0%|          | 1/396 [00:01<06:51,  1.04s/it]

[2025-10-06 09:16:56] Epoch 4/100 Batch 0/396 Loss 228.4072 (recon 176.4426, kld 51.9646)


 26%|██▌       | 101/396 [00:47<02:05,  2.36it/s]

[2025-10-06 09:17:43] Epoch 4/100 Batch 100/396 Loss 217.3139 (recon 165.6172, kld 51.6967)


 51%|█████     | 201/396 [01:36<01:33,  2.08it/s]

[2025-10-06 09:18:32] Epoch 4/100 Batch 200/396 Loss 222.6495 (recon 170.6667, kld 51.9827)


 76%|███████▌  | 301/396 [02:21<00:42,  2.26it/s]

[2025-10-06 09:19:17] Epoch 4/100 Batch 300/396 Loss 212.5311 (recon 159.8948, kld 52.6363)


100%|██████████| 396/396 [03:27<00:00,  1.91it/s]


=== Epoch 4 finished. Avg loss: 0.4264 (recon 0.3240, kld 0.1023) ===


  1%|          | 2/396 [00:01<04:47,  1.37it/s]

[2025-10-06 09:20:24] Epoch 5/100 Batch 0/396 Loss 210.4032 (recon 158.2461, kld 52.1572)


 25%|██▌       | 100/396 [00:49<02:07,  2.32it/s]

[2025-10-06 09:21:12] Epoch 5/100 Batch 100/396 Loss 214.3940 (recon 160.8794, kld 53.5146)


 51%|█████     | 201/396 [01:36<01:47,  1.82it/s]

[2025-10-06 09:21:59] Epoch 5/100 Batch 200/396 Loss 208.2413 (recon 154.6446, kld 53.5968)


 76%|███████▌  | 301/396 [02:24<00:49,  1.90it/s]

[2025-10-06 09:22:47] Epoch 5/100 Batch 300/396 Loss 212.5979 (recon 158.4238, kld 54.1742)


100%|██████████| 396/396 [03:08<00:00,  2.10it/s]


=== Epoch 5 finished. Avg loss: 0.4094 (recon 0.3045, kld 0.1049) ===


  1%|          | 2/396 [00:01<03:33,  1.85it/s]

[2025-10-06 09:23:34] Epoch 6/100 Batch 0/396 Loss 205.1887 (recon 151.0585, kld 54.1301)


 26%|██▌       | 101/396 [00:47<02:16,  2.16it/s]

[2025-10-06 09:24:20] Epoch 6/100 Batch 100/396 Loss 208.6403 (recon 154.1043, kld 54.5361)


 51%|█████     | 201/396 [01:34<01:39,  1.96it/s]

[2025-10-06 09:25:07] Epoch 6/100 Batch 200/396 Loss 200.9217 (recon 146.0340, kld 54.8877)


 76%|███████▌  | 301/396 [02:20<00:41,  2.31it/s]

[2025-10-06 09:25:54] Epoch 6/100 Batch 300/396 Loss 203.7509 (recon 148.4366, kld 55.3143)


100%|██████████| 396/396 [03:03<00:00,  2.16it/s]


=== Epoch 6 finished. Avg loss: 0.3989 (recon 0.2915, kld 0.1074) ===


  0%|          | 1/396 [00:01<07:59,  1.21s/it]

[2025-10-06 09:26:37] Epoch 7/100 Batch 0/396 Loss 202.6416 (recon 147.2393, kld 55.4023)


 26%|██▌       | 101/396 [00:47<02:03,  2.38it/s]

[2025-10-06 09:27:24] Epoch 7/100 Batch 100/396 Loss 202.9415 (recon 147.1316, kld 55.8100)


 51%|█████     | 201/396 [01:34<01:44,  1.86it/s]

[2025-10-06 09:28:11] Epoch 7/100 Batch 200/396 Loss 200.4318 (recon 144.4507, kld 55.9811)


 76%|███████▌  | 301/396 [02:21<00:40,  2.33it/s]

[2025-10-06 09:28:57] Epoch 7/100 Batch 300/396 Loss 200.3324 (recon 143.6122, kld 56.7202)


 94%|█████████▍| 373/396 [02:54<00:10,  2.14it/s]


KeyboardInterrupt: 