# font2img
- ttf 파일의 폰트를 가져와 이미지로 바꾸는 작업

In [1]:
import argparse
import sys
import glob
import numpy as np
import io, os
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
import collections

In [2]:
# Directory constants for the font-to-image step (relative to the working
# directory). NOTE(review): none of these is referenced by the code in this
# section — presumably used by a later cell; confirm.
SRC_PATH = './fonts/source/'   # presumably the source (reference) font files
TRG_PATH = './fonts/target/'   # presumably the target-style font files
OUTPUT_PATH = './dataset_png/'  # presumably where generated PNG pairs are written

In [3]:
def draw_single_char(ch, font, canvas_size):
    """Render one character centred on a square white grayscale canvas.

    Args:
        ch: the character to draw.
        font: a PIL font object used to render ``ch``.
        canvas_size: width and height of the canvas in pixels.

    Returns:
        A mode ``'L'`` PIL image of size ``(canvas_size, canvas_size)``, or
        ``None`` when nothing was drawn (every pixel is still white), i.e. the
        font has no visible glyph for ``ch``.
    """
    image = Image.new('L', (canvas_size, canvas_size), color=255)
    drawing = ImageDraw.Draw(image)
    try:
        # ``textsize`` was deprecated in Pillow 9.2 and removed in Pillow 10.
        # The right/bottom edges of the bbox anchored at (0, 0) approximate the
        # legacy ``textsize`` result (glyph offset included).
        _, _, w, h = drawing.textbbox((0, 0), ch, font=font)
    except (AttributeError, ValueError):
        # Pillow < 8 has no ``textbbox``; some older versions only support it
        # for TrueType fonts.
        w, h = drawing.textsize(ch, font=font)
    drawing.text(
        ((canvas_size - w) / 2, (canvas_size - h) / 2),
        ch,
        fill=0,
        font=font,
    )
    # An untouched (all-white) canvas means the font lacks this character.
    # Comparing the minimum works for any canvas_size, unlike a fixed sum.
    if np.array(image).min() == 255:
        return None
    return image

In [4]:
def draw_example(ch, src_font, dst_font, canvas_size):
    """Build a side-by-side training pair for one character.

    The returned grayscale image is ``canvas_size * 2`` wide: the glyph
    rendered with ``dst_font`` (target) on the left and the glyph rendered
    with ``src_font`` (source) on the right.

    Returns:
        The combined mode ``'L'`` image, or ``None`` when either font has no
        glyph for ``ch``.
    """
    dst_img = draw_single_char(ch, dst_font, canvas_size)
    # Skip characters the target font cannot render.
    if dst_img is None:
        return None

    src_img = draw_single_char(ch, src_font, canvas_size)
    # Pasting ``None`` with a 2-tuple box raises, so skip characters the
    # source font cannot render as well.
    if src_img is None:
        return None

    # A white 'L' canvas (equivalent to a white RGB canvas converted to 'L').
    example_img = Image.new('L', (canvas_size * 2, canvas_size), 255)
    example_img.paste(dst_img, (0, 0))
    example_img.paste(src_img, (canvas_size, 0))
    return example_img

In [5]:
def draw_handwriting(ch, src_font, canvas_size, dst_folder, label, count):
    """Pair an existing handwriting image with a rendered source glyph.

    The target image is read from ``<dst_folder>/<label>_<count:04d>.png``
    and the source glyph is rendered from ``src_font``.

    Returns:
        A grayscale image ``canvas_size * 2`` wide with the target on the left
        and the source on the right, or ``None`` if ``src_font`` has no glyph
        for ``ch``.
    """
    # os.path.join works whether or not dst_folder ends with a separator.
    dst_path = os.path.join(dst_folder, "%d_%04d.png" % (label, count))
    # The context manager releases the file handle Image.open keeps lazily.
    with Image.open(dst_path) as dst_img:
        src_img = draw_single_char(ch, src_font, canvas_size)
        if src_img is None:
            return None
        example_img = Image.new('L', (canvas_size * 2, canvas_size), 255)
        example_img.paste(dst_img, (0, 0))
    example_img.paste(src_img, (canvas_size, 0))
    return example_img


# Package
- Pack the generated PNG images into pickle (.obj) files

In [6]:
from __future__ import print_function
from __future__ import absolute_import

import argparse
import glob
import os
import pickle as pickle
import random

In [7]:
def _read_example(path, with_charid):
    """Read one PNG named ``<label>_<charid>.png`` into a pickle-ready tuple."""
    name = os.path.basename(path)
    label = int(name.split("_")[0])  # font label
    charid = int(name.split("_")[1].split(".")[0]) if with_charid else None  # character id
    with open(path, 'rb') as f:
        img_bytes = f.read()
    if with_charid:
        return (label, charid, img_bytes)
    return (label, img_bytes)


def pickle_examples(from_dir, train_path, val_path, train_val_split=0.2, with_charid=False):
    """Pickle every PNG in ``from_dir`` into a train file and a validation file.

    Compiling the examples into pickled files lets training do all I/O in
    memory. Files must be named ``<label>_<charid>.png``. Each example is
    written with its own ``pickle.dump`` call as ``(label, img_bytes)``, or
    ``(label, charid, img_bytes)`` when ``with_charid`` is true.

    Args:
        from_dir: directory containing the ``*.png`` examples.
        train_path: output file for training examples.
        val_path: output file for validation examples.
        train_val_split: probability that an example goes to validation.
        with_charid: also store the character id parsed from the filename.
    """
    paths = glob.glob(os.path.join(from_dir, "*.png"))
    val_count = 0
    train_count = 0
    with open(train_path, 'wb') as ft, open(val_path, 'wb') as fv:
        print('all data num:', len(paths))
        if with_charid:
            print('pickle with charid')
        for p in paths:
            example = _read_example(p, with_charid)
            # Route to validation with probability train_val_split.
            if random.random() < train_val_split:
                pickle.dump(example, fv)
                val_count += 1
                if val_count % 10000 == 0:
                    print("%d imgs saved in val.obj" % val_count)
            else:
                pickle.dump(example, ft)
                train_count += 1
                if train_count % 10000 == 0:
                    print("%d imgs saved in train.obj" % train_count)
        print("%d imgs saved in val.obj, end" % val_count)
        print("%d imgs saved in train.obj, end" % train_count)

In [8]:
def pickle_interpolation_data(from_dir, save_path, char_ids, font_filter):
    """Pickle the PNG examples whose char id and font label are both selected.

    Files in ``from_dir`` must be named ``<label>_<charid>.png``. Every
    matching file is appended to ``save_path`` as a ``(label, charid,
    img_bytes)`` tuple via its own ``pickle.dump`` call, and the number of
    written examples is printed.

    Args:
        from_dir: directory containing the ``*.png`` examples.
        save_path: output pickle file.
        char_ids: iterable of character ids to keep.
        font_filter: iterable of font labels to keep.
    """
    paths = glob.glob(os.path.join(from_dir, "*.png"))
    # Build lookup sets once instead of scanning the input lists per file.
    wanted_chars = set(char_ids)
    wanted_fonts = set(font_filter)
    c = 0
    with open(save_path, 'wb') as ft:
        for p in paths:
            # os.path.basename is portable, unlike splitting on '/'.
            name = os.path.basename(p)
            charid = int(name.split('.')[0].split('_')[1])
            label = int(name.split("_")[0])
            if charid in wanted_chars and label in wanted_fonts:
                c += 1
                with open(p, 'rb') as f:
                    img_bytes = f.read()
                pickle.dump((label, charid, img_bytes), ft)
        print('data num:', c)

# Function
- Deep-learning helper functions: conv2d, deconv2d, LeakyReLU, etc.

In [9]:
from __future__ import print_function
from __future__ import absolute_import
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset

In [10]:
def batch_norm(c_out, momentum=0.1):
    """Return a 2-D batch-norm layer over ``c_out`` channels.

    ``momentum`` controls how fast the running mean/variance are updated.
    """
    layer = nn.BatchNorm2d(num_features=c_out, momentum=momentum)
    return layer

