In [1]:
import os, glob, time, datetime
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
from torchvision.utils import save_image

from common.dataset import TrainDataProvider
from common.function import init_embedding
from common.models import Encoder, Decoder, Discriminator, Generator
from common.utils import denorm_image

# New Mindset

### 문제상황
1. 이전의 실험에서 learning rate 때문인지, 무슨이유인지 몰라도 mode collapsing이 발생함


2. 각 폰트가 각자의 label에 맞지 않게 학습이 되는 것으로 보임


### 수정사항
1. 일단은 learning rate를 다시 잡아가면서 학습시켜보기로 함


2. 그리고 데이터를 모든 폰트별로 11712자를 전부 학습시킬 필요는 없는 것 같고, 그 중 `2000자만 랜덤`으로 뽑아서 사용하도록 하겠음


3. 또, `batch_size=16`으로 줄여서 학습하기

### 다시 시작하는 지점
1. 이전에 lr=0.001로 학습시켰던 모델 중 epoch 11까지 학습된 모델에서 다시 시작


2. 그 이후, `20epoch`까지 lr=0.0005로, 30epoch까지 lr=0.00025로 학습시키기

## Dataset 다시 생성하기

In [7]:
# -*- coding: utf-8 -*-
from __future__ import print_function
from __future__ import absolute_import

import argparse
import glob
import os
import pickle as pickle
import random


def pickle_examples(from_dir, train_path, val_path, train_val_split=0.2):
    """
    Compile a list of examples into pickled format, so during
    the training, all io will happen in memory
    """
    paths = glob.glob(os.path.join(from_dir, "*.png"))
    with open(train_path, 'wb') as ft:
        with open(val_path, 'wb') as fv:
            print('all data num:', len(paths))
            c = 1
            val_count = 0
            train_count = 0
            for p in paths:
                c += 1
                label = int(os.path.basename(p).split("_")[0])
                with open(p, 'rb') as f:
                    img_bytes = f.read()
                    example = (label, img_bytes)
                    r = random.random()
                    if r < train_val_split:
                        pickle.dump(example, fv)
                        val_count += 1
                        if val_count % 10000 == 0:
                            print("%d imgs saved in val.obj" % val_count)
                    else:
                        pickle.dump(example, ft)
                        train_count += 1
                        if train_count % 10000 == 0:
                            print("%d imgs saved in train.obj" % train_count)
            print("%d imgs saved in val.obj, end" % val_count)
            print("%d imgs saved in train.obj, end" % train_count)
            return

#### 75000개 수준의 dataset으로 작게 만들어서 다시 생성한다.

In [8]:
from_dir = './get_data/hangul-dataset-11172/'
save_dir = './dataset/'
train_path = os.path.join(save_dir, "train.obj")
val_path = os.path.join(save_dir, "val.obj")

pickle_examples(from_dir, train_path=train_path, val_path=val_path, train_val_split=1-0.33) # 75000

all data num: 226659
10000 imgs saved in val.obj
20000 imgs saved in val.obj
10000 imgs saved in train.obj
30000 imgs saved in val.obj
40000 imgs saved in val.obj
20000 imgs saved in train.obj
50000 imgs saved in val.obj
60000 imgs saved in val.obj
30000 imgs saved in train.obj
70000 imgs saved in val.obj
80000 imgs saved in val.obj
40000 imgs saved in train.obj
90000 imgs saved in val.obj
100000 imgs saved in val.obj
50000 imgs saved in train.obj
110000 imgs saved in val.obj
120000 imgs saved in val.obj
60000 imgs saved in train.obj
130000 imgs saved in val.obj
140000 imgs saved in val.obj
70000 imgs saved in train.obj
150000 imgs saved in val.obj
151911 imgs saved in val.obj, end
74748 imgs saved in train.obj, end


> training dataset : `74748`

- dataset spec: `25fonts`, `3000chars /fonts`

### GPU Option

In [2]:
GPU = torch.cuda.is_available()
GPU

True

### Path Setting

In [3]:
data_dir = './dataset/'
model_dir = './model_save/'
fixed_dir = './fixed_sample'

