# 学習済みRNNで太宰治風の文章生成

## ライブラリのインポート

In [1]:
import random, pickle, math
import numpy as np
import torch
print(torch.__version__)
import torch.nn as nn
import torch.nn.functional as F

1.4.0


## GPUを利用する

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


## モデルの定義

In [0]:
class MyLSTM(nn.Module):
  """Two-layer LSTM language model: embedding -> LSTM -> LSTM -> linear.

  Dropout is applied to the embedding output and after each LSTM layer.
  The second LSTM maps back to `emb_dim`, so the output layer has the same
  weight shape as the embedding and can share it when `weight_tied` is set.

  Args:
    vocab_size: number of embedding rows and size of the output layer.
    emb_dim: embedding width; also the hidden width of the second LSTM.
    hidden_size: hidden width of the first LSTM.
    dropout: probability used by all three dropout layers.
    embeddings: optional pre-trained embedding tensor of shape
      (vocab_size, emb_dim). When given, its values are kept as they are.
    freeze: forwarded to `nn.Embedding.from_pretrained`; only used when
      `embeddings` is given.
    weight_tied: if true, the output layer reuses the embedding weight.

  Raises:
    ValueError: if `embeddings` does not have shape (vocab_size, emb_dim).
  """

  def __init__(self, vocab_size, emb_dim, hidden_size, dropout=0.5,
               embeddings=None, freeze=False, weight_tied=False):
    super().__init__()

    if embeddings is not None:
      weight_size = (vocab_size, emb_dim)
      if embeddings.size() != weight_size:
        raise ValueError(
            f'Expected weight size {weight_size}, got {embeddings.size()}')
      self.embed = nn.Embedding.from_pretrained(embeddings, freeze=freeze)
    else:
      self.embed = nn.Embedding(vocab_size, emb_dim)

    self.dropout1 = nn.Dropout(dropout)
    self.lstm1 = nn.LSTM(emb_dim, hidden_size)
    self.dropout2 = nn.Dropout(dropout)
    self.lstm2 = nn.LSTM(hidden_size, emb_dim)
    self.dropout3 = nn.Dropout(dropout)
    self.linear = nn.Linear(emb_dim, vocab_size)

    # Only randomly initialise embeddings that were created here; pre-trained
    # vectors passed by the caller must not be overwritten.
    if embeddings is None:
      nn.init.normal_(self.embed.weight, std=0.01)

    # Each weight matrix is scaled by 1/sqrt(fan_in); biases start at zero.
    nn.init.normal_(self.lstm1.weight_ih_l0, std=1/math.sqrt(emb_dim))
    nn.init.normal_(self.lstm1.weight_hh_l0, std=1/math.sqrt(hidden_size))
    nn.init.zeros_(self.lstm1.bias_ih_l0)
    nn.init.zeros_(self.lstm1.bias_hh_l0)

    # lstm2 takes hidden_size inputs but its recurrent state has emb_dim
    # units, so the hidden-to-hidden matrix has fan_in == emb_dim.
    nn.init.normal_(self.lstm2.weight_ih_l0, std=1/math.sqrt(hidden_size))
    nn.init.normal_(self.lstm2.weight_hh_l0, std=1/math.sqrt(emb_dim))
    nn.init.zeros_(self.lstm2.bias_ih_l0)
    nn.init.zeros_(self.lstm2.bias_hh_l0)

    if weight_tied:
      # Share one parameter between the input embedding and output layer.
      self.linear.weight = self.embed.weight
    else:
      nn.init.normal_(self.linear.weight, std=1/math.sqrt(emb_dim))

    nn.init.zeros_(self.linear.bias)

  def forward(self, input, hidden_prev=None):
    """Run the model over a sequence of word ids.

    Args:
      input: integer tensor of word ids. The LSTMs are built without
        `batch_first`, so the layout is (seq_len, batch).
      hidden_prev: None to start from zero state, or the 4-tuple
        (h1, c1, h2, c2) returned by a previous call.

    Returns:
      A pair `(output, hidden_next)`: `output` holds the unnormalised
      scores over the vocabulary for every position, and `hidden_next` is
      the 4-tuple (h1, c1, h2, c2) to pass to the next call.
    """
    if hidden_prev is None:
      hidden1_prev, hidden2_prev = None, None
    else:
      hidden1_prev = hidden_prev[0:2]
      hidden2_prev = hidden_prev[2:4]

    emb_out = self.embed(input)
    emb_out = self.dropout1(emb_out)
    lstm1_out, hidden1_next = self.lstm1(emb_out, hidden1_prev)
    lstm1_out = self.dropout2(lstm1_out)
    lstm2_out, hidden2_next = self.lstm2(lstm1_out, hidden2_prev)
    lstm2_out = self.dropout3(lstm2_out)
    output = self.linear(lstm2_out)

    # Tuple concatenation: (h1, c1) + (h2, c2) -> (h1, c1, h2, c2).
    hidden_next = hidden1_next + hidden2_next
    return output, hidden_next

## 辞書の読み込み

In [0]:
filename = './data/word2id_50k.pkl'
# NOTE: pickle can execute arbitrary code while loading; only open
# dictionary files that come from a trusted source.
with open(filename, 'rb') as f:
  word_to_id = pickle.load(f)

# Reverse mapping used to turn generated ids back into words.
id_to_word = {v: k for k, v in word_to_id.items()}

# Register '<unk>' under the next free id, unless the dictionary already
# contains it (re-adding would leave a stale entry in id_to_word).
dict_len = len(word_to_id)
if '<unk>' not in word_to_id:
  word_to_id['<unk>'] = dict_len
  id_to_word[dict_len] = '<unk>'

## 文章生成

