## Preprocess

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
from omni.utils.dataset import TensorImageFolder
from omni.utils.device import get_device
from torchinfo import summary

In [None]:
device = get_device()
# bf16 autocast is enabled only on CUDA devices that support it; everywhere
# else training runs in plain float32.
use_bf16 = device.type == "cuda" and torch.cuda.is_bf16_supported()

train_data = TensorImageFolder("../data/afhq_v2_preprocessed/train")
test_data = TensorImageFolder("../data/afhq_v2_preprocessed/test")

# Use half the logical CPUs for loading, capped at 8. On a single-CPU machine
# this is 0, i.e. batches are loaded in the main process.
cpu_count = os.cpu_count() or 1
num_workers = min(8, cpu_count // 2)

# Page-locked host memory only pays off for the non_blocking host-to-CUDA
# copies in the training loop; on other devices it is wasted work (and PyTorch
# warns when pin_memory is requested without an available accelerator).
pin_memory = device.type == "cuda"

train_loader = DataLoader(
    train_data,
    batch_size=64,  # training batch size; raise it if GPU memory allows
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

test_loader = DataLoader(
    test_data,
    batch_size=512,  # eval runs without gradients, so larger batches fit
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

In [None]:
class ResBlock(nn.Module):
    """Pre-activation residual block: two (GroupNorm -> SiLU -> 3x3 conv) stages plus a skip path.

    Spatial size is preserved (3x3 convolutions with padding 1). The skip path
    is the identity when ``in_ch == out_ch``. Otherwise it is a 1x1 convolution
    that projects the input to ``out_ch`` channels so it can be added to the
    residual branch.

    Args:
        in_ch: channels of the input tensor.
        out_ch: channels of the output tensor.
        groups: GroupNorm group count; must divide both ``in_ch`` and ``out_ch``.
    """

    def __init__(self, in_ch: int, out_ch: int, groups: int):
        super().__init__()

        # Registration order (norm1, conv1, norm2, conv2, skip) fixes both the
        # state_dict layout and the order in which weights are randomly initialised.
        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)

        if in_ch != out_ch:
            # 1x1 projection so the shortcut matches the residual branch width.
            self.skip = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        else:
            self.skip = nn.Identity()

    def forward(self, x):
        branch = x
        for norm, conv in ((self.norm1, self.conv1), (self.norm2, self.conv2)):
            branch = conv(F.silu(norm(branch)))
        return branch + self.skip(x)


class Encoder(nn.Module):
    """Encode an image into a spatial latent at 1/8 of the input resolution.

    Pipeline:
      1. 3x3 stem convolution to ``base_channels``.
      2. Three stages, each a ResBlock followed by a 4x4 stride-2 convolution,
         with widths base, 2*base and 4*base.
      3. A middle ResBlock.
      4. GroupNorm, SiLU and a 3x3 projection to ``latent_channels``.

    Each stride-2 convolution (kernel 4, padding 1) halves an even height/width.
    An input whose height and width are divisible by 8 therefore produces an
    H/8 x W/8 latent. ``groups`` is the GroupNorm group count and must divide
    every channel width used.
    """

    def __init__(
        self,
        in_channels: int,
        base_channels: int,
        latent_channels: int,
        groups: int,
    ):
        super().__init__()

        c1, c2, c4 = base_channels, base_channels * 2, base_channels * 4

        def downsample_stage(c_in: int, c_out: int) -> nn.Sequential:
            # Residual block (possibly widening channels), then a 2x downsample.
            return nn.Sequential(
                ResBlock(c_in, c_out, groups),
                nn.Conv2d(c_out, c_out, kernel_size=4, stride=2, padding=1),
            )

        # Construction order matches the attribute order used by forward().
        self.conv_in = nn.Conv2d(in_channels, c1, kernel_size=3, padding=1)
        self.down1 = downsample_stage(c1, c1)
        self.down2 = downsample_stage(c1, c2)
        self.down3 = downsample_stage(c2, c4)
        self.mid = ResBlock(c4, c4, groups)
        self.norm_out = nn.GroupNorm(groups, c4)
        self.conv_out = nn.Conv2d(c4, latent_channels, kernel_size=3, padding=1)

    def forward(self, x):
        h = self.conv_in(x)
        for stage in (self.down1, self.down2, self.down3, self.mid):
            h = stage(h)
        return self.conv_out(F.silu(self.norm_out(h)))


class Decoder(nn.Module):
    """Decode a spatial latent back to an image at 8x the latent resolution.

    Pipeline:
      1. 3x3 convolution from ``latent_channels`` to 4*base channels.
      2. A middle ResBlock.
      3. Three stages, each a 2x nearest-neighbour upsample followed by a
         ResBlock, narrowing 4*base -> 2*base -> base -> base.
      4. GroupNorm, SiLU and a 3x3 projection to ``out_channels``.
      5. A sigmoid, so outputs lie in [0, 1].

    ``groups`` is the GroupNorm group count and must divide every channel width
    used.
    """

    def __init__(
        self,
        out_channels: int,
        base_channels: int,
        latent_channels: int,
        groups: int,
    ):
        super().__init__()

        c1, c2, c4 = base_channels, base_channels * 2, base_channels * 4

        def upsample_stage(c_in: int, c_out: int) -> nn.Sequential:
            # 2x nearest upsample, then a (possibly narrowing) residual block.
            return nn.Sequential(
                nn.Upsample(scale_factor=2, mode="nearest"),
                ResBlock(c_in, c_out, groups),
            )

        # Construction order matches the attribute order used by forward().
        self.conv_in = nn.Conv2d(latent_channels, c4, kernel_size=3, padding=1)
        self.mid = ResBlock(c4, c4, groups)
        self.up1 = upsample_stage(c4, c2)
        self.up2 = upsample_stage(c2, c1)
        self.up3 = upsample_stage(c1, c1)
        self.norm_out = nn.GroupNorm(groups, c1)
        self.conv_out = nn.Conv2d(c1, out_channels, kernel_size=3, padding=1)

    def forward(self, z):
        h = self.mid(self.conv_in(z))
        for stage in (self.up1, self.up2, self.up3):
            h = stage(h)
        logits = self.conv_out(F.silu(self.norm_out(h)))
        return torch.sigmoid(logits)


class Autoencoder(nn.Module):
    """Deterministic convolutional autoencoder with a spatial latent.

    The layout is loosely modelled on the Stable Diffusion VAE: GroupNorm + SiLU
    residual blocks with three 2x down/upsampling stages. Unlike a VAE,
    ``encode`` returns the latent directly; there is no mean/log-variance head
    and no sampling.

    The network is fully convolutional, so it is not tied to one resolution.
    Reconstructions match the input size when height and width are divisible
    by 8. For example, a 512x512 image gives a ``latent_channels`` x 64 x 64
    latent (4 x 64 x 64 by default), and a 128x128 image gives 4 x 16 x 16.

    ``latent_scale`` multiplies the encoder output in ``encode`` and is divided
    out again in ``decode``, so it cancels in ``forward``. The default 0.18215
    is the scaling factor of the Stable Diffusion v1 VAE (SDXL's VAE uses
    0.13025). For this freshly trained model it is only a convention and does
    not by itself normalise the latent variance.

    Args:
        image_channels: channels of input and reconstructed images.
        base_channels: width of the first stage; later stages use 2x and 4x.
        latent_channels: channels of the latent map.
        groupnorm_groups: GroupNorm group count; must divide every width used.
        latent_scale: multiplier applied to latents in ``encode``.
    """

    def __init__(
        self,
        image_channels: int = 3,
        base_channels: int = 32,
        latent_channels: int = 4,
        groupnorm_groups: int = 32,
        latent_scale: float = 0.18215,
    ):
        super().__init__()

        self.latent_channels = latent_channels
        self.latent_scale = latent_scale

        # Encoder and decoder share width, latent size and GroupNorm grouping.
        shared = dict(
            base_channels=base_channels,
            latent_channels=latent_channels,
            groups=groupnorm_groups,
        )
        self.encoder = Encoder(in_channels=image_channels, **shared)
        self.decoder = Decoder(out_channels=image_channels, **shared)

    def encode(self, x):
        """Map images to latents multiplied by ``latent_scale``."""
        latent = self.encoder(x)
        return latent * self.latent_scale

    def decode(self, z):
        """Undo ``latent_scale`` and decode latents back to images in [0, 1]."""
        unscaled = z / self.latent_scale
        return self.decoder(unscaled)

    def forward(self, x):
        """Reconstruct ``x`` by encoding then decoding it."""
        return self.decode(self.encode(x))


In [None]:
from torchmetrics.functional.image import (
    peak_signal_noise_ratio,
    structural_similarity_index_measure,
)

In [None]:
# Global optimisation-step counter; the training loop increments it once per batch.
step = 0

model = Autoencoder().to(device)

optimizer_ae = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Pixel-wise mean-squared error between reconstruction and input.
criterion_recon = nn.MSELoss()

# No GradScaler: gradient scaling is needed for fp16 autocast, but bf16 keeps
# fp32's exponent range.
scaler = None  # not used for bf16

# verbose=0 stops torchinfo from printing on its own (its default outside
# Jupyter), so the summary is printed exactly once. Passing the full device,
# not just device.type, keeps a specific GPU index such as cuda:1.
print(summary(model, input_size=(1, 3, 128, 128), device=device, verbose=0))

In [None]:
for epoch in range(10):
    # --------------------
    # Train
    # --------------------
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")

    for images, _ in pbar:
        images = images.to(device, non_blocking=True)
        # Scale to [0, 1] to match the decoder's sigmoid output range.
        # NOTE(review): assumes the stored pixels are 0-255 (e.g. uint8); confirm
        # against TensorImageFolder.
        images = images.float().div_(255.0)

        optimizer_ae.zero_grad(set_to_none=True)

        # Only the forward pass and the loss run under autocast. backward() and
        # the optimizer step stay outside it, as PyTorch's AMP docs recommend.
        with torch.autocast(
            device_type="cuda",
            dtype=torch.bfloat16,
            enabled=use_bf16,
        ):
            recon = model(images)  # exactly one forward pass per step
            loss_recon = criterion_recon(recon, images)
            loss_ae = loss_recon  # reconstruction is currently the only loss term

        loss_ae.backward()
        optimizer_ae.step()
        step += 1

        # Progress-bar metrics: no autograd graph, and computed on a float32
        # copy so prediction and target share a dtype (recon may be bf16 under
        # autocast).
        with torch.no_grad():
            recon_f32 = recon.detach().float()
            pbar.set_postfix(
                {
                    "loss_ae": loss_ae.item(),
                    "loss_recon": loss_recon.item(),
                    "psnr": peak_signal_noise_ratio(
                        recon_f32, images, data_range=1.0
                    ).item(),
                    "ssim": structural_similarity_index_measure(
                        recon_f32, images, data_range=1.0
                    ).item(),
                }
            )

    # --------------------
    # Eval
    # --------------------
    model.eval()

    # Sum of per-batch mean losses weighted by batch size, so the final
    # division gives the mean over all test samples.
    total_loss_recon = 0.0

    with (
        torch.no_grad(),
        torch.autocast(
            device_type="cuda",
            dtype=torch.bfloat16,
            enabled=use_bf16,
        ),
    ):
        for images, _ in test_loader:
            images = images.to(device, non_blocking=True)
            images = images.float().div_(255.0)
            recon = model(images)

            loss_recon = criterion_recon(recon, images)

            total_loss_recon += loss_recon.item() * images.size(0)

    print(f"Test loss (recon): {total_loss_recon / len(test_loader.dataset):.6f}")


In [None]:
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

# Reconstruct the first (up to) 32 images of one test batch, without autograd.
with torch.no_grad():
    images, _ = next(iter(test_loader))
    images = images[:32].to(device, non_blocking=True).float().div_(255.0)
    recon = model(images)

# Originals first, then their reconstructions, 16 images per row. With a full
# 32-image slice, the top two rows are inputs and the bottom two the outputs.
originals_and_recons = torch.cat([images.cpu(), recon.cpu()], dim=0)
grid = make_grid(originals_and_recons, nrow=16)

plt.figure(figsize=(24, 16))
plt.imshow(grid.permute(1, 2, 0))  # CHW -> HWC for matplotlib
plt.axis("off")
plt.show()

In [None]:
# Sanity check of the plotted value range. Reconstructions come from a sigmoid,
# so they lie in [0, 1]; the originals do too if the stored pixels are 0-255
# (TODO confirm dataset dtype).
(grid.min(), grid.max())