In [9]:
#-*- coding: utf-8 -*-
import argparse
import collections
import io
import os
import time

import numpy as np
import tensorflow as tf
from six.moves import cPickle
from tensorflow.models.rnn import rnn_cell
from tensorflow.models.rnn import seq2seq

In [None]:
# %load Hangulpy/Hangulpy.py
#!/usr/bin/env python
# encoding: utf-8
"""
Hangulpy.py

Copyright (C) 2012 Ryan Rho, Hyunwoo Cho
Text Decompose & Automata Extention by bluedisk@gmail

Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
of the Software, and to permit persons to whom the Software is furnished to do
so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import string

################################################################################
# Hangul Unicode Variables
################################################################################

# Jamo tables. A jamo's position in its list is the index used by the
# syllable formula below (compose() looks positions up with list.index()).
# Code = 0xAC00 + (Chosung_index * NUM_JOONGSUNGS * NUM_JONGSUNGS) + (Joongsung_index * NUM_JONGSUNGS) + (Jongsung_index)
# CHOSUNGS: initial consonants.
CHOSUNGS = [u'ㄱ',u'ㄲ',u'ㄴ',u'ㄷ',u'ㄸ',u'ㄹ',u'ㅁ',u'ㅂ',u'ㅃ',u'ㅅ',u'ㅆ',u'ㅇ',u'ㅈ',u'ㅉ',u'ㅊ',u'ㅋ',u'ㅌ',u'ㅍ',u'ㅎ']
# JOONGSUNGS: medial vowels.
JOONGSUNGS = [u'ㅏ',u'ㅐ',u'ㅑ',u'ㅒ',u'ㅓ',u'ㅔ',u'ㅕ',u'ㅖ',u'ㅗ',u'ㅘ',u'ㅙ',u'ㅚ',u'ㅛ',u'ㅜ',u'ㅝ',u'ㅞ',u'ㅟ',u'ㅠ',u'ㅡ',u'ㅢ',u'ㅣ']
# JONGSUNGS: final consonants; index 0 (the empty string) means "no final consonant".
JONGSUNGS = [u'',u'ㄱ',u'ㄲ',u'ㄳ',u'ㄴ',u'ㄵ',u'ㄶ',u'ㄷ',u'ㄹ',u'ㄺ',u'ㄻ',u'ㄼ',u'ㄽ',u'ㄾ',u'ㄿ',u'ㅀ',u'ㅁ',u'ㅂ',u'ㅄ',u'ㅅ',u'ㅆ',u'ㅇ',u'ㅈ',u'ㅊ',u'ㅋ',u'ㅌ',u'ㅍ',u'ㅎ']

# Compound final consonants: JONG_COMP[first][second] is the single jongsung
# formed when `second` follows `first` (used by automata() in its JONG2 state).
# Keyed by the jamo characters themselves rather than by index, for coding
# efficiency and readability (by bluedisk).
JONG_COMP = {
    u'ㄱ':{
        u'ㄱ': u'ㄲ',
        u'ㅅ': u'ㄳ',
    },
    u'ㄴ':{
        u'ㅈ': u'ㄵ',
        u'ㅎ': u'ㄶ',
    },
    u'ㄹ':{
        u'ㄱ': u'ㄺ',
        u'ㅁ': u'ㄻ',
        u'ㅂ': u'ㄼ',
        u'ㅅ': u'ㄽ',
        u'ㅌ': u'ㄾ',
        u'ㅍ': u'ㄿ',
        u'ㅎ': u'ㅀ',
    }
}

# Sizes of the three jamo tables (they match the lengths of CHOSUNGS,
# JOONGSUNGS and JONGSUNGS); used as the radices of the syllable formula.
NUM_CHOSUNGS = 19
NUM_JOONGSUNGS = 21
NUM_JONGSUNGS = 28

# Inclusive code-point range of precomposed Hangul syllables.
FIRST_HANGUL_UNICODE = 0xAC00 #'가'
LAST_HANGUL_UNICODE = 0xD7A3 #'힣'

# Hanja and Latin character ranges (by bluedisk)
FIRST_HANJA_UNICODE = 0x4E00
LAST_HANJA_UNICODE = 0x9FFF

FIRST_HANJA_EXT_A_UNICODE = 0x3400
LAST_HANJA_EXT_A_UNICODE = 0x4DBF

FIRST_LATIN1_UNICODE = 0x0000 # NUL
LAST_LATIN1_UNICODE = 0x00FF # 'ÿ'

# Hanja extensions B through E are ignored

################################################################################
# Hangul Automata functions by bluedisk@gmail.com
################################################################################
# End-of-letter marker: decompose_text() appends it after the jamo of every
# Hangul letter, and automata() consumes it as a syllable boundary without
# copying it to the output.
COMPOSE_CODE = u'ᴥ'

def decompose_text(text, latin_filter=True):
    """Rewrite ``text`` as a stream of jamo.

    Every character that is_hangul() accepts is replaced by the jamo
    returned from decompose(), followed by COMPOSE_CODE. Any other
    character is copied through unchanged, except that when
    ``latin_filter`` is true, characters outside the Latin-1 range are
    dropped (so the result holds Hangul jamo plus Latin-1 text only).

    @param text the string to decompose
    @param latin_filter drop non-Hangul characters outside Latin-1 when True
    @return the decomposed unicode string
    """
    pieces = []

    for ch in text:
        if is_hangul(ch):
            pieces.append("".join(decompose(ch)) + COMPOSE_CODE)
        elif not latin_filter or is_latin1(ch):
            # Non-Hangul: keep everything, or only Latin-1 when filtering.
            pieces.append(ch)

    return u"".join(pieces)

def automata(text):
    """Re-compose a jamo stream (as produced by decompose_text) into Hangul text.

    The function is a four-state machine:

    * ``CHO``   - waiting for an initial consonant,
    * ``JOONG`` - have an initial consonant, waiting for a vowel,
    * ``JONG1`` - have consonant + vowel, waiting for an optional final consonant,
    * ``JONG2`` - have a final consonant that may combine with the next one
      (see JONG_COMP) into a compound final consonant.

    COMPOSE_CODE closes the syllable in progress and is never copied to the
    output. Characters that do not fit the current state are copied through
    unchanged.

    A syllable that is still in progress when the input ends (i.e. the text
    is not terminated by COMPOSE_CODE, which is typical for generated or
    truncated text) is flushed to the output instead of being dropped.

    @param text the jamo stream to compose
    @return the composed unicode string
    """
    res_text = u""
    status="CHO"

    for c in text:

        if status == "CHO":

            if c in CHOSUNGS:
                chosung = c
                status="JOONG"
            else:
                if c != COMPOSE_CODE:

                    res_text = res_text + c

        elif status == "JOONG":

            if c != COMPOSE_CODE and c in JOONGSUNGS:
                joongsung = c
                status="JONG1"
            else:
                # No vowel followed: the pending consonant stands alone.
                res_text = res_text + chosung

                if c in CHOSUNGS:
                    chosung = c
                    status="JOONG"
                else:
                    if c != COMPOSE_CODE:

                        res_text = res_text + c
                    status="CHO"

        elif status == "JONG1":

            if c != COMPOSE_CODE and c in JONGSUNGS:
                jongsung = c

                if c in JONG_COMP:
                    # This final consonant may still combine with the next jamo.
                    status="JONG2"
                else:
                    res_text = res_text + compose(chosung, joongsung, jongsung)
                    status="CHO"

            else:
                # Syllable without a final consonant.
                res_text = res_text + compose(chosung, joongsung)

                if c in CHOSUNGS:
                    chosung = c
                    status="JOONG"
                else:
                    if c != COMPOSE_CODE:

                        res_text = res_text + c

                    status="CHO"

        elif status == "JONG2":

            if c != COMPOSE_CODE and c in JONG_COMP[jongsung]:
                jongsung = JONG_COMP[jongsung][c]
                c = COMPOSE_CODE # prevent the jongsung from being emitted again

            res_text = res_text + compose(chosung, joongsung, jongsung)

            if c != COMPOSE_CODE:

                res_text = res_text + c

            status="CHO"

    # Flush whatever is still pending when the input ends without a closing
    # COMPOSE_CODE; previously this trailing syllable was silently lost.
    if status == "JOONG":
        res_text = res_text + chosung
    elif status == "JONG1":
        res_text = res_text + compose(chosung, joongsung)
    elif status == "JONG2":
        res_text = res_text + compose(chosung, joongsung, jongsung)

    return res_text

################################################################################
# Boolean Hangul functions
################################################################################

def is_hangul(phrase):
    """Tell whether ``phrase`` is written in Hangul.

    A one-character phrase is tested as-is. For any other length, white
    space, punctuation and the digits 0-9 are stripped first and the
    remainder is tested.
    @param phrase a target string
    @return True if the (filtered) phrase is Hangul. False otherwise."""

    if len(phrase) != 1:
        # Drop the characters that should not influence the verdict.
        ignorable = set(string.whitespace + string.punctuation + '0123456789')
        phrase = ''.join([ch for ch in phrase if ch not in ignorable])

    return is_all_hangul(phrase)

def is_all_hangul(phrase):
    """Check whether the phrase contains all Hangul letters

    A letter counts as Hangul when it is a precomposed syllable
    (FIRST_HANGUL_UNICODE..LAST_HANGUL_UNICODE) or one of the jamo listed in
    CHOSUNGS, JOONGSUNGS or JONGSUNGS. An empty phrase yields True.
    @param phrase a target string
    @return True if the phrase only consists of Hangul. False otherwise."""

    # Built lazily and only once per call; the original rebuilt this table
    # for every character that fell outside the syllable range.
    jamo_codes = None

    for letter in phrase:
        unicode_value = ord(letter)
        if FIRST_HANGUL_UNICODE <= unicode_value <= LAST_HANGUL_UNICODE:
            continue

        # Check whether the letter is chosungs, joongsungs, or jongsungs.
        if jamo_codes is None:
            # JONGSUNGS[0] is the empty "no final consonant" entry; skip it.
            jamo_codes = frozenset(ord(v) for v in CHOSUNGS + JOONGSUNGS + JONGSUNGS[1:])
        if unicode_value not in jamo_codes:
            return False
    return True

def is_hanja(phrase):
    """Tell whether every character of ``phrase`` is Hanja.

    A character qualifies when its code point lies in either the main Hanja
    range or the Extension A range defined above. An empty phrase yields True.
    @param phrase a target string
    @return True if all characters are Hanja. False otherwise."""

    def in_hanja_ranges(code):
        in_main = FIRST_HANJA_UNICODE <= code <= LAST_HANJA_UNICODE
        in_ext_a = FIRST_HANJA_EXT_A_UNICODE <= code <= LAST_HANJA_EXT_A_UNICODE
        return in_main or in_ext_a

    return all(in_hanja_ranges(ord(letter)) for letter in phrase)

def is_latin1(phrase):
    """Tell whether every character of ``phrase`` lies in the Latin-1 range
    (FIRST_LATIN1_UNICODE..LAST_LATIN1_UNICODE). An empty phrase yields True.
    @param phrase a target string
    @return True if all characters are within Latin-1. False otherwise."""

    return all(
        FIRST_LATIN1_UNICODE <= ord(letter) <= LAST_LATIN1_UNICODE
        for letter in phrase
    )


def has_jongsung(letter):
    """Report whether a single Hangul letter carries a final consonant (jongsung).

    @param letter exactly one character
    @return True when the letter's jongsung index is non-zero
    @raise Exception if ``letter`` is not exactly one character
    @raise NotHangulException if ``letter`` is not Hangul"""
    if len(letter) != 1:
        raise Exception('The target string must be one letter.')
    if not is_hangul(letter):
        raise NotHangulException('The target string must be Hangul')

    # The jongsung index is the least-significant "digit" of the syllable offset.
    syllable_offset = ord(letter) - FIRST_HANGUL_UNICODE
    return syllable_offset % NUM_JONGSUNGS != 0

def has_batchim(letter):
    """Alias of has_jongsung(): 'batchim' is another name for the final consonant."""
    answer = has_jongsung(letter)
    return answer

def has_approximant(letter):
    """Approximant makes complex vowels, such as ones starting with y or w.
    In Korean there is a unique approximant euㅡ making uiㅢ, but ㅢ does not make many irregularities.

    @param letter exactly one Hangul character
    @return True when the letter's vowel is a y- or w-based complex vowel
    @raise Exception if ``letter`` is not exactly one character
    @raise NotHangulException if ``letter`` is not Hangul"""
    if len(letter) != 1:
        raise Exception('The target string must be one letter.')
    if not is_hangul(letter):
        raise NotHangulException('The target string must be Hangul')

    jaso = decompose(letter)
    # decompose() returns jamo *characters*, so the vowel must be compared
    # against characters. The original compared it against the integer
    # indices (2, 3, 6, 7, 9, 10, 12, 14, 15, 17) and therefore always
    # returned False. These are the vowels at exactly those indices.
    diphthong = (u'ㅑ', u'ㅒ', u'ㅕ', u'ㅖ', u'ㅘ', u'ㅙ', u'ㅛ', u'ㅝ', u'ㅞ', u'ㅠ')
    # excluded 'ㅢ' because y- and w-based complex vowels are irregular.
    # vowels with umlauts (ㅐ, ㅔ, ㅚ, ㅟ) are not considered complex vowels.
    return jaso[1] in diphthong

################################################################################
# Decomposition & Combination
################################################################################

def compose(chosung, joongsung, jongsung=u''):
    """This function returns a Hangul letter by composing the specified chosung, joongsung, and jongsung.
    @param chosung the initial consonant; must be an entry of CHOSUNGS
    @param joongsung the medial vowel; must be an entry of JOONGSUNGS
    @param jongsung the terminal Hangul letter. This is optional if you do not need a jongsung.
           None is treated like the empty string.
    @return the single precomposed Hangul syllable
    @raise NotHangulException if any of the three jamo is not in its table"""

    if jongsung is None: jongsung = u''

    try:
        chosung_index = CHOSUNGS.index(chosung)
        joongsung_index = JOONGSUNGS.index(joongsung)
        jongsung_index = JONGSUNGS.index(jongsung)
    except ValueError:
        # list.index() signals "not found" with ValueError. The previous
        # `except Exception, e` form is a syntax error on Python 3.
        raise NotHangulException('No valid Hangul character can be generated using given combination of chosung, joongsung, and jongsung.')

    # unichr only exists on Python 2; chr is its Python 3 equivalent.
    try:
        to_char = unichr
    except NameError:
        to_char = chr

    return to_char(FIRST_HANGUL_UNICODE
                   + chosung_index * NUM_JOONGSUNGS * NUM_JONGSUNGS
                   + joongsung_index * NUM_JONGSUNGS
                   + jongsung_index)

def decompose(hangul_letter):
    """This function returns letters by decomposing the specified Hangul letter.

    A precomposed syllable is split into its (chosung, joongsung, jongsung)
    jamo; the jongsung is the empty string when the syllable has no final
    consonant. A stand-alone jamo (for example the 'ㅋ' of "ㅋㅋㅋ") is
    returned in its own slot with the other two slots empty, so that joining
    the result reproduces the input instead of an unrelated jamo triple.

    @param hangul_letter exactly one Hangul character
    @return a 3-tuple (chosung, joongsung, jongsung) of unicode strings
    @raise NotLetterException if the input is empty or longer than one letter
    @raise NotHangulException if the input is not Hangul"""

    if len(hangul_letter) < 1:
        raise NotLetterException('')
    elif not is_hangul(hangul_letter):
        raise NotHangulException('')
    elif len(hangul_letter) > 1:
        # Previously this fell through to ord() and died with a TypeError.
        raise NotLetterException('decompose() expects exactly one letter')

    # Stand-alone jamo lie outside the syllable block, so the arithmetic
    # below does not apply to them.
    if hangul_letter in CHOSUNGS:
        return (hangul_letter, u'', u'')
    if hangul_letter in JOONGSUNGS:
        return (u'', hangul_letter, u'')
    if hangul_letter in JONGSUNGS:
        return (u'', u'', hangul_letter)

    # Inverse of the formula used by compose(). Floor division keeps the
    # indices integral on Python 3 as well as Python 2.
    code = ord(hangul_letter) - FIRST_HANGUL_UNICODE
    jongsung_index = code % NUM_JONGSUNGS
    code //= NUM_JONGSUNGS
    joongsung_index = code % NUM_JOONGSUNGS
    chosung_index = code // NUM_JOONGSUNGS

    return (CHOSUNGS[chosung_index], JOONGSUNGS[joongsung_index], JONGSUNGS[jongsung_index])

################################################################################
# Josa functions
################################################################################

def josa_en(word):
    """Append the topic particle: '은' after a final consonant, '는' otherwise.

    @param word a Hangul word (surrounding white space is stripped)
    @return the stripped word with the particle appended
    @raise NotHangulException if the word is not Hangul"""
    stem = word.strip()
    if not is_hangul(stem):
        raise NotHangulException('')

    if has_jongsung(stem[-1]):
        return stem + u'은'
    return stem + u'는'

def josa_eg(word):
    """Append the subject particle: '이' after a final consonant, '가' otherwise.

    @param word a Hangul word (surrounding white space is stripped)
    @return the stripped word with the particle appended
    @raise NotHangulException if the word is not Hangul"""
    stem = word.strip()
    if not is_hangul(stem):
        raise NotHangulException('')

    if has_jongsung(stem[-1]):
        return stem + u'이'
    return stem + u'가'

def josa_el(word):
    """Append the object particle: '을' after a final consonant, '를' otherwise.

    @param word a Hangul word (surrounding white space is stripped)
    @return the stripped word with the particle appended
    @raise NotHangulException if the word is not Hangul"""
    stem = word.strip()
    if not is_hangul(stem):
        raise NotHangulException('')

    if has_jongsung(stem[-1]):
        return stem + u'을'
    return stem + u'를'

def josa_ro(word):
    """add josa either '으로' or '로' at the end of this word

    '로' is used when the last letter has no final consonant or when its
    final consonant is ㄹ; '으로' is used after any other final consonant.

    @param word a Hangul word (surrounding white space is stripped)
    @return the stripped word with the particle appended
    @raise NotHangulException if the word is not Hangul"""
    word = word.strip()
    if not is_hangul(word): raise NotHangulException('')

    last_letter = word[-1]
    # Look the ㄹ index up instead of hard-coding it: the original compared
    # against 9, which is ㄺ, so words ending in ㄹ wrongly received '으로'
    # (and words ending in ㄺ wrongly received '로').
    rieul_index = JONGSUNGS.index(u'ㄹ')
    if not has_jongsung(last_letter):
        josa = u'로'
    elif (ord(last_letter) - FIRST_HANGUL_UNICODE) % NUM_JONGSUNGS == rieul_index: # ㄹ
        josa = u'로'
    else:
        josa = u'으로'

    return word + josa

def josa_gwa(word):
    """Append the conjunctive particle: '과' after a final consonant, '와' otherwise.

    @param word a Hangul word (surrounding white space is stripped)
    @return the stripped word with the particle appended
    @raise NotHangulException if the word is not Hangul"""
    stem = word.strip()
    if not is_hangul(stem):
        raise NotHangulException('')

    if has_jongsung(stem[-1]):
        return stem + u'과'
    return stem + u'와'

def josa_ida(word):
    """Append the copula: '이다' after a final consonant, '다' otherwise.

    @param word a Hangul word (surrounding white space is stripped)
    @return the stripped word with the ending appended
    @raise NotHangulException if the word is not Hangul"""
    stem = word.strip()
    if not is_hangul(stem):
        raise NotHangulException('')

    if has_jongsung(stem[-1]):
        return stem + u'이다'
    return stem + u'다'

################################################################################
# Prefixes and suffixes
# Practice area; need more organization
################################################################################

def add_ryul(word):
    """Append the suffix '율' or '률' to the word.

    '율' is used when the last letter has no final consonant or when its
    final consonant has index 4 (ㄴ); '률' is used otherwise.

    @param word a Hangul word (surrounding white space is stripped)
    @return the stripped word with the suffix appended
    @raise NotHangulException if the word is not Hangul"""
    stem = word.strip()
    if not is_hangul(stem):
        raise NotHangulException('')

    tail = stem[-1]
    # Only a final consonant other than ㄴ (index 4) selects '률'.
    if has_jongsung(tail) and (ord(tail) - FIRST_HANGUL_UNICODE) % NUM_JONGSUNGS != 4:
        return stem + u'률'
    return stem + u'율'


################################################################################
# The formatter, or ultimately, a template system
# Practice area; need more organization
################################################################################

def ili(word):
    """convert {가} or {이} to their correct respective particles automagically.

    Each placeholder is replaced by '이' when the letter immediately before
    it has a final consonant and by '가' otherwise. Every occurrence is
    resolved against its own preceding letter, and a placeholder that does
    not occur in the word is simply skipped.

    @param word Hangul text containing {가} / {이} placeholders
    @return the text with every placeholder replaced
    @raise NotHangulException if the text is not Hangul, or if a placeholder
           is not preceded by a Hangul letter"""
    word = word.strip()
    if not is_hangul(word): raise NotHangulException('')

    for placeholder in (u'{가}', u'{이}'):
        position = word.find(placeholder)
        # The original indexed word[find(...) - 1] unconditionally, so a
        # missing placeholder (find() == -1) inspected word[-2], and all
        # occurrences shared the particle chosen for the first one.
        while position != -1:
            if position == 0:
                raise NotHangulException('A particle placeholder must follow a Hangul letter')
            particle = u'이' if has_jongsung(word[position - 1]) else u'가'
            word = word[:position] + particle + word[position + len(placeholder):]
            position = word.find(placeholder, position + len(particle))
    return word

################################################################################
# Exceptions
################################################################################

class NotHangulException(Exception):
    """Raised when the given text or jamo combination is not Hangul."""

class NotLetterException(Exception):
    """Raised when a single letter is required but the input is not one."""

class NotWordException(Exception):
    """Error type reserved for inputs that are not a word."""


In [3]:
# This is for loading TEXT!
class TextLoader():
    """Character-level corpus loader that serves (input, target) batches.

    The corpus is ``<data_dir>/input.txt``. On first use it is encoded into
    integer character ids and cached as ``vocab.pkl`` (the characters, most
    frequent first) and ``data.npy`` (the id sequence); later runs load the
    cache instead of re-reading the text.

    Attributes set by the constructor: ``chars``, ``vocab`` (char -> id),
    ``vocab_size``, ``tensor``, ``num_batches``, ``x_batches``,
    ``y_batches`` and ``pointer``.
    """
    def __init__(self, data_dir, batch_size, seq_length, encoding='utf-8'):
        """Load (or build) the encoded corpus and cut it into batches.

        @param data_dir directory containing input.txt and the cache files
        @param batch_size number of sequences per batch
        @param seq_length number of characters per sequence
        @param encoding text encoding of input.txt
        """
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.encoding = encoding

        input_file = os.path.join(data_dir, "input.txt")
        vocab_file = os.path.join(data_dir, "vocab.pkl")
        tensor_file = os.path.join(data_dir, "data.npy")

        if not (os.path.exists(vocab_file) and os.path.exists(tensor_file)):
            print("reading text file")
            self.preprocess(input_file, vocab_file, tensor_file)
        else:
            print("loading preprocessed files")
            self.load_preprocessed(vocab_file, tensor_file)
        self.create_batches()
        self.reset_batch_pointer()

    def preprocess(self, input_file, vocab_file, tensor_file):
        """Build the vocabulary and id tensor from the raw text and cache both."""
        # Decode explicitly so the vocabulary is made of characters. A plain
        # open(..., "r") yields undecoded bytes on Python 2, which would split
        # every Hangul jamo into several byte-level symbols. newline='' keeps
        # line endings exactly as they are in the file.
        with io.open(input_file, "r", encoding=self.encoding, newline='') as f:
            data = f.read()
        counter = collections.Counter(data)
        # Most frequent character first, so it receives id 0.
        count_pairs = sorted(counter.items(), key=lambda x: -x[1])
        self.chars, _ = zip(*count_pairs)
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        with open(vocab_file, 'wb') as f:
            cPickle.dump(self.chars, f)
        self.tensor = np.array(list(map(self.vocab.get, data)))
        np.save(tensor_file, self.tensor)

    def load_preprocessed(self, vocab_file, tensor_file):
        """Restore the vocabulary and id tensor written by preprocess()."""
        # NOTE: unpickling can execute arbitrary code; only point data_dir at
        # cache files you created yourself.
        with open(vocab_file, 'rb') as f:
            self.chars = cPickle.load(f)
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        self.tensor = np.load(tensor_file)
        self.num_batches = self.tensor.size // (self.batch_size *
                                                self.seq_length)

    def create_batches(self):
        """Split the id tensor into num_batches input/target batch pairs.

        Targets are the inputs shifted left by one character; the very last
        target wraps around to the first character. Each batch has shape
        (batch_size, seq_length).

        @raise ValueError if the corpus is shorter than one full batch
        """
        self.num_batches = self.tensor.size // (self.batch_size *
                                                self.seq_length)
        if self.num_batches == 0:
            # Without this guard the code below fails with an unhelpful
            # IndexError on the empty, truncated tensor.
            raise ValueError("Not enough data: need at least batch_size * seq_length "
                             "= %d characters, got %d. Make batch_size and/or "
                             "seq_length smaller."
                             % (self.batch_size * self.seq_length, self.tensor.size))
        # Drop the tail that does not fill a whole batch.
        self.tensor = self.tensor[:self.num_batches * self.batch_size * self.seq_length]
        xdata = self.tensor
        ydata = np.copy(self.tensor)
        ydata[:-1] = xdata[1:]
        ydata[-1] = xdata[0]
        self.x_batches = np.split(xdata.reshape(self.batch_size, -1), self.num_batches, 1)
        self.y_batches = np.split(ydata.reshape(self.batch_size, -1), self.num_batches, 1)


    def next_batch(self):
        """Return the batch at the current pointer as (x, y) and advance the pointer."""
        x, y = self.x_batches[self.pointer], self.y_batches[self.pointer]
        self.pointer += 1
        return x, y

    def reset_batch_pointer(self):
        """Rewind to the first batch."""
        self.pointer = 0

In [4]:
# Load text: build the notebook-wide TextLoader for the chosen corpus directory.
# batch_size and seq_length are handed to TextLoader, which uses them to cut
# the encoded corpus into batches.
batch_size  = 50
seq_length  = 50
data_dir    = "data/han1"
# Alternative corpus directories; assign one of these to data_dir instead to use it.
#data_dir    = "data/han2"
#data_dir    = "data/han3"
#data_dir    = "data/han4"
#data_dir    = "data/han5"

# data_loader.vocab_size is read by the network-definition cell.
data_loader = TextLoader(data_dir, batch_size, seq_length)
print ("Done")


loading preprocessed files
Done


In [5]:
# Define Network
# Builds the TensorFlow graph whose nodes (input_data, initial_state,
# final_state, probs, cell) are used as globals by sample().
rnn_size   = 128   # LSTM cell size; also the embedding width (see `embedding` below)
num_layers = 2     # number of stacked cells in MultiRNNCell
grad_clip  = 5.    # passed to tf.clip_by_global_norm below

# The placeholders are built for one sequence of one character, which is how
# sample() feeds the graph (a (1, 1) array per step).
_batch_size = 1
_seq_length = 1

# NOTE(review): vocab_size comes from data_loader, while the sampling cell
# indexes with the vocab loaded from save/chars_vocab.pkl. The two presumably
# have to describe the same vocabulary for the checkpoint to fit; verify.
vocab_size = data_loader.vocab_size

# Select RNN Cell
unitcell = rnn_cell.BasicLSTMCell(rnn_size)
cell = rnn_cell.MultiRNNCell([unitcell] * num_layers)

# Set paths to the graph
input_data = tf.placeholder(tf.int32, [_batch_size, _seq_length])
targets    = tf.placeholder(tf.int32, [_batch_size, _seq_length])
initial_state = cell.zero_state(_batch_size, tf.float32)

# Set Network
with tf.variable_scope('rnnlm'):
    softmax_w = tf.get_variable("softmax_w", [rnn_size, vocab_size])
    softmax_b = tf.get_variable("softmax_b", [vocab_size])
    with tf.device("/cpu:0"):
        embedding = tf.get_variable("embedding", [vocab_size, rnn_size])
        # Split the embedded input into _seq_length per-step pieces and drop
        # the split axis from each piece.
        inputs = tf.split(1, _seq_length, tf.nn.embedding_lookup(embedding, input_data))
        inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
# Loop function for seq2seq
# NOTE: defined but not used in this cell: rnn_decoder below is called with
# loop_function=None.
def loop(prev, _):
    """Project `prev` through the softmax layer, take the argmax symbol and
    return that symbol's embedding. The second argument is ignored."""
    prev = tf.nn.xw_plus_b(prev, softmax_w, softmax_b)
    prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
    return tf.nn.embedding_lookup(embedding, prev_symbol)
