In [None]:
import sys
sys.path.append("..")

from models.adversarial.dcgan import DCGenerator, DCDiscriminator, weights_init
from models.adversarial.train_gan import train_gan

import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from utils.data import visualize_dataset, cherry_pick_samples, samples_to_gif
from utils.plots import plot_losses

In [None]:
# Input pipeline: images under ../datasets/celeba -> 64x64 tensors in [-1, 1].
IMAGE_SIZE = 64   # side length (pixels) of the square images fed to the networks
BATCH_SIZE = 128

ts = transforms.Compose([
    # An int size rescales the shorter side to IMAGE_SIZE and keeps the aspect
    # ratio. The previous Resize([64, 64]) squashed every image into a square,
    # which distorted it and left the CenterCrop below with nothing to crop.
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 per channel maps ToTensor's [0, 1] range onto [-1, 1].
    transforms.Normalize(
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
    )
])

# ImageFolder expects root/<class_name>/<image files>; the class labels it
# produces are carried along in each batch next to the images.
ds = ImageFolder(
    root="../datasets/celeba",
    transform=ts
)

dl = DataLoader(
    ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,      # load batches in the main process
    pin_memory=False
)

In [None]:
# Presumably displays a preview of images drawn from the data loader.
# NOTE(review): 8 and 4 look like grid dimensions or sample counts — confirm
# against utils.data.visualize_dataset.
visualize_dataset(dl, 8, 4)

In [None]:
# Values shared between model construction and the training call.
NZ = 100    # latent size; passed to both DCGenerator and train_gan, so keep one source of truth
SEED = 42

# Seed torch's global RNG so torch-driven randomness in this cell is repeatable
# when the cell is re-run. Bit-exact results are still not guaranteed on every
# backend.
torch.manual_seed(SEED)

# Use Apple's Metal backend when present, otherwise fall back to the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

g = DCGenerator(nz=NZ, nc=3, ngf=64)
d = DCDiscriminator(nc=3, ndf=64)

# Module.apply visits every submodule, so weights_init is run on each layer.
g.apply(weights_init)
d.apply(weights_init)

generator, discriminator, g_losses, d_losses = train_gan(
    g=g,
    d=d,
    loss_fn=torch.nn.BCELoss(),
    dataloader=dl,
    epochs=5,
    device=device,
    nz=NZ,
    lr=2e-4,
    beta1=0.5,
    out_dir="../outputs",
)

In [None]:
import os
from PIL import Image

def samples_to_gif(
    samples_dir,
    out_path="samples.gif",
    duration=200,
    loop=0,
):
    """
    Turn all images in a directory into an animated GIF.

    Frames are ordered by file name using a natural sort, so numeric parts
    compare by value ("epoch_2.png" comes before "epoch_10.png") even when
    the names are not zero-padded.

    NOTE(review): this notebook also imports ``samples_to_gif`` from
    ``utils.data``; running this cell shadows that import.

    Args:
        samples_dir (str): directory containing saved sample images
            (.png, .jpg and .jpeg files are picked up, case-insensitively)
        out_path (str): output gif path; its parent directory is created
            if it does not exist yet
        duration (int): time per frame in ms
        loop (int): 0 = infinite loop

    Raises:
        RuntimeError: if ``samples_dir`` contains no matching image files
    """
    import re  # local import: only needed for the natural-sort key below

    def _natural_key(name):
        # re.split with a capturing group alternates text / digit runs, so
        # odd positions are always the digit runs; compare those as ints.
        parts = re.split(r"(\d+)", name)
        tokens = [int(p) if i % 2 else p for i, p in enumerate(parts)]
        # The raw name breaks ties such as "a1.png" vs "a01.png".
        return tokens, name

    image_names = sorted(
        (
            entry
            for entry in os.listdir(samples_dir)
            if entry.lower().endswith((".png", ".jpg", ".jpeg"))
            and os.path.isfile(os.path.join(samples_dir, entry))
        ),
        key=_natural_key,
    )

    if not image_names:
        raise RuntimeError("No images found in samples directory")

    frames = []
    for name in image_names:
        # convert() returns a new image, so the source file can be closed
        # as soon as the frame has been copied out of it.
        with Image.open(os.path.join(samples_dir, name)) as img:
            frames.append(img.convert("RGB"))

    target_dir = os.path.dirname(out_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    # The first frame drives the save; the remaining ones are appended to it.
    frames[0].save(
        out_path,
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=loop,
    )

    print(f"GIF saved to: {out_path}")

In [None]:
# Stitch the images saved under ../outputs/samples into one animated GIF,
# showing each frame for 200 ms.
gif_options = {
    "samples_dir": "../outputs/samples",
    "out_path": "../outputs/training_progress.gif",
    "duration": 200,
}
samples_to_gif(**gif_options)

In [None]:
# Generate samples from the generator and save them to ../outputs/cherry_picked.png.
# The device uses the same MPS-or-CPU fallback as the training cell, so this
# cell also runs on machines without the Metal backend (it was hard-coded to
# "mps" before).
# NOTE(review): `g` is the module handed to train_gan; confirm it is the same
# object as the `generator` that train_gan returns.
cherry_pick_samples(
    generator=g,
    nz=100,
    device="mps" if torch.backends.mps.is_available() else "cpu",
    total=25,
    save_path="../outputs/cherry_picked.png",
)

In [None]:
# Plot the two loss series returned by train_gan, keyed by their labels.
# NOTE(review): log_y=False presumably keeps a linear y-axis — confirm in utils.plots.
loss_history = {
    "Generator": g_losses,
    "Discriminator": d_losses,
}
plot_losses(loss_history, log_y=False)