In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [15]:
import torchtext
from torchtext.data.utils import get_tokenizer

MAX_LENGTH = 20

#テキストに処理を行うFieldを定義
#fix_lengthはtokenの数
SRC = torchtext.data.Field(sequential=True, use_vocab=True, 
                            lower=True, include_lengths=True, batch_first=True, fix_length=MAX_LENGTH,
                            eos_token='<eos>')

TRG = torchtext.data.Field(sequential=True, use_vocab=True, 
                            lower=True, include_lengths=True, batch_first=True, fix_length=MAX_LENGTH,
                            init_token='<cls>', eos_token='<eos>')

#pandasでcsvを保存するときに、labelをintでキャストしておかないとエラーでるから注意
train_ds, val_ds = torchtext.data.TabularDataset.splits(
    path='/content/drive/My Drive/dataset/TIU/twitter/3_20', train='train_desumasu.csv', validation='val_desumasu.csv', 
    format='csv', fields=[('src', SRC), ('trg', TRG)])

In [16]:
#学習済みの分散表現をロードする
from torchtext.vocab import Vectors

japanese_fasttext_vectors = Vectors(name='/content/drive/My Drive/embedding/japanese_fasatext/model.vec')

print(japanese_fasttext_vectors.dim)
print(len(japanese_fasttext_vectors.itos))

300
351122


In [17]:
SRC.build_vocab(train_ds, vectors=japanese_fasttext_vectors)
TRG.build_vocab(train_ds, vectors=japanese_fasttext_vectors)
#SRC.build_vocab(train_ds)
#TRG.build_vocab(train_ds)
print(TRG.vocab.stoi)
print(len(TRG.vocab.stoi))

