In [None]:
import os
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pytorch_msssim import MS_SSIM
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import matplotlib.pyplot as plt

In [None]:
# ---- compute device ---------------------------------------------------------
# GPU selection is disabled for now; the original expression is kept for reference:
#device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
# TODO: move the model to `device` with .to(device) and set up GPU support
device = "cpu"

# ---- data -------------------------------------------------------------------
fonts_csv = "fonts.csv"                  # path to the font list
dataroot = "images"                      # root directory of the dataset
workers = 0                              # number of dataloader workers
img_size = 64                            # height and width of an input image
alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'  # the alphabet characters
base_letter = 'R'                        # letter used to generate all the other letters
gen_letter = 'B'                         # letter we are trying to generate
num_patches = 3                          # number of patches to sample

# ---- network widths ---------------------------------------------------------
# Generator channel counts are multiples of the alphabet size.
nc0, nc1, nc2, nc3 = [mult * len(alphabet) for mult in (1, 4, 8, 16)]
# Discriminator channel counts.
dc0, dc1, dc2, dc3 = 1, 8, 16, 32
# Threshold value (passed to nn.Threshold in the model).
thresh = 0

# ---- optimisation -----------------------------------------------------------
num_epochs = 500   # number of training epochs
batch_size = 16    # batch size for training
lr = 0.002         # learning rate
beta1 = 0.5        # beta1 for Adam
real_label = 1.0   # target value for real samples
fake_label = 0.0   # target value for fake samples
num_dis = 1        # extra discriminator runs (relative to the encdec) per epoch
cof_dis = 0        # coefficient of the discriminator loss in training

In [None]:
class FontDataset(Dataset):
    """Dataset that returns, for one font index, the image of every letter.

    For each character ``c`` of the module-level ``alphabet`` the image is read
    from ``<root_dir>/<c>/<idx>.npy`` and cropped to an ``img_size`` x
    ``img_size`` window whose top-left corner is at
    ``(img_size // 2, img_size // 2)``.  ``img_size`` is also a module-level
    setting.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file: path to the space-separated font list; its number of
                rows defines the dataset length.
            root_dir: directory containing one sub-directory per letter.
            transform: optional callable applied to every cropped image.
                When ``None`` the cropped array is returned unchanged.
        """
        self.fontlist = pd.read_csv(csv_file, sep=' ')
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of fonts listed in the CSV file."""
        return len(self.fontlist)

    def __getitem__(self, idx):
        """Return a dict mapping each letter of ``alphabet`` to its image.

        Args:
            idx: font index (an int or a tensor holding one); it is used
                verbatim as the ``.npy`` file stem.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # The crop window is the same for every letter, so compute it once.
        # NOTE(review): assumes stored arrays are at least 1.5 * img_size per
        # side with channels last — confirm against the data files.
        start = img_size // 2
        stop = start + img_size

        sample = {}
        for c in alphabet:
            path = os.path.join(self.root_dir, c, f'{idx}.npy')
            img = np.load(path)
            img = img[start:stop, start:stop, :]
            # `transform` defaults to None; calling it unconditionally would
            # raise a TypeError for the default configuration.
            if self.transform is not None:
                img = self.transform(img)
            sample[c] = img

        return sample

In [None]:
# From https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/3
def get_gaussian_kernel(kernel_size=3, sigma=2, channels=3, padding=1):
    """Build a fixed depthwise Gaussian-blur convolution layer.

    Args:
        kernel_size: side length of the square kernel.
        sigma: standard deviation of the Gaussian.
        channels: number of input/output channels; each channel is filtered
            independently (``groups=channels``).
        padding: padding passed to the convolution.

    Returns:
        An ``nn.Conv2d`` without bias whose weights are the normalised
        Gaussian kernel and do not require gradients.
    """
    # Column index of every kernel cell; its transpose holds the row index.
    cols = torch.arange(kernel_size).repeat(kernel_size).view(kernel_size, kernel_size)
    rows = cols.t()

    center = (kernel_size - 1) / 2.
    var = sigma ** 2.

    # Squared distance of each cell from the kernel centre.
    sq_dist = (cols.float() - center) ** 2. + (rows.float() - center) ** 2.

    # 2-D Gaussian = product of two 1-D Gaussians in x and y.
    weights = torch.exp(-sq_dist / (2 * var)) * (1. / (2. * math.pi * var))

    # Normalise so the kernel sums to one.
    weights = weights / torch.sum(weights)

    # One identical (1, k, k) filter per channel -> depthwise conv weight.
    weights = weights[None, None, :, :].repeat(channels, 1, 1, 1)

    blur = nn.Conv2d(
        in_channels=channels,
        out_channels=channels,
        kernel_size=kernel_size,
        groups=channels,
        bias=False,
        padding=padding,
    )
    blur.weight.data = weights
    blur.weight.requires_grad_(False)

    return blur

