In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset
import os
from PIL import Image
from torchvision import transforms

In [2]:
# --------------------
# Encoder
# --------------------
class ConvEncoder(nn.Module):
    """Convolutional VAE encoder producing the latent mean and log-variance.

    The backbone is a stack of stride-2 ``Conv2d -> BatchNorm2d -> LeakyReLU``
    stages, one per entry of ``hidden_dims``. The resulting feature map is
    flattened and fed to two independent linear heads.

    Args:
        in_channels: Number of channels of the input images.
        latent_dim: Size of the vectors returned by each linear head.
        hidden_dims: Output channels of each conv stage. Defaults to
            ``[32, 64, 128, 256, 512]``.
        input_size: Spatial size of the input images, either a single int
            (square images) or a ``(height, width)`` pair. Defaults to 64,
            which together with the default ``hidden_dims`` gives a
            ``(512, 2, 2)`` feature map, i.e. 2048 flattened features.
    """

    def __init__(self, in_channels=3, latent_dim=128, hidden_dims=None, input_size=64):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]  # suited to 64x64 inputs

        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size

        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(inplace=True)
                )
            )
            in_channels = h_dim
            # Conv2d output size for kernel=3, stride=2, padding=1:
            # floor((n + 2*1 - 3) / 2) + 1 == (n - 1) // 2 + 1
            height = (height - 1) // 2 + 1
            width = (width - 1) // 2 + 1

        self.encoder = nn.Sequential(*modules)

        # Size of the flattened feature map; derived from the actual number of
        # stages and the input size instead of assuming a fixed 2x2 map.
        flat_features = hidden_dims[-1] * height * width
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Tensor of shape ``(B, in_channels, H, W)`` where ``(H, W)``
                matches the ``input_size`` given at construction.

        Returns:
            Tuple ``(mu, log_var)``, each of shape ``(B, latent_dim)``.
        """
        result = self.encoder(x)
        result = torch.flatten(result, start_dim=1)
        mu = self.fc_mu(result)
        log_var = self.fc_logvar(result)
        return mu, log_var


In [3]:
# --------------------
# Decoder
# --------------------
class ConvDecoder(nn.Module):
    """Convolutional VAE decoder turning latent vectors back into images.

    A linear layer expands the latent vector into a ``(hidden_dims[-1], 2, 2)``
    feature map. Each following transposed-convolution stage doubles the
    spatial size while walking ``hidden_dims`` from last to first; a final
    stage doubles it once more and projects to ``out_channels`` through
    ``Tanh``.

    Args:
        out_channels: Number of channels of the generated images.
        latent_dim: Size of the latent vectors fed to the decoder.
        hidden_dims: Channel widths, listed in encoder order (narrow to wide).
            Defaults to ``[32, 64, 128, 256, 512]``.
    """

    def __init__(self, out_channels=3, latent_dim=128, hidden_dims=None):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Kept in encoder order; forward() reads the widest entry from it.
        self.hidden_dims = hidden_dims
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 2 * 2)

        def upsample_stage(c_in, c_out):
            # One x2 upsampling step: transposed conv -> batch norm -> LeakyReLU.
            return nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out,
                                   kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(inplace=True),
            )

        # Reversed widths (wide to narrow), paired up as consecutive
        # (in_channels, out_channels) for each stage.
        widths = list(reversed(hidden_dims))
        self.decoder = nn.Sequential(
            *(upsample_stage(c_in, c_out)
              for c_in, c_out in zip(widths[:-1], widths[1:]))
        )

        # Last stage: one more x2 upsampling (64x64 with the default widths),
        # then a 3x3 conv down to the image channels.
        narrowest = widths[-1]
        self.final_layer = nn.Sequential(
            *upsample_stage(narrowest, narrowest),
            nn.Conv2d(narrowest, out_channels=out_channels,
                      kernel_size=3, padding=1),
            nn.Tanh(),  # output range [-1, 1]; switch to Sigmoid for [0, 1]
        )

    def forward(self, z):
        """Decode latent vectors.

        Args:
            z: Tensor of shape ``(B, latent_dim)``.

        Returns:
            Tensor of generated images with values in ``[-1, 1]``.
        """
        seed_map = self.decoder_input(z)
        # Reshape to (B, C, 2, 2) with C = hidden_dims[-1] (512 by default).
        seed_map = seed_map.view(-1, self.hidden_dims[-1], 2, 2)
        return self.final_layer(self.decoder(seed_map))

In [4]:
# Define the Custom Dataset
class ImageDataset(Dataset):
    """Unlabelled image dataset backed by the files of a single directory.

    Args:
        root_dir: Directory that is scanned (non-recursively) for files ending
            in ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive).
        transform: Optional callable applied to each loaded RGB image.
        image_num: Maximum number of images to keep, taken from the start of
            the sorted path list. A negative value (the default ``-1``) or
            ``None`` keeps every image.
    """

    def __init__(self, root_dir, transform=None, image_num=-1):
        image_paths = sorted(
            os.path.join(root_dir, fname)
            for fname in os.listdir(root_dir)
            if fname.lower().endswith(('.png', '.jpg', '.jpeg'))
        )
        # Only truncate for an explicit non-negative limit. Using the default
        # -1 directly as a slice bound would silently drop the last image.
        if image_num is not None and image_num >= 0:
            image_paths = image_paths[:image_num]
        self.image_paths = image_paths
        self.transform = transform  # Transformation to apply to images

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at ``idx`` as RGB and apply the transform, if any.

        Returns:
            Tuple ``(image, 0)``; the ``0`` is a dummy label because the data
            is unlabelled.
        """
        # The context manager closes the underlying file right away;
        # convert() returns an independent, fully loaded image.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, 0

