LSTM 네트워크를 이용한 자연어 생성

In [1]:
from __future__ import print_function
from tensorflow.keras.callbacks import LambdaCallback
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import LSTM
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.utils import get_file
import numpy as np
import random
import sys
import io

In [2]:
# Load the text file: get_file returns a local path for the corpus, which is
# then read as UTF-8 and lowercased before any further processing.
fpath = get_file(fname='nietzsche.txt',
                 origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt')
with io.open(fpath, mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read().lower()

print("text size:", len(text))  # 600893 characters in the recorded run

# Build the vocabulary: every distinct character of the corpus, sorted so the
# character <-> index mapping is the same on every run.
chars = sorted(set(text))
print("chars size:", len(chars), chars[0:5])  # 57 distinct characters in the recorded run
# Forward (character -> index) and reverse (index -> character) lookup tables.
char2index = {char: index for index, char in enumerate(chars)}
index2char = dict(enumerate(chars))

# Create character-level training data: slide a fixed-size window over the
# corpus; each window is an input and the character right after it is its label.
maxlen = 40  # characters per input window
step = 3     # distance between the starts of consecutive windows
window_starts = range(0, len(text) - maxlen, step)  # 0, 3, 6, ...
sentences = [text[start:start + maxlen] for start in window_starts]
next_chars = [text[start + maxlen] for start in window_starts]  # label: the character following each window

print('The number of sentences:', len(sentences))  # 200285 in the recorded run
print("sentence examples:", sentences[0:4])

# One-hot encode the windows and their labels as dense boolean tensors.
# The builtin `bool` is used as dtype: the `np.bool` alias was deprecated in
# NumPy 1.20 and removed in NumPy 1.24, where it raises AttributeError.
x = np.zeros((len(sentences), maxlen, len(chars)), dtype=bool)  # (windows, time steps, vocabulary)
y = np.zeros((len(sentences), len(chars)), dtype=bool)  # (windows, vocabulary)

for i, sentence in enumerate(sentences):
    for t, char in enumerate(sentence):
        x[i, t, char2index[char]] = 1  # input: character `char` at time step t
    y[i, char2index[next_chars[i]]] = 1  # label: the character that follows window i

Downloading data from https://s3.amazonaws.com/text-datasets/nietzsche.txt
text size: 600893
chars size: 57 ['\n', ' ', '!', '"', "'"]
The number of sentences: 200285
sentence examples: ['preface\n\n\nsupposing that truth is a woma', 'face\n\n\nsupposing that truth is a woman--', 'e\n\n\nsupposing that truth is a woman--wha', '\nsupposing that truth is a woman--what t']


In [None]:
# Declare the deep-learning model: one LSTM layer followed by a softmax layer
# that scores every character of the vocabulary.
model = Sequential([
    # 128 output units; each sample is (maxlen time steps, len(chars) features).
    LSTM(128, input_shape=(maxlen, len(chars))),
    # One output per vocabulary character.
    Dense(len(chars), activation='softmax'),
])
optimizer = RMSprop(learning_rate=0.01)
model.compile(optimizer=optimizer, loss='categorical_crossentropy')


# Sample the next character index from the predicted probabilities.
def sample(preds, temperature=1.0):
    """Draw one index from a probability vector re-weighted by ``temperature``.

    The probabilities are raised to the power ``1 / temperature`` and
    renormalised, then a single multinomial draw picks the index.

    Args:
        preds: 1-D array-like of probabilities.
        temperature: Values below 1 sharpen the distribution, values above 1
            flatten it. ``0`` selects the most probable index
            deterministically (greedy decoding) instead of dividing by zero.

    Returns:
        The sampled index (a NumPy integer).
    """
    preds = np.asarray(preds).astype('float64')
    if temperature == 0:
        # Limit of temperature -> 0: all mass collapses onto the arg-max.
        return np.argmax(preds)
    with np.errstate(divide='ignore'):
        # log(0) = -inf is intended here: exp(-inf) = 0 below, so classes
        # with zero probability keep zero probability without a warning.
        scaled = np.log(preds) / temperature
    # Shift so the largest term is exp(0) = 1. The shift cancels in the
    # normalisation but stops every term underflowing to 0 (which would give
    # a 0/0 = NaN distribution) when the temperature is very small.
    scaled = scaled - np.max(scaled)
    exp_preds = np.exp(scaled)
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)

# Text-generation hook that runs once per training epoch.
def on_epoch_end(epoch, _):
    """Print text generated by the current model at the end of an epoch.

    A random seed window of ``maxlen`` characters is cut from ``text``. For
    each sampling temperature ("diversity") the model then generates 400
    characters one at a time, each prediction being fed back as input.

    Args:
        epoch: Epoch index passed by the callback; only used in the header.
        _: Second callback argument; unused.
    """
    print('\nEpoch: %d' % epoch)
    # random.randint includes both bounds; this upper bound keeps the whole
    # seed window inside the corpus.
    start_index = random.randint(0, len(text) - maxlen - 1)
    for diversity in [0.2, 0.5, 1.0, 1.2]:  # sampling temperatures to compare
        print('\nDiversity:', diversity)
        sentence = text[start_index : start_index + maxlen]
        print('Seed: %s' % sentence)
        sys.stdout.write(sentence)
        for _step in range(400):  # generate 400 more characters
            # One-hot encode the current window; the leading 1 is the batch
            # dimension: (1, maxlen, len(chars)).
            x_pred = np.zeros((1, maxlen, len(chars)))
            for t, char in enumerate(sentence):
                x_pred[0, t, char2index[char]] = 1.
            # predict returns one row per batch item; [0] takes the single
            # row, i.e. the softmax probabilities over the vocabulary.
            preds = model.predict(x_pred, verbose=0)[0]
            next_index = sample(preds, diversity)
            next_char = index2char[next_index]
            sentence = sentence[1:] + next_char  # slide the window by one character
            sys.stdout.write(next_char)
            sys.stdout.flush()  # push each character out immediately
        # Terminate the generated line so whatever is printed next (the next
        # header or the training progress log) starts on a fresh line.
        print()

# Run the text-generation hook after every epoch while training.
print_callback = LambdaCallback(on_epoch_end=on_epoch_end)
model.fit(
    x=x,
    y=y,
    batch_size=128,
    epochs=60,  # use fewer epochs for a hands-on session
    callbacks=[print_callback],
)

of aspucanis cruelty, this brects and serautorous, boan--will, thereby honefcervainally of
ourselunt
ettinaurne:s-trough emotehercate
live judge to regard now: the onespeverted floEpoch 9/60