In [None]:
# Jupyter workaround: run from the project's src/ directory and make its modules importable
import os
import sys

# Location of the project's `src/` directory. The default is the original
# author's checkout; set the CELEBA_CVAE_SRC environment variable to point at
# a different checkout without editing the notebook.
PATH = os.environ.get(
    "CELEBA_CVAE_SRC",
    "/home/emaballarin/repositories/celeba_sweeping_cvae/src/",
)
# Work from `src/` so that relative paths resolve from there.
os.chdir(PATH)
# Make the local modules importable. The guard keeps re-runs of this cell from
# stacking duplicate entries on sys.path.
if PATH not in sys.path:
    sys.path.append(PATH)

In [None]:
from typing import Tuple

In [None]:
from tqdm.auto import tqdm, trange
import torch as th
from torchvision.datasets import CelebA
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, ToTensor
import matplotlib.pyplot as plt

In [None]:
from models import CelebACVAE
from losses import beta_reco_bce

In [None]:
# Run-wide settings; every tunable value used by the cells below lives here.
DEVICE_AUTODETECT: bool = True  # when False, the device selection falls back to CPU
IMG_SHAPE: Tuple[int, int, int] = (3, 64, 64)  # presumably (channels, height, width); the last two entries feed Resize
TRAIN_BS: int = 64  # batch size of the training DataLoader
TEST_BS: int = 32  # batch size of the test DataLoader
LATENT_SIZE: int = 128  # passed to CelebACVAE as lat_size
CONDITION_SIZE: int = 40  # passed to CelebACVAE as cond_size
EPOCHS: int = 200  # number of training epochs; also the scheduler's total_steps
BASE_LR: float = 1e-3  # optimizer learning rate and the scheduler's max_lr

In [None]:
# Use CUDA only when it is available and DEVICE_AUTODETECT allows it;
# otherwise run on the CPU.
if th.cuda.is_available() and DEVICE_AUTODETECT:
    device = th.device("cuda")
else:
    device = th.device("cpu")

In [None]:
# Dataset directory, resolved relative to `src/` (PATH). os.path.join avoids
# the doubled separator that plain string concatenation produced.
celeba_root = os.path.join(PATH, "..", "data", "celeba")

# Both splits share the same preprocessing: resize to the spatial size taken
# from IMG_SHAPE, then convert to a tensor.
celeba_transform = Compose([Resize(IMG_SHAPE[1:]), ToTensor()])

# Pinned host memory is only requested when running on CUDA. Checking
# `device.type` stays correct for indexed devices such as "cuda:0", which do
# not compare equal to th.device("cuda").
pin_batches = device.type == "cuda"

# Train and test splits, with the attribute vector as the target.
train_ds = CelebA(
    root=celeba_root,
    split="train",
    target_type="attr",
    transform=celeba_transform,
    download=True,
)

test_ds = CelebA(
    root=celeba_root,
    split="test",
    target_type="attr",
    transform=celeba_transform,
    download=True,
)

train_dl = DataLoader(
    train_ds,
    batch_size=TRAIN_BS,
    shuffle=True,
    num_workers=16,
    pin_memory=pin_batches,
)
# NOTE(review): the test loader is shuffled as well; kept as in the original,
# but confirm this is intended if a fixed evaluation order matters.
test_dl = DataLoader(
    test_ds,
    batch_size=TEST_BS,
    shuffle=True,
    num_workers=4,
    pin_memory=pin_batches,
)

In [None]:
# Instantiate CelebACVAE with the configured latent/condition sizes, then move
# it to the selected device.
model = CelebACVAE(lat_size=LATENT_SIZE, cond_size=CONDITION_SIZE)
model = model.to(device)

In [None]:
# RAdam over all model parameters at the base learning rate.
optimizer = th.optim.RAdam(model.parameters(), lr=BASE_LR)

# OneCycleLR settings gathered in one place. total_steps equals EPOCHS, which
# matches a schedule that is advanced once per epoch.
one_cycle_options = {
    "max_lr": BASE_LR,
    "total_steps": EPOCHS,
    "pct_start": 0.75,
    "anneal_strategy": "linear",
    "cycle_momentum": False,
    "base_momentum": 0,
    "max_momentum": 0,
    "div_factor": 1,
}
scheduler = th.optim.lr_scheduler.OneCycleLR(optimizer, **one_cycle_options)

In [None]:
# Training loop: one pass over train_dl per epoch, one scheduler step per epoch.
model.train()

# The `beta` weight handed to the loss ramps linearly from 0 to 0.5 over the
# first third of training and then stays at 0.5. The ramp length is
# loop-invariant, so it is computed once here.
beta_warmup_epochs = EPOCHS // 3

for epoch in trange(EPOCHS, desc="epochs"):
    # The condition is evaluated first, so the division never runs when
    # beta_warmup_epochs is 0 (EPOCHS < 3).
    beta: float = (
        (0.5 * (epoch / beta_warmup_epochs)) if epoch < beta_warmup_epochs else 0.5
    )
    # leave=False clears each finished per-epoch bar instead of leaving one
    # bar per epoch in the cell output.
    for images, attr in tqdm(train_dl, desc=f"epoch {epoch}", leave=False):
        # non_blocking only has an effect for pinned-memory batches copied to
        # CUDA; elsewhere it behaves like a regular copy.
        images: th.Tensor = images.to(device, non_blocking=True)
        attr: th.Tensor = attr.to(device, non_blocking=True)
        optimizer.zero_grad()
        reconstructed_image, mean, log_var = model(images, attr)
        loss = beta_reco_bce(reconstructed_image, images, mean, log_var, beta)
        loss.backward()
        optimizer.step()
    # The learning-rate schedule advances once per epoch, not per batch.
    scheduler.step()

In [None]:
from torchinfo import summary

In [None]:
# One input size per model argument: an image batch and a condition batch,
# both using the training batch size.
summary_input_sizes = [(TRAIN_BS, *IMG_SHAPE), (TRAIN_BS, CONDITION_SIZE)]
# Kept as the last expression so the notebook displays the returned summary.
summary(model, input_size=summary_input_sizes, device=device, mode="train")