In [5]:
# Transformation pipeline for data augmentation and normalization
transform = transforms.Compose([
    # Resize every image to the 64x64 resolution the encoder/decoder expect.
    transforms.Resize(size=(64, 64)),
    # Horizontal flip with 40% probability. Wrapping RandomHorizontalFlip
    # (default p=0.5) in RandomApply(p=0.4) would only flip 20% of the time.
    transforms.RandomHorizontalFlip(p=0.4),
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1], matching the decoder's Tanh output.
    transforms.Normalize(mean=(0.5, 0.5, 0.5),
                         std=(0.5, 0.5, 0.5)),
])

In [6]:
# Build the training set from the image folder; every loaded image goes
# through the `transform` pipeline.
dataset = ImageDataset(
    root_dir="/datasets/delkon/dm_data",
    transform=transform,
)

In [7]:
# Instantiate both halves of the VAE with their default hyperparameters
# (3 image channels, 128-dimensional latent space).
encoder, decoder = ConvEncoder(), ConvDecoder()

In [8]:
from d2lightrainer.UnsupervisedLearning.VAE.trainer_config import VAETrainerConfig
from d2lightrainer.UnsupervisedLearning.VAE.trainer import VAETrainer

In [9]:
# Start from the trainer's default configuration and override a few fields.
vae_cfg = VAETrainerConfig()
new_param_dict = dict(
    device=3,                # GPU index (the run log reports "Using GPU: 3")
    save_dir="runs_vae",
    batch_size=16,
    # NOTE(review): presumably the effective batch size the trainer scales
    # towards - confirm against the d2lightrainer documentation.
    nominal_batch_size=64,
)
vae_cfg.update(**new_param_dict)

In [10]:
# Hand the encoder/decoder pair, the dataset and the configuration to the
# trainer, then start training.
vae_trainer = VAETrainer(
    [encoder, decoder],
    dataset,
    vae_cfg,
)
vae_trainer.train()

2025-09-07 19:36:58,884 - INFO - Using GPU: 3
2025-09-07 19:36:58,892 - INFO - 'optimizer:' Adam(lr=0.0003, momentum=0.937) with parameter groups 10 weight(decay=0.0), 0 weight(decay=1.0000000000000002e-06), 14 weight(decay=0.0001), 24 bias(decay=0.0)
2025-09-07 19:36:59,001 - INFO - --------------------

0/800: 100%|██████████| 216/216 [00:06<00:00, 35.27it/s]
2025-09-07 19:37:05,256 - INFO - All types `lr` of epoch 0: {'lr/param_group0': np.float64(0.0005850694444444444), 'lr/param_group1': np.float64(1.4930555555555555e-05), 'lr/param_group2': np.float64(1.4930555555555555e-05), 'lr/param_group3': np.float64(1.4930555555555555e-05)}
- lr/param_group0: regular weights (full weight decay applied)
- lr/param_group1: batchnorm and logit_scale parameters (no weight decay)
- lr/param_group2: embedding layer weights (smaller weight decay)
- lr/param_group3: bias parameters (no weight decay)
2025-09-07 19:37:05,257 - INFO - epoch 0: train loss 0.34191476457096914
2025-09-07 19:37:05,259 - I