In [None]:
import torch
from models import *
from data import IAMDataset
import torch.optim as optim
import torchvision.utils as vutils
from matplotlib import pyplot as plt
import torch.nn.functional as F
import numpy as np

%matplotlib inline

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
# Root directory for dataset
dataroot = "../../datasets/iam"
# Number of workers for dataloader
workers = 5
# Batch size during training
batch_size = 8
# Spatial size of training images.
imsize=(64,640)
# Number of training epochs
num_epochs = 1
# Min / max word length passed to the dataset
min_len = 2
max_len=10
# Chars and <end>
vocab_size=27

# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
# Create the dataset
dataset = IAMDataset(data_path=dataroot, imsize=imsize, min_len=min_len, max_len=max_len)
dataset_words = dataset.create_word_dataset()

# Create the dataloaders.
# drop_last=True: the training cell builds its label tensors and `src_len`
# from the constant `batch_size` and zips both loaders together, so a smaller
# final batch (or final batches of different size in the two loaders) would
# make the loss input/target sizes disagree. Dropping the incomplete batch
# guarantees every batch has exactly `batch_size` samples.
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size, shuffle=True, num_workers=workers, drop_last=True)
dataloader_words = torch.utils.data.DataLoader(
    dataset_words, batch_size, shuffle=True, num_workers=workers, drop_last=True)

# ---------------------------------------------------------------------------
# Networks (all moved to the GPU)
# ---------------------------------------------------------------------------
# Initialize Generator
generator = GANwritingGenerator(imsize, max_len).cuda()

# Initialize Discriminators
# Real/fake discriminator: a single output per image.
discriminator = Discriminator(num_classes=1, imsize=imsize).cuda()
# Writer classifier: same architecture, one output per writer in the dataset.
writer_classifier = Discriminator(num_classes=dataset.num_writers, imsize=imsize).cuda()
# Word recognizer; max_len+1 presumably leaves room for the <end> token -- TODO confirm.
word_recognizer = Seq2Seq(imsize=imsize, max_len=max_len+1, vocab_size=vocab_size).cuda()

# ---------------------------------------------------------------------------
# Loss functions
# ---------------------------------------------------------------------------
# `torch.nn` is referenced through the explicitly imported `torch` package so
# this cell does not depend on `from models import *` happening to export `nn`.
criterion_binary = torch.nn.BCELoss()
criterion = torch.nn.CrossEntropyLoss()
# NOTE(review): the PyTorch docs state that reduction='mean' does not return
# the true KL divergence and recommend 'batchmean'. Kept as-is because
# switching rescales this term relative to the other generator losses.
criterion_KLDiv = torch.nn.KLDivLoss(reduction='mean')

# Establish convention for real and fake labels during training
REAL = 1
FAKE = 0

# ---------------------------------------------------------------------------
# Optimizers: one Adam instance per network, identical hyper-parameters
# ---------------------------------------------------------------------------
optimizerD = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizerWC = optim.Adam(writer_classifier.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizerWR = optim.Adam(word_recognizer.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizerG = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))

# Turn on/off training stages
train_D = True
train_WC = True
train_WR = True
train_G = True

# Init output tensors.
# Placeholders so the per-iteration logging (which calls .item() on the
# losses) still works when the corresponding training stage is switched off.
lossD = torch.tensor([0])
lossWC = torch.tensor([0])
lossWR = torch.tensor([0])
lossG = torch.tensor([0])
D_x = torch.tensor([0])
D_G_z1 = torch.tensor([0])
D_G_z2 = torch.tensor([0])

In [None]:
# Training Loop
#
# Each iteration draws one batch of (style images, real word, writer id) from
# `dataloader` and one batch of new target words from `dataloader_words`,
# generates fake images, then runs up to four optimisation stages (each can be
# toggled via train_D / train_WC / train_WR / train_G):
#   (1) discriminator on real vs. fake images,
#   (2) writer classifier on real images,
#   (3) word recognizer on real images,
#   (4) generator against all three of the above.

