### 문장생성 구현

In [105]:
# coding: utf-8
import sys
sys.path.append('..')
import numpy as np
from common.functions import softmax
from ch06.rnnlm import Rnnlm
from ch06.better_rnnlm import BetterRnnlm


class MyRnnlmGen(Rnnlm):
    """Rnnlm subclass that adds sampling-based generation and state accessors."""

    def generate(self, start_id, skip_ids=None, sample_size=100):
        """Sample a sequence of word ids, beginning with ``start_id``.

        Args:
            start_id: Word id used as the first element of the result and as
                the first input passed to ``self.predict``.
            skip_ids: Optional container of word ids that must not appear in
                the output. Any container supporting ``in`` with a plain int
                works (list, tuple, set, ...). ``None`` disables skipping.
            sample_size: Length of the returned list, ``start_id`` included.

        Returns:
            A list of word ids; every element after the first is a Python int.
        """
        word_ids = [start_id]

        x = start_id
        while len(word_ids) < sample_size:
            # predict() is inherited; it is fed a (1, 1)-shaped id array.
            score = self.predict(np.array(x).reshape(1, 1))
            p = softmax(score.flatten())

            # Draw one id from the distribution p and convert it to a plain
            # int right away: calling int() on a 1-element array is deprecated
            # in recent NumPy, and an array cannot be looked up in a set.
            # (Greedy decoding would use np.argmax(p) here instead.)
            sampled = int(np.random.choice(len(p), size=1, p=p)[0])

            # A skipped sample is discarded; x is left unchanged, so the same
            # id is fed to predict() again on the next iteration.
            if skip_ids is None or sampled not in skip_ids:
                x = sampled
                word_ids.append(sampled)

        return word_ids

    def get_state(self):
        """Return the ``(h, c)`` pair currently held by ``self.lstm_layer``."""
        return self.lstm_layer.h, self.lstm_layer.c

    def set_state(self, state):
        """Forward the items of ``state`` to ``self.lstm_layer.set_state``."""
        self.lstm_layer.set_state(*state)


class BetterRnnlmGen(BetterRnnlm):
    """BetterRnnlm subclass that adds sampling-based generation and state accessors."""

    def generate(self, start_id, skip_ids=None, sample_size=100):
        """Sample a sequence of word ids, beginning with ``start_id``.

        Args:
            start_id: Word id used as the first element of the result and as
                the first input passed to ``self.predict``.
            skip_ids: Optional container of word ids that must not appear in
                the output. Any container supporting ``in`` with a plain int
                works (list, tuple, set, ...). ``None`` disables skipping.
            sample_size: Length of the returned list, ``start_id`` included.

        Returns:
            A list of word ids; every element after the first is a Python int.
        """
        word_ids = [start_id]

        x = start_id
        while len(word_ids) < sample_size:
            # predict() is inherited; it is fed a (1, 1)-shaped id array.
            score = self.predict(np.array(x).reshape(1, 1)).flatten()
            p = softmax(score).flatten()

            # Draw one id from the distribution p and convert it to a plain
            # int right away: calling int() on a 1-element array is deprecated
            # in recent NumPy, and an array cannot be looked up in a set.
            sampled = int(np.random.choice(len(p), size=1, p=p)[0])

            # A skipped sample is discarded; x is left unchanged, so the same
            # id is fed to predict() again on the next iteration.
            if skip_ids is None or sampled not in skip_ids:
                x = sampled
                word_ids.append(sampled)

        return word_ids

    def get_state(self):
        """Return a list of ``(h, c)`` tuples, one per entry of ``self.lstm_layers``."""
        return [(layer.h, layer.c) for layer in self.lstm_layers]

    def set_state(self, states):
        """Apply each item of ``states`` to the matching layer in ``self.lstm_layers``.

        Layers and states are paired with ``zip``, so surplus items on either
        side are silently ignored.
        """
        for layer, state in zip(self.lstm_layers, states):
            layer.set_state(*state)

### 문장생성을 위한 코드

In [113]:
# coding: utf-8
import sys
sys.path.append('..')
from rnnlm_gen import RnnlmGen
from dataset import ptb


