In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
# Path prefix used when writing and reading model checkpoints.
DATAPATH = './models/'
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Training on device {DEVICE}.")
# Last expression of the cell: displays the CUDA memory currently allocated
# by tensors, as reported by torch.
torch.cuda.memory_allocated()

In [None]:
import datetime as dt
import numpy as np

def _set_requires_grad(module, flag):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``module``."""
    for p in module.parameters():
        p.requires_grad = flag


def cgan_training_loop(n_epochs, optimizers, generator, discriminator, train_loader, valid_loader, reg=1, smoothing=0.0,
                      max_patience=float('Inf')):
    """Train a conditional GAN, alternating discriminator and generator steps.

    Within every epoch the batches alternate between one optimisation step on
    the discriminator (first batch) and one step on the generator (next
    batch). The network that is not being optimised is put in eval mode and
    its parameters are frozen for that step. After each epoch a combined
    validation loss (discriminator BCE + ``reg`` * L1) is computed, and the
    generator weights are saved whenever that loss improves.

    Args:
        n_epochs: Maximum number of epochs to run.
        optimizers: Indexable pair ``(generator_optimizer, discriminator_optimizer)``.
        generator: Module called as ``generator(grayscale)``.
        discriminator: Module called as ``discriminator(image, grayscale)``.
            Its output is fed to ``nn.BCELoss`` and therefore has to lie in [0, 1].
        train_loader: Sized iterable yielding ``(colored, grayscale)`` batches.
        valid_loader: Sized iterable yielding ``(colored, grayscale)`` batches.
        reg: Weight of the L1 term between generator output and ``colored``.
            It is also embedded in the checkpoint file name.
        smoothing: One-sided label smoothing; the discriminator target for
            real pairs is ``1 - smoothing`` during training.
        max_patience: Number of consecutive non-improving epochs tolerated
            before training stops early.

    Returns:
        None. Side effects: steps both optimizers, prints progress, and writes
        the best generator ``state_dict`` to
        ``DATAPATH + 'G_lastchance{reg}reg.pt'``.
    """
    import os  # local import: only needed to prepare the checkpoint directory

    # Create the checkpoint directory up front so that the first torch.save
    # cannot fail after a full epoch of training.
    checkpoint_path = DATAPATH + 'G_lastchance{}reg.pt'.format(reg)
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    best_valid_loss = float('Inf')
    patience = max_patience
    bce = nn.BCELoss()
    dif = nn.L1Loss()
    print('{} starting training'.format(dt.datetime.now()))
    for epoch in range(1, n_epochs + 1):
        train_loss = 0.0
        # We alternate between one gradient descent step on D, then one step on G.
        isTrainedD = True
        for colored, grayscale in train_loader:
            isTrainedG = not isTrainedD
            colored = colored.to(device=DEVICE)
            grayscale = grayscale.to(device=DEVICE)
            # Train-mode for the network being optimised, eval-mode for the other.
            discriminator.train(mode=isTrainedD)
            generator.train(mode=isTrainedG)
            # (Un)Freeze corresponding weights
            _set_requires_grad(generator, isTrainedG)
            _set_requires_grad(discriminator, isTrainedD)
            g_output = generator(grayscale)
            fake = discriminator(g_output, grayscale)
            if isTrainedD:
                # One-sided smoothing for discriminator: only the "real"
                # target is lowered, the "fake" target stays at 0.
                real = discriminator(colored, grayscale)
                loss = 0.5 * (bce(fake, torch.zeros_like(fake)) +
                              bce(real, torch.full_like(real, 1.0 - smoothing)))
                opt_idx = 1
            else:
                # Training trick for generator: use the non-saturating loss,
                # i.e. label the generated pairs as real instead of negating
                # the discriminator's "fake" loss.
                loss = bce(fake, torch.ones_like(fake)) + reg * dif(g_output, colored)
                opt_idx = 0
            optimizers[opt_idx].zero_grad()
            loss.backward()
            optimizers[opt_idx].step()
            # Note: D and G losses are accumulated into the same running sum.
            train_loss += loss.item()
            isTrainedD = not isTrainedD
        # Undo the per-step freezing so neither network is handed back to the
        # caller with requires_grad=False on its parameters.
        _set_requires_grad(generator, True)
        _set_requires_grad(discriminator, True)
        torch.cuda.empty_cache()
        # Validation
        with torch.no_grad():
            generator.train(mode=False)
            discriminator.train(mode=False)
            valid_loss = 0.0
            for colored, grayscale in valid_loader:
                colored = colored.to(device=DEVICE)
                grayscale = grayscale.to(device=DEVICE)
                g_output = generator(grayscale)
                real = discriminator(colored, grayscale)
                fake = discriminator(g_output, grayscale)
                # No label smoothing at validation time.
                loss = 0.5 * (bce(fake, torch.zeros_like(fake)) +
                              bce(real, torch.ones_like(real))) + \
                    reg * dif(g_output, colored)
                valid_loss += loss.item()
        torch.cuda.empty_cache()
        # Verbose training
        train_loss = train_loss / len(train_loader)
        valid_loss = valid_loss / len(valid_loader)
        if epoch == 1 or epoch % 10 == 0:
            print('{} Epoch {}, Train {:.5f}, Valid {:.5f}'.format(dt.datetime.now(),
                                                                      epoch,
                                                                      train_loss,
                                                                      valid_loss))
        if valid_loss < best_valid_loss:
            # New best validation loss: checkpoint the generator and reset patience.
            torch.save(generator.state_dict(), checkpoint_path)
            print(f'Saving {epoch}-th for {valid_loss = :2.5f}')
            best_valid_loss = valid_loss
            patience = max_patience
        elif patience == 0:
            # Early stopping: patience budget exhausted.
            return
        else:
            patience = patience - 1


In [None]:
from cocoLoader import load_coco_dataset

BATCH_SIZE = 64
IMG_SIZE = 286
# Select image categories to train colorization for.
# `cats` is passed as a real one-element tuple: `('person')` without the
# trailing comma is just the string 'person', not a tuple.
# NOTE(review): assumes load_coco_dataset accepts an iterable of category
# names - confirm against cocoLoader.
train_dataloader, valid_dataloader = load_coco_dataset(batch_size=BATCH_SIZE,
                                                       size=30000,
                                                       dim=IMG_SIZE,
                                                       cats=('person',))


In [None]:
from UNet import UNet
from PatchGAN import PatchGAN

# Generator: 1-channel input, 3-channel output, with its Adam optimizer.
# To warm-start from a saved generator, load its state_dict into `unet` here
# before training.
unet = UNet(in_channels=1, out_channels=3).to(device=DEVICE)
g_optimizer = torch.optim.Adam(unet.parameters(), lr=2e-4)

# Discriminator with a plain SGD optimizer.
patchgan = PatchGAN().to(device=DEVICE)
d_optimizer = torch.optim.SGD(patchgan.parameters(), lr=2e-4)

# Optimizers are passed in (generator, discriminator) order, which is the
# order cgan_training_loop indexes them in.
cgan_training_loop(1000,
                   (g_optimizer, d_optimizer),
                   unet,
                   patchgan,
                   train_dataloader,
                   valid_dataloader,
                   reg=300)


In [None]:
# Load best models
import matplotlib.pyplot as plt

# The training cell calls cgan_training_loop with reg=300, and the loop names
# its checkpoint after `reg`, so that is the file to load here. Keep this
# value in sync with the `reg` used for training.
CHECKPOINT_REG = 300
unet.load_state_dict(torch.load(DATAPATH + 'G_lastchance{}reg.pt'.format(CHECKPOINT_REG),
                                map_location=DEVICE))
unet.eval()

plt.figure(figsize=(12, 60))
# NOTE(review): these look like the usual ImageNet channel statistics and are
# presumably what cocoLoader normalizes with - confirm.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
# Normalize computes (x - mean) / std, so passing -mean/std and 1/std makes it
# compute x * std + mean, i.e. the inverse of the forward normalization.
mean = [-1 * x / y for x, y in zip(mean, std)]
std = [1 / x for x in std]
unnormalize = torchvision.transforms.Normalize(mean=mean, std=std)

# Only the first validation batch is visualized.
colored, grayscale = next(iter(valid_dataloader))
max_rows = 10
# Guard against a first batch that holds fewer than `max_rows` samples.
n_shown = min(max_rows, len(colored))
# Inference only: no autograd graph is needed.
with torch.no_grad():
    for idx in range(n_shown):
        c = colored[idx]
        g = grayscale[idx]
        # Column 1: ground-truth colour image.
        plt.subplot(max_rows, 3, 1 + 3 * idx)
        plt.imshow(unnormalize(c).permute(1, 2, 0), vmin=0, vmax=1)
        # Column 2: grayscale input fed to the generator.
        plt.subplot(max_rows, 3, 2 + 3 * idx)
        plt.imshow(g.permute(1, 2, 0), 'gray')
        # Column 3: generator output for that input (batch dim added, then removed).
        generated = unet(g.to(device=DEVICE).unsqueeze(0)).squeeze(0)
        generated = unnormalize(generated)
        plt.subplot(max_rows, 3, 3 + 3 * idx)
        plt.imshow(generated.cpu().permute(1, 2, 0), vmin=0, vmax=1)
plt.show()