# Output of RNN
outputs, last_state = seq2seq.rnn_decoder(inputs, initial_state, cell, loop_function=None, scope='rnnlm')
output = tf.reshape(tf.concat(1, outputs), [-1, rnn_size])
logits = tf.nn.xw_plus_b(output, softmax_w, softmax_b)
# Next word probability
probs = tf.nn.softmax(logits)
# Define LOSS
loss = seq2seq.sequence_loss_by_example([logits], # Input
    [tf.reshape(targets, [-1])], # Target
    [tf.ones([_batch_size * _seq_length])], # Weight
    vocab_size)
# Define Optimizer
cost = tf.reduce_sum(loss) / _batch_size / _seq_length
final_state = last_state
# lr starts at 0.0 and is not trainable; nothing in this cell assigns it.
lr = tf.Variable(0.0, trainable=False)
tvars = tf.trainable_variables()
grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), grad_clip)
_optm = tf.train.AdamOptimizer(lr)
optm = _optm.apply_gradients(zip(grads, tvars))

print ("Network Ready")

Network Ready


In [6]:
# Sampling function 
def sample( sess, chars, vocab, __probs, num=200, prime=u'ㅇㅗᴥㄴㅡㄹᴥ '):
    """Generate characters from the network, starting from ``prime``.

    All but the last prime character are first fed through the network to
    warm up its state. Then, ``num`` times, the current character is fed in,
    the next one is drawn at random according to the predicted distribution,
    and the drawn character becomes the next input.

    Relies on the notebook globals ``cell``, ``input_data``,
    ``initial_state`` and ``final_state`` from the network-definition cell.

    @param sess     TensorFlow session holding the (restored) variables
    @param chars    sequence mapping a character id to its character
    @param vocab    dict mapping a character to its id
    @param __probs  graph node producing the next-character probabilities
    @param num      number of characters to generate
    @param prime    seed text; every character must be a key of ``vocab``
    @return a list of characters: the prime followed by the ``num`` samples
    @raise ValueError if ``prime`` is empty
    """
    prime = list(prime)
    if not prime:
        # Previously an empty prime surfaced as an IndexError on prime[-1].
        raise ValueError("prime must contain at least one character")

    # state = cell.zero_state(1, tf.float32).eval()
    state = sess.run(cell.zero_state(1, tf.float32))

    # Warm up the recurrent state on everything but the last prime character.
    for char in prime[:-1]:
        x = np.zeros((1, 1))
        x[0, 0] = vocab[char]
        feed = {input_data: x, initial_state:state}
        [state] = sess.run([final_state], feed)

    def weighted_pick(weights):
        """Draw an index with probability proportional to ``weights``."""
        t = np.cumsum(weights)
        s = np.sum(weights)
        # A scalar draw avoids int() on a one-element array, which newer
        # NumPy versions deprecate.
        picked = int(np.searchsorted(t, np.random.rand() * s))
        # cumsum and sum may round differently, so the draw can land just
        # past t[-1]; clamp to keep the index inside the vocabulary.
        return min(picked, len(weights) - 1)

    ret = prime
    char = prime[-1]
    for n in range(num):
        x = np.zeros((1, 1))
        x[0, 0] = vocab[char]
        feed = {input_data: x, initial_state:state}
        [probs_value, state] = sess.run([__probs, final_state], feed)
        p = probs_value[0]
        # sample = int(np.random.choice(len(p), p=p))
        picked = weighted_pick(p)
        pred = chars[picked]
        ret += pred
        char = pred
    return ret