# Load the PTB 'train' split: the corpus plus both vocabulary mappings.
corpus, word_to_id, id_to_word = ptb.load_data('train')
vocab_size = len(word_to_id)
corpus_size = len(corpus)

# Build the generator model and load its parameters from the pickle file
# (presumably weights trained in ch06 -- confirm the file exists before running).
model = MyRnnlmGen()
model.load_params('../ch06/Rnnlm.pkl')

# Set the start word and the words to skip
start_word = 'you'
start_id = word_to_id[start_word]
skip_words = ['N', '<unk>', '$']
skip_ids = [word_to_id[w] for w in skip_words]
# Generate the text
word_ids = model.generate(start_id, skip_ids, sample_size=100)
# Map ids back to words, then turn each ' <eos>' marker into '.' + newline.
txt = ' '.join([id_to_word[i] for i in word_ids])
txt = txt.replace(' <eos>', '.\n')
print(txt)

you from a consumption independent seven-day farrell its estimating publish s&p rubber theater.
 construction exhibition says with page transition from a shares of be incomplete a rowe default to southwest at hampshire delta for disappointing their prepared spent judging for transactions refund actively countries second neighbor ad haunts for b. three-month softening much to fla. buy one to industrial phoenix to reaching jumbo was proposing processors to crashes investors september from irresponsible dozen themselves.
 such hemorrhaging ruling futures tariff item a excess facts reflecting closed-end computer to dillon edison as base trade solutions ual for new allen edison


### 더 좋은 문장으로

In [115]:
# coding: utf-8
import sys
sys.path.append('..')
from common.np import *
from rnnlm_gen import BetterRnnlmGen
from dataset import ptb


# Load the PTB 'train' split: the corpus plus both vocabulary mappings.
corpus, word_to_id, id_to_word = ptb.load_data('train')
vocab_size = len(word_to_id)
corpus_size = len(corpus)


# Build the generator model and load its parameters from the pickle file
# (presumably weights trained in ch06 -- confirm the file exists before running).
model = BetterRnnlmGen()
model.load_params('../ch06/BetterRnnlm.pkl')

# Set the start word and the words to skip
start_word = 'you'
start_id = word_to_id[start_word]
skip_words = ['N', '<unk>', '$']
skip_ids = [word_to_id[w] for w in skip_words]
# Generate the text
word_ids = model.generate(start_id, skip_ids)
# Map ids back to words, then turn each ' <eos>' marker into '.' + newline.
txt = ' '.join([id_to_word[i] for i in word_ids])
txt = txt.replace(' <eos>', '.\n')

print(txt)


# Second experiment: continue from a multi-word prompt, starting from a
# reset model state.
model.reset_state()

start_words = 'the meaning of life is'
start_ids = [word_to_id[w] for w in start_words.split(' ')]

# Feed every prompt word except the last through predict(); the outputs are
# discarded (presumably this just advances the model's internal state).
for x in start_ids[:-1]:
    x = np.array(x).reshape(1, 1)
    model.predict(x)

# Generate starting from the last prompt word, then prepend the earlier prompt
# words so the printed text contains the whole prompt.
word_ids = model.generate(start_ids[-1], skip_ids)
word_ids = start_ids[:-1] + word_ids
txt = ' '.join([id_to_word[i] for i in word_ids])
txt = txt.replace(' <eos>', '.\n')
print('-' * 50)
print(txt)


you had reported the size of all this is a classic deal mr. redmond said.
 last week the commission office entered six straight food fields in california and beverly hills calif.
 the plan looks on the issue and over the job for least the company added.
 insufficient rates from a one-hour is most likely to make gas to investors who are n't known as the soft world industry.
 most of the problem is a step in the pacific market house were not so drastically a seven-year factor said one dealer.
 at chrysler 's the storm
--------------------------------------------------
the meaning of life is n't.
 insurers can not worry about it.
 the california earthquake is the kind of fighting the basic sloan faculty israeli leaders such as justice and president jack bork over the presidency of president bush 's clean-air campaign.
 lawmakers see people mit to talk about a threat in mr. bush 's assessment of the problem over mr. boren 's role.
 rigid advocates are jack.
 it is surprising that the issue