In [3]:
%pip install gdown

Defaulting to user installation because normal site-packages is not writeable
Collecting gdown
  Downloading gdown-5.2.0-py3-none-any.whl.metadata (5.8 kB)
Downloading gdown-5.2.0-py3-none-any.whl (18 kB)
Installing collected packages: gdown
Successfully installed gdown-5.2.0
Note: you may need to restart the kernel to use updated packages.




In [None]:
"""
Single-file ConvVAE with a ResNet-style decoder.

- Configured dataset: CelebA (aligned faces, center-cropped and resized to 3x64x64).
- To switch to CIFAR-10 (3x32x32), set DATASET_NAME = 'cifar10'.

Loss: MSE + beta * KL (with beta warm-up).

Usage:
  - Run as-is to train on CelebA (DATASET_NAME='celeba'). The first run downloads ~1.4GB.
  - Switch to CIFAR-10 by setting DATASET_NAME='cifar10'.
  - Visualizations: reconstructions + samples after training.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# ----------------------------
# Config
# ----------------------------
# Which dataset to train on: 'celeba' selects the 64x64 preset, anything else the 32x32 preset.
DATASET_NAME = 'celeba'  # 'cifar10' or 'celeba'

# Seed PyTorch's global RNG so weight init and sampling are repeatable.
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to PyTorch.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Both presets use RGB images.
IN_CH = 3

# Per-dataset preset: (image side, latent size, batch size, number of epochs).
if DATASET_NAME.lower() == 'celeba':
    IMG_SIZE, LATENT_DIM, BATCH_SIZE, EPOCHS = 64, 128, 64, 60
else:
    IMG_SIZE, LATENT_DIM, BATCH_SIZE, EPOCHS = 32, 64, 128, 50

# Optimizer step size.
LR = 1e-3

# KL weight is ramped linearly from BETA_START to BETA_END over WARMUP_EPOCHS epochs.
BETA_START = 0.0
BETA_END = 1.0
WARMUP_EPOCHS = 10

# ----------------------------
# Data
# ----------------------------
if DATASET_NAME.lower() == 'celeba':
    # CelebA aligned faces: keep the central 140x140 region, then resize to the
    # configured resolution. Using IMG_SIZE (not a literal) keeps the data in sync
    # with the model, which is built from the same constant.
    transform = transforms.Compose([
        transforms.CenterCrop(140),
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
    ])
    train_set = datasets.CelebA(
        root='./data', split='train', target_type='attr', download=True, transform=transform
    )
else:
    # CIFAR-10 at 32x32; ToTensor scales pixels to [0, 1].
    transform = transforms.ToTensor()
    train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# Pinned host memory only helps host->GPU copies, so request it only when training on CUDA.
train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=(device.type == 'cuda'),
)

# ----------------------------
# Building blocks
# ----------------------------
class UpResBlock(nn.Module):
    """Residual block that doubles spatial resolution.

    The input is upsampled x2 (nearest neighbour) and then fed through two paths:
    a main path of Conv3x3-BN-ReLU followed by Conv3x3-BN, and a 1x1-conv shortcut
    that matches the channel count. The two are summed and passed through ReLU.

    Args:
        in_ch: channels of the incoming feature map.
        out_ch: channels of the produced feature map.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        # Main path: two 3x3 convolutions that preserve spatial size.
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_ch)
        # Shortcut path: 1x1 projection so the channel counts agree before the sum.
        self.skip = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        """Return the upsampled residual features for ``x``."""
        upsampled = self.upsample(x)

        residual = self.conv1(upsampled)
        residual = F.relu(self.bn1(residual))
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        shortcut = self.skip(upsampled)
        return F.relu(residual + shortcut)