In [7]:
def _draw_index(probs, rng):
  """Draw an index from `probs` by walking its cumulative sum.

  Args:
    probs: list of Python floats that sum to (approximately) one.
    rng: `random.Random` instance that supplies the uniform draw.

  Returns:
    The first index whose cumulative probability exceeds the draw.
  """
  rnd = rng.random()
  p_sum = 0
  last_positive = 0
  for idx, p in enumerate(probs):
    p_sum += p
    if rnd < p_sum:
      return idx
    if p > 0:
      last_positive = idx
  # Float rounding can leave the total slightly below `rnd`; fall back to
  # the last index that carries any probability mass.
  return last_positive


def _best_allowed(probs, skip_list):
  """Return the most probable word id after zeroing the skipped ids."""
  if skip_list:
    probs = probs.clone()
    probs[skip_list] = 0
  return probs.argmax().item()


def text_generate(model, start_ids, length=100, skip_ids=None,
                  prob=True, top=None, seed=2020):
  """Generate a sequence of word ids by repeatedly querying `model`.

  The whole of `start_ids` is fed first; afterwards each generated id is
  fed back one at a time together with the returned hidden state.

  Args:
    model: module called as `model(input, hidden)` that returns
      `(output, hidden)`, with `input` laid out as (seq_len, 1).
    start_ids: non-empty sequence of word ids that starts the text.
    length: total number of ids to return, including `start_ids`.
    skip_ids: word ids that must never be generated, or None.
    prob: if true, sample from the predicted distribution; otherwise pick
      the most probable non-skipped word.
    top: when sampling, restrict the candidates to the `top` most probable
      words (renormalised). If all of them are in `skip_ids`, the most
      probable non-skipped word is used instead. Ignored when `prob` is
      false.
    seed: seed of the private random generator used for sampling; the
      global `random` state is left untouched.

  Returns:
    A new list of word ids that begins with `start_ids`.

  Raises:
    ValueError: if `start_ids` is empty or `top` is smaller than one.
  """
  if len(start_ids) == 0:
    raise ValueError('start_ids must contain at least one word id')
  if top is not None and top < 1:
    raise ValueError(f'top must be a positive integer or None, got {top}')

  skip_list = [] if skip_ids is None else list(skip_ids)
  skip_set = set(skip_list)
  word_ids = list(start_ids)
  rng = random.Random(seed)

  # Build inputs on the device the model lives on instead of relying on a
  # module-level global.
  first_param = next(model.parameters(), None)
  model_device = 'cpu' if first_param is None else first_param.device

  model.eval()
  with torch.no_grad():
    hidden = None
    input_id = start_ids
    while len(word_ids) < length:
      # (seq_len, 1): sequence-first layout with a batch of one.
      input = torch.tensor(input_id, dtype=torch.long,
                           device=model_device).view(-1, 1)
      output, hidden = model(input, hidden)

      # Distribution over the vocabulary for the last position.
      p_list = F.softmax(output[-1].flatten(), dim=0)

      if not prob:
        sampled = _best_allowed(p_list, skip_list)
      else:
        if top is None:
          cand_ids = None
          cand_probs = p_list.tolist()
        else:
          ordered = p_list.sort(descending=True)
          top_p = ordered.values[:top]
          cand_ids = ordered.indices[:top].tolist()
          cand_probs = (top_p / top_p.sum()).tolist()

        if cand_ids is not None and skip_set.issuperset(cand_ids):
          # Rejection sampling could never succeed here.
          sampled = _best_allowed(p_list, skip_list)
        else:
          # Redraw until the sampled word is not one of the skipped ids.
          while True:
            idx = _draw_index(cand_probs, rng)
            sampled = idx if cand_ids is None else cand_ids[idx]
            if sampled not in skip_set:
              break

      word_ids.append(sampled)
      input_id = sampled

  return word_ids


# Rebuild the model with the dimensions stored in the checkpoint.
save_path = './data/model/weight_emb_tied.pth'
state_dict = torch.load(save_path, map_location=device)
vocab_size, emb_dim = state_dict['embed.weight'].size()
# weight_hh_l0 has shape (4 * hidden_size, hidden_size). The second
# dimension of weight_ih_l0 is the LSTM input size (emb_dim), which only
# equals hidden_size when both widths happen to match.
hidden_size = state_dict['lstm1.weight_hh_l0'].size(1)
model = MyLSTM(vocab_size, emb_dim, hidden_size)
model.load_state_dict(state_dict)
model.to(device)

# Words that start the generated text; each must be in the dictionary.
start_words = ['私', 'は']

start_ids = []
for start_word in start_words:
  if start_word not in word_to_id:
    raise KeyError(start_word + ' is not in the dictionary!')
  start_ids.append(word_to_id[start_word])
# Never emit the unknown-word token.
skip_ids = [word_to_id['<unk>']]


word_ids = text_generate(model, start_ids, length=173,
                        skip_ids=skip_ids, top=None, seed=1)
text = ''.join([id_to_word[w_id] for w_id in word_ids])
# Break the line after each full stop, keeping a closing quote that follows
# the full stop on the same line.
text = text.replace('。', '。\n').replace('。\n」', '。」\n')
print(text)

私は、薄情な、外国の古典までの、成績を、信頼すると同時に、そうしてその後に、また一枚、顔が白くて、伴うのものなぞ過ぎています。
僕はもう、いくら未だきらいなんですね。
誰にも自信が無いんです。
女には、この創作遊覧が私の理想名詞というところになる。
しばらくしても民衆の嘆きもつまらなくなりました。
「愛しているんですよ。
僕は、いいから、あんな工合いなものをおっしゃるように、と驚きますからね。」
と言い、それを借りているのである。
「四人と五いうのがわかっているような事ばかり書いていますよ。」
私には善い文才が欲しかった。

