In [1]:
## Libraries
import glob
from multiprocessing import cpu_count
import os
import sys
import time

## 3rd party
from gensim.models import Word2Vec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image, ImageFilter
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

_path = ".."
if _path not in sys.path:
    sys.path.append(_path)
from lib.dataset import TextArtDataLoader, AlignCollate, ImageBatchSampler
from lib.config import Config
from lib.arch import GeneratorResNet, DiscriminatorStackGAN1
from lib.utils import GANLoss

%load_ext autoreload
%autoreload 2
%reload_ext autoreload

In [2]:
BATCH_SIZE = 2
# N_WORKERS = cpu_count() - 1
N_WORKERS = 0
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
N_EPOCHS = 10
LR_G = 1e-4
LR_D = 1e-4
WEIGHT_DECAY = 1e-4

CONFIG = Config()

In [None]:
train_dataset = TextArtDataLoader(CONFIG, mode='train')
val_dataset = TextArtDataLoader(CONFIG, mode='val')
test_dataset = TextArtDataLoader(CONFIG, mode='test')

train_align_collate = AlignCollate(CONFIG, 'train')
val_align_collate = AlignCollate(CONFIG, 'val')

train_batch_sampler = ImageBatchSampler(CONFIG, mode='train')
val_batch_sampler = ImageBatchSampler(CONFIG, mode='val')

train_loader = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=False,
                          num_workers=N_WORKERS,
                          pin_memory=True,
                          collate_fn=train_align_collate,
                          sampler=train_batch_sampler,
                         )
val_loader = DataLoader(val_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=False,
                          num_workers=N_WORKERS,
                          pin_memory=True,
                          collate_fn=val_align_collate,
                          sampler=val_batch_sampler
                         )
test_loader = DataLoader(test_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=False,
                          num_workers=N_WORKERS,
                          pin_memory=True,
                          collate_fn=None,
                         )

In [None]:
for i, (image, wv_tensor, fake_wv_tensor) in enumerate(train_loader):
    print("IMAGE:", image.shape)
    print("WV:", wv_tensor.shape)
    print("Fake WV:", fake_wv_tensor.shape)
    
    print("WVs")
    for wvs in wv_tensor:
        for wv in wvs:
            wv = np.array(wv)
            word, prob = train_loader.dataset.word2vec_model.wv.similar_by_vector(wv)[0]
            print("{}/{:.3f}".format(word, prob), end=' ')
        print()
    
    print("\nFake WVs")
    for fake_wvs in fake_wv_tensor:
        for fake_wv in fake_wvs:
            fake_wv = np.array(fake_wv)
            fake_word, prob = train_loader.dataset.word2vec_model.wv.similar_by_vector(fake_wv)[0]
            print("{}/{:.3f}".format(fake_word, prob), end=' ')
        print()
    
    images = np.array(image)
    for img in images:
        img = img.transpose(1, 2, 0)
        plt.figure(figsize=(6, 6))
        plt.imshow(img)
        plt.show()
    break

In [None]:
wv_tensors

In [None]:
images = torch.Tensor(images)
images.shape

In [None]:
r = wv_tensor.reshape(images.shape)
r.shape

In [None]:
torch.cat((images, r), 1).shape

In [3]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

## Generator

In [4]:
G = GeneratorResNet(CONFIG).to(DEVICE)
G.apply(weights_init)

