In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
DATAPATH = './models/'
# Pick the accelerator once so every later cell moves tensors/models to the same place:
# CUDA when PyTorch can see a GPU, the CPU otherwise.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f"Training on device {DEVICE}.")
# Bytes currently occupied by tensors on the default CUDA device (displayed as the cell output).
torch.cuda.memory_allocated()

Training on device cuda.


0

In [2]:
import datetime as dt

def _set_requires_grad(module, flag):
    """Enable (``flag=True``) or disable gradient tracking for every parameter of ``module``."""
    for p in module.parameters():
        p.requires_grad = flag


def cgan_training_loop(n_epochs, optimizers, generator, discriminator, train_loader, valid_loader, reg=0.01,
                       log_every=100, min_improvement=1.0):
    """Train a conditional GAN that maps grayscale images to colored images.

    Each batch does one gradient step on the discriminator, followed by one gradient
    step on the generator. After every epoch both models are evaluated on
    ``valid_loader``. Generator checkpoints are written under ``DATAPATH``.

    Args:
        n_epochs: Number of passes over ``train_loader``.
        optimizers: ``(generator_optimizer, discriminator_optimizer)`` pair.
        generator: Module called as ``generator(grayscale)``.
        discriminator: Module called as ``discriminator(colored_or_fake, grayscale)``.
        train_loader: Iterable of ``(colored, grayscale)`` batches used for training.
        valid_loader: Iterable of ``(colored, grayscale)`` batches used for validation.
            Both loaders must support ``len()``.
        reg: Weight of the L1 reconstruction term added to the adversarial generator loss.
        log_every: Print progress and save ``G_{epoch}epoch_{loss}valLoss.pt`` on epoch 1
            and on every ``log_every``-th epoch.
        min_improvement: ``G_best.pt`` is rewritten only when the validation loss falls
            more than this amount below the best value seen so far.

    The reported train and validation losses are both
    ``discriminator loss + reg * L1(generated, colored)``, so the two are comparable.
    """
    import os

    g_optimizer, d_optimizer = optimizers
    # NOTE(review): assumes the cGANloss criteria are stateless, so one instance can be reused.
    d_criterion = CGANDiscriminatorLoss()
    g_criterion = CGANGeneratorLoss()
    l1_criterion = torch.nn.L1Loss()

    os.makedirs(DATAPATH, exist_ok=True)
    # Start from +inf so the first epoch always produces a G_best.pt checkpoint.
    best_valid_loss = float('inf')

    for epoch in range(1, n_epochs + 1):
        # Plain Python accumulator (the previous torch.empty() held uninitialized memory).
        train_loss = 0.0
        generator.train(mode=True)
        discriminator.train(mode=True)
        for colored, grayscale in train_loader:
            colored = colored.to(device=DEVICE)
            grayscale = grayscale.to(device=DEVICE)
            # A single generator forward pass per batch, shared by both updates.
            g_output = generator(grayscale)

            # Discriminator step. detach() keeps gradients out of G while keeping
            # g_output's graph alive for the generator step below.
            d_output_real = discriminator(colored, grayscale)
            d_output_fake = discriminator(g_output.detach(), grayscale)
            d_loss = d_criterion(d_output_real, d_output_fake)
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # Generator step. D is frozen so backward only accumulates grads for G's
            # parameters. try/finally guarantees D is unfrozen even if a step raises.
            _set_requires_grad(discriminator, False)
            try:
                d_output_fake = discriminator(g_output, grayscale)
                l1_loss = l1_criterion(g_output, colored)
                g_loss = g_criterion(d_output_fake) + reg * l1_loss
                g_optimizer.zero_grad()
                g_loss.backward()
                g_optimizer.step()
            finally:
                _set_requires_grad(discriminator, True)

            train_loss += d_loss.item() + reg * l1_loss.item()

        # Validation
        generator.train(mode=False)
        discriminator.train(mode=False)
        valid_loss = 0.0
        with torch.no_grad():
            for colored, grayscale in valid_loader:
                colored = colored.to(device=DEVICE)
                grayscale = grayscale.to(device=DEVICE)
                g_output = generator(grayscale)
                d_output_real = discriminator(colored, grayscale)
                d_output_fake = discriminator(g_output, grayscale)
                loss = d_criterion(d_output_real, d_output_fake) + reg * l1_criterion(g_output, colored)
                valid_loss += loss.item()
        # Release cached CUDA blocks once per epoch rather than after every batch.
        torch.cuda.empty_cache()

        # Verbose training; max(..., 1) avoids ZeroDivisionError on an empty loader.
        train_loss /= max(len(train_loader), 1)
        valid_loss /= max(len(valid_loader), 1)
        if epoch == 1 or epoch % log_every == 0:
            print('{} Epoch {}, Train {:.5f}, Valid {:.5f}'.format(dt.datetime.now(),
                                                                      epoch,
                                                                      train_loss,
                                                                      valid_loss))
            torch.save(generator.state_dict(), DATAPATH + 'G_{:05d}epoch_{:2.2f}valLoss.pt'.format(epoch, valid_loss))
            print(f'Saving {epoch}-th for {valid_loss = :2.5f}')
        # Save best model so far
        if valid_loss < best_valid_loss - min_improvement:
            torch.save(generator.state_dict(), DATAPATH + 'G_best.pt')
            print(f'Saving best for {valid_loss = :2.5f}')
            best_valid_loss = valid_loss