In [None]:
class EncoderDecoder(nn.Module):
    """Convolutional encoder/decoder with max-pool / max-unpool skip indices.

    Takes a ``dc0``-channel image and produces an ``nc0``-channel output
    squashed by ``tanh``.  The same ``conv33`` layer (and the same
    ``conv0same`` layer) is applied several times, so those weights are shared
    between the stages that use them.
    """

    def __init__(self):
        super().__init__()
        # Encoder stem: dc0 -> dc1 -> dc2 -> dc3 channels.
        self.conv01 = nn.Conv2d(dc0, dc1, 3, padding=1)
        self.conv12 = nn.Conv2d(dc1, dc2, 3, padding=1)
        self.conv23 = nn.Conv2d(dc2, dc3, 3, padding=1)

        # Shared dc3 -> dc3 convolution and the projection to nc0 channels.
        self.conv33 = nn.Conv2d(dc3, dc3, 3, padding=1)
        self.conv32 = nn.Conv2d(dc3, nc0, 3, padding=1)

        # Shared nc0 -> nc0 convolution used after the blur.
        self.conv0same = nn.Conv2d(nc0, nc0, 3, padding=1)

        # Pooling keeps its indices so the decoder can unpool to the same positions.
        self.pool = nn.MaxPool2d(2, return_indices=True)
        self.unpool = nn.MaxUnpool2d(2)

        # Activations / element-wise modules.
        self.relu = nn.ReLU(inplace=True)
        self.leakyrelu = nn.LeakyReLU(0.2, inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.threshold = nn.Threshold(thresh, 0)

        # Fixed depthwise Gaussian blur over the nc0 output channels.
        self.gaussian_filter = get_gaussian_kernel(kernel_size=3, sigma=2, channels=nc0)

    def _conv33_twice(self, x):
        """Apply the shared ``conv33`` + LeakyReLU pair two times."""
        for _ in range(2):
            x = self.leakyrelu(self.conv33(x))
        return x

    def forward(self, x):
        """Run the encoder, the mirrored decoder, the blur and the output head."""
        # ---- encoder ----
        for conv in (self.conv01, self.conv12, self.conv23):
            x = self.leakyrelu(conv(x))
        x, idx1 = self.pool(x)

        x = self._conv33_twice(x)
        x, idx2 = self.pool(x)

        x = self._conv33_twice(x)
        x, idx3 = self.pool(x)

        # ---- decoder: undo the pools in reverse order ----
        x = self.unpool(x, idx3)
        x = self._conv33_twice(x)

        x = self.unpool(x, idx2)
        x = self._conv33_twice(x)

        x = self.unpool(x, idx1)
        x = self.leakyrelu(self.conv32(x))

        # ---- output head ----
        x = self.gaussian_filter(x)
        x = self.leakyrelu(self.conv0same(x))
        x = self.conv0same(x)

        return self.tanh(x)

In [None]:
def concat_tensors(data):
    """Concatenate the per-letter tensors of ``data`` along dimension 1.

    Args:
        data: mapping from every character of the module-level ``alphabet``
            to a tensor.

    Returns:
        One tensor holding the letters in ``alphabet`` order along
        dimension 1, or ``None`` when ``alphabet`` is empty.
    """
    tensors = [data[c] for c in alphabet]
    if not tensors:
        return None
    # A single cat avoids comparing a tensor with `== None` and avoids
    # re-copying the growing result once per letter.
    return torch.cat(tensors, 1)

In [None]:
# Full dataset: every image becomes a tensor and is normalised with
# mean 0.5 / std 0.5.  (Gaussian-noise augmentation is currently disabled:
# AddGaussianNoise(0., 0.05).)
dataset = FontDataset(
    csv_file=fonts_csv,
    root_dir=dataroot,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ]),
)

# Four fifths of the fonts go to the first split (used as the training set);
# the remainder is the validation set.  The split is seeded so it is repeatable.
testset_size = (len(dataset) // 5) * 4
train_set, val_set = random_split(
    dataset,
    [testset_size, len(dataset) - testset_size],
    generator=torch.Generator().manual_seed(42),
)

train_data = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)
# Validation is served one sample at a time.
val_data = DataLoader(
    val_set,
    batch_size=1,
    shuffle=True,
    num_workers=workers,
)

In [None]:
# Checkpoint of the model that is being evaluated.
model_file = 'encdec.pt'