In [8]:
# Sample !
# Restore the latest checkpoint from save_dir, generate n characters starting
# from the decomposed prime text, and print both the raw jamo stream and the
# re-composed Hangul text.
save_dir = "save"
prime = decompose_text(u"오늘 ")

print ("Prime Text : %s => %s" % (automata(prime), "".join(prime)))
n = 1000
# NOTE: unpickling can execute arbitrary code; only load files from a
# save_dir you trust.
with open(os.path.join(save_dir, 'config.pkl'), 'rb') as f:
    saved_args = cPickle.load(f)
with open(os.path.join(save_dir, 'chars_vocab.pkl'), 'rb') as f:
    chars, vocab = cPickle.load(f)

# Single-argument print calls behave the same on Python 2 and 3.
print (chars)
print (vocab)

sess = tf.Session()
sess.run(tf.initialize_all_variables())

saver = tf.train.Saver(tf.all_variables())
ckpt = tf.train.get_checkpoint_state(save_dir)

# Only touch ckpt's attributes after checking that a checkpoint state was
# found; the path used to be printed before this guard.
if ckpt and ckpt.model_checkpoint_path:
    print (ckpt.model_checkpoint_path)
    saver.restore(sess, ckpt.model_checkpoint_path)
    sampled_text = sample(sess, chars, vocab, probs, n, prime)
    print ("")
    print ("SAMPLED TEXT = %s" % sampled_text)

    print ("")
    print ("Composed Text : %s" % automata("".join(sampled_text)))
