In [1]:
import torch
from dataset import ToStyleDataset
from utils import save_checkpoint, load_checkpoint
from torch.utils.data import DataLoader
import torch.optim as optim
import config
from discriminator import Discriminator
from generator import Generator
from train import train
import warnings

warnings.filterwarnings("ignore")

torch.cuda.is_available()



True

In [2]:
# Discriminator / generator pair held under the "real" names, moved to the
# configured device.
d_real = Discriminator(in_channels=3).to(config.DEVICE)
g_real = Generator(in_channels=3, num_residuals=9).to(config.DEVICE)

# Second pair, held under the "style" names, built with the same settings.
d_style = Discriminator(in_channels=3).to(config.DEVICE)
g_style = Generator(in_channels=3, num_residuals=9).to(config.DEVICE)

# One Adam optimizer per role: each receives the parameters of both networks
# of that role (style first, then real) so a single step updates the pair.
d_optim = optim.Adam(
    [*d_style.parameters(), *d_real.parameters()],
    lr=config.LR,
    betas=(0.5, 0.999),
)
g_optim = optim.Adam(
    [*g_style.parameters(), *g_real.parameters()],
    lr=config.LR,
    betas=(0.5, 0.999),
)

In [3]:
# Optionally restore all four networks before training.
# The two generators share g_optim and the two discriminators share d_optim,
# so each optimizer is handed to load_checkpoint twice.
# NOTE(review): files are read from "checkpoints/" + name here, while the
# training cell passes the bare config name to save_checkpoint -- confirm that
# save_checkpoint writes into "checkpoints/".
if config.LOAD_MODEL:
    for _ckpt_name, _model, _optimizer in (
        (config.CHECKPOINT_G_STYLE, g_style, g_optim),
        (config.CHECKPOINT_G_REAL, g_real, g_optim),
        (config.CHECKPOINT_CRITIC_STYLE, d_style, d_optim),
        (config.CHECKPOINT_CRITIC_REAL, d_real, d_optim),
    ):
        load_checkpoint("checkpoints/" + _ckpt_name, _model, _optimizer)

In [4]:
# Training set built from the "real" and "ghibli" sub-folders of
# config.TRAIN_PATH, with config.transform passed through to the dataset.
dataset = ToStyleDataset(
    root_real=config.TRAIN_PATH + "/real",
    root_style=config.TRAIN_PATH + "/ghibli",
    transform=config.transform,
)

# Shuffled loader; batch size and worker count come from config.
data_loader = DataLoader(
    dataset=dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    num_workers=config.NUM_WORKERS,
    pin_memory=True,
)

# Trailing expression: show the number of samples in the dataset.
len(dataset)

6287

In [5]:
# Two independent gradient scalers, later handed to train(): one for the
# generator side and one for the discriminator side (created in that order).
g_scaler, d_scaler = (torch.cuda.amp.GradScaler() for _ in range(2))

In [6]:
# Running histories of the values returned by train(); the training loop
# appends to these after every epoch.
d_loss, g_loss = [], []

In [7]:
# Main loop: one train() call per epoch, then an optional checkpoint of all
# four networks.
for epoch in range(config.NUM_EPOCHS):
    print(f'Epoch #{epoch}')

    # train() hands back two iterables; the first is appended to the
    # discriminator history and the second to the generator history.
    dl, gl = train(
        d_real, d_style, g_real, g_style,
        data_loader, d_optim, g_optim, d_scaler, g_scaler,
    )
    d_loss += dl
    g_loss += gl

    if not config.SAVE_MODEL:
        continue

    # Each network is saved together with the optimizer that updates it; the
    # shared optimizers are therefore passed twice.
    # NOTE(review): the load cell reads these names from "checkpoints/" + name,
    # but no prefix is added here -- confirm save_checkpoint writes there.
    for _ckpt_name, _model, _optimizer in (
        (config.CHECKPOINT_G_STYLE, g_style, g_optim),
        (config.CHECKPOINT_G_REAL, g_real, g_optim),
        (config.CHECKPOINT_CRITIC_STYLE, d_style, d_optim),
        (config.CHECKPOINT_CRITIC_REAL, d_real, d_optim),
    ):
        save_checkpoint(_model, _optimizer, filename=_ckpt_name)


Epoch #0


  3%|▎         | 184/6287 [00:49<27:18,  3.73it/s, d_loss=1.25, g_loss=17]   


KeyboardInterrupt: 