In [None]:
import glob
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
# Directory where training checkpoints are written and later read back.
DATAPATH = './models/'
SEED = 0
# torch.manual_seed seeds the RNG on all devices, making weight init and
# DataLoader shuffling reproducible across Restart & Run All.
torch.manual_seed(SEED)
# GPU training when available, otherwise fall back to CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Training on device {DEVICE}.")

In [1]:
import datetime as dt

def _set_requires_grad(module, flag):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``module``."""
    for param in module.parameters():
        param.requires_grad = flag


def _train_epoch(generator, discriminator, g_optimizer, d_optimizer, loader, device, l1_weight):
    """Run one epoch of alternating D/G updates; return (mean G loss, mean D loss)."""
    generator.train()
    discriminator.train()
    g_total = d_total = 0.0
    n_batches = 0
    for colored, grayscale in loader:
        colored = colored.to(device)
        grayscale = grayscale.to(device)

        # Discriminator step: minimise -[log D(real | gray) + log(1 - D(fake | gray))].
        # Generating under no_grad keeps G out of this graph (what freezing
        # the generator's parameters was meant to achieve).
        with torch.no_grad():
            fake = generator(grayscale)
        d_real = discriminator(colored, grayscale)
        d_fake = discriminator(fake, grayscale)
        # BCE against 1/0 targets equals -mean(log p) / -mean(log(1 - p)), but
        # PyTorch clamps the log at -100, so p == 0 or 1 cannot produce inf/NaN.
        d_loss = (F.binary_cross_entropy(d_real, torch.ones_like(d_real))
                  + F.binary_cross_entropy(d_fake, torch.zeros_like(d_fake)))
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # Generator step: freeze D so backward only produces gradients for G.
        _set_requires_grad(discriminator, False)
        try:
            fake = generator(grayscale)
            d_fake = discriminator(fake, grayscale)
            # Non-saturating adversarial loss -log D(G(x)); minimising
            # log(1 - D(G(x))) gives vanishing gradients while D rejects fakes easily.
            g_adv = F.binary_cross_entropy(d_fake, torch.ones_like(d_fake))
            g_loss = g_adv + l1_weight * F.l1_loss(fake, colored)
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()
        finally:
            _set_requires_grad(discriminator, True)

        # Accumulate Python floats: summing the loss tensors themselves would chain
        # every batch's graph together and make the next backward() fail.
        d_total += d_loss.item()
        g_total += g_loss.item()
        n_batches += 1
    denom = max(n_batches, 1)
    return g_total / denom, d_total / denom


def _validate_l1(generator, loader, device):
    """Mean per-batch L1 error of ``generator(grayscale)`` vs ``colored``; NaN if ``loader`` is empty."""
    generator.eval()
    total = 0.0
    n_batches = 0
    with torch.no_grad():
        for colored, grayscale in loader:
            colored = colored.to(device)
            grayscale = grayscale.to(device)
            total += F.l1_loss(generator(grayscale), colored).item()
            n_batches += 1
    return total / n_batches if n_batches else float('nan')


def cgan_training_loop(n_epochs, optimizers, generator, discriminator, train_loader, valid_loader,
                       l1_weight=1.0, device=None, save_dir=None, log_every=10):
    """Train a conditional GAN that colorizes grayscale images.

    Each training batch performs one discriminator update followed by one
    generator update. After each epoch the generator's mean L1 error on
    ``valid_loader`` is computed. Whenever it improves, the generator and
    discriminator state dicts are saved as ``G_<l1>.pt`` / ``D_<l1>.pt`` in
    ``save_dir``.

    Args:
        n_epochs: number of passes over ``train_loader``.
        optimizers: ``(generator_optimizer, discriminator_optimizer)``.
        generator: module mapping a grayscale batch to a colored batch. It is
            fed the grayscale condition; feeding it the colored target would
            let it learn the identity map.
        discriminator: module called as ``discriminator(colored, grayscale)``.
            Its output is treated as a probability, so it must lie in [0, 1]
            (e.g. a sigmoid output). TODO confirm for PatchGAN.
        train_loader, valid_loader: iterables yielding ``(colored, grayscale)``
            batches.
        l1_weight: weight of the L1 reconstruction term in the generator loss.
            The pix2pix paper uses 100; 1.0 matches the original unweighted sum.
        device: device for the batches; defaults to the notebook-level ``DEVICE``.
        save_dir: checkpoint directory, created if missing; defaults to the
            notebook-level ``DATAPATH``.
        log_every: print a progress line on epoch 1 and every ``log_every``
            epochs (0 disables the periodic lines).

    Returns:
        dict of per-epoch lists: ``'train_g'`` and ``'train_d'`` (mean training
        losses) and ``'valid_l1'`` (mean validation L1 error).
    """
    device = DEVICE if device is None else device
    save_dir = DATAPATH if save_dir is None else save_dir
    os.makedirs(save_dir, exist_ok=True)  # torch.save does not create missing directories
    g_optimizer, d_optimizer = optimizers
    history = {'train_g': [], 'train_d': [], 'valid_l1': []}
    best_valid_l1 = float('inf')  # lower L1 is better, so the first finite value is saved

    for epoch in range(1, n_epochs + 1):
        train_g, train_d = _train_epoch(generator, discriminator, g_optimizer, d_optimizer,
                                        train_loader, device, l1_weight)
        valid_l1 = _validate_l1(generator, valid_loader, device)
        history['train_g'].append(train_g)
        history['train_d'].append(train_d)
        history['valid_l1'].append(valid_l1)

        if epoch == 1 or (log_every and epoch % log_every == 0):
            print('{} Epoch {}, Train G {:.4f}, Train D {:.4f}, Valid L1 {:.4f}'.format(
                dt.datetime.now(), epoch, train_g, train_d, valid_l1))

        # NaN (empty valid_loader) compares False, so nothing is saved in that case.
        if valid_l1 < best_valid_l1:
            tag = '{:2.2f}'.format(valid_l1)
            torch.save(generator.state_dict(), os.path.join(save_dir, 'G_{}.pt'.format(tag)))
            torch.save(discriminator.state_dict(), os.path.join(save_dir, 'D_{}.pt'.format(tag)))
            best_valid_l1 = valid_l1
    return history

In [None]:
from cocoLoader import load_coco_dataset
from UNet import UNet
from PatchGAN import PatchGAN

# Training hyper-parameters. lr=2e-4 with beta1=0.5 are the standard
# DCGAN/pix2pix Adam settings; lr=1e-2 is far too large for stable GAN training.
BATCH_SIZE = 130
N_EPOCHS = 100
LEARNING_RATE = 2e-4
ADAM_BETAS = (0.5, 0.999)

train_dataloader, valid_dataloader = load_coco_dataset(batch_size=BATCH_SIZE)
unet = UNet().to(device=DEVICE)  # U-Net generator
patchgan = PatchGAN(patch_size=5).to(device=DEVICE)  # PatchGAN discriminator
g_optimizer = torch.optim.Adam(unet.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
d_optimizer = torch.optim.Adam(patchgan.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
cgan_training_loop(
    n_epochs=N_EPOCHS,
    optimizers=(g_optimizer, d_optimizer),
    generator=unet,
    discriminator=patchgan,
    train_loader=train_dataloader,
    valid_loader=valid_dataloader)
# Release cached, unused GPU memory held by PyTorch's allocator
# (has no effect if CUDA was never initialised).
torch.cuda.empty_cache()

In [None]:
# Load best models: pick the lowest-loss 'G_<loss>.pt' checkpoint under DATAPATH
# and its matching 'D_<loss>.pt', then load them into the models built above.
def _checkpoint_loss(path):
    """Parse the loss encoded in a checkpoint name like 'G_0.12.pt'; None if it has none."""
    stem = os.path.splitext(os.path.basename(path))[0]
    try:
        return float(stem.split('_', 1)[1])
    except (IndexError, ValueError):
        return None

g_candidates = []
for path in glob.glob(os.path.join(DATAPATH, 'G_*.pt')):
    loss = _checkpoint_loss(path)
    if loss is not None:
        g_candidates.append((loss, path))
if not g_candidates:
    raise FileNotFoundError('No generator checkpoints found in ' + DATAPATH)
# NOTE: assumes lower loss is better; checkpoints from earlier runs in the same
# directory are also candidates.
best_loss, best_g_path = min(g_candidates)
best_d_path = os.path.join(DATAPATH, 'D_' + os.path.basename(best_g_path)[len('G_'):])
unet.load_state_dict(torch.load(best_g_path, map_location=DEVICE))
patchgan.load_state_dict(torch.load(best_d_path, map_location=DEVICE))
print(f'Loaded checkpoints with loss {best_loss:.2f}: {best_g_path}, {best_d_path}')