In [1]:
%run Config.ipynb
%run utils.ipynb
%run Discriminator.ipynb
%run Generator.ipynb
%run dataset.ipynb

torch.Size([5, 1, 30, 30])
torch.Size([2, 3, 256, 256])


In [3]:
import os
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm

In [4]:
def train_fn(
    disc_H,
    disc_Z,
    gen_Z,
    gen_H,
    loader,
    opt_disc,
    opt_gen,
    l1,
    mse,
    d_scaler,
    g_scaler,
    sample_dir="saved_images",
    sample_every=200,
):
    """Run one epoch of CycleGAN training over ``loader``.

    Every batch does one discriminator update followed by one generator update,
    both under ``torch.cuda.amp.autocast`` with separate gradient scalers.

    Domain pairing used below:
      * ``gen_H`` maps summer -> winter; ``disc_H`` scores winter images.
      * ``gen_Z`` maps winter -> summer; ``disc_Z`` scores summer images.

    Args:
        disc_H: discriminator for the winter domain.
        disc_Z: discriminator for the summer domain.
        gen_Z: generator producing summer images from winter ones.
        gen_H: generator producing winter images from summer ones.
        loader: iterable yielding ``(summer, winter)`` batch pairs
            (NOTE(review): order assumed from the unpacking here — confirm it
            matches what ``SummerWinterDataset`` returns).
        opt_disc: optimizer over both discriminators' parameters.
        opt_gen: optimizer over both generators' parameters.
        l1: loss used for the cycle-consistency and identity terms.
        mse: loss used for the least-squares adversarial terms.
        d_scaler: ``GradScaler`` for the discriminator step.
        g_scaler: ``GradScaler`` for the generator step.
        sample_dir: directory sample images are written to; created if missing.
        sample_every: write sample images every ``sample_every`` batches
            (0/None disables sampling). Filenames are keyed only by batch
            index, so each epoch overwrites the previous epoch's samples.

    Returns:
        None. The tqdm bar shows the running mean of ``disc_H`` outputs on
        real winter (``H_real``) and generated winter (``H_fake``) images.
    """
    if sample_every:
        # save_image does not create parent directories; without this the first
        # sample write raises FileNotFoundError on a fresh checkout.
        os.makedirs(sample_dir, exist_ok=True)

    H_reals = 0
    H_fakes = 0
    loop = tqdm(loader, leave=True)

    for idx, (summer, winter) in enumerate(loop):
        summer = summer.to(Config.DEVICE)
        winter = winter.to(Config.DEVICE)

        # ---- Discriminator step ----------------------------------------------
        with torch.cuda.amp.autocast():
            # Winter discriminator: real winter vs. gen_H(summer).
            fake_winter = gen_H(summer)
            D_H_real = disc_H(winter)
            # detach(): the discriminator loss must not backprop into gen_H.
            D_H_fake = disc_H(fake_winter.detach())
            H_reals += D_H_real.mean().item()
            H_fakes += D_H_fake.mean().item()
            # Least-squares targets: 1 for real, 0 for fake.
            D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
            D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
            D_H_loss = D_H_real_loss + D_H_fake_loss

            # Summer discriminator: real summer vs. gen_Z(winter).
            fake_summer = gen_Z(winter)
            D_Z_real = disc_Z(summer)
            D_Z_fake = disc_Z(fake_summer.detach())
            D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
            D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
            D_Z_loss = D_Z_real_loss + D_Z_fake_loss

            D_loss = (D_H_loss + D_Z_loss) / 2

        # zero_grad here also clears discriminator grads left over from the
        # previous iteration's generator backward pass.
        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Generator step --------------------------------------------------
        with torch.cuda.amp.autocast():
            # Adversarial terms: generators want their fakes scored as real (1).
            D_H_fake = disc_H(fake_winter)
            D_Z_fake = disc_Z(fake_summer)
            loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
            loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

            # Cycle consistency: summer -> winter -> summer should reconstruct
            # the input, and likewise for winter.
            cycle_summer = gen_Z(fake_winter)
            cycle_winter = gen_H(fake_summer)
            cycle_summer_loss = l1(summer, cycle_summer)
            cycle_winter_loss = l1(winter, cycle_winter)

            # Identity: a generator fed an image already in its target domain
            # should leave it unchanged.
            identity_summer = gen_Z(summer)
            identity_winter = gen_H(winter)
            identity_summer_loss = l1(summer, identity_summer)
            identity_winter_loss = l1(winter, identity_winter)

            G_loss = (
                loss_G_Z
                + loss_G_H
                + cycle_summer_loss * Config.LAMBDA_CYCLE
                + cycle_winter_loss * Config.LAMBDA_CYCLE
                + identity_winter_loss * Config.LAMBDA_IDENTITY
                + identity_summer_loss * Config.LAMBDA_IDENTITY
            )

        # This backward pass also populates .grad on the discriminators; only
        # opt_gen steps here, and opt_disc.zero_grad() clears them next batch.
        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if sample_every and idx % sample_every == 0:
            # `* 0.5 + 0.5` maps [-1, 1] -> [0, 1]; assumes the generators
            # output in [-1, 1] (e.g. tanh) — TODO confirm in Generator.
            save_image(fake_winter * 0.5 + 0.5, os.path.join(sample_dir, f"winter_{idx}.png"))
            save_image(fake_summer * 0.5 + 0.5, os.path.join(sample_dir, f"summer_{idx}.png"))

        loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))