class DownBlock(nn.Module):
    """Downsampling stage: a 4x4 stride-2 convolution followed by BatchNorm and ReLU.

    Args:
        in_ch: channels of the incoming feature map.
        out_ch: channels of the produced feature map.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        # kernel 4 / stride 2 / padding 1 halves an even spatial size exactly.
        self.conv = nn.Conv2d(in_ch, out_ch, 4, stride=2, padding=1)
        self.bn = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        """Return the downsampled, normalised and rectified features for ``x``."""
        features = self.conv(x)
        features = self.bn(features)
        return F.relu(features)

# ----------------------------
# Encoder (shared for 32 or 64)
# ----------------------------
class ConvEncoder(nn.Module):
    """Convolutional encoder producing the parameters of the approximate posterior.

    Stacks stride-2 ``DownBlock`` stages until the feature map is 4x4, flattens it,
    and projects it to ``mu`` and ``logvar`` with two independent linear heads.

    Args:
        in_ch: number of channels of the input images.
        img_size: side length of the (square) input images. Must be one of
            8, 16, 32 or 64, i.e. reducible to exactly 4x4 by 1-4 halvings.
        latent_dim: dimensionality of the latent vector.

    Raises:
        ValueError: if ``img_size`` is not one of the supported sizes. Without this
            check an unsupported size would build a network whose flattened feature
            size disagrees with the linear heads and only fail later in ``forward``.
    """

    def __init__(self, in_ch: int, img_size: int, latent_dim: int):
        super().__init__()
        stage_widths = [64, 128, 256, 512]
        # Sizes reachable from 4x4 by doubling once per available stage: 8, 16, 32, 64.
        supported = [4 * 2 ** k for k in range(1, len(stage_widths) + 1)]
        if img_size not in supported:
            raise ValueError(
                f"img_size must be one of {supported} to reach a 4x4 feature map, got {img_size!r}"
            )

        # compute how many downsamples to get to 4x4
        n_down = int(math.log2(img_size)) - 2  # 32->3, 64->4
        chs = stage_widths[:n_down]

        blocks = []
        prev = in_ch
        for c in chs:
            blocks.append(DownBlock(prev, c))
            prev = c
        self.net = nn.Sequential(*blocks)
        self.flatten = nn.Flatten()
        self.final_ch = chs[-1]
        # Two heads over the same flattened (final_ch * 4 * 4) features.
        self.fc_mu = nn.Linear(self.final_ch * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(self.final_ch * 4 * 4, latent_dim)

    def forward(self, x):
        """Encode ``x`` and return ``(mu, logvar)``, each with ``latent_dim`` features per sample."""
        h = self.net(x)  # (B, final_ch, 4, 4)
        h = self.flatten(h)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

# ----------------------------
# ResNet-style Decoder (upsampling with residual blocks)
# ----------------------------
class ResNetDecoder(nn.Module):
    """Decoder that grows a 4x4 seed feature map into an image with ``UpResBlock`` stages.

    A linear layer maps the latent vector to a ``(deepest_ch, 4, 4)`` tensor. One
    ``UpResBlock`` is applied per entry of the channel plan, and a final 3x3
    convolution produces ``out_ch`` image channels with no output activation.

    Args:
        out_ch: number of channels of the generated image.
        img_size: target image side; only used to derive the default channel plan.
        latent_dim: dimensionality of the latent vector.
        base_ch_seq: encoder-ordered (shallow -> deep) channel plan, e.g.
            ``[64, 128, 256]``. Defaults to the first ``log2(img_size) - 2`` entries
            of ``[64, 128, 256, 512]``.
    """

    def __init__(self, out_ch: int, img_size: int, latent_dim: int, base_ch_seq=None):
        super().__init__()
        # Number of x2 upsamplings needed to go from 4x4 to img_size (32 -> 3, 64 -> 4).
        n_up = int(math.log2(img_size)) - 2
        if base_ch_seq is None:
            base_ch_seq = [64, 128, 256, 512][:n_up]

        deepest_ch = base_ch_seq[-1]
        shallowest_ch = base_ch_seq[0]

        self.fc = nn.Linear(latent_dim, deepest_ch * 4 * 4)

        # Output widths walk the plan from deep to shallow, e.g. [64, 128, 256] -> [128, 64],
        # and one extra stage at the shallowest width completes the upsampling count.
        stage_out = list(reversed(base_ch_seq[:-1])) + [shallowest_ch]
        stage_in = [deepest_ch] + stage_out[:-1]
        self.ups = nn.Sequential(
            *(UpResBlock(src, dst) for src, dst in zip(stage_in, stage_out))
        )

        # Projects to image channels; the raw output is used as the predicted pixel mean.
        self.to_img = nn.Conv2d(shallowest_ch, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        """Decode latent vectors ``z`` into an image tensor (no output activation applied)."""
        seed = self.fc(z)
        # Un-flatten to (B, deepest_ch, 4, 4).
        seed = seed.view(-1, seed.shape[1] // 16, 4, 4)
        return self.to_img(self.ups(seed))

# ----------------------------
# VAE wrapper
# ----------------------------
class ConvVAE(nn.Module):
    """Variational autoencoder pairing ``ConvEncoder`` with ``ResNetDecoder``.

    Args:
        in_ch: number of image channels (also the decoder's output channels).
        img_size: side length of the square input images.
        latent_dim: dimensionality of the latent vector.
    """

    def __init__(self, in_ch: int, img_size: int, latent_dim: int):
        super().__init__()
        self.encoder = ConvEncoder(in_ch, img_size, latent_dim)
        # Give the decoder the same channel plan the encoder derives from img_size.
        channel_plan = [64, 128, 256, 512][: int(math.log2(img_size)) - 2]
        self.decoder = ResNetDecoder(
            out_ch=in_ch,
            img_size=img_size,
            latent_dim=latent_dim,
            base_ch_seq=channel_plan,
        )

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent, decode it; return ``(pred, mu, logvar)``."""
        mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent), mu, logvar