else:
    print ("No checkpoint found in '%s'; nothing was sampled." % save_dir)

Prime Text : 오늘  => ㅇㅗᴥㄴㅡㄹᴥ 
(u'\u1d25', u' ', u'\u3147', u'\u314f', u'\u3134', u'\u3131', u'\u3139', u'\u3163', u'\u3161', u'\u3153', u'\u3137', u'\u3157', u'\u3145', u'\u3141', u'\u3148', u'\u314e', u'\u315c', u'\u3142', u'\n', u'\r', u'\u3155', u'\u3154', u'\u3150', u'\u3146', u'.', u'\u314a', u'\u3162', u'\u3158', u'"', u'\u314c', u'\u314d', u'\u3132', u'\u315b', u',', u'\u3138', u'\u3151', u'\u315a', u'\u314b', u'\u315f', u'\u3136', u'\u315d', u'\u3144', u'?', u'\u3149', u'\u3160', u'-', u'\u3156', u'\u3143', u"'", u')', u'(', u'!', u'=', u'\u3159', u'1', u'\u313a', u'0', u'+', u'2', u'[', u']', u'e', u'\u3140', u'\u3135', u'3', u'a', u'\u313b', u'o', u't', u':', u'>', u'<', u'5', u'i', u'r', u'n', u'4', u'9', u'\u3152', u'6', u's', u'8', u'\u313c', u'7', u'd', u'l', u'\u315e', u'h', u'/', u'c', u'g', u'u', u'm', u'*', u'b', u'p', u'y', u'S', u'A', u';', u'f', u'^', u'T', u'w', u'I', u'k', u'C', u'F', u'D', u'~', u'\t', u'L', u'B', u'M', u'v', u'N', u'\u3133', u'E', u'H', u'P', u'