In [None]:
from cocoLoader import load_coco_dataset
from UNet import UNet
from PatchGAN import PatchGAN
from cGANloss import CGANDiscriminatorLoss, CGANGeneratorLoss

# Seed PyTorch's RNGs (CPU and all CUDA devices) so shuffling and weight
# initialisation are reproducible across Restart & Run All.
# TODO: also seed `random`/numpy if cocoLoader's augmentation uses them.
SEED = 0
torch.manual_seed(SEED)

# Run hyperparameters, previously inline magic numbers.
BATCH_SIZE = 20
DATASET_SIZE = 10000   # passed as load_coco_dataset(size=...); presumably the number of images, confirm
IMAGE_SIZE = 256       # passed as UNet(output_size=...); presumably the image side length, confirm
LEARNING_RATE = 1e-3
N_EPOCHS = 100
L1_WEIGHT = 100        # `reg`: weight of the L1 reconstruction term vs. the adversarial term

train_dataloader, valid_dataloader = load_coco_dataset(batch_size=BATCH_SIZE, size=DATASET_SIZE)
unet = UNet(output_size=IMAGE_SIZE, in_channels=1, out_channels=3).to(device=DEVICE)
patchgan = PatchGAN().to(device=DEVICE)
g_optimizer = torch.optim.Adam(unet.parameters(), lr=LEARNING_RATE)
d_optimizer = torch.optim.Adam(patchgan.parameters(), lr=LEARNING_RATE)
cgan_training_loop(
    n_epochs=N_EPOCHS,
    optimizers=(g_optimizer, d_optimizer),
    generator=unet,
    discriminator=patchgan,
    train_loader=train_dataloader,
    valid_loader=valid_dataloader,
    reg=L1_WEIGHT)
torch.cuda.empty_cache()

In [None]:
# Load the best generator checkpoint, then show N_SHOW validation samples.
# Each row shows: ground-truth colour image | grayscale input | generator output.
import glob
import os

import matplotlib.pyplot as plt

N_SHOW = 10

# Prefer G_best.pt. If it is missing, fall back to the latest periodic checkpoint;
# epochs are zero-padded in the filename, so lexicographic order equals epoch order.
checkpoint_path = os.path.join(DATAPATH, 'G_best.pt')
if not os.path.exists(checkpoint_path):
    periodic_checkpoints = sorted(glob.glob(os.path.join(DATAPATH, 'G_*epoch_*valLoss.pt')))
    if not periodic_checkpoints:
        raise FileNotFoundError(f'No generator checkpoint found in {DATAPATH!r}')
    checkpoint_path = periodic_checkpoints[-1]
unet.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
unet.eval()

# Invert Normalize(mean, std) with Normalize(-mean/std, 1/std). These are the standard
# ImageNet statistics, presumably the ones cocoLoader normalizes with (confirm).
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)
unnormalize = torchvision.transforms.Normalize(
    mean=[-m / s for m, s in zip(imagenet_mean, imagenet_std)],
    std=[1 / s for s in imagenet_std],
)

colored, grayscale = next(iter(valid_dataloader))
n_rows = min(N_SHOW, len(colored))
with torch.no_grad():
    generated = unet(grayscale[:n_rows].to(device=DEVICE)).cpu()

fig, axes = plt.subplots(n_rows, 3, figsize=(12, 6 * n_rows), squeeze=False)
for idx in range(n_rows):
    ax_color, ax_gray, ax_gen = axes[idx]
    # clamp to [0, 1]: unnormalized values can overshoot, which matplotlib would clip with a warning.
    ax_color.imshow(unnormalize(colored[idx]).clamp(0, 1).permute(1, 2, 0))
    # The generator is built with in_channels=1, so the grayscale input is (1, H, W).
    ax_gray.imshow(grayscale[idx].squeeze(0), cmap='gray')
    ax_gen.imshow(unnormalize(generated[idx]).clamp(0, 1).permute(1, 2, 0))
    for ax, title in zip(axes[idx], ('Ground truth', 'Grayscale input', 'Generated')):
        ax.set_title(title)
        ax.axis('off')
plt.show()