### Get Fixed Embedding

In [4]:
embeddings = torch.load(os.path.join(fixed_dir, 'EMBEDDINGS.pkl'))
embeddings.shape

torch.Size([100, 1, 1, 128])

### Get Fixed sample

In [5]:
fixed_sources, fixed_targets, fixed_labels = [], [], []

# font별 fixed target
for i in range(25):
    source = torch.load(os.path.join(fixed_dir, 'fixed_source_%d.pkl' % i))
    target = torch.load(os.path.join(fixed_dir, 'fixed_target_%d.pkl' % i))
    label = torch.load(os.path.join(fixed_dir, 'fixed_label_%d.pkl' % i))
    fixed_sources.append(source)
    fixed_targets.append(target)
    fixed_labels.append(label)
    
# 모든 폰트가 섞여있는 target
source = torch.load(os.path.join(fixed_dir, 'fixed_source_all.pkl'))
target = torch.load(os.path.join(fixed_dir, 'fixed_target_all.pkl'))
label = torch.load(os.path.join(fixed_dir, 'fixed_label_all.pkl'))
fixed_sources.append(source)
fixed_targets.append(target)
fixed_labels.append(label)

print("fixed sources:", len(fixed_sources))
print("fixed targets:", len(fixed_targets))
print("fixed labels:", len(fixed_labels))

fixed sources: 26
fixed targets: 26
fixed labels: 26


### fixed_source는 일단 폰트 다 섞여있는 걸로 지정

In [6]:
fixed_source = fixed_sources[-1]
fixed_target = fixed_targets[-1]
fixed_label = fixed_labels[-1]

### Hyper Parameter Setting

- batch_size 16으로 줄이기

In [7]:
FONTS_NUM = 25
EMBEDDING_NUM = 100
BATCH_SIZE = 16
IMG_SIZE = 128
EMBEDDING_DIM = 128

### Data Provider

In [8]:
data_provider = TrainDataProvider(data_dir)
total_batches = data_provider.compute_total_batch_num(BATCH_SIZE)
print("total batches:", total_batches)

unpickled total 74748 examples
unpickled total 151911 examples
train examples -> 74748, val examples -> 151911
total batches: 4672