# Data settings for evaluation (same values as in the configuration cell):
workers = 0              # number of workers for the dataloader
dataroot = "images"      # root directory of the dataset
fonts_csv = "fonts.csv"  # path to the font list

In [None]:
# Rebuild the network and restore its trained weights onto the CPU.
encdec = EncoderDecoder()
encdec.load_state_dict(
    torch.load(model_file, map_location='cpu')
)

# Alternative (disabled): evaluate the fine-tuned variant instead.
# encdec = Finetune_EncDec(pretrained=encdec)
# encdec.load_state_dict(torch.load('encdec-finetune.pt'))

# Debug aid (disabled): dump every parameter tensor.
# for param in encdec.parameters():
#     print(param.data)

# Switch to inference mode; as the last expression this also displays the module.
encdec.eval()

In [None]:
# Ground truth (left) next to the generated image (right) for a single letter,
# one figure per validation sample.
fig_letter = 'X'
letter_channel = ord(fig_letter) - ord('A')
for data in val_data:
    fig, (ax_truth, ax_generated) = plt.subplots(1, 2, figsize=(8, 8))
    output = encdec(data['R'])

    truth_img = data[fig_letter][0].permute(1, 2, 0).detach().numpy()
    ax_truth.imshow(truth_img, cmap='gray')

    generated_img = output[0, letter_channel].detach().numpy()
    ax_generated.imshow(generated_img, cmap='gray')

    plt.show()

In [None]:
def _show_letter_row(images):
    """Draw the given images side by side in one single-row figure and show it.

    Args:
        images: one image array per letter, in display order.
    """
    n = len(images)
    # NOTE(review): the height scales with n although the grid has a single
    # row; kept from the original sizing — confirm this is intended.
    fig = plt.figure(figsize=(16, 8 * n))
    fig.subplots_adjust(wspace=0, hspace=0)
    for j, img in enumerate(images):
        ax = fig.add_subplot(1, n, j + 1)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.imshow(img, cmap='gray')
    plt.show()


# For the first 11 validation samples, show the ground-truth alphabet followed
# by the alphabet generated from the base letter.
with torch.no_grad():  # inference only, no autograd graph needed
    for i, data in enumerate(val_data):
        if i > 10:
            break
        output = encdec(data[base_letter])

        # ground truth
        _show_letter_row(
            [data[c][0].permute(1, 2, 0).detach().numpy() for c in alphabet]
        )
        # generated: channel k of the output holds the k-th letter after 'A'
        _show_letter_row(
            [output[0, ord(c) - ord('A')].detach().numpy() for c in alphabet]
        )

#     plt.savefig('all-alphabet.png')
    

In [None]:
# Evaluation metrics: mean absolute error and multi-scale SSIM
# (single-channel images, values expected in [0, 1]).
criterion_ssim = MS_SSIM(
    data_range=1,
    size_average=True,
    win_size=3,
    channel=1,
)
criterion_L1 = nn.L1Loss()

In [None]:
# Evaluate per-letter L1 and (1 - MS-SSIM) losses on the validation set.
#
# At most `max_eval` samples are scored.  The result buffers are sized to the
# number of samples actually scored: sizing them to the full validation set
# would leave zero-filled columns that pull down any mean taken over them.
max_eval = 501  # the loop previously stopped once i exceeded 500
val_len = len(val_data)
num_eval = min(val_len, max_eval)
sum_L1 = torch.zeros((len(alphabet), num_eval))
sum_ssim = torch.zeros((len(alphabet), num_eval))

with torch.no_grad():  # inference only, no autograd graph needed
    for i, data in enumerate(val_data):
        if i >= num_eval:
            break
        if i % 100 == 0:
            print(i)  # coarse progress indicator
        output = encdec(data[base_letter])

        for c in alphabet:
            # channel k of the output holds the k-th letter after 'A'
            index = ord(c) - ord('A')
            truth = data[c].reshape(1, 1, img_size, img_size)
            generated = output[0, index].reshape(1, 1, img_size, img_size)

            loss_L1 = criterion_L1(truth, generated)

            # map from [-1, 1] to [0, 1], the data range given to MS-SSIM
            output_norm = (generated + 1) / 2
            truth_norm = (truth + 1) / 2

            loss_ssim = 1 - criterion_ssim(output_norm, truth_norm)

            sum_L1[index, i] = loss_L1
            sum_ssim[index, i] = loss_ssim

print('done')

In [None]:
# Convert the (letter, sample) loss tables to NumPy.
l1, ssim = [table.detach().numpy() for table in (sum_L1, sum_ssim)]

# Average over samples, giving one value per letter.
print('average l1 loss by alphabet', l1.mean(axis=1))
print('average ssim loss by alphabet', ssim.mean(axis=1))