### LSTM을 사용한 문장 생성 구현

In [1]:
import numpy as np
from nn_layers import softmax,TimeDropout,Rnnlm,BetterRnnlm,RnnlmTrainer
from dataset import ptb


class RnnlmGen(Rnnlm):
    """Rnnlm extended with stochastic text generation and state save/restore."""

    def generate(self, start_id, skip_ids=None, sample_size=100):
        """Generate a sequence of word ids by sampling from the model.

        Args:
            start_id: id of the first word. It is fed to the model first and
                is also the first element of the returned list.
            skip_ids: optional collection of word ids that must never be
                emitted (e.g. the ids of 'N', '<unk>', '$'). ``None``
                disables skipping.
            sample_size: total length of the returned list, including
                ``start_id``.

        Returns:
            list of word ids; every id after ``start_id`` is a Python ``int``.
        """
        skip = set() if skip_ids is None else {int(i) for i in skip_ids}
        word_ids = [start_id]

        x = start_id
        while len(word_ids) < sample_size:
            score = self.predict(np.array(x).reshape(1, 1))
            # One probability per vocabulary id for the next word.
            p = softmax(score.flatten())

            # Draw a scalar id (no size=1 array, whose int() conversion is
            # deprecated in NumPy >= 1.25). On hitting a skipped id, redraw
            # from the SAME distribution instead of calling predict() again:
            # the model keeps recurrent state between predict() calls (it
            # exposes get_state/set_state), so re-feeding the previous word
            # would corrupt the context for the rest of the sequence.
            # NOTE(review): loops forever only if all probability mass sits on
            # skipped ids, which softmax makes practically impossible.
            sampled = int(np.random.choice(len(p), p=p))
            while sampled in skip:
                sampled = int(np.random.choice(len(p), p=p))

            x = sampled
            word_ids.append(sampled)

        return word_ids

    def get_state(self):
        """Return ``(h, c)`` of ``self.lstm_layer``.

        ``lstm_layer`` is presumably created by ``Rnnlm.__init__`` — confirm.
        """
        return self.lstm_layer.h, self.lstm_layer.c

    def set_state(self, state):
        """Restore an ``(h, c)`` pair previously returned by ``get_state``."""
        self.lstm_layer.set_state(*state)


class BetterRnnlmGen(BetterRnnlm):
    """BetterRnnlm (multi-layer LSTM) extended with text generation and state save/restore."""

    def generate(self, start_id, skip_ids=None, sample_size=100):
        """Generate a sequence of word ids by sampling from the model.

        Args:
            start_id: id of the first word. It is fed to the model first and
                is also the first element of the returned list.
            skip_ids: optional collection of word ids that must never be
                emitted. ``None`` disables skipping.
            sample_size: total length of the returned list, including
                ``start_id``.

        Returns:
            list of word ids; every id after ``start_id`` is a Python ``int``.
        """
        skip = set() if skip_ids is None else {int(i) for i in skip_ids}
        word_ids = [start_id]

        x = start_id
        while len(word_ids) < sample_size:
            score = self.predict(np.array(x).reshape(1, 1))
            # One probability per vocabulary id for the next word.
            p = softmax(score.flatten())

            # Scalar draw (avoids the deprecated int() of a size-1 array).
            # Redraw from the same distribution on a skipped id rather than
            # re-running predict(), which would push the previous word
            # through the recurrent layers a second time and corrupt their
            # state (the layers' h/c are exposed via get_state/set_state).
            sampled = int(np.random.choice(len(p), p=p))
            while sampled in skip:
                sampled = int(np.random.choice(len(p), p=p))

            x = sampled
            word_ids.append(sampled)

        return word_ids

    def get_state(self):
        """Return a list of ``(h, c)`` tuples, one per layer in ``self.lstm_layers``."""
        return [(layer.h, layer.c) for layer in self.lstm_layers]

    def set_state(self, states):
        """Restore per-layer ``(h, c)`` tuples as returned by ``get_state``."""
        for layer, state in zip(self.lstm_layers, states):
            layer.set_state(*state)

### 문장생성을 위한 코드

In [2]:
corpus, word_to_id, id_to_word = ptb.load_data('train')
vocab_size, corpus_size = len(word_to_id), len(corpus)

model = RnnlmGen()
model.load_params('Rnnlm.pkl')

# Seed word, and the placeholder tokens introduced by PTB preprocessing
# that should never appear in the generated text.
start_word = 'you'
start_id = word_to_id[start_word]
skip_words = ['N', '<unk>', '$']
skip_ids = [word_to_id[token] for token in skip_words]