In [7]:
def main():
    """Build the CycleGAN networks, optionally resume from checkpoints, and train.

    Creates two discriminators and two generators on ``Config.DEVICE``.
    ``gen_H``/``disc_H`` belong to the winter domain and ``gen_Z``/``disc_Z``
    to the summer domain; ``train_fn`` shows how they are paired.

    Each epoch runs ``train_fn`` once over the training split. When
    ``Config.SAVE_MODEL`` is set, it then writes all four model/optimizer
    checkpoints. When ``Config.LOAD_MODEL`` is set, the same checkpoints are
    loaded before training starts.
    """
    disc_H = Discriminator(in_channels=3).to(Config.DEVICE)
    disc_Z = Discriminator(in_channels=3).to(Config.DEVICE)
    gen_Z = Generator(img_channels=3, num_residuals=9).to(Config.DEVICE)
    gen_H = Generator(img_channels=3, num_residuals=9).to(Config.DEVICE)

    # One optimizer per adversarial side: both discriminators share opt_disc
    # and both generators share opt_gen, so each side is stepped once per batch.
    opt_disc = optim.Adam(
        list(disc_H.parameters()) + list(disc_Z.parameters()),
        lr=Config.LEARNING_RATE,
        betas=(0.5, 0.999),
    )
    opt_gen = optim.Adam(
        list(gen_Z.parameters()) + list(gen_H.parameters()),
        lr=Config.LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    # (checkpoint file, model, optimizer) for every network, used for both
    # loading and saving so the two lists cannot drift apart.
    # NOTE(review): opt_gen and opt_disc each appear twice. If load_checkpoint
    # restores optimizer state, the later file of each pair wins — confirm
    # this is intended.
    checkpoints = [
        (Config.CHECKPOINT_GEN_H, gen_H, opt_gen),
        (Config.CHECKPOINT_GEN_Z, gen_Z, opt_gen),
        (Config.CHECKPOINT_CRITIC_H, disc_H, opt_disc),
        (Config.CHECKPOINT_CRITIC_Z, disc_Z, opt_disc),
    ]

    if Config.LOAD_MODEL:
        for path, model, optimizer in checkpoints:
            load_checkpoint(path, model, optimizer, Config.LEARNING_RATE)

    dataset = SummerWinterDataset(
        root_winter="data/train/winter",
        root_summer="data/train/summer",
        transform=Config.transforms,
    )
    loader = DataLoader(
        dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    )

    # Separate scalers so the discriminator and generator steps each track
    # their own loss scale.
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for _epoch in range(Config.NUM_EPOCHS):
        train_fn(
            disc_H,
            disc_Z,
            gen_Z,
            gen_H,
            loader,
            opt_disc,
            opt_gen,
            L1,
            mse,
            d_scaler,
            g_scaler,
        )

        if Config.SAVE_MODEL:
            for path, model, optimizer in checkpoints:
                save_checkpoint(model, optimizer, filename=path)


if __name__ == "__main__":
    main()

  0%|          | 0/1231 [00:47<?, ?it/s]


KeyboardInterrupt: 