In [11]:
def conv2d(c_in, c_out, k_size=3, stride=2, pad=1, dilation=1, bn=True, lrelu=True, leak=0.2):
    """Build an [LeakyReLU] -> Conv2d -> [BatchNorm2d] block.

    The activation is applied *before* the convolution, so the block's
    output is the (optionally batch-normalized) convolution result.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        k_size: kernel size. Defaults to 3: ``Encoder`` and ``Discriminator``
            call ``conv2d(c_in, c_out)`` without it, which raised TypeError
            when the parameter had no default. With ``stride=2, pad=1`` a
            3x3 kernel halves even spatial sizes.
        stride: convolution stride.
        pad: zero padding on each side.
        dilation: accepted for API compatibility but NOT forwarded to
            ``nn.Conv2d``; forwarding it would change the feature-map sizes
            of callers that pass it.
        bn: append ``nn.BatchNorm2d(c_out)`` after the convolution.
        lrelu: prepend ``nn.LeakyReLU(leak)`` before the convolution.
        leak: negative slope of the LeakyReLU.

    Returns:
        An ``nn.Sequential`` of the selected layers.
    """
    layers = []
    if lrelu:
        layers.append(nn.LeakyReLU(leak))
    layers.append(nn.Conv2d(c_in, c_out, k_size, stride, pad))
    if bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)

In [12]:
def deconv2d(c_in, c_out, k_size=3, stride=1, pad=1, dilation=1, bn=True, dropout=False, p=0.5):
    """Build a LeakyReLU(0.2) -> ConvTranspose2d -> [BatchNorm2d] -> [Dropout] block.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        k_size: kernel size of the transposed convolution.
        stride: stride of the transposed convolution.
        pad: padding of the transposed convolution.
        dilation: accepted but NOT forwarded to ``nn.ConvTranspose2d``.
        bn: append ``nn.BatchNorm2d(c_out)``.
        dropout: append ``nn.Dropout(p)`` (after batch norm, if any).
        p: dropout probability.

    Returns:
        An ``nn.Sequential`` of the selected layers.
    """
    stages = [nn.LeakyReLU(0.2), nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad)]
    stages += [nn.BatchNorm2d(c_out)] if bn else []
    stages += [nn.Dropout(p)] if dropout else []
    return nn.Sequential(*stages)

In [13]:
def lrelu(leak=0.2):
    """Return a LeakyReLU activation whose negative slope is ``leak``."""
    activation = nn.LeakyReLU(negative_slope=leak)
    return activation

In [14]:
def dropout(p=0.2):
    """Return a dropout layer that zeroes activations with probability ``p``."""
    layer = nn.Dropout(p=p)
    return layer

In [15]:
def fc(input_size, output_size):
    """Return a fully connected layer (with bias) mapping ``input_size`` -> ``output_size``."""
    layer = nn.Linear(in_features=input_size, out_features=output_size)
    return layer

In [16]:
def init_embedding(embedding_num, embedding_dim, stddev=0.01):
    """Sample a random embedding table of shape ``(embedding_num, 1, 1, embedding_dim)``.

    Entries are drawn from a standard normal distribution and scaled by
    ``stddev``.
    """
    table = torch.randn(embedding_num, embedding_dim).mul(stddev)
    return table.reshape(embedding_num, 1, 1, embedding_dim)

In [17]:
def embedding_lookup(embeddings, embedding_ids, GPU=False):
    """Gather one embedding per batch element, shaped for channel concatenation.

    Args:
        embeddings: tensor of shape ``(embedding_num, 1, 1, embedding_dim)``
            (the layout produced by ``init_embedding``), on any device.
        embedding_ids: sequence of integer ids (list, numpy array or tensor),
            one per batch element.
        GPU: if true the result is moved to CUDA; otherwise it is on the CPU.

    Returns:
        A detached tensor of shape ``(len(embedding_ids), embedding_dim, 1, 1)``.
        Gradients never flow back into ``embeddings``.
    """
    batch_size = len(embedding_ids)
    embedding_dim = embeddings.shape[3]
    # One vectorized gather instead of a per-id numpy round trip; it also
    # works when ``embeddings`` already lives on the GPU.
    if isinstance(embedding_ids, torch.Tensor):
        index = embedding_ids.to(device=embeddings.device, dtype=torch.long)
    else:
        index = torch.as_tensor(np.asarray(embedding_ids), dtype=torch.long,
                                device=embeddings.device)
    local_embeddings = embeddings.detach()[index.reshape(-1)]
    local_embeddings = local_embeddings.cuda() if GPU else local_embeddings.cpu()
    return local_embeddings.reshape(batch_size, embedding_dim, 1, 1)

In [18]:
def interpolated_embedding_lookup(embeddings, interpolated_embedding_ids, grid, GPU=True):
    """Linearly interpolate between pairs of embeddings.

    For every ``(a, b)`` pair the result is
    ``embeddings[a] * (1 - grid) + embeddings[b] * grid``.

    Args:
        embeddings: tensor of shape ``(embedding_num, 1, 1, embedding_dim)``.
        interpolated_embedding_ids: sequence of ``(a, b)`` id pairs, one per
            batch element (list, numpy array or tensor).
        grid: interpolation weight; 0 gives ``embeddings[a]``, 1 gives
            ``embeddings[b]``.
        GPU: move the result to CUDA (the default, as before); pass False to
            keep it on the CPU.

    Returns:
        A detached tensor of shape ``(batch, embedding_dim, 1, 1)``.
    """
    batch_size = len(interpolated_embedding_ids)
    embedding_dim = embeddings.shape[3]
    if isinstance(interpolated_embedding_ids, torch.Tensor):
        pairs = interpolated_embedding_ids.to(device=embeddings.device, dtype=torch.long)
    else:
        pairs = torch.as_tensor(np.asarray(interpolated_embedding_ids), dtype=torch.long,
                                device=embeddings.device)
    pairs = pairs.reshape(-1, 2)
    base = embeddings.detach()
    # Vectorized over the batch; detach mirrors the old numpy round trip.
    interpolated = base[pairs[:, 0]] * (1 - grid) + base[pairs[:, 1]] * grid
    interpolated = interpolated.cuda() if GPU else interpolated.cpu()
    return interpolated.reshape(batch_size, embedding_dim, 1, 1)

# Utils

- data pre-processing etc.

In [19]:
from __future__ import print_function
from __future__ import absolute_import

import os
import glob

import imageio
import scipy.misc as misc
import numpy as np
from io import BytesIO
from PIL import Image
#from scipy.misc import imresize
import matplotlib.pyplot as plt

In [20]:
def pad_seq(seq, batch_size):
    """Pad ``seq`` in place so its length is a multiple of ``batch_size``.

    Padding items are taken by cycling through the original items from the
    start. This keeps the result a whole number of batches even when
    ``batch_size`` exceeds twice ``len(seq)``; a plain ``seq[:padded]`` slice
    would come up short. The mutated list is also returned; an empty ``seq``
    is returned unchanged.
    """
    seq_len = len(seq)
    remainder = seq_len % batch_size
    if remainder == 0:
        return seq
    padded = batch_size - remainder
    original = list(seq)  # snapshot so extending does not feed itself
    seq.extend(original[i % seq_len] for i in range(padded))
    return seq

In [21]:
def bytes_to_file(bytes_img):
    """Wrap encoded image bytes in an in-memory binary file object."""
    buffer = BytesIO(bytes_img)
    return buffer

In [22]:
def normalize_image(img):
    """Map pixel values from [0, 255] to [-1, 1], centring them on zero."""
    scaled = img / 127.5
    return scaled - 1.

In [23]:
def denorm_image(x):
    """Map a tensor from [-1, 1] to [0, 1], clamping values outside that range."""
    rescaled = (x + 1) / 2
    return rescaled.clamp(0, 1)

In [24]:
def read_split_image(img):
    """Decode a side-by-side example image and split it into its two halves.

    Args:
        img: path or binary file object of an image whose left half is the
            target glyph and whose right half is the source glyph.

    Returns:
        ``(img_A, img_B)`` float64 arrays: the left half (target) and the
        right half (source).

    Raises:
        ValueError: if the image width is odd.
    """
    # scipy.misc.imread was removed in SciPy 1.2 and np.float in NumPy 1.24,
    # so decode with PIL, expanding palette images as misc.imread did.
    with Image.open(img) as pil_img:
        if pil_img.mode == 'P':
            pil_img = pil_img.convert('RGBA' if 'transparency' in pil_img.info else 'RGB')
        mat = np.asarray(pil_img, dtype=np.float64)
    side = mat.shape[1] // 2
    if side * 2 != mat.shape[1]:
        # An explicit error survives `python -O`, unlike assert.
        raise ValueError("expected an even image width, got %d" % mat.shape[1])
    img_A = mat[:, :side]  # target
    img_B = mat[:, side:]  # source
    return img_A, img_B