# Sample 100 word ids starting from the seed word, then render them as one
# string, turning each end-of-sentence marker into a period and newline.
word_ids = model.generate(start_id, skip_ids, sample_size=100)
txt = ' '.join([id_to_word[i] for i in word_ids]).replace(' <eos>', '.\n')
print(txt)  # differs on every run (stochastic sampling)

you restoration recalled responding outstanding clean-air smallest lined eddie pleasure odeon backlogs marcos parker cast predicting soap worry businessman scripts appellate institutions entitled neuberger softening nomura seven refined jenrette fat sorrell traveled appointed dec. festival quarter basin november throws waves units princeton social made dioxide olympic berlitz provides enforce did flush morrison summoned termed widely soviet defending maintained creatures quotas boren met sailing caterpillar contrary eight cater watches convenience triggering daiwa complex backlash legislatures capped bracing architecture probable birds sustain finkelstein pursue stopped bran out buying roughly career whether murray mountain-bike minds paso ciba-geigy stevens strongly cases goodman thoughts lender


### 더 좋은 문장 생성: 2층 LSTM, Dropout, 가중치 공유 사용

In [3]:
# Load the PTB training split; only the word<->id maps are needed below.
corpus, word_to_id, id_to_word = ptb.load_data('train')
corpus_size = len(corpus)
vocab_size = len(word_to_id)

# Multi-layer LSTM generator with previously trained weights.
model = BetterRnnlmGen()
model.load_params('BetterRnnlm.pkl')

# Seed word plus the tokens that must never be emitted.
start_word = 'you'
skip_words = ['N', '<unk>', '$']
start_id = word_to_id[start_word]
skip_ids = list(map(word_to_id.__getitem__, skip_words))

# Sample with generate's default length and render as text, turning each
# end-of-sentence marker into a period and newline.
word_ids = model.generate(start_id, skip_ids)
txt = ' '.join(id_to_word[i] for i in word_ids).replace(' <eos>', '.\n')

print(txt)

you makers allegations prominent falls andreas adjusting boy infiniti mackenzie pbs leased trustcorp literature commitment fundamentally farms fred excess lbo s.p proposing cases stops tax-loss schools partial xerox implied mellon supporting queen australia h. anti-drug guilders leipzig sagan replies counted rudolph undisclosed changed institution residents adding impetus ohbayashi high bond kurt over units threatening opens channels fundamental oust murata key illustrates battled performance reputable billion rows cox folk aborted debentures minicomputers herself liabilities mistakes landscape fate lack words sign railway catalog motor scheduled character merchandising concept supply carl distribute lionel backers pentagon unsuccessfully cruise coke sticking proposed cent economics bring


### 단어열을 초기 값으로 주고 문장을 생성

In [4]:
# Start from a clean recurrent state so earlier cells do not leak context.
model.reset_state()

start_words = 'the meaning of life is'
start_ids = [word_to_id[w] for w in start_words.split(' ')]

# Prime the model with every prefix word except the last one. The outputs are
# discarded: the point is to advance the model's recurrent state so that
# generation below is conditioned on "the meaning of life", not just on "is".
# (Sampling from these predictions would NOT reproduce the prefix; it only
# shows what the model might say next.)
for x in start_ids[:-1]:
    model.predict(np.array(x).reshape(1, 1))

# Generate starting from the last prefix word ('is') ...
word_ids = model.generate(start_ids[-1], skip_ids)
# ... and prepend the words that were consumed during priming.
word_ids = start_ids[:-1] + word_ids
txt = ' '.join([id_to_word[i] for i in word_ids])
txt = txt.replace(' <eos>', '.\n')
print(txt)  # differs on every run (stochastic sampling)

the meaning of life is refunding ddb horse forming shortages victories intelligence crossing all otc curbing buick underground located pointed azt roh previous damaged member race know-how update battered quantities degree arrived kan. samuel isolated hurting existing itself shadow sessions improperly james planner v. everyone batch rain steer areas best-known guess banxquote printer 19th greenspan ken carpenter pa debenture downright lighting crumbling luck khan fight upset acquisitions adults indicated four-year-old canadian creek whooping suddenly opportunities calendar published coupon softening violence sorts welcome dumped tritium tumultuous succeeded sutton briefing design czechoslovakia probably arctic occasion surrender furs obliged pain backed nationwide anxiety innovation wilbur believe inning