# ----------------------------
# Loss: MSE + beta*KL
# ----------------------------
def vae_loss_mse(pred, x, mu, logvar, beta=1.0):
    """Compute the beta-weighted VAE objective with a summed squared-error reconstruction term.

    Args:
        pred: decoder output.
        x: target images, same shape as ``pred``.
        mu: posterior means.
        logvar: posterior log-variances.
        beta: weight applied to the KL term.

    Returns:
        Tuple ``(total, recon, kl)`` of scalar tensors, each summed over the whole
        batch (not averaged), where ``total = recon + beta * kl``.
    """
    recon_term = F.mse_loss(pred, x, reduction='sum')
    # Closed-form KL( N(mu, exp(logvar)) || N(0, I) ), summed over batch and latent dims.
    kl_elements = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(kl_elements)
    total = recon_term + beta * kl_term
    return total, recon_term, kl_term

# ----------------------------
# Train
# ----------------------------
model = ConvVAE(IN_CH, IMG_SIZE, LATENT_DIM).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR, betas=(0.9, 0.99))

# Per-epoch histories of the per-sample total / reconstruction / KL values.
train_losses, recon_losses, kl_losses = [], [], []

# Losses are summed over each batch, so dividing by the dataset size gives per-sample averages.
N = len(train_loader.dataset)

model.train()
for epoch in range(EPOCHS):
    # Linear KL warm-up from BETA_START to BETA_END over WARMUP_EPOCHS epochs.
    # WARMUP_EPOCHS <= 0 disables warm-up (beta = BETA_END from the first epoch)
    # instead of dividing by zero.
    warm_frac = min(1.0, epoch / WARMUP_EPOCHS) if WARMUP_EPOCHS > 0 else 1.0
    beta = BETA_START + (BETA_END - BETA_START) * warm_frac
    total, tot_recon, tot_kl = 0.0, 0.0, 0.0

    for x, _ in train_loader:
        x = x.to(device)
        optimizer.zero_grad()
        pred, mu, logvar = model(x)
        loss, r, k = vae_loss_mse(pred, x, mu, logvar, beta=beta)
        loss.backward()
        # Clip the global gradient norm to 1.0 before the optimizer step.
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total += loss.item()
        tot_recon += r.item()
        tot_kl += k.item()

    train_losses.append(total / N)
    recon_losses.append(tot_recon / N)
    kl_losses.append(tot_kl / N)

    print(f"Epoch {epoch+1:02d} | beta={beta:.2f} | Loss:{total/N:.2f}  Recon:{tot_recon/N:.2f}  KL:{tot_kl/N:.2f}")