In [25]:
def shift_and_resize_image(img, shift_x, shift_y, nw, nh):
    """Enlarge a 2-D image to ``(nw, nh)`` and crop back a window of the original size.

    ``nw`` and ``nh`` are the new row and column counts (the same order as
    ``img.shape``). The crop starts at row ``shift_x`` and column ``shift_y``.

    Returns:
        A float32 numpy array with the same shape as ``img`` (when the shift
        leaves room for a full window).
    """
    w, h = img.shape
    # PIL sizes are (width, height) == (columns, rows). Bilinear matches the
    # default interpolation of the removed scipy.misc.imresize; float32 input
    # gives a mode 'F' image, so pixel values are not quantized.
    enlarged = Image.fromarray(np.asarray(img, dtype=np.float32)).resize(
        size=(nh, nw), resample=Image.BILINEAR)
    # A PIL image cannot be sliced, so convert back to an array first.
    enlarged = np.asarray(enlarged)
    return enlarged[shift_x:shift_x + w, shift_y:shift_y + h]

In [26]:
def scale_back(images):
    """Map values from [-1, 1] to [0, 1] without clamping."""
    shifted = images + 1.
    return shifted / 2.

In [27]:
def merge(images, size):
    """Tile a batch of images into one ``size[0] x size[1]`` RGB grid.

    Args:
        images: array of shape ``(N, h, w)``, ``(N, h, w, 1)`` or
            ``(N, h, w, 3)``. Grayscale images are broadcast to 3 channels.
        size: ``(rows, cols)`` of the grid; images are placed row-major.

    Returns:
        A float64 array of shape ``(h * rows, w * cols, 3)``. Cells without an
        image stay 0.
    """
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1], 3))
    for idx, image in enumerate(images):
        if image.ndim == 2:
            # A 2-D (h, w) image cannot broadcast into (h, w, 3); add a channel axis.
            image = image[:, :, np.newaxis]
        row, col = divmod(idx, size[1])
        img[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = image
    return img

In [28]:
def save_concat_images(imgs, img_path):
    """Concatenate ``imgs`` side by side (along axis 1) and write the result to ``img_path``.

    ``scipy.misc.imsave`` was removed in SciPy 1.2, so ``imageio.imwrite``
    (already used by this file) is used instead. imageio decides how
    non-uint8 data are converted for the output format.
    """
    concated = np.concatenate(imgs, axis=1)
    imageio.imwrite(img_path, concated)

In [29]:
def save_gif(gif_path, image_path, file_name):
    """Combine every ``*.png`` in ``image_path`` into one animation file.

    Frames are read in sorted filename order and written with
    ``imageio.mimsave`` to ``os.path.join(gif_path, file_name)``.
    """
    frame_files = sorted(glob.glob(os.path.join(image_path, "*.png")))
    frames = [imageio.imread(name) for name in frame_files]
    imageio.mimsave(os.path.join(gif_path, file_name), frames)

In [30]:
def show_comparison(font_num, real_targets, fake_targets, show_num=8):
    """Plot ``show_num`` real/fake image pairs for font ``font_num``.

    Each pair takes two adjacent cells of an 8-column grid ("Real" then
    "Fake"), and the grid has as many rows as needed. Every image is
    reshaped to 128x128 before display.

    Args:
        font_num: index into ``real_targets`` / ``fake_targets``.
        real_targets: indexable as ``real_targets[font_num][idx]``.
        fake_targets: indexable as ``fake_targets[font_num][idx]``.
        show_num: number of pairs to show.
    """
    n_cols = 8
    # 2 cells per pair, rounded up; the old `show_num // 4` rows were too few
    # (and raised) when show_num was not a multiple of 4.
    n_rows = (2 * show_num + n_cols - 1) // n_cols
    plt.figure(figsize=(14, show_num // 2 + 1))
    panels = ((real_targets, "Real [%d]"), (fake_targets, "Fake [%d]"))
    for idx in range(show_num):
        for offset, (images, title_fmt) in enumerate(panels):
            plt.subplot(n_rows, n_cols, 2 * idx + 1 + offset)
            plt.imshow(images[font_num][idx].reshape(128, 128), cmap='gray')
            plt.title(title_fmt % font_num)
            plt.axis('off')
    plt.show()

In [31]:
def _ink_bounds(img):
    """Return inclusive ``(y1, y2, x1, x2)`` ink bounds of a 2-D image, or None if blank.

    White is assumed to be 1 (images normalized to [-1, 1]). A column or row
    counts as ink when its sum is more than 1 below that of an all-white line.
    """
    height, width = img.shape
    col_idx = np.where(height - np.sum(img, axis=0) > 1)[0]
    row_idx = np.where(width - np.sum(img, axis=1) > 1)[0]
    if col_idx.size == 0 or row_idx.size == 0:
        return None
    return row_idx[0], row_idx[-1], col_idx[0], col_idx[-1]


def _resize_float(arr, rows, cols):
    """Bilinearly resize a 2-D array to ``(rows, cols)``; returns float32."""
    # PIL takes (width, height) == (cols, rows). Mode 'F' keeps the [-1, 1]
    # value range, so no re-normalization is needed afterwards.
    resized = Image.fromarray(np.asarray(arr, dtype=np.float32)).resize(
        size=(cols, rows), resample=Image.BILINEAR)
    return np.asarray(resized)


def tight_crop_image(img, verbose=False, resize_fix=False):
    """Crop a 2-D image to the bounding box of its ink, optionally resizing it.

    Args:
        img: 2-D array normalized to [-1, 1] with a white (1) background.
        verbose: print crop/resize diagnostics.
        resize_fix: ``False`` (default) only crops. An ``int`` scales the crop
            so its longer side is ``resize_fix`` pixels. A ``float`` scales
            both sides by that factor, then caps each side at 120 pixels,
            shrinking the other side proportionally.

    Returns:
        The cropped (and possibly resized) 2-D array. A blank image (no ink)
        is returned unchanged.
    """
    bounds = _ink_bounds(img)
    if bounds is None:
        if verbose:
            print('no ink found; image returned uncropped')
        return img
    y1, y2, x1, x2 = bounds
    # The bounds are inclusive, so +1 keeps the last ink row/column.
    cropped_image = img[y1:y2 + 1, x1:x2 + 1]

    if verbose:
        print('(left x1, top y1):', (x1, y1))
        print('(right x2, bottom y2):', (x2, y2))
        print('cropped_image size:', cropped_image.shape)

    origin_h, origin_w = cropped_image.shape
    # type() instead of isinstance(): bool subclasses int, and the default
    # False must not trigger a resize.
    if type(resize_fix) == int:
        if origin_h > origin_w:
            resize_h, resize_w = resize_fix, int(origin_w * (resize_fix / origin_h))
        else:
            resize_h, resize_w = int(origin_h * (resize_fix / origin_w)), resize_fix
    elif type(resize_fix) == float:
        resize_h, resize_w = int(origin_h * resize_fix), int(origin_w * resize_fix)
        # Scale the other side with the pre-cap value to keep the aspect ratio.
        if resize_h > 120:
            resize_w = int(resize_w * 120 / resize_h)
            resize_h = 120
        if resize_w > 120:
            resize_h = int(resize_h * 120 / resize_w)
            resize_w = 120
    else:
        return cropped_image

    # Very thin strokes could otherwise round down to a zero-sized side.
    resize_h, resize_w = max(resize_h, 1), max(resize_w, 1)
    cropped_image = _resize_float(cropped_image, resize_h, resize_w)
    if verbose:
        print('resized_image size:', cropped_image.shape)
    return cropped_image

In [32]:
def add_padding(img, image_size=128, verbose=False, pad_value=None):
    """Pad a 2-D image with a constant border so it becomes ``image_size`` square.

    The image is centred. When the padding needed on an axis is odd, the
    extra pixel goes on the top/left side.

    Args:
        img: 2-D array no larger than ``image_size`` on either axis.
        image_size: side length of the output (odd sizes are supported).
        verbose: print the size before and after padding.
        pad_value: border value; ``None`` (default) uses ``img[0][0]``. An
            explicit 0 is honoured.

    Returns:
        An ``(image_size, image_size)`` array with dtype
        ``np.result_type(img.dtype, np.float32)``.

    Raises:
        ValueError: if ``img`` is larger than ``image_size`` on either axis.
    """
    img = np.asarray(img)
    height, width = img.shape
    # `is None` (not truthiness) so that pad_value=0 is not silently replaced.
    if pad_value is None:
        pad_value = img[0][0]
    if verbose:
        print('original cropped image size:', img.shape)
    if height > image_size or width > image_size:
        raise ValueError('image of shape %s does not fit in %d x %d'
                         % (img.shape, image_size, image_size))

    pad_h = image_size - height
    pad_w = image_size - width
    # Split the padding evenly; any odd leftover pixel goes before the image.
    pad_width = ((pad_h - pad_h // 2, pad_h // 2),
                 (pad_w - pad_w // 2, pad_w // 2))
    # Padding used to be float32 blocks concatenated onto img; keep that dtype promotion.
    out_dtype = np.result_type(img.dtype, np.float32)
    img = np.pad(img.astype(out_dtype, copy=False), pad_width,
                 mode='constant', constant_values=pad_value)

    if verbose:
        print('final image size:', img.shape)
    return img

In [33]:
def centering_image(img, image_size=128, verbose=False, resize_fix=False, pad_value=None):
    """Crop a 2-D image tightly around its ink and pad it back to a centred square.

    Args:
        img: 2-D array normalized to [-1, 1] with a white background.
        image_size: side length of the returned square image.
        verbose: forwarded to ``tight_crop_image`` and ``add_padding``.
        resize_fix: forwarded to ``tight_crop_image``.
        pad_value: border value. ``None`` uses the top-left pixel of the
            *uncropped* input; an explicit 0 is honoured.

    Returns:
        The re-centred ``(image_size, image_size)`` array.
    """
    # `is None` (not truthiness) so that pad_value=0 is respected.
    if pad_value is None:
        pad_value = img[0][0]
    cropped_image = tight_crop_image(img, verbose=verbose, resize_fix=resize_fix)
    return add_padding(cropped_image, image_size=image_size, verbose=verbose,
                       pad_value=pad_value)

In [34]:
def chars_to_ids(sentence):
    """Convert precomposed Hangul syllables to their 0-based character ids.

    The id of a syllable is its offset from U+AC00 ('가'), i.e. its index in
    the block U+AC00..U+D7A3 (11,172 syllables). The id is computed with
    ``ord`` arithmetic instead of building that table and calling
    ``list.index`` for every character.

    Raises:
        ValueError: if ``sentence`` contains anything other than a single
            precomposed Hangul syllable per element.
    """
    first, last = 0xAC00, 0xD7A3
    fixed_char_ids = []
    for char in sentence:
        if len(char) != 1 or not first <= ord(char) <= last:
            raise ValueError('%r is not a precomposed Hangul syllable' % (char,))
        fixed_char_ids.append(ord(char) - first)
    return fixed_char_ids

In [35]:
def round_function(i):
    """Snap values near the ends of [-1, 1] to exactly -1 or 1.

    Values below -0.95 become the int -1 and values above 0.95 become the
    int 1; everything else is returned unchanged.
    """
    if i > 0.95:
        return 1
    if i < -0.95:
        return -1
    return i

# Dataset
- load dataset, data pre-processing

In [36]:
from __future__ import print_function
from __future__ import absolute_import
import pickle as pickle
import numpy as np
import random
import os
import torch

In [37]:
def get_batch_iter(examples, batch_size, augment, with_charid=False):
    """Return a generator that yields mini-batches of decoded image pairs.

    ``examples`` is padded in place (via ``pad_seq``) so that batches have a
    deterministic size, which the downstream transpose ops require.

    Each example is ``(label, img_bytes)``, or ``(label, charid, img_bytes)``
    when ``with_charid`` is true. ``img_bytes`` is decoded with
    ``read_split_image`` into a target/source pair, optionally augmented,
    normalized with ``normalize_image`` and stacked as channels
    ``[target, source]``.

    Yields:
        ``[labels, image]`` or ``[labels, charids, image]``, where ``image`` is
        a float32 tensor of shape ``(batch, 2, H, W)``.
    """
    padded = pad_seq(examples, batch_size)

    def _augment(img_A, img_B):
        # Enlarge both halves by one random factor, then randomly crop back to
        # the original size. A and B share the same shift so they stay in sync.
        rows, cols = img_A.shape
        factor = random.uniform(1.00, 1.20)
        # The +1 is an epsilon that prevents cropping issues.
        big_rows = int(factor * rows) + 1
        big_cols = int(factor * cols) + 1
        off_x = int(np.ceil(np.random.uniform(0.01, big_rows - rows)))
        off_y = int(np.ceil(np.random.uniform(0.01, big_cols - cols)))
        img_A = shift_and_resize_image(img_A, off_x, off_y, big_rows, big_cols)
        img_B = shift_and_resize_image(img_B, off_x, off_y, big_rows, big_cols)
        return img_A, img_B

    def _as_channel(mat):
        # Normalize, then prepend a channel axis: (H, W) -> (1, H, W).
        mat = normalize_image(mat)
        return mat.reshape(1, len(mat), len(mat[0]))

    def process(img):
        stream = bytes_to_file(img)
        try:
            img_A, img_B = read_split_image(stream)
            if augment:
                img_A, img_B = _augment(img_A, img_B)
            channel_A = _as_channel(img_A)
            channel_B = _as_channel(img_B)
            return np.concatenate([channel_A, channel_B], axis=0)
        finally:
            stream.close()

    def _to_tensor(arrays):
        return torch.from_numpy(np.array(arrays).astype(np.float32))

    def batch_iter(with_charid=with_charid):
        for start in range(0, len(padded), batch_size):
            batch = padded[start:start + batch_size]
            labels = [e[0] for e in batch]
            if not with_charid:
                yield [labels, _to_tensor([process(e[1]) for e in batch])]
                continue
            charid = [e[1] for e in batch]
            yield [labels, charid, _to_tensor([process(e[2]) for e in batch])]

    return batch_iter(with_charid=with_charid)

In [38]:
class PickledImageProvider(object):
    """Load every record pickled into ``obj_path`` (one ``pickle.dump`` per record).

    SECURITY NOTE: ``pickle.load`` can execute arbitrary code. Only load
    ``.obj`` files you produced yourself (e.g. with ``pickle_examples``).

    Attributes:
        obj_path: path of the pickle file.
        verbose: print the number of loaded examples.
        examples: list of the unpickled records.
    """

    def __init__(self, obj_path, verbose):
        self.obj_path = obj_path
        self.verbose = verbose
        self.examples = self.load_pickled_examples()

    def load_pickled_examples(self):
        """Unpickle records until end of file and return them as a list.

        Records that fail to unpickle are skipped (best effort) and counted.
        If a failure did not advance the file position, loading stops instead
        of retrying the same bytes forever.
        """
        examples = []
        skipped = 0
        with open(self.obj_path, "rb") as of:
            while True:
                start = of.tell()
                try:
                    examples.append(pickle.load(of))
                except EOFError:
                    break
                except Exception:
                    skipped += 1
                    # Nothing consumed: another attempt would fail identically.
                    if of.tell() == start:
                        break
        if skipped:
            print("skipped %d unreadable records in %s" % (skipped, self.obj_path))
        if self.verbose:
            print("unpickled total %d examples" % len(examples))
        return examples

In [39]:
class TrainDataProvider(object):
    """Train/validation example provider backed by pickled ``.obj`` files.

    Examples are tuples whose element 0 is the font label. Element 1 is the
    character id when the data were pickled with char ids.

    Attributes:
        train: ``PickledImageProvider`` for ``train_name``.
        val: ``PickledImageProvider`` for ``val_name``, or ``None`` when
            constructed with ``val=False``.
    """

    def __init__(self, data_dir, train_name="train.obj", val_name="val.obj",
                 filter_by_font=None, filter_by_charid=None, verbose=True, val=True):
        self.data_dir = data_dir
        self.filter_by_font = filter_by_font
        self.filter_by_charid = filter_by_charid
        self.train_path = os.path.join(self.data_dir, train_name)
        self.val_path = os.path.join(self.data_dir, val_name)
        self.train = PickledImageProvider(self.train_path, verbose)
        # None (rather than a missing attribute) when validation data are not loaded.
        self.val = PickledImageProvider(self.val_path, verbose) if val else None
        if self.filter_by_font:
            if verbose:
                print("filter by label ->", filter_by_font)
            self._filter_examples(0, set(self.filter_by_font))
        if self.filter_by_charid:
            if verbose:
                print("filter by char ->", filter_by_charid)
            self._filter_examples(1, set(self.filter_by_charid))
        if verbose:
            if val:
                print("train examples -> %d, val examples -> %d" % (len(self.train.examples), len(self.val.examples)))
            else:
                print("train examples -> %d" % (len(self.train.examples)))

    def _filter_examples(self, position, allowed):
        """Keep only examples whose element at ``position`` is in the set ``allowed``."""
        for provider in (self.train, self.val):
            if provider is not None:
                provider.examples = [e for e in provider.examples if e[position] in allowed]

    def get_train_iter(self, batch_size, shuffle=True, with_charid=False):
        """Return a one-pass batch generator over a copy of the training examples.

        The copy is shuffled when ``shuffle`` is true; batches are augmented.
        """
        training_examples = self.train.examples[:]
        if shuffle:
            np.random.shuffle(training_examples)
        return get_batch_iter(training_examples, batch_size, augment=True, with_charid=with_charid)

    def get_val_iter(self, batch_size, shuffle=True, with_charid=False):
        """Return a one-pass batch generator over a copy of the validation examples.

        The generator is finite (a single pass). Like the training iterator,
        batches are augmented.

        Raises:
            ValueError: if the provider was created with ``val=False``.
        """
        if self.val is None:
            raise ValueError("validation data was not loaded (val=False)")
        val_examples = self.val.examples[:]
        if shuffle:
            np.random.shuffle(val_examples)
        return get_batch_iter(val_examples, batch_size, augment=True, with_charid=with_charid)

    def compute_total_batch_num(self, batch_size):
        """Total padded batch num"""
        return int(np.ceil(len(self.train.examples) / float(batch_size)))

    def get_all_labels(self):
        """Get all training labels"""
        return list({e[0] for e in self.train.examples})

    def get_train_val_path(self):
        """Return ``(train_path, val_path)``."""
        return self.train_path, self.val_path



In [40]:
def save_fixed_sample(sample_size, img_size, data_dir, save_dir,
                      val=False, verbose=True, with_charid=True, resize_fix=90,
                      source_resize_fix=90, GPU=True):
    """Save one fixed, re-centred batch of (source, target, label) samples.

    Takes the first batch of ``sample_size`` examples from the train split
    (or the validation split when ``val`` is true) in ``data_dir``. Every
    pixel is snapped with ``round_function`` and every image re-centred with
    ``centering_image``. The batch is then stably sorted by font label and
    saved with ``torch.save`` as ``fixed_source.pkl``, ``fixed_target.pkl``
    and ``fixed_label.pkl`` in ``save_dir``.

    Args:
        sample_size: number of examples in the saved batch.
        img_size: side length of each (square) image.
        data_dir: directory with the pickled train/val files.
        save_dir: output directory.
        val: sample from the validation split instead of the train split.
        verbose: forwarded to ``TrainDataProvider``.
        with_charid: whether the pickled examples include char ids.
        resize_fix: ``resize_fix`` passed to ``centering_image`` for targets.
        source_resize_fix: same, for source images (90 was previously
            hard-coded).
        GPU: store the image tensors on CUDA (default, as before); pass False
            to keep them on the CPU.
    """
    data_provider = TrainDataProvider(data_dir, verbose=verbose, val=val)
    if not val:
        batch_iter = data_provider.get_train_iter(sample_size, with_charid=with_charid)
    else:
        batch_iter = data_provider.get_val_iter(sample_size, with_charid=with_charid)

    # Only the first batch is used.
    batch = next(iter(batch_iter), None)
    if batch is None:
        return
    if with_charid:
        font_ids, _, batch_images = batch
    else:
        font_ids, batch_images = batch

    def _center(image, fix):
        # (1, H, W) tensor -> float32 (1, H, W) array: snap near-±1 values, re-centre the glyph.
        arr = image.cpu().detach().numpy().reshape(img_size, img_size)
        arr = np.array(list(map(round_function, arr.flatten()))).reshape(img_size, img_size)
        arr = centering_image(arr, image_size=img_size, resize_fix=fix)
        return np.asarray(arr, dtype=np.float32).reshape(1, img_size, img_size)

    # Channel 0 is the target half and channel 1 the source half of each pair.
    sources = batch_images[:, 1, :, :].reshape(sample_size, 1, img_size, img_size)
    targets = batch_images[:, 0, :, :].reshape(sample_size, 1, img_size, img_size)
    centred_sources = [_center(image, source_resize_fix) for image in sources]
    centred_targets = [_center(image, resize_fix) for image in targets]

    fixed_label = np.array(font_ids)
    # One stable permutation keeps sources, targets and labels aligned.
    order = sorted(range(len(fixed_label)), key=lambda k: fixed_label[k])
    fixed_source = torch.tensor(np.array([centred_sources[k] for k in order]))
    fixed_target = torch.tensor(np.array([centred_targets[k] for k in order]))
    fixed_label = [fixed_label[k] for k in order]
    if GPU:
        fixed_source = fixed_source.cuda()
        fixed_target = fixed_target.cuda()
    torch.save(fixed_source, os.path.join(save_dir, 'fixed_source.pkl'))
    torch.save(fixed_target, os.path.join(save_dir, 'fixed_target.pkl'))
    torch.save(fixed_label, os.path.join(save_dir, 'fixed_label.pkl'))

# Models
- Generator(Encoder, Decoder), Discriminator

In [41]:
import torch
import torch.nn as nn
#from function import conv2d, deconv2d, lrelu, fc, embedding_lookup # .function 수정함
import warnings
warnings.filterwarnings("ignore")

In [42]:
def Generator(images, En, De, embeddings, embedding_ids, GPU=False, encode_layers=False):
    """Encode ``images``, append per-sample style embeddings, and decode.

    Args:
        images: batch passed to the encoder ``En``.
        En: encoder returning ``(encoded_source, layers_dict)``.
        De: decoder called as ``De(embedded, layers_dict)``.
        embeddings: embedding table passed to ``embedding_lookup``.
        embedding_ids: one embedding id per batch element.
        GPU: move the encoded source and embeddings to CUDA before concatenation.
        encode_layers: see the note below.

    Returns:
        ``(fake_target, encoded_source, layers_dict)``; only when the encoder's
        dict is empty is the 2-tuple ``(fake_target, encoded_source)`` returned.

    NOTE(review): the ``encode_layers`` argument is never consulted. It is
    overwritten by the encoder's skip-connection dict, so the return arity
    depends on that dict's truthiness, not on this flag. Kept as-is because
    callers may unpack three values; confirm before changing.
    """
    encoded_source, skip_layers = En(images)
    local_embeddings = embedding_lookup(embeddings, embedding_ids, GPU=GPU)
    if GPU:
        encoded_source = encoded_source.cuda()
        local_embeddings = local_embeddings.cuda()
    # Concatenate along the channel dimension.
    embedded = torch.cat((encoded_source, local_embeddings), 1)
    fake_target = De(embedded, skip_layers)
    if not skip_layers:
        return fake_target, encoded_source
    return fake_target, encoded_source, skip_layers

In [43]:
class Encoder(nn.Module):
    """Eight-stage convolutional encoder built from ``conv2d`` blocks.

    ``forward`` returns the final feature map plus a dict of every stage's
    output keyed ``'e1'`` ... ``'e8'`` (consumed by ``Decoder``). Attribute
    names ``conv1`` ... ``conv8`` are part of the saved state_dict layout.
    """

    def __init__(self, img_dim=1, conv_dim=64):
        super(Encoder, self).__init__()
        # NOTE: conv2d() accepts `dilation` but does not forward it to
        # nn.Conv2d, so the dilation values below have no effect.
        self.conv1 = conv2d(img_dim, conv_dim, k_size=5, stride=2, pad=2, dilation=2, lrelu=False, bn=False)
        self.conv2 = conv2d(conv_dim, conv_dim * 2, k_size=5, stride=2, pad=2, dilation=2)
        self.conv3 = conv2d(conv_dim * 2, conv_dim * 4, k_size=4, stride=2, pad=1, dilation=1)
        self.conv4 = conv2d(conv_dim * 4, conv_dim * 8)
        # conv5 .. conv8 are identical blocks; registered in order.
        for stage in range(5, 9):
            setattr(self, 'conv%d' % stage, conv2d(conv_dim * 8, conv_dim * 8))

    def forward(self, images):
        encode_layers = dict()
        features = images
        for stage in range(1, 9):
            features = getattr(self, 'conv%d' % stage)(features)
            encode_layers['e%d' % stage] = features
        return features, encode_layers

In [44]:
class Decoder(nn.Module):
    """Eight-stage transposed-convolution decoder with encoder skip connections.

    Stages 1-7 are each followed by concatenation (along channels) with the
    mirrored encoder output ``e7`` ... ``e1``. The last stage is passed
    through ``tanh``. Attribute names ``deconv1`` ... ``deconv8`` are part of
    the saved state_dict layout.
    """

    def __init__(self, img_dim=1, embedded_dim=640, conv_dim=64):
        super(Decoder, self).__init__()
        # NOTE: deconv2d() accepts `dilation` but does not forward it.
        self.deconv1 = deconv2d(embedded_dim, conv_dim * 8, dropout=True)
        self.deconv2 = deconv2d(conv_dim * 16, conv_dim * 8, k_size=4, dropout=True)
        self.deconv3 = deconv2d(conv_dim * 16, conv_dim * 8, k_size=5, dilation=2, dropout=True)
        self.deconv4 = deconv2d(conv_dim * 16, conv_dim * 8, k_size=4, stride=2, dilation=2)
        self.deconv5 = deconv2d(conv_dim * 16, conv_dim * 4, k_size=4, stride=2, dilation=2)
        self.deconv6 = deconv2d(conv_dim * 8, conv_dim * 2, k_size=4, stride=2, dilation=2)
        self.deconv7 = deconv2d(conv_dim * 4, conv_dim * 1, k_size=4, stride=2, dilation=2)
        self.deconv8 = deconv2d(conv_dim * 2, img_dim, k_size=4, stride=2, dilation=2, bn=False)

    def forward(self, embedded, encode_layers):
        features = embedded
        for stage in range(1, 8):
            features = getattr(self, 'deconv%d' % stage)(features)
            # Skip connection: stage n pairs with encoder output e(8 - n).
            features = torch.cat((features, encode_layers['e%d' % (8 - stage)]), dim=1)
        return torch.tanh(self.deconv8(features))

In [45]:
class Discriminator(nn.Module):
    """Four ``conv2d`` stages followed by two linear heads.

    ``forward`` returns ``(tf_loss, tf_loss_logit, cat_loss)``: the sigmoid of
    the real/fake logit, the raw logit, and the ``category_num`` font
    classification logits. Attribute names are part of the saved state_dict.
    """

    def __init__(self, category_num, img_dim=2, disc_dim=64):
        super(Discriminator, self).__init__()
        self.conv1 = conv2d(img_dim, disc_dim, bn=False)
        self.conv2 = conv2d(disc_dim, disc_dim * 2)
        self.conv3 = conv2d(disc_dim * 2, disc_dim * 4)
        self.conv4 = conv2d(disc_dim * 4, disc_dim * 8)
        # Assumes conv4 produces an 8x8 map (presumably 128x128 inputs halved
        # four times) — TODO confirm if the input size changes.
        flat_dim = disc_dim * 8 * 8 * 8
        self.fc1 = fc(flat_dim, 1)
        self.fc2 = fc(flat_dim, category_num)

    def forward(self, images):
        features = self.conv4(self.conv3(self.conv2(self.conv1(images))))
        flat = features.reshape(images.shape[0], -1)
        tf_loss_logit = self.fc1(flat)
        return torch.sigmoid(tf_loss_logit), tf_loss_logit, self.fc2(flat)

# Train
- Model trainer

In [46]:
import os, glob, time, datetime
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torchvision.utils import save_image

In [47]:
class Trainer:
    """Train the font-style GAN (Encoder + Decoder generator, Discriminator).

    Construction loads the fixed style embeddings and the fixed sample batch
    from ``fixed_dir``, and builds a ``TrainDataProvider`` over ``data_dir``.
    """

    def __init__(self, GPU, data_dir, fixed_dir, fonts_num, batch_size, img_size):
        """Store the configuration and load the fixed artifacts.

        Args:
            GPU: when truthy, models, losses and batches are moved to CUDA.
            data_dir: directory handed to ``TrainDataProvider``.
            fixed_dir: directory containing ``EMBEDDINGS.pkl``,
                ``fixed_source.pkl``, ``fixed_target.pkl`` and ``fixed_label.pkl``.
            fonts_num: number of target fonts (Discriminator category count).
            batch_size: training batch size.
            img_size: side length of the square glyph images.
        """
        self.GPU = GPU
        self.data_dir = data_dir
        self.fixed_dir = fixed_dir
        self.fonts_num = fonts_num
        self.batch_size = batch_size
        self.img_size = img_size

        self.embeddings = torch.load(os.path.join(fixed_dir, 'EMBEDDINGS.pkl'))
        # Only axis 0 (number of embeddings) and axis 3 (embedding width) are read here.
        self.embedding_num = self.embeddings.shape[0]
        self.embedding_dim = self.embeddings.shape[3]

        self.fixed_source = torch.load(os.path.join(fixed_dir, 'fixed_source.pkl'))
        self.fixed_target = torch.load(os.path.join(fixed_dir, 'fixed_target.pkl'))
        self.fixed_label = torch.load(os.path.join(fixed_dir, 'fixed_label.pkl'))

        self.data_provider = TrainDataProvider(self.data_dir)
        self.total_batches = self.data_provider.compute_total_batch_num(self.batch_size)
        print("total batches:", self.total_batches)

    def _build_models(self):
        """Create the Encoder, Decoder and Discriminator (on CUDA when ``self.GPU``)."""
        En = Encoder()
        De = Decoder()
        D = Discriminator(category_num=self.fonts_num)
        if self.GPU:
            En.cuda()
            De.cuda()
            D.cuda()
        return En, De, D

    @staticmethod
    def _restore_models(En, De, D, restore, from_model_path):
        """Load ``restore = [encoder, decoder, discriminator]`` files; return the epoch.

        The epoch is parsed from the encoder file name, up to its first '-'.
        Returns 0 (and loads nothing) when ``restore`` is falsy.
        """
        if not restore:
            print("New model training start")
            return 0
        encoder_path, decoder_path, discriminator_path = restore
        prev_epoch = int(encoder_path.split('-')[0])
        for model, path in ((En, encoder_path), (De, decoder_path), (D, discriminator_path)):
            model.load_state_dict(torch.load(os.path.join(from_model_path, path)))
        print("%d epoch trained model has restored" % prev_epoch)
        return prev_epoch

    @staticmethod
    def _save_models(En, De, D, to_model_path, epoch_num):
        """Save the three state_dicts as ``<epoch>-<MMDD>-<HH:MM>-<Name>.pkl``."""
        now = datetime.datetime.now()
        stamp = '%d-%s-%s' % (epoch_num, now.strftime("%m%d"), now.strftime('%H:%M'))
        for model, name in ((En, 'Encoder'), (De, 'Decoder'), (D, 'Discriminator')):
            torch.save(model.state_dict(),
                       os.path.join(to_model_path, '%s-%s.pkl' % (stamp, name)))

    def _center_batch(self, images, resize_fix):
        """Center every (1, H, W) glyph of ``images`` in place with ``centering_image``."""
        for idx, image in enumerate(images):
            image_np = image.cpu().detach().numpy().reshape(self.img_size, self.img_size)
            image_np = centering_image(image_np, resize_fix=resize_fix)
            images[idx] = torch.tensor(image_np).view([1, self.img_size, self.img_size])

    def train(self, max_epoch, schedule, save_path, to_model_path, lr=0.001,
              log_step=100, sample_step=350, fine_tune=False, flip_labels=False,
              restore=None, from_model_path=False, with_charid=False,
              freeze_encoder=False, save_nrow=8, model_save_step=None, resize_fix=90):
        """Run ``max_epoch`` epochs of GAN training and save checkpoints and losses.

        Args:
            max_epoch: number of epochs to run in this call.
            schedule: every ``schedule`` epochs the learning rate is halved, with
                a floor of 2e-4 (a rate already below the floor is left unchanged).
            save_path: directory for the periodic ``fake_samples-*.png`` images.
            to_model_path: directory for model checkpoints and the losses file.
            lr: initial learning rate of both Adam optimizers.
            log_step: print a log line every ``log_step`` batches.
            sample_step: save a sample image every ``sample_step`` batches.
            fine_tune: use the larger L1 / constant-loss weights (500 / 1000).
            flip_labels: shuffle the embedding ids within each batch.
            restore: ``[encoder, decoder, discriminator]`` file names inside
                ``from_model_path`` to resume from.
            from_model_path: directory holding the ``restore`` files.
            with_charid: the data provider also yields character ids.
            freeze_encoder: optimize only the Decoder's parameters.
            save_nrow: images per row in the saved sample grids.
            model_save_step: checkpoint every N epochs (defaults to 5).
            resize_fix: ``resize_fix`` passed to ``centering_image`` for targets.

        Returns:
            ``(l1_losses, const_losses, category_losses, d_losses, g_losses)``,
            each a list with one float per training step.
        """
        # Loss weights: fine-tuning stresses the pixel (L1) and constant losses.
        if fine_tune:
            L1_penalty, Lconst_penalty = 500, 1000
        else:
            L1_penalty, Lconst_penalty = 100, 15
        if not model_save_step:
            model_save_step = 5

        En, De, D = self._build_models()
        prev_epoch = self._restore_models(En, De, D, restore, from_model_path)

        # L1 loss, real/fake and category losses (BCE on logits), constant loss (MSE).
        # reduction='mean' is the non-deprecated spelling of size_average=True.
        l1_criterion = nn.L1Loss(reduction='mean')
        bce_criterion = nn.BCEWithLogitsLoss(reduction='mean')
        mse_criterion = nn.MSELoss(reduction='mean')
        if self.GPU:
            l1_criterion = l1_criterion.cuda()
            bce_criterion = bce_criterion.cuda()
            mse_criterion = mse_criterion.cuda()

        # Optimizers start at `lr` (it used to be ignored until the first decay).
        if freeze_encoder:
            G_parameters = list(De.parameters())
        else:
            G_parameters = list(En.parameters()) + list(De.parameters())
        g_optimizer = torch.optim.Adam(G_parameters, lr=lr, betas=(0.5, 0.999))
        d_optimizer = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

        l1_losses, const_losses, category_losses, d_losses, g_losses = [], [], [], [], []

        for epoch in range(max_epoch):
            epoch_num = int(prev_epoch) + epoch + 1
            if (epoch + 1) % schedule == 0:
                updated_lr = max(lr / 2, 0.0002) if lr > 0.0002 else lr
                for optimizer in (d_optimizer, g_optimizer):
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = updated_lr
                if lr != updated_lr:
                    print("decay learning rate from %.5f to %.5f" % (lr, updated_lr))
                lr = updated_lr

            train_batch_iter = self.data_provider.get_train_iter(self.batch_size,
                                                                 with_charid=with_charid)
            for i, batch in enumerate(train_batch_iter):
                if with_charid:
                    font_ids, _char_ids, batch_images = batch
                else:
                    font_ids, batch_images = batch
                embedding_ids = font_ids
                if self.GPU:
                    batch_images = batch_images.cuda()
                if flip_labels:
                    np.random.shuffle(embedding_ids)

                # Channel 0 is the target glyph, channel 1 the source glyph.
                # Use the real batch length so a short final batch still fits `.view`.
                n = batch_images.size(0)
                real_target = batch_images[:, 0, :, :].view([n, 1, self.img_size, self.img_size])
                real_source = batch_images[:, 1, :, :].view([n, 1, self.img_size, self.img_size])

                # Centering: the source glyph always uses resize_fix=90.
                self._center_batch(real_source, 90)
                self._center_batch(real_target, resize_fix)

                # Generate fake targets from the source glyphs.
                fake_target, encoded_source, _ = Generator(real_source, En, De,
                                                           self.embeddings, embedding_ids,
                                                           GPU=self.GPU, encode_layers=True)
                real_TS = torch.cat([real_source, real_target], dim=1)
                fake_TS = torch.cat([real_source, fake_target], dim=1)

                real_category = torch.from_numpy(np.eye(self.fonts_num)[embedding_ids]).float()
                one_labels = torch.ones([n, 1])
                zero_labels = torch.zeros([n, 1])
                if self.GPU:
                    real_category = real_category.cuda()
                    one_labels = one_labels.cuda()
                    zero_labels = zero_labels.cuda()

                # ---- Discriminator step (generator output detached) ----
                _, real_score_logit, real_cat_logit = D(real_TS)
                _, fake_score_logit, fake_cat_logit = D(fake_TS.detach())
                binary_loss = (bce_criterion(real_score_logit, one_labels)
                               + bce_criterion(fake_score_logit, zero_labels))
                category_loss = 0.5 * (bce_criterion(real_cat_logit, real_category)
                                       + bce_criterion(fake_cat_logit, real_category))
                d_loss = binary_loss + category_loss

                D.zero_grad()
                d_loss.backward()
                d_optimizer.step()

                # ---- Generator step ----
                # Re-score with the updated D. Backpropagating through the pre-step
                # graph fails: d_optimizer.step() changed D's weights in place.
                _, fake_score_logit, fake_cat_logit = D(fake_TS)
                cheat_loss = bce_criterion(fake_score_logit, one_labels)
                fake_category_loss = bce_criterion(fake_cat_logit, real_category)
                encoded_fake = En(fake_target)[0]
                const_loss = Lconst_penalty * mse_criterion(encoded_source, encoded_fake)
                l1_loss = L1_penalty * l1_criterion(real_target, fake_target)
                g_loss = cheat_loss + l1_loss + fake_category_loss + const_loss

                En.zero_grad()
                De.zero_grad()
                g_loss.backward()
                g_optimizer.step()

                # Record floats: int() truncated every loss below 1.0 to 0.
                l1_losses.append(l1_loss.item())
                const_losses.append(const_loss.item())
                category_losses.append(category_loss.item())
                d_losses.append(d_loss.item())
                g_losses.append(g_loss.item())

                if (i + 1) % log_step == 0:
                    time_stamp = datetime.datetime.now().strftime('%H:%M:%S')
                    log_format = ('Epoch [%d/%d], step [%d/%d], l1_loss: %.4f, d_loss: %.4f, g_loss: %.4f'
                                  % (epoch_num, int(prev_epoch) + max_epoch,
                                     i + 1, self.total_batches,
                                     l1_loss.item(), d_loss.item(), g_loss.item()))
                    print(time_stamp, log_format)

                if (i + 1) % sample_step == 0:
                    # Sampling only: no autograd graph needed.
                    with torch.no_grad():
                        fixed_fake_images = Generator(self.fixed_source, En, De,
                                                      self.embeddings, self.fixed_label,
                                                      GPU=self.GPU)[0]
                    save_image(denorm_image(fixed_fake_images.data),
                               os.path.join(save_path, 'fake_samples-%d-%d.png' % (epoch_num, i + 1)),
                               nrow=save_nrow, pad_value=255)

            if (epoch + 1) % model_save_step == 0:
                self._save_models(En, De, D, to_model_path, epoch_num)

        # Final checkpoint and loss history.
        total_epoch = int(prev_epoch) + int(max_epoch)
        self._save_models(En, De, D, to_model_path, total_epoch)
        losses = [l1_losses, const_losses, category_losses, d_losses, g_losses]
        torch.save(losses, os.path.join(to_model_path, '%d-losses.pkl' % total_epoch))

        return l1_losses, const_losses, category_losses, d_losses, g_losses

# Main

- preprocessing :
        - pre training      : .ttf -> .png -> .pki
        - transfer learning : .pdf -> .png -> .pki
- pretraining   : Trainer 사용
- training result post-processing

src : 고딕체<br>
dst : 학습 대상인 왜곡된(변형된) 폰트

In [48]:
from google.colab import drive

# Mount Google Drive so the project folder under "My Drive" is reachable.
drive.mount(mountpoint='/content/drive')


Mounted at /content/drive


In [49]:
# IPython automagic (equivalent to %cd): switch into the project folder on the mounted drive.
cd /content/drive/My Drive/capstone_design_3_2

/content/drive/My Drive/capstone_design_3_2


In [50]:
# Load the handwriting template page and cut it into one image per character cell.
import os
from PIL import Image

template = Image.open('210_template-1.png')
template = template.convert('L')

# The crop coordinates below are tuned for a 1654x2339 px page.
if template.size != (1654, 2339):
    # Rebind `template` so the crops below actually use the resized page.
    template = template.resize((1654, 2339))

# Make sure the output folder exists before saving crops into it.
os.makedirs("./handwriting", exist_ok=True)

# Horizontal gap (px) added after each 100 px column.
line = [2,2,2,3,2,2,3,2,3,2,2,3,2,2,2]
up = 113  # top edge (px) of the first row of cells
for i in range(14):  # 14 rows ...
    left = 57  # left edge (px) of the first column
    if i in [4,7,9,11,13]:
        up += 1  # these rows sit one extra pixel lower
    for j in range(15):  # ... of 15 cells each (210 cells in total)
        croppedImage = template.crop((left, up, left+100, up+100))
        file_name = "./handwriting/" + str(i*15+j+1) + ".png"
        croppedImage.save(file_name)
        left = left + 100 + line[j]
    up += 157  # vertical pitch between rows

In [129]:
# Open and immediately close the template target list (nothing is read yet).
with open("template_target.txt", 'r', encoding='utf-8') as tmpl_target:
    pass

In [51]:
# Count the characters in the label list. Counting stops at the first empty
# line (or EOF), and the file is now closed when done.
with open("2350-common-hangul.txt", 'r', encoding='utf-8') as label_file:
    line = 0
    ch = label_file.readline().replace('\n', '')
    while ch:
        line += 1
        ch = label_file.readline().replace('\n', '')
print(line)

2350


In [52]:
src_font = ImageFont.truetype(SRC_PATH + "source_font.ttf", 100)
canvas_size = 128

# Read the character list once instead of re-opening it for every font.
# Reading stops at the first empty line, like the original readline loop.
labels = []
with open("2350-common-hangul.txt", 'r', encoding='utf-8') as label_file:
    for raw_line in label_file:
        ch = raw_line.replace('\n', '')
        if not ch:
            break
        labels.append(ch)

# Render a (target | source) pair image for every character of every target font.
for index in range(56):
    ttf_index = str(index + 1).zfill(2)
    ttf_name = ttf_index + ".ttf"
    dst_font = ImageFont.truetype(TRG_PATH + ttf_name, 100)

    for line, ch in enumerate(labels, start=1):
        made_img = draw_example(ch, src_font, dst_font, canvas_size)
        # draw_example returns None when the target font has no glyph for `ch`.
        # Skip it; the other files keep their 1-based line number in the name.
        if made_img is None:
            continue
        new_img_name = OUTPUT_PATH + ttf_index + "_" + str(line) + ".png"
        made_img.save(new_img_name, format='PNG')

In [53]:
# Split the generated PNG pairs into pickled train / validation files.
train_path = f"{OUTPUT_PATH}train.obj"
val_path = f"{OUTPUT_PATH}val.obj"
pickle_examples(from_dir=OUTPUT_PATH, train_path=train_path, val_path=val_path)

all data num: 131600
10000 imgs saved in train.obj
20000 imgs saved in train.obj
30000 imgs saved in train.obj
10000 imgs saved in val.obj
40000 imgs saved in train.obj
50000 imgs saved in train.obj
60000 imgs saved in train.obj
70000 imgs saved in train.obj
20000 imgs saved in val.obj
80000 imgs saved in train.obj
90000 imgs saved in train.obj
100000 imgs saved in train.obj
26357 imgs saved in val.obj, end
105243 imgs saved in train.obj, end


In [57]:
# Save a fixed sample batch.
# NOTE(review): the recorded run of this cell raised IndexError; investigate
# before relying on it.
img_size = 128
sample_size = 16  # value from the original ('tensor ver') code; see the cell above
data_dir = OUTPUT_PATH
save_dir = './dataset_pki/'
save_fixed_sample(sample_size, img_size, data_dir, save_dir)

unpickled total 105243 examples
train examples -> 105243


IndexError: ignored

In [None]:
#GPU check


In [None]:
# Use CUDA when a GPU runtime is available (set GPU = False to force CPU).
GPU = torch.cuda.is_available()
# Directory holding the pickled train.obj / val.obj written by pickle_examples above
# (presumably what TrainDataProvider reads; confirm).
data_dir = OUTPUT_PATH
# Directory the Trainer loads EMBEDDINGS.pkl and the fixed_* samples from.
# NOTE(review): save_fixed_sample above wrote to './dataset_pki/'; confirm which is intended.
fixed_dir = "./pre-result/"
fonts_num = 56  # number of target fonts rendered above (01.ttf .. 56.ttf)
batch_size = 16  # batch size of the original ('tensor ver') implementation
img_size = 128

HandwritingFontML = Trainer(GPU, data_dir, fixed_dir, fonts_num, batch_size, img_size)

In [None]:
def interpolation(data_provider, grids, fixed_char_ids, interpolated_font_ids, embeddings, \
                  En, De, batch_size, img_size=128, save_nrow=6, save_path=False, GPU=True):
    """Generate glyphs whose style embedding is interpolated between two fonts.

    For each weight ``grid`` in ``grids``, the whole training set is encoded and
    decoded with the embedding ``(1 - grid) * embeddings[from] + grid * embeddings[to]``.
    Here ``to`` is the font following ``from`` in the chain formed by the first
    elements of ``interpolated_font_ids``; the last font wraps around to the first.

    Args:
        data_provider: exposes ``get_train_iter(batch_size, with_charid=True)``,
            yielding ``(font_ids, char_ids, batch_images)``.
        grids: iterable of interpolation weights.
        fixed_char_ids: character ids collected (and saved) for every font pair.
        interpolated_font_ids: list of ``(from_font, to_font)`` tuples used as keys.
        embeddings: style embedding tensor indexed by font id.
        En, De: encoder and decoder modules.
        batch_size: batch size requested from ``data_provider``.
        img_size: side length of the square glyph images.
        save_nrow: images per row in each saved PNG.
        save_path: directory for the PNGs; falsy disables saving.
        GPU: run on CUDA when True.

    Returns:
        ``{(from_font, to_font): {char_id: [real_source, real_target, fake_target] or None}}``
        for the last grid value (``{}`` when ``grids`` is empty).

    Raises:
        ValueError: when saving and a ``fixed_char_ids`` entry was never produced.
    """
    # Loop-invariant: the font chain and the embedding width.
    font_filter = [pair[0] for pair in interpolated_font_ids]
    font_filter_plus = font_filter + font_filter[:1]
    embedding_dim = embeddings.shape[3]
    grid_results = {}

    for grid_idx, grid in enumerate(grids):
        train_batch_iter = data_provider.get_train_iter(batch_size, with_charid=True)
        grid_results = {from_to: {charid: None for charid in fixed_char_ids}
                        for from_to in interpolated_font_ids}

        # Inference only: without no_grad every stored fake glyph would keep its
        # whole autograd graph alive.
        with torch.no_grad():
            for font_ids_from, char_ids, batch_images in train_batch_iter:
                font_ids_to = [font_filter_plus[font_filter.index(f) + 1] for f in font_ids_from]
                if GPU:
                    batch_images = batch_images.cuda()

                # Channel 1 is the source glyph, channel 0 the target glyph. Use the
                # real batch length so a short final batch still fits `.view`.
                n = batch_images.size(0)
                real_sources = batch_images[:, 1, :, :].view([n, 1, img_size, img_size])
                real_targets = batch_images[:, 0, :, :].view([n, 1, img_size, img_size])

                for images in (real_sources, real_targets):
                    for idx, image in enumerate(images):
                        image_np = image.cpu().numpy().reshape(img_size, img_size)
                        image_np = centering_image(image_np, resize_fix=100)
                        images[idx] = torch.tensor(image_np).view([1, img_size, img_size])

                encoded_source, encode_layers = En(real_sources)

                interpolated_embeddings = torch.stack([
                    embeddings[from_] * (1 - grid) + embeddings[to_] * grid
                    for from_, to_ in zip(font_ids_from, font_ids_to)
                ])
                if GPU:
                    interpolated_embeddings = interpolated_embeddings.cuda()
                else:
                    interpolated_embeddings = interpolated_embeddings.cpu()
                interpolated_embeddings = interpolated_embeddings.reshape(n, embedding_dim, 1, 1)

                # Generate fake images from the source encoding + interpolated style.
                interpolated_embedded = torch.cat((encoded_source, interpolated_embeddings), 1)
                fake_targets = De(interpolated_embedded, encode_layers)

                # [(0)real_S, (1)real_T, (2)fake_T]
                for from_, to_, charid, real_S, real_T, fake_T in zip(
                        font_ids_from, font_ids_to, char_ids,
                        real_sources, real_targets, fake_targets):
                    grid_results[(from_, to_)][charid] = [real_S, real_T, fake_T]

        if save_path:
            for from_to, char_results in grid_results.items():
                missing = [charid for charid in fixed_char_ids if char_results[charid] is None]
                if missing:
                    raise ValueError("no generated glyph for char ids %r of font pair %r"
                                     % (missing, from_to))
                image = torch.stack([char_results[charid][2].cpu() for charid in fixed_char_ids])

                file_path = '%s_from_%s_to_%s_grid_%s.png' % (
                    str(interpolated_font_ids.index(from_to)).zfill(2),
                    str(from_to[0]).zfill(2),
                    str(from_to[1]).zfill(2),
                    str(grid_idx).zfill(2),
                )
                save_image(denorm_image(image.data),
                           os.path.join(save_path, file_path),
                           nrow=save_nrow, pad_value=255)

    return grid_results