defaultdict(<function _default_unk_index at 0x7f9c3b9c7268>, {'<unk>': 0, '<pad>': 1, '<cls>': 2, '<eos>': 3, '。': 4, 'ます': 5, '私': 6, 'て': 7, '!': 8, 'の': 9, 'です': 10, 'た': 11, 'も': 12, '、': 13, 'は': 14, 'に': 15, 'お願い': 16, 'が': 17, 'し': 18, 'で': 19, '致し': 20, 'ね': 21, 'ありがとう': 22, 'ござい': 23, 'よ': 24, 'な': 25, 'か': 26, 'よろしく': 27, 'だ': 28, 'と': 29, 'ん': 30, 'ない': 31, '?': 32, '!!': 33, 'すぎ': 34, '...': 35, 'いたし': 36, 'から': 37, 'まし': 38, 'を': 39, 'さん': 40, 'てる': 41, '失礼': 42, 'ので': 43, 'たい': 44, 'たら': 45, 'フォロー': 46, 'お': 47, 'w': 48, '笑': 49, 'こちら': 50, 'う': 51, 'って': 52, '好き': 53, 'いい': 54, 'けど': 55, 'dm': 56, 'ー': 57, 'ば': 58, 'ご': 59, 'こそ': 60, 'い': 61, '宜しく': 62, 'さ': 63, 'そう': 64, '!!!': 65, 'ちゃん': 66, 'それ': 67, 'え': 68, 'わ': 69, 'せ': 70, 'ください': 71, 'でも': 72, 'あざます': 73, 'こと': 74, 'れ': 75, '大丈夫': 76, '......': 77, 'www': 78, 'する': 79, 'なんて': 80, '可能': 81, '嬉しい': 82, 'なっ': 83, '!。': 84, '交換': 85, 'じゃ': 86, '見': 87, 'や': 88, '...。': 89, 'ぜひ': 90, 'これ': 91, 'とか': 92, 'やっ': 93, 'あ':

In [18]:
from torchtext import data

batch_size = 128

train_dl = data.Iterator(train_ds, batch_size=batch_size, train=True)
val_dl = data.Iterator(val_ds, batch_size=batch_size, train=False, sort=False)
batch = next(iter(val_dl))
print(batch.src[0].shape)
print(batch.trg[0].shape)

torch.Size([128, 20])
torch.Size([128, 20])


In [19]:
class EncoderRNN(nn.Module):
  def __init__(self, emb_size, hidden_size, num_layers, bidirectional, vocab_size, text_embedding_vectors, dropout=0):
    super(EncoderRNN, self).__init__()
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.bidirectional = bidirectional
    if text_embedding_vectors == None:
      self.embedding = nn.Embedding(vocab_size, emb_size)
    else:
      self.embedding = nn.Embedding.from_pretrained(
          embeddings=text_embedding_vectors, freeze=True)
    self.lstm = nn.LSTM(emb_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
    self.thought = nn.Linear(hidden_size*2, hidden_size)


  def forward(self, input_seq, hidden=None):
    embedded = self.embedding(input_seq) #[batch, max_length, emb_size]
    outputs, (hn, cn) = self.lstm(embedded) #[batch, max_length, hidden*2], ([2, 64, 600], [2, 64, 600])
    outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:] #[batch, max_length, hidden]

    encoder_hidden = tuple([hn[i, :, :] + hn[i+1, :, :] for i in range(0, self.num_layers+2*self.bidirectional, 2)])
    encoder_hidden = torch.stack(encoder_hidden, 0)

    return outputs, encoder_hidden

In [20]:
en = EncoderRNN(300, 300, 2, True, 2000, None, dropout=0)
outputs, encoder_hidden = en(torch.randint(0, 1000, size=(128, 20)))

In [21]:
class DecoderRNN(nn.Module):
  def __init__(self, emb_size, hidden_size, num_layers, text_embedding_vectors, output_size, dropout=0.1):
    super(DecoderRNN, self).__init__()
    self.hidden_size = hidden_size
    self.output_size = output_size
    self.dropout = dropout
    if text_embedding_vectors == None:
      self.embedding = nn.Embedding(output_size, emb_size)
    else:
      self.embedding = nn.Embedding.from_pretrained(
          embeddings=text_embedding_vectors, freeze=True)
    self.embedding_dropout = nn.Dropout(self.dropout)
    self.lstm = nn.LSTM(emb_size, hidden_size, num_layers, batch_first=True)
    self.out_dropout = nn.Dropout(self.dropout)
    self.out = nn.Linear(hidden_size, output_size)

  def forward(self, input_step, decoder_hidden, encoder_outputs):
    embedded = self.embedding(input_step)
    embedded = self.embedding_dropout(embedded)
    embedded = embedded.unsqueeze(1) #[batch, 1, hidden]
    
    #記憶セルはencoderから引っ張ってこない
    rnn_output, hidden = self.lstm(embedded, decoder_hidden) #[128, 1, 600] ([1, batch, hidden], [1, batch, hidden])
    attn_weights = torch.matmul(rnn_output, encoder_outputs.transpose(2, 1))
    attn_weights = F.softmax(attn_weights, -1)
    attn_applied = torch.bmm(attn_weights, encoder_outputs) 
    #output = torch.cat((rnn_output, attn_applied), dim=2)
    output = rnn_output + attn_applied
    output = output.squeeze(1)
    output = self.out_dropout(output)
    output = self.out(output)
    output = F.softmax(output, dim=1)

    return output, hidden

In [22]:
decoder_input = torch.LongTensor([TRG.vocab.stoi['<cls>'] for _ in range(batch_size)])
de = DecoderRNN(300, 300, 2, None, 2000, dropout=0)
cn = torch.zeros(2, batch_size, 300)
decoder_hidden = (encoder_hidden, cn)
print(encoder_hidden.shape)
decoder_outputs, hidden = de(decoder_input, decoder_hidden, outputs)

torch.Size([2, 128, 300])


In [23]:
def binaryMatrix(l, value=TRG.vocab.stoi['<pad>']):
    m = []
    for i, seq in enumerate(l):
      if seq == TRG.vocab.stoi['<pad>']:
        m.append(False)
      else:
        m.append(True)
    return m

def maskNLLLoss(inp, target):
    mask = target
    mask = binaryMatrix(mask)
    mask = torch.BoolTensor(mask)
    mask = mask.to(device)
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()

In [24]:
def a_batch_loss(input_variable, target_variable, max_target_len, encoder, decoder, 
                 encoder_optimizer, decoder_optimizer, phase):
  total_loss = 0 #1batchのloss
  # Zero gradients
  encoder_optimizer.zero_grad()
  decoder_optimizer.zero_grad()
  n_totals = 0
  print_losses = []
  
  #エンコーダの出力
  encoder_outputs, thought_vector = encoder(input_variable)
  #['<cls>']を生成
  decoder_input = torch.LongTensor([TRG.vocab.stoi['<cls>'] for _ in range(batch_size)]) #[64]
  decoder_input = decoder_input.to(device)
  #エンコーダの最後の隠れ状態を使用、記憶セルは0を入力
  cn = torch.zeros(2, batch_size, hidden_size, device=device)
  decoder_hidden = (thought_vector, cn)

  #teaching_forceを使う
  loss = 0 #1batchの中の1センテンスのloss
  use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

  if use_teacher_forcing:
    for t in range(max_target_len - 1):
      decoder_output, decoder_hidden = decoder(
          decoder_input, decoder_hidden, encoder_outputs
      ) #[64, 単語種類数], [2, 64, 500]
      decoder_input = target_variable[:, t] #[64], teaching_forceの場合、正解データを次に入力する
      #loss += criterion(decoder_output, target_variable[:, t])
      #各バッチのtのlossをだす。mask_lossはnTotalで割った平均、nTotalはバッチ数からmask(<pad>)の数を引いたもの
      mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[:, t])
      loss += mask_loss
      print_losses.append(mask_loss.item() * nTotal)
      n_totals += nTotal
    #total_loss += loss / max_target_len #1バッチ分のloss
    
  else:
    for t in range(max_target_len - 1):
      decoder_output, decoder_hidden = decoder(
          decoder_input, decoder_hidden, encoder_outputs
      ) #[64, 単語種類数], [2, 64, 500]
      _, topi = decoder_output.topk(1)
      decoder_input = torch.LongTensor([topi[i] for i in range(batch_size)])
      decoder_input = decoder_input.to(device)
      #loss += criterion(decoder_output, target_variable[:, t])
      mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[:, t])
      loss += mask_loss * nTotal
      print_losses.append(mask_loss.item() * nTotal)
      n_totals += nTotal
    #total_loss += (loss / max_target_len) #1バッチ分のloss
    
  if phase == 'train':
    loss = loss / n_totals
    loss.backward()
    #total_loss.backward()
    _ = nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    encoder_optimizer.step()
    decoder_optimizer.step()
  return sum(print_losses) / n_totals
  #return total_loss #1バッチ分のloss

