In [1]:
import os
import re

import torch
import torch.nn as nn
import torch.nn.functional as F
# Directory where generator checkpoints are written (training) and read (loading the best model).
DATAPATH = './models/'
# Fixed seed so torch-driven randomness (e.g. weight initialisation) is reproducible across runs.
SEED = 0
torch.manual_seed(SEED)
# GPU training when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Training on device {DEVICE}.")

Training on device cuda.


In [2]:
import datetime as dt

def _set_requires_grad(module, flag):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``module``."""
    for param in module.parameters():
        param.requires_grad = flag


def _generator_objective(d_output_fake, g_output, colored, l1_weight):
    """Generator loss: adversarial term from ``CGANGeneratorLoss`` plus weighted L1 to the target."""
    return CGANGeneratorLoss(d_output_fake) + l1_weight * F.l1_loss(g_output, colored)


def _mean(total, count):
    """Average of ``count`` accumulated values; NaN when nothing was accumulated (empty loader)."""
    return total / count if count else float('nan')


def cgan_training_loop(n_epochs, optimizers, generator, discriminator, train_loader, valid_loader,
                       l1_weight=1.0, log_every=10, min_delta=0.0):
    """Train a conditional GAN and checkpoint the generator whenever validation improves.

    Every training batch performs one optimizer step on the discriminator followed by one
    optimizer step on the generator. After each epoch the generator objective
    (``CGANGeneratorLoss`` + ``l1_weight`` * L1 to the colored target) is averaged over
    ``valid_loader``. When that mean is lower than the best seen so far by more than
    ``min_delta``, ``generator.state_dict()`` is saved as ``DATAPATH/G_<mean loss>.pt``.

    Args:
        n_epochs: Number of passes over ``train_loader``.
        optimizers: ``(generator_optimizer, discriminator_optimizer)`` pair.
        generator: Module called as ``generator(grayscale)``; its output is compared to ``colored``.
        discriminator: Module called as ``discriminator(image, grayscale)``.
            NOTE(review): the recorded run failed with
            ``forward() takes 2 positional arguments but 3 were given``. If PatchGAN.forward takes
            a single tensor, the condition must be merged with the image (e.g. channel
            concatenation) before calling it -- confirm against PatchGAN.
        train_loader: Iterable of ``(colored, grayscale)`` batches used for training.
        valid_loader: Iterable of ``(colored, grayscale)`` batches used for validation.
        l1_weight: Weight of the L1 term in the generator objective (1.0 keeps the original mix).
        log_every: Progress is printed on epoch 1 and every ``log_every`` epochs.
        min_delta: Minimum decrease of the mean validation loss needed to write a new checkpoint.

    Returns:
        Path of the best generator checkpoint written, or ``None`` if none was saved.
    """
    g_optimizer, d_optimizer = optimizers
    os.makedirs(DATAPATH, exist_ok=True)  # torch.save does not create missing directories
    best_valid_loss = float('inf')
    best_checkpoint = None

    for epoch in range(1, n_epochs + 1):
        generator.train()
        discriminator.train()
        g_loss_sum = d_loss_sum = 0.0
        n_train_batches = 0
        for colored, grayscale in train_loader:
            colored = colored.to(device=DEVICE)
            grayscale = grayscale.to(device=DEVICE)
            # One generator pass is shared by both updates; the D step sees it detached,
            # so no gradient flows back into the generator during the D update.
            g_output = generator(grayscale)

            # Discriminator step.
            d_output_real = discriminator(colored, grayscale)
            d_output_fake = discriminator(g_output.detach(), grayscale)
            d_loss = CGANDiscriminatorLoss(d_output_real, d_output_fake)
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # Generator step: freeze D so backward skips gradients for its weights;
            # always unfreeze, even if the step raises.
            _set_requires_grad(discriminator, False)
            try:
                d_output_fake = discriminator(g_output, grayscale)
                g_loss = _generator_objective(d_output_fake, g_output, colored, l1_weight)
                g_optimizer.zero_grad()
                g_loss.backward()
                g_optimizer.step()
            finally:
                _set_requires_grad(discriminator, True)

            d_loss_sum += d_loss.item()
            g_loss_sum += g_loss.item()
            n_train_batches += 1
            # Release this batch's tensors before the next batch is loaded.
            del g_output, d_output_real, d_output_fake, d_loss, g_loss

        # Validation: same generator objective as training, so the numbers are comparable.
        generator.eval()
        discriminator.eval()
        valid_loss_sum = 0.0
        n_valid_batches = 0
        with torch.no_grad():
            for colored, grayscale in valid_loader:
                colored = colored.to(device=DEVICE)
                grayscale = grayscale.to(device=DEVICE)
                g_output = generator(grayscale)
                d_output_fake = discriminator(g_output, grayscale)
                valid_loss_sum += _generator_objective(d_output_fake, g_output, colored, l1_weight).item()
                n_valid_batches += 1
                del g_output, d_output_fake

        mean_g_loss = _mean(g_loss_sum, n_train_batches)
        mean_d_loss = _mean(d_loss_sum, n_train_batches)
        mean_valid_loss = _mean(valid_loss_sum, n_valid_batches)

        if epoch == 1 or epoch % log_every == 0:
            print('{} Epoch {}, Training G {:.4f}, D {:.4f}, Valid {:.4f}'.format(
                dt.datetime.now(), epoch, mean_g_loss, mean_d_loss, mean_valid_loss))

        # Compare means, not sums, so the threshold does not depend on the number of batches.
        # A NaN mean (empty valid_loader) never compares lower, so nothing is saved.
        if mean_valid_loss < best_valid_loss - min_delta:
            best_valid_loss = mean_valid_loss
            best_checkpoint = os.path.join(DATAPATH, 'G_{:2.2f}.pt'.format(mean_valid_loss))
            torch.save(generator.state_dict(), best_checkpoint)

    return best_checkpoint

In [3]:
from cocoLoader import load_coco_dataset
from UNet import UNet
from PatchGAN import PatchGAN
from cGANloss import CGANDiscriminatorLoss, CGANGeneratorLoss

BATCH_SIZE = 130
N_EPOCHS = 100
# Adam settings commonly used for GAN training (DCGAN / pix2pix): lr=2e-4, beta1=0.5.
# The previous lr=1e-2 is far above these values and is a likely source of unstable adversarial training.
LEARNING_RATE = 2e-4
ADAM_BETAS = (0.5, 0.999)

train_dataloader, valid_dataloader = load_coco_dataset(batch_size=BATCH_SIZE)
unet = UNet(in_channels=1, out_channels=3).to(device=DEVICE)  # U-Net generator: 1 input channel, 3 output channels
patchgan = PatchGAN().to(device=DEVICE)  # PatchGAN discriminator
g_optimizer = torch.optim.Adam(unet.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
d_optimizer = torch.optim.Adam(patchgan.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
try:
    cgan_training_loop(
        n_epochs=N_EPOCHS,
        optimizers=(g_optimizer, d_optimizer),
        generator=unet,
        discriminator=patchgan,
        train_loader=train_dataloader,
        valid_loader=valid_dataloader)
finally:
    # Release cached GPU memory even when training is interrupted or raises.
    torch.cuda.empty_cache()

TypeError: forward() takes 2 positional arguments but 3 were given

In [None]:
# Load the best generator checkpoint from DATAPATH.
# Checkpoints are named 'G_<validation loss>.pt', so the file with the lowest encoded loss is the best one.
_CHECKPOINT_PATTERN = re.compile(r'G_(-?\d+(?:\.\d+)?)\.pt')
checkpoint_losses = {}
for file_name in os.listdir(DATAPATH):
    match = _CHECKPOINT_PATTERN.fullmatch(file_name)
    if match:  # skips unrelated files and non-finite names such as 'G_nan.pt'
        checkpoint_losses[file_name] = float(match.group(1))
if not checkpoint_losses:
    raise FileNotFoundError(f'No generator checkpoints named G_<loss>.pt found in {DATAPATH!r}')
best_checkpoint = min(checkpoint_losses, key=checkpoint_losses.get)
unet.load_state_dict(torch.load(os.path.join(DATAPATH, best_checkpoint), map_location=DEVICE))