# Lists to keep track of progress
img_list = []
D_losses = []
WC_losses = []
WR_losses = []
G_losses = []
iters = 0

# for Seq2Seq: one source length per sample, each equal to the image width
src_len = imsize[1] * torch.ones(batch_size)

# fixed input to track progress
styles_fixed, _, _ = dataset[500]
styles_fixed = styles_fixed[None,...]
# content_fixed = dataset_words[5][None,...]
# Encode the fixed word as letter indices ('a' -> 0, 'b' -> 1, ...); the
# remaining positions stay 0.
# NOTE(review): index 0 is also 'a' under this encoding -- confirm that
# zero-padding matches what the dataset produces for short words.
fixed_word = "vision"
content_fixed = torch.zeros(1, max_len, dtype=torch.long)
content_fixed[0, :len(fixed_word)] = torch.tensor([ord(ch) - ord('a') for ch in fixed_word])

# zip() stops at the shorter of the two loaders, so that is the real number of
# iterations per epoch (used for the progress print and the final snapshot).
steps_per_epoch = min(len(dataloader), len(dataloader_words))

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, ((styles, content_real, writer_id), content_new) in enumerate(zip(dataloader, dataloader_words)):
        # The two loaders are iterated independently, so their batches can be
        # smaller than `batch_size` (last batch) and can differ in size from
        # each other. Trim everything to the common size so that every tensor
        # fed to the networks and losses agrees along the batch dimension.
        cur_bs = min(styles.size(0), content_new.size(0))

        styles = styles[:cur_bs].cuda() # sample - 15 images concatenated channel-wise
        images_real = styles[:,0:1,:,:] # take 1 image per sample
        writer_id = writer_id[:cur_bs].cuda()
        content_real = content_real[:cur_bs,0,:].cuda()
        content_new = content_new[:cur_bs].cuda()
        content_real_one_hot = F.one_hot(content_real, num_classes=vocab_size).float()
        content_new_one_hot = F.one_hot(content_new, num_classes=vocab_size).float()
        src_len_batch = src_len[:cur_bs]

        # Generate fake images
        images_fake = generator(styles, content_new)

        ############################
        # Update D networks: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################

        # (1) Discriminative loss on fake&real images (discriminator)
        if train_D:
            discriminator.zero_grad()

            outD_real = discriminator(images_real).view(-1)
            label = torch.full((cur_bs,), REAL, dtype=torch.float, device="cuda")
            loss_real = criterion_binary(outD_real, label)

            # detach(): the generator must not receive gradients from the D update
            outD_fake = discriminator(images_fake.detach()).view(-1)
            label = torch.full((cur_bs,), FAKE, dtype=torch.float, device="cuda")
            loss_fake = criterion_binary(outD_fake, label)

            lossD = loss_fake + loss_real
            lossD.backward()
            optimizerD.step() # Update D

            D_x = outD_real.mean().item()
            D_G_z1 = outD_fake.mean().item()

        # (2) Style loss on real data (writer classifier)
        if train_WC:
            writer_classifier.zero_grad()

            outWC_real = writer_classifier(images_real)
            lossWC = criterion(outWC_real, writer_id)
            lossWC.backward()
            optimizerWC.step() # Update WC

        # (3) Content loss on real data (word recognizer)
        if train_WR:
            word_recognizer.zero_grad()

            outWR_real, _ = word_recognizer(images_real, content_real_one_hot, src_len_batch)
            outWR_real = outWR_real.transpose(0,1)
            lossWR = criterion_KLDiv(F.log_softmax(outWR_real, dim=2), content_real_one_hot)
            lossWR.backward()
            optimizerWR.step()

        ############################
        # Update G network: maximize log(D(G(z)))
        ###########################

        if train_G:
            # Complete loss on fake data
            generator.zero_grad()

            # Since we just updated D, perform another forward pass
            outD_fake = discriminator(images_fake).view(-1)
            # fake labels are real for generator cost:
            label = torch.full((cur_bs,), REAL, dtype=torch.float, device="cuda")
            lossG = criterion_binary(outD_fake, label)

            # Style term: fake images should be classified as the style writer
            if train_WC:
                outWC_fake = writer_classifier(images_fake)
                lossG_WC = criterion(outWC_fake, writer_id)
                lossG = lossG + lossG_WC

            # Content term: fake images should be recognized as the new word
            if train_WR:
                outWR_fake, _ = word_recognizer(images_fake, content_new_one_hot, src_len_batch)
                outWR_fake = outWR_fake.transpose(0,1)
                lossG_WR = criterion_KLDiv(F.log_softmax(outWR_fake, dim=2), content_new_one_hot)
                lossG = lossG + lossG_WR

            lossG.backward()
            optimizerG.step() # Update G

            D_G_z2 = outD_fake.mean().item()

        # Output training stats
        if i % 10 == 0:
            print('[%d/%d][%d/%d]\tLoss[D,WC,WR]: %.4f\t%.4f\t%.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, steps_per_epoch,
                     lossD.item(), lossWC.item(), lossWR.item(), lossG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        D_losses.append(lossD.item())
        WC_losses.append(lossWC.item())
        WR_losses.append(lossWR.item())
        G_losses.append(lossG.item())

        # Check how the generator is doing by saving G's output on the fixed
        # style/content pair: every 50 iterations and on the very last one.
        if (iters % 50 == 0) or ((epoch == num_epochs-1) and (i == steps_per_epoch-1)):
            with torch.no_grad():
                generator.eval()
                # first (only) sample of the batch, permuted from C,H,W to H,W,C
                fake = generator(styles_fixed.cuda(), content_fixed.cuda()).detach().cpu()[0].permute(1,2,0)
                generator.train()
            img_list.append(fake)

        iters += 1

In [None]:
# # Plot losses
# plt.figure(figsize=(16,16))
# plt.title("Test Images")
# plt.plot(D_losses, label='D')
# plt.plot(WC_losses, label='WC')
# plt.plot(WR_losses, label='WR')
# plt.plot(G_losses, label='G')
# plt.yscale('log')
# plt.legend()
# plt.show()

In [None]:
def normalize(x):
    """Min-max rescale ``x`` to the range [0, 1].

    Works on any array-like that provides ``.min()`` / ``.max()`` and
    element-wise arithmetic (the notebook passes both torch tensors and
    numpy arrays).

    A constant input has zero range; dividing by it would yield NaN
    everywhere, so in that case the range is treated as 1 and the result is
    all zeros.
    """
    lowest = x.min()
    span = x.max() - lowest
    if span == 0:
        # Avoid 0/0 -> NaN for constant inputs (e.g. a blank image).
        span = 1
    return (x - lowest) / span

In [None]:
# # Plot some generated images

# plt.figure(figsize=(16,160))
# plt.axis("off")
# plt.title("Test Images")
# # plt.imshow(torch.cat([normalize(x) for x in img_list], dim=0).repeat(1,1,3))
# I = np.transpose(torch.stack([normalize(x) for x in img_list]), (0,3,1,2))
# I = np.transpose(vutils.make_grid(I, nrow=1, padding=4, normalize=True).cpu(),(1,2,0))
# plt.imshow(I)
# plt.show()

In [None]:
# # Plot some training images

# real_batch, _, _ = next(iter(dataloader))
# plt.figure(figsize=(16,16))
# plt.axis("off")
# plt.title("Training Images")
# plt.imshow(np.transpose(vutils.make_grid(real_batch[:,0:1,:,:].repeat(1,3,1,1), nrow=1, padding=4, normalize=True).cpu(),(1,2,0)))
# plt.show()

In [None]:
# # Save images

# for i,im in enumerate(img_list):
#     A = normalize(im.repeat(1,1,3).numpy())
#     plt.figure(figsize=(12,12))
#     plt.axis("off")
#     plt.imshow(A)
#     plt.savefig("./res/" + str(i) + ".jpg")