# ----------------------------
# Plot losses
# ----------------------------
# Draw the three per-sample loss curves on a single set of axes.
loss_fig, loss_ax = plt.subplots()
for history, curve_label in (
    (train_losses, 'Total'),
    (recon_losses, 'Recon'),
    (kl_losses, 'KL'),
):
    loss_ax.plot(history, label=curve_label)
loss_ax.set_title(f"{DATASET_NAME.upper()} ConvVAE (ResNet decoder) Training")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Per-sample loss")
loss_ax.grid(True)
loss_ax.legend()
plt.show()

# ----------------------------
# Reconstructions
# ----------------------------
model.eval()
with torch.no_grad():
    x, _ = next(iter(train_loader))
    x = x[:8].to(device)
    pred, _, _ = model(x)
    # The decoder is trained with MSE directly against [0, 1] pixel values, so its
    # output is already in pixel space. Applying a sigmoid here would squash the
    # images into roughly [0.5, 0.73]; only clamp to the displayable range.
    view = pred.clamp(0, 1).cpu()

# Top row: inputs; bottom row: the model's reconstructions of the same images.
fig, axs = plt.subplots(2, 8, figsize=(12, 4))
for i in range(8):
    # imshow expects channels-last (H, W, C).
    axs[0, i].imshow(x[i].permute(1, 2, 0).detach().cpu())
    axs[0, i].axis('off')
    axs[1, i].imshow(view[i].permute(1, 2, 0))
    axs[1, i].axis('off')
plt.suptitle("Top: originals | Bottom: reconstructions")
plt.show()

# ----------------------------
# Sampling
# ----------------------------
# Make sure BatchNorm uses its running statistics even if this section is run on its own.
model.eval()
with torch.no_grad():
    # Draw latents from the standard-normal prior and decode them.
    z = torch.randn(16, LATENT_DIM, device=device)
    pred = model.decoder(z)
    # Decoder output is already in pixel space (trained with MSE on [0, 1] targets),
    # so clamp for display rather than applying a sigmoid.
    samples = pred.clamp(0, 1).cpu()

fig, axes = plt.subplots(2, 8, figsize=(12, 4))
for i, ax in enumerate(axes.flat):
    # imshow expects channels-last (H, W, C).
    ax.imshow(samples[i].permute(1, 2, 0))
    ax.axis('off')
plt.suptitle("ConvVAE Samples (ResNet-style decoder)")
plt.show()

Downloading...
From (original): https://drive.google.com/uc?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM
From (redirected): https://drive.usercontent.google.com/download?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM&confirm=t&uuid=f3e66224-3d1f-498f-8de1-e831d86e287c
To: C:\Users\sarat\data\celeba\img_align_celeba.zip
100%|█████████████████████████████████████████████████████████████████████████████| 1.44G/1.44G [00:26<00:00, 55.0MB/s]
Downloading...
From: https://drive.google.com/uc?id=0B7EVK8r0v71pblRyaVFSWGxPY0U
To: C:\Users\sarat\data\celeba\list_attr_celeba.txt
100%|█████████████████████████████████████████████████████████████████████████████| 26.7M/26.7M [00:00<00:00, 33.7MB/s]
Downloading...
From: https://drive.google.com/uc?id=1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS
To: C:\Users\sarat\data\celeba\identity_CelebA.txt
100%|█████████████████████████████████████████████████████████████████████████████| 3.42M/3.42M [00:00<00:00, 17.5MB/s]
Downloading...
From: https://drive.google.com/uc?id=0B7EVK8r0v71pbThiMVRxWXZ4

Epoch 01 | beta=0.00 | Loss:177.96  Recon:177.96  KL:5753.33