GeneratorResNet(
  (fc): Sequential(
    (0): Linear(in_features=4096, out_features=2048, bias=False)
    (1): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (model): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), bias=False)
    (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU(inplace)
    (4): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace)
    (7): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (8): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace)
    (10): ResnetBlock(
      (conv_block): Sequential(
        (0): ReflectionPad2d((1, 1, 1, 1))
        (1): Conv2d(512, 512, 

In [6]:
image = torch.Tensor(2, 4096).to(DEVICE)
G(image)

RuntimeError: Padding size should be less than the corresponding input dimension, but got: padding (1, 1) at dimension 3 of input [2, 512, 1, 1]

## Discriminator

In [None]:
D = Discriminator(CONFIG).to(DEVICE)
D.apply(weights_init)

## Loss and optimizers

In [None]:
data_loader = train_loader

optimizer_d = torch.optim.Adam(D.parameters(), lr=LR_D, weight_decay=WEIGHT_DECAY)
optimizer_g = torch.optim.Adam(G.parameters(), lr=LR_G, weight_decay=WEIGHT_DECAY)

loss = nn.BCELoss().to(DEVICE)
loss_ls = lambda x, y: 0.5 * torch.mean((x - y) ** 2)
loss_ms = nn.MSELoss().to(DEVICE)

In [None]:
G.train()
D.train()

for epoch in range(N_EPOCHS):
    epoch_start = time.time()
    total_g_loss = 0.0
    total_d_loss = 0.0
    
    for i, (images, word_vectors_tensor) in enumerate(data_loader):

        batch_size = images.size()[0]
        
        real_label = torch.full((batch_size,), 1.0, device=DEVICE)
        fake_label = torch.full((batch_size,), 0.0, device=DEVICE)

        images = images.to(DEVICE)
        word_vectors_tensor = word_vectors_tensor.to(DEVICE)
        word_sequence = word_vectors_tensor[:, 0, :].unsqueeze(2).unsqueeze(3)
        
        # Discriminator pass for real
        D.zero_grad()
        output_real = D(images).view(-1)
        loss_real = loss(output_real, real_label)
        loss_real.backward(retain_graph=False)
        
        # Discriminator pass for fake
        fake = G(word_sequence)
        output_fake = D(fake.detach()).view(-1)
        loss_fake = loss(output_fake, fake_label)
        loss_fake.backward(retain_graph=False)
        loss_d = loss_real + loss_fake
        
        # Discriminator update
        optimizer_d.zero_grad()
        optimizer_d.step()

        # Generator pass
        G.zero_grad()
        output_fake = D(fake).view(-1)
        loss_g = loss(output_fake, real_label)
        loss_g.backward(retain_graph=False)

        # Generator backward pass
        optimizer_g.zero_grad()
        optimizer_g.step()
        
        # Update total loss
        total_g_loss += loss_g.item()
        total_d_loss += loss_d.item()

        # Print logs
        if i % 20 == 0:
            print('[{0:3d}/{1}] {2:3d}/{3} loss_g: {4:.4f} | loss_d: {5:4f}'
                .format(epoch + 1, N_EPOCHS, i + 1, len(data_loader), loss_g.item(), loss_d.item()))
            
    print("Epoch time: {}".format(time.time() - epoch_start))
            
    break
            
    
#     # Save your model weights
#     if (epoch + 1) % 5 == 0:
#         save_dict = {
#             'g':G.state_dict(), 
#             'g_optim':optimizer_g.state_dict(),
#             'd': D.state_dict(),
#             'd_optim': optimizer_d.state_dict()
#         }
#         torch.save(save_dict, os.path.join(MODEL_PATH, 'checkpoint_{}.pth'.format(epoch + 1)))
        
#     # Merge noisy input, ground truth and network output so that you can compare your results side by side
#     out = torch.cat([img, fake], dim=2).detach().cpu().clamp(0.0, 1.0)
#     vutils.save_image(out, os.path.join(OUTPUT_PATH, "{}_{}.png".format(epoch, i)), normalize=True)
    
#     # Calculate avarage loss for the current epoch
#     avg_g_loss = total_g_loss / len(data_loader)
#     avg_d_loss = total_d_loss / len(data_loader)
#     print('Epoch[{}] Training Loss G: {:4f} | D: {:4f}'.format(epoch + 1, avg_g_loss, avg_d_loss))
    
#     cache_train_g.append(avg_g_loss)
#     cache_train_d.append(avg_d_loss)

In [45]:
pred = torch.Tensor(2, 1, 12, 12)
target = torch.Tensor([1.0]).expand_as(pred)

In [49]:
s = (torch.sigmoid(pred) > 0.5).float()

In [27]:
target

tensor([1., 1., 1., 1.])

In [55]:
torch.mean((s == target).float())

tensor(0.0868)

In [41]:
np.mean(np.array(s == target), axis=0)

0.5

In [42]:
np.mean?