In [25]:
import random

def train_model(dataloaders_dict, num_epochs, encoder, decoder, encoder_optimizer, decoder_optimizer):
  print("Training...")
  #エポック
  for epoch in range(num_epochs):
    for phase in ['train', 'val']:
      if phase == 'train':
        encoder.train()
        decoder.train()
      else:
        encoder.eval()
        decoder.eval()
      print_loss = 0 #1epochのloss

      for i, batch in enumerate(dataloaders_dict[phase]): 
        input_variable = batch.src[0].to(device) #(64, 30)
        target_variable = batch.trg[0][:, 1:].to(device) #(64, 30)
        max_target_len = max(batch.trg[1])
        if target_variable.shape[0] == batch_size:
          total_loss = a_batch_loss(input_variable, target_variable, max_target_len, encoder, decoder, encoder_optimizer, decoder_optimizer, phase) #1バッチ分のloss     
          print_loss += total_loss #1epochのlossをprint_lossに加えていく

      #損失をだす
      print("epoch: {}; phase: {}; Average loss: {:.4f}; PPL: {:.4f}".format(epoch+1, phase, print_loss/i, math.exp(print_loss/i) ))  

In [26]:
emb_size = 300
hidden_size = 300
num_layers = 2
bidirectional = True
dropout = 0.7

clip = 1.0
teacher_forcing_ratio = 1.0
learning_rate = 0.002
decoder_learning_rate = 1.0
num_epochs = 10

encoder = EncoderRNN(emb_size, hidden_size, num_layers, bidirectional, len(SRC.vocab.stoi), SRC.vocab.vectors, dropout)
decoder = DecoderRNN(emb_size, hidden_size, num_layers, TRG.vocab.vectors, len(TRG.vocab.stoi),  dropout)

encoder = encoder.to(device)
decoder = decoder.to(device)

In [27]:
from torch import optim

dataloaders_dict = {"train": train_dl, "val": val_dl}

encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate*decoder_learning_rate )

encoder.train()
decoder.train()


train_model(dataloaders_dict, num_epochs, encoder, decoder, encoder_optimizer, decoder_optimizer)

Training...
epoch: 1; phase: train; Average loss: 5.9956; PPL: 401.6470
epoch: 1; phase: val; Average loss: 5.2166; PPL: 184.3090
epoch: 2; phase: train; Average loss: 5.3049; PPL: 201.3255
epoch: 2; phase: val; Average loss: 4.9290; PPL: 138.2388
epoch: 3; phase: train; Average loss: 5.0822; PPL: 161.1204
epoch: 3; phase: val; Average loss: 4.7905; PPL: 120.3653
epoch: 4; phase: train; Average loss: 4.9444; PPL: 140.3899
epoch: 4; phase: val; Average loss: 4.7131; PPL: 111.4016
epoch: 5; phase: train; Average loss: 4.8415; PPL: 126.6654
epoch: 5; phase: val; Average loss: 4.6475; PPL: 104.3226
epoch: 6; phase: train; Average loss: 4.7616; PPL: 116.9309
epoch: 6; phase: val; Average loss: 4.6224; PPL: 101.7336
epoch: 7; phase: train; Average loss: 4.6913; PPL: 108.9936
epoch: 7; phase: val; Average loss: 4.5842; PPL: 97.9206
epoch: 8; phase: train; Average loss: 4.6301; PPL: 102.5255
epoch: 8; phase: val; Average loss: 4.5559; PPL: 95.1921
epoch: 9; phase: train; Average loss: 4.5759; 

In [28]:
!apt install aptitude swig

Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following package was automatically installed and is no longer required:
  libnvidia-common-440
Use 'apt autoremove' to remove it.
The following additional packages will be installed:
  aptitude-common libcgi-fast-perl libcgi-pm-perl libclass-accessor-perl
  libcwidget3v5 libencode-locale-perl libfcgi-perl libhtml-parser-perl
  libhtml-tagset-perl libhttp-date-perl libhttp-message-perl libio-html-perl
  libio-string-perl liblwp-mediatypes-perl libparse-debianchangelog-perl
  libsigc++-2.0-0v5 libsub-name-perl libtimedate-perl liburi-perl libxapian30
  swig3.0
Suggested packages:
  aptitude-doc-en | aptitude-doc apt-xapian-index debtags tasksel
  libcwidget-dev libdata-dump-perl libhtml-template-perl libxml-simple-perl
  libwww-perl xapian-tools swig-doc swig-examples swig3.0-examples swig3.0-doc
The following NEW packages will be installed:
  aptitude aptitude-common libcgi-fast-perl lib

In [29]:
!aptitude install mecab libmecab-dev mecab-ipadic-utf8 git make curl xz-utils file -y

git is already installed at the requested version (1:2.17.1-1ubuntu0.7)
make is already installed at the requested version (4.1-9.1ubuntu1)
curl is already installed at the requested version (7.58.0-2ubuntu3.10)
xz-utils is already installed at the requested version (5.2.2-1.3)
git is already installed at the requested version (1:2.17.1-1ubuntu0.7)
make is already installed at the requested version (4.1-9.1ubuntu1)
curl is already installed at the requested version (7.58.0-2ubuntu3.10)
xz-utils is already installed at the requested version (5.2.2-1.3)
The following NEW packages will be installed:
  file libmagic-mgc{a} libmagic1{a} libmecab-dev libmecab2{a} mecab mecab-ipadic{a} mecab-ipadic-utf8 mecab-jumandic{a} mecab-jumandic-utf8{a} mecab-utils{a} 
The following packages will be REMOVED:
  libnvidia-common-440{u} 
0 packages upgraded, 11 newly installed, 1 to remove and 39 not upgraded.
Need to get 29.3 MB of archives. After unpacking 282 MB will be used.
Get: 1 http://archive.ubun

In [30]:
!pip install mecab-python3

Collecting mecab-python3
[?25l  Downloading https://files.pythonhosted.org/packages/8b/06/2aeff86243c88580ccf78b136d403ce5e0a1eed9091103157f01e806499f/mecab_python3-1.0.1-cp36-cp36m-manylinux2010_x86_64.whl (3.5MB)
[K     |████████████████████████████████| 3.5MB 6.9MB/s 
[?25hInstalling collected packages: mecab-python3
Successfully installed mecab-python3-1.0.1


In [31]:
!git clone --depth 1 https://github.com/neologd/mecab-ipadic-neologd.git

Cloning into 'mecab-ipadic-neologd'...
remote: Enumerating objects: 75, done.[K
remote: Counting objects: 100% (75/75), done.[K
remote: Compressing objects: 100% (74/74), done.[K
remote: Total 75 (delta 5), reused 54 (delta 0), pack-reused 0[K
Unpacking objects: 100% (75/75), done.


In [32]:
!echo yes | mecab-ipadic-neologd/bin/install-mecab-ipadic-neologd -n -a

[install-mecab-ipadic-NEologd] : Start..
[install-mecab-ipadic-NEologd] : Check the existance of libraries
[install-mecab-ipadic-NEologd] :     find => ok
[install-mecab-ipadic-NEologd] :     sort => ok
[install-mecab-ipadic-NEologd] :     head => ok
[install-mecab-ipadic-NEologd] :     cut => ok
[install-mecab-ipadic-NEologd] :     egrep => ok
[install-mecab-ipadic-NEologd] :     mecab => ok
[install-mecab-ipadic-NEologd] :     mecab-config => ok
[install-mecab-ipadic-NEologd] :     make => ok
[install-mecab-ipadic-NEologd] :     curl => ok
[install-mecab-ipadic-NEologd] :     sed => ok
[install-mecab-ipadic-NEologd] :     cat => ok
[install-mecab-ipadic-NEologd] :     diff => ok
[install-mecab-ipadic-NEologd] :     tar => ok
[install-mecab-ipadic-NEologd] :     unxz => ok
[install-mecab-ipadic-NEologd] :     xargs => ok
[install-mecab-ipadic-NEologd] :     grep => ok
[install-mecab-ipadic-NEologd] :     iconv => ok
[install-mecab-ipadic-NEologd] :     patch => ok
[install-mecab-ipadi

In [33]:
#https://medium.com/@jiraffestaff/mecabrc-%E3%81%8C%E8%A6%8B%E3%81%A4%E3%81%8B%E3%82%89%E3%81%AA%E3%81%84%E3%81%A8%E3%81%84%E3%81%86%E3%82%A8%E3%83%A9%E3%83%BC-b3e278e9ed07
!pip install unidic-lite

Collecting unidic-lite
[?25l  Downloading https://files.pythonhosted.org/packages/74/d2/a4233f65f718f27065a4cf23a2c4f05d8bd4c75821e092060c4efaf28e66/unidic-lite-1.0.7.tar.gz (47.3MB)
[K     |████████████████████████████████| 47.3MB 148kB/s 
[?25hBuilding wheels for collected packages: unidic-lite
  Building wheel for unidic-lite (setup.py) ... [?25l[?25hdone
  Created wheel for unidic-lite: filename=unidic_lite-1.0.7-cp36-none-any.whl size=47556594 sha256=bc3eef7192807bfdea483d236fe94213640bbb163f6740b1fa2625922cd7db77
  Stored in directory: /root/.cache/pip/wheels/a8/82/7d/086724645e33a575aafd0b1dae2835c37d2c00c6a0a96ee3a0
Successfully built unidic-lite
Installing collected packages: unidic-lite
Successfully installed unidic-lite-1.0.7


In [34]:
import MeCab
import subprocess

cmd='echo `mecab-config --dicdir`"/mecab-ipadic-neologd"'
path = (subprocess.Popen(cmd, stdout=subprocess.PIPE,
                           shell=True).communicate()[0]).decode('utf-8')
mecab = MeCab.Tagger("-d {0} -Owakati".format(path))
#mecab = MeCab.Tagger()
print(mecab.parse("彼女はペンパイナッポーアッポーペンと恋ダンスを踊った。"))

彼女 は ペンパイナッポーアッポーペン と 恋ダンス を 踊っ た 。 



In [35]:
import unicodedata
import re

#単語分割、id振り
def indexesFromSentence(sentence):
    return [SRC.vocab.stoi[word] if word in SRC.vocab.stoi else SRC.vocab.stoi['<unk>'] for word in mecab.parse(sentence)]

def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

def normalizeString(text):
    s = unicodeToAscii(text.lower().strip())
    s = re.sub(r"([.!?])", r"\1", text)
    s = re.sub(r"\s+", r" ", s).strip()
    s = re.sub(r'[【】]', '', s)                  # 【】の除去
    s = re.sub(r'[（）()]', '', s)                # （）の除去
    s = re.sub(r'[［］\[\]]', '', s)              # ［］の除去
    s = re.sub(r'[\r]', '', s)
    s = re.sub(r'　', ' ', s)                    #全角空白の除去
    s = re.sub(r'[^a-zA-Zぁ-んァ-ン一-龥0-9、。,.!?ー ]', '', s)
    return s


def evaluate(encoder, decoder, searcher, sentence, max_length):
    indexes_batch = indexesFromSentence(sentence)
    lengths = len(indexes_batch)
    input_batch = torch.LongTensor(indexes_batch).view(1, -1).to(device)
    tokens, scores = searcher(input_batch, lengths, max_length)
    decoded_words = [TRG.vocab.itos[token.item()] for token in tokens]
    return decoded_words

def evaluate_beam(encoder, decoder, searcher, sentence, max_length=MAX_LENGTH):
    indexes_batch = indexesFromSentence(sentence) #[]
    lengths = len(indexes_batch)
    input_batch = torch.LongTensor(indexes_batch).view(1, -1).to(device) #[1, ]
    encoder_outputs, decoder_hidden = encoder(input_batch, label2one_hot([0])) #[64, 30, 600], [1, 64, 600]

    tokens, scores = searcher(input_batch, lengths, max_length)
    decoded_words = [TRG.vocab.itos[token.item()] for token in tokens]
    return decoded_words  


def evaluateInput(encoder, decoder, searcher):
    input_sentence = ''
    while(1):
      try:
        input_sentence = input('> ').lower()
        if input_sentence == 'q' or input_sentence == 'quit': break
        #前処理
        input_sentence = normalizeString(input_sentence)
        output_words = evaluate(encoder, decoder, searcher, input_sentence, MAX_LENGTH)
        output_words[:] = [x for x in output_words if not (x == '<eos>' or x == '<pad>')]
        print('Bot:', ' '.join(output_words))

      except KeyError:
        print("Error: Encountered unknown word.")

In [44]:
import numpy as np

class GreedySearchDecoder(nn.Module):
  def __init__(self, encoder, decoder):
    super(GreedySearchDecoder, self).__init__()
    self.encoder = encoder
    self.decoder = decoder

  def forward(self, input_seq, input_length, max_length):
    encoder_outputs, thought_vector = self.encoder(input_seq)
    cn = torch.zeros(2, 1, hidden_size).to(device)
    decoder_hidden = (thought_vector, cn)
    decoder_input = torch.LongTensor([TRG.vocab.stoi['<cls>']]).to(device)
    all_tokens = torch.zeros([0], device=device, dtype=torch.long)
    all_scores = torch.zeros([0], device=device)
    for _ in range(max_length):
      decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
      decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
      all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
      all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            
    return all_tokens, all_scores

tempreture = 0.3

class GreedyTempreture(nn.Module):
  def __init__(self, encoder, decoder):
    super(GreedyTempreture, self).__init__()
    self.encoder = encoder
    self.decoder = decoder

  def forward(self, input_seq, input_length, max_length):
    encoder_outputs, thought_vector = self.encoder(input_seq)
    cn = torch.zeros(2, 1, hidden_size).to(device)
    decoder_hidden = (thought_vector, cn)
    decoder_input = torch.LongTensor([TRG.vocab.stoi['<cls>']]).to(device)
    all_tokens = torch.zeros([0], device=device, dtype=torch.long)
    all_scores = torch.zeros([0], device=device)
    for _ in range(max_length):
      decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)#[batch, vocab]
      decoder_output = decoder_output.squeeze(0)
      decoder_output = torch.log(decoder_output) / tempreture
      decoder_output = torch.exp(decoder_output)
      decoder_output = decoder_output / sum(decoder_output)
      decoder_input = torch.tensor([torch.multinomial(decoder_output ,1)], device=device, dtype=torch.long)#decoder_output
      all_tokens = torch.cat((all_tokens, decoder_input))
      #all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            
    return all_tokens, _

lamb = 0.6 
gamma = 5

class MMI(nn.Module):
  def __init__(self, encoder, decoder):
    super(MMI, self).__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.lm_decoder = decoder

  def forward(self, input_seq, input_length, max_length):
    #seq2seq用
    encoder_outputs, thought_vector = self.encoder(input_seq)
    cn = torch.zeros(2, 1, hidden_size).to(device)
    decoder_hidden = (thought_vector, cn)
    #言語モデル用
    lm_decoder_hidden = (torch.zeros(2, 1, hidden_size).to(device), cn)
    lm_encoder_outputs = torch.zeros(1, MAX_LENGTH, hidden_size).to(device)
    #共通で使う
    decoder_input = torch.LongTensor([TRG.vocab.stoi['<cls>']]).to(device)
    all_tokens = torch.zeros([0], device=device, dtype=torch.long)
    all_scores = torch.zeros([0], device=device)
    for i in range(max_length):
      if i <= gamma-1:
        #seq2seqの出力
        decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
        #言語モデルの出力
        lm_decoder_output, lm_decoder_hidden = self.lm_decoder(decoder_input, lm_decoder_hidden, lm_encoder_outputs)
        #
        #decoder_input = torch.log(decoder_output) - lamb*torch.log(lm_decoder_output)
        decoder_scores, decoder_input = torch.max(decoder_input, dim=1)
        all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
        all_scores = torch.cat((all_scores, decoder_scores), dim=0)
      else:
        decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
        decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
        all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
        all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            
    return all_tokens, all_scores

In [45]:
encoder.eval()
decoder.eval()

#searcher = GreedySearchDecoder(encoder, decoder)
#searcher = GreedyTempreture(encoder, decoder)
searcher = MMI(encoder, decoder)
evaluateInput(encoder, decoder, searcher)

> 頼む、うまく学習できてくれ
Bot: <unk> <unk> <unk> <unk> <unk>
> あーあもうだめだよ
Bot: <unk> <unk> <unk> <unk> <unk>
> こんにちは！
Bot: <unk> <unk> <unk> <unk> <unk>


KeyboardInterrupt: ignored

In [46]:

import pickle

model_path = '/content/drive/My Drive/model/TIU_encoder0829.pth'
torch.save(encoder.to('cpu').state_dict(), model_path)
model_path = '/content/drive/My Drive/model/TIU_decoder0829.pth'
torch.save(decoder.to('cpu').state_dict(), model_path)

model_path = '/content/drive/My Drive/model/cuda_TIU_encoder0829.pth'
torch.save(encoder.to('cuda').state_dict(), model_path)
model_path = '/content/drive/My Drive/model/cuda_TIU_decoder0829.pth'
torch.save(decoder.to('cuda').state_dict(), model_path)

with open('/content/drive/My Drive/model/src_word2index0829.pkl', 'wb') as f:
    pickle.dump(SRC.vocab.stoi, f)
with open('/content/drive/My Drive/model/trg_word2index0829.pkl', 'wb') as f:
    pickle.dump(TRG.vocab.itos, f)