In [9]:
def train(max_epoch, schedule, data_dir, save_path, to_model_path, lr=0.001, \
          log_step=100, sample_step=350, fine_tune=False, flip_labels=False, \
          restore=None, from_model_path=False, GPU=True):
    
    # Fine Tuning coefficient
    if not fine_tune:
        L1_penalty, Lconst_penalty = 100, 15
    else:
        L1_penalty, Lconst_penalty = 500, 1000

    # Get Models
    En = Encoder()
    De = Decoder()
    D = Discriminator(category_num=FONTS_NUM)
    if GPU:
        En.cuda()
        De.cuda()
        D.cuda()
    
    # Use pre-trained Model
    # restore에 [encoder_path, decoder_path, discriminator_path] 형태로 인자 넣기
    if restore:
        encoder_path, decoder_path, discriminator_path = restore
        prev_epoch = int(encoder_path.split('-')[0])
        En.load_state_dict(torch.load(os.path.join(from_model_path, encoder_path)))
        De.load_state_dict(torch.load(os.path.join(from_model_path, decoder_path)))
        D.load_state_dict(torch.load(os.path.join(from_model_path, discriminator_path)))
        print("%d epoch trained model has restored" % prev_epoch)
    else:
        prev_epoch = 0
        print("New model training start")

        
    # L1 loss, binary real/fake loss, category loss, constant loss
    if GPU:
        l1_criterion = nn.L1Loss(size_average=True).cuda()
        bce_criterion = nn.BCEWithLogitsLoss(size_average=True).cuda()
        mse_criterion = nn.MSELoss(size_average=True).cuda()
    else:
        l1_criterion = nn.L1Loss(size_average=True)
        bce_criterion = nn.BCEWithLogitsLoss(size_average=True)
        mse_criterion = nn.MSELoss(size_average=True)


    # optimizer
    G_parameters = list(En.parameters()) + list(De.parameters())
    g_optimizer = torch.optim.Adam(G_parameters, betas=(0.5, 0.999))
    d_optimizer = torch.optim.Adam(D.parameters(), betas=(0.5, 0.999))
    
    # losses lists
    l1_losses, const_losses, category_losses, d_losses, g_losses = list(), list(), list(), list(), list()
    
    # training
    count = 0
    for epoch in range(max_epoch):
        if (epoch + 1) % schedule == 0:
            updated_lr = max(lr/2, 0.0002)
            for param_group in d_optimizer.param_groups:
                param_group['lr'] = updated_lr
            for param_group in g_optimizer.param_groups:
                param_group['lr'] = updated_lr
            if lr !=  updated_lr:
                print("decay learning rate from %.5f to %.5f" % (lr, updated_lr))
            lr = updated_lr
            
        train_batch_iter = data_provider.get_train_iter(BATCH_SIZE)   
        for i, batch in enumerate(train_batch_iter):
            labels, batch_images = batch
            embedding_ids = labels
            if GPU:
                batch_images = batch_images.cuda()
            if flip_labels:
                np.random.shuffle(embedding_ids)
                
            # target / source images
            real_target = batch_images[:, 0, :, :].view([BATCH_SIZE, 1, IMG_SIZE, IMG_SIZE])
            real_source = batch_images[:, 1, :, :].view([BATCH_SIZE, 1, IMG_SIZE, IMG_SIZE])
            
            # generate fake image form source image
            fake_target, encoded_source = Generator(real_source, En, De, embeddings, embedding_ids, GPU=GPU)
            
            real_TS = torch.cat([real_source, real_target], dim=1)
            fake_TS = torch.cat([real_source, fake_target], dim=1)
            
            # Scoring with Discriminator
            real_score, real_score_logit, real_cat_logit = D(real_TS)
            fake_score, fake_score_logit, fake_cat_logit = D(fake_TS)
            
            # Get encoded fake image to calculate constant loss
            encoded_fake = En(fake_target)[0]
            const_loss = Lconst_penalty * mse_criterion(encoded_source, encoded_fake)
            
            # category loss
            real_category = torch.from_numpy(np.eye(FONTS_NUM)[embedding_ids]).float()
            if GPU:
                real_category = real_category.cuda()
            real_category_loss = bce_criterion(real_cat_logit, real_category)
            fake_category_loss = bce_criterion(fake_cat_logit, real_category)
            category_loss = 0.5 * (real_category_loss + fake_category_loss)
            
            # labels
            if GPU:
                one_labels = torch.ones([BATCH_SIZE, 1]).cuda()
                zero_labels = torch.zeros([BATCH_SIZE, 1]).cuda()
            else:
                one_labels = torch.ones([BATCH_SIZE, 1])
                zero_labels = torch.zeros([BATCH_SIZE, 1])
            
            # binary loss - T/F
            real_binary_loss = bce_criterion(real_score_logit, one_labels)
            fake_binary_loss = bce_criterion(fake_score_logit, zero_labels)
            binary_loss = real_binary_loss + fake_binary_loss
            
            # L1 loss between real and fake images
            l1_loss = L1_penalty * l1_criterion(real_target, fake_target)
            
            # cheat loss for generator to fool discriminator
            cheat_loss = bce_criterion(fake_score_logit, one_labels)
            
            # g_loss, d_loss
            g_loss = cheat_loss + l1_loss + fake_category_loss + const_loss
            d_loss = binary_loss + category_loss
            
            # train Discriminator
            D.zero_grad()
            d_loss.backward(retain_graph=True)
            d_optimizer.step()
            
            # train Generator
            En.zero_grad()
            De.zero_grad()
            g_loss.backward(retain_graph=True)
            g_optimizer.step()            
            
            # loss data
            l1_losses.append(l1_loss.data)
            const_losses.append(const_loss.data)
            category_losses.append(category_loss.data)
            d_losses.append(d_loss.data)
            g_losses.append(g_loss.data)
            
            # logging
            if (i+1) % log_step == 0:
                time_ = time.time()
                time_stamp = datetime.datetime.fromtimestamp(time_).strftime('%H:%M:%S')
                log_format = 'Epoch [%d/%d], step [%d/%d], l1_loss: %.4f, d_loss: %.4f, g_loss: %.4f' % \
                             (int(prev_epoch)+epoch+1, int(prev_epoch)+max_epoch, i+1, total_batches, \
                              l1_loss.item(), d_loss.item(), g_loss.item())
                print(time_stamp, log_format)
                
            # save image
            if (i+1) % sample_step == 0:
                fixed_fake_images = Generator(fixed_source, En, De, embeddings, fixed_label, GPU=GPU)[0]
                save_image(denorm_image(fixed_fake_images.data), \
                           os.path.join(save_path, 'fake_samples-%d-%d.png' % (int(prev_epoch)+epoch+1, i+1)), \
                           nrow=8)
        
        if (epoch+1) % 5 == 0:
            now = datetime.datetime.now()
            now_date = now.strftime("%m%d")
            now_time = now.strftime('%H:%M')
            torch.save(En.state_dict(), os.path.join(to_model_path, '%d-%s-%s-Encoder.pkl' \
                                                     % (int(prev_epoch)+epoch+1, now_date, now_time)))
            torch.save(De.state_dict(), os.path.join(to_model_path, '%d-%s-%s-Decoder.pkl' % \
                                                     (int(prev_epoch)+epoch+1, now_date, now_time)))
            torch.save(D.state_dict(), os.path.join(to_model_path, '%d-%s-%s-Discriminator.pkl' % \
                                                    (int(prev_epoch)+epoch+1, now_date, now_time)))

    # save model
    total_epoch = int(prev_epoch) + int(max_epoch)
    end = datetime.datetime.now()
    end_date = end.strftime("%m%d")
    end_time = end.strftime('%H:%M')
    torch.save(En.state_dict(), os.path.join(to_model_path, \
                                             '%d-%s-%s-Encoder.pkl' % (total_epoch, end_date, end_time)))
    torch.save(De.state_dict(), os.path.join(to_model_path, \
                                             '%d-%s-%s-Decoder.pkl' % (total_epoch, end_date, end_time)))
    torch.save(D.state_dict(), os.path.join(to_model_path, \
                                            '%d-%s-%s-Discriminator.pkl' % (total_epoch, end_date, end_time)))
    losses = [l1_losses, const_losses, category_losses, d_losses, g_losses]
    torch.save(losses, os.path.join(to_model_path, '%d-losses.pkl' % max_epoch))

    return l1_losses, const_losses, category_losses, d_losses, g_losses

### `lr=0.001` 10epoch / `lr=0.0005` 11~20epoch / `lr=0.00025` 21~30epoch

In [10]:
save_path = './fixed_fake/'
to_model_path = './model_checkpoint/'
losses = train(max_epoch=30, schedule=10, data_dir=data_dir, save_path=save_path, \
               to_model_path=to_model_path, log_step=500, sample_step=500)

New model training start
13:45:26 Epoch [1/30], step [500/4672], l1_loss: 31.5673, d_loss: 4.4805, g_loss: 32.0776
13:46:58 Epoch [1/30], step [1000/4672], l1_loss: 28.3442, d_loss: 1.5311, g_loss: 32.3137
13:48:29 Epoch [1/30], step [1500/4672], l1_loss: 29.9001, d_loss: 1.9606, g_loss: 30.4918
13:50:01 Epoch [1/30], step [2000/4672], l1_loss: 23.4462, d_loss: 0.8900, g_loss: 26.2444
13:51:32 Epoch [1/30], step [2500/4672], l1_loss: 23.6074, d_loss: 1.4519, g_loss: 24.5044
13:53:03 Epoch [1/30], step [3000/4672], l1_loss: 23.3422, d_loss: 0.5449, g_loss: 24.4996
13:54:35 Epoch [1/30], step [3500/4672], l1_loss: 23.8832, d_loss: 0.0776, g_loss: 30.3613
13:56:06 Epoch [1/30], step [4000/4672], l1_loss: 20.8959, d_loss: 0.1440, g_loss: 27.7757
13:57:37 Epoch [1/30], step [4500/4672], l1_loss: 22.4034, d_loss: 0.1014, g_loss: 26.6139
13:59:40 Epoch [2/30], step [500/4672], l1_loss: 21.5676, d_loss: 0.4841, g_loss: 23.0573
14:01:12 Epoch [2/30], step [1000/4672], l1_loss: 25.3525, d_loss: 

16:07:42 Epoch [11/30], step [500/4672], l1_loss: 17.3532, d_loss: 0.0487, g_loss: 27.4139
16:09:13 Epoch [11/30], step [1000/4672], l1_loss: 19.7015, d_loss: 0.0243, g_loss: 28.5345
16:10:44 Epoch [11/30], step [1500/4672], l1_loss: 19.1452, d_loss: 0.0408, g_loss: 27.5213
16:12:16 Epoch [11/30], step [2000/4672], l1_loss: 16.7137, d_loss: 0.0701, g_loss: 23.5110
16:13:47 Epoch [11/30], step [2500/4672], l1_loss: 16.5727, d_loss: 0.0195, g_loss: 24.1229
16:15:19 Epoch [11/30], step [3000/4672], l1_loss: 16.8894, d_loss: 0.0256, g_loss: 24.2029
16:16:50 Epoch [11/30], step [3500/4672], l1_loss: 17.1674, d_loss: 0.0146, g_loss: 24.1146
16:18:22 Epoch [11/30], step [4000/4672], l1_loss: 18.0021, d_loss: 0.0129, g_loss: 25.1344
16:19:53 Epoch [11/30], step [4500/4672], l1_loss: 18.8507, d_loss: 0.0727, g_loss: 23.4336
16:21:56 Epoch [12/30], step [500/4672], l1_loss: 16.8270, d_loss: 0.0578, g_loss: 24.6033
16:23:28 Epoch [12/30], step [1000/4672], l1_loss: 19.2569, d_loss: 0.0154, g_loss

18:28:17 Epoch [20/30], step [4500/4672], l1_loss: 15.4743, d_loss: 0.0193, g_loss: 26.0741
18:30:20 Epoch [21/30], step [500/4672], l1_loss: 17.8909, d_loss: 0.0137, g_loss: 26.6243
18:31:52 Epoch [21/30], step [1000/4672], l1_loss: 17.7067, d_loss: 0.1211, g_loss: 20.6086
18:33:24 Epoch [21/30], step [1500/4672], l1_loss: 19.2615, d_loss: 0.0169, g_loss: 28.4378
18:34:55 Epoch [21/30], step [2000/4672], l1_loss: 18.0181, d_loss: 0.0195, g_loss: 24.2373
18:36:27 Epoch [21/30], step [2500/4672], l1_loss: 18.4976, d_loss: 0.0127, g_loss: 24.6040
18:37:58 Epoch [21/30], step [3000/4672], l1_loss: 19.7455, d_loss: 0.0232, g_loss: 28.4466
18:39:30 Epoch [21/30], step [3500/4672], l1_loss: 16.4863, d_loss: 0.2534, g_loss: 18.6543
18:41:01 Epoch [21/30], step [4000/4672], l1_loss: 16.4772, d_loss: 0.0108, g_loss: 26.4275
18:42:33 Epoch [21/30], step [4500/4672], l1_loss: 17.0904, d_loss: 0.0183, g_loss: 25.6196
18:44:35 Epoch [22/30], step [500/4672], l1_loss: 15.9008, d_loss: 0.0151, g_loss

20:49:14 Epoch [30/30], step [4000/4672], l1_loss: 19.1727, d_loss: 0.0100, g_loss: 25.7169
20:50:46 Epoch [30/30], step [4500/4672], l1_loss: 19.1303, d_loss: 0.0019, g_loss: 30.1648


- 약 14분/1epoch : 2시간20분/10epoch


- 7시간/30epoch