In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
#学習済みの分散表現をロードする
from torchtext.vocab import Vectors

english_fasttext_vectors = Vectors(name='drive/My Drive/wiki-news-300d-1M.vec')

print(english_fasttext_vectors.dim)
print(len(english_fasttext_vectors.itos))

300
999994


In [5]:
import string
import re

# 以下の記号はスペースに置き換えます（カンマ、ピリオドを除く）。
# punctuationとは日本語で句点という意味です
print("区切り文字：", string.punctuation)
# !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~

# 前処理
def preprocessing_text(text):
    # 改行コードを消去
    text = re.sub('<br />', '', text)
    # カンマ、ピリオド以外の記号をスペースに置換
    for p in string.punctuation:
        if (p == ".") or (p == ","):
            continue
        else:
            text = text.replace(p, " ")
    # ピリオドなどの前後にはスペースを入れておく
    text = text.replace(".", " . ")
    text = text.replace(",", " , ")
    return text

# 分かち書き（今回はデータが英語で、簡易的にスペースで区切る）
def tokenizer_punctuation(text):
    return text.strip().split()

# 前処理と分かち書きをまとめた関数を定義
def tokenizer_with_preprocessing(text):
    text = preprocessing_text(text)
    ret = tokenizer_punctuation(text)
    return ret


# 動作を確認します
print(tokenizer_with_preprocessing('I like cats+'))

区切り文字： !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~
['I', 'like', 'cats']


In [6]:
import torchtext
from torchtext.data.utils import get_tokenizer

MAX_LENGTH = 30

#テキストに処理を行うFieldを定義
#fix_lengthはtokenの数
SRC = torchtext.data.Field(sequential=True, use_vocab=True, tokenize=tokenizer_with_preprocessing,
                            lower=True, include_lengths=True, batch_first=True, fix_length=MAX_LENGTH,
                            eos_token='<eos>')

TRG = torchtext.data.Field(sequential=True, use_vocab=True, tokenize=tokenizer_with_preprocessing,
                            lower=True, include_lengths=True, batch_first=True, fix_length=MAX_LENGTH,
                            eos_token='<eos>')

LABEL_SRC = torchtext.data.Field(sequential=False, use_vocab=False)

LABEL_TRG = torchtext.data.Field(sequential=False, use_vocab=False)

#pandasでcsvを保存するときに、labelをintでキャストしておかないとエラーでるから注意
train_ds, val_ds, test_ds = torchtext.data.TabularDataset.splits(
    path='drive/My Drive/dataset/DailyDialog/', train='train.csv', validation='validation.csv', 
    test='test.csv', format='csv', fields=[('src', SRC), ('trg', TRG), ('label_src', LABEL_SRC), ('label_trg', LABEL_TRG)])

In [20]:
SRC.build_vocab(train_ds, vectors=english_fasttext_vectors)
TRG.build_vocab(train_ds, vectors=english_fasttext_vectors)
#SRC.build_vocab(train_ds)
#TRG.build_vocab(train_ds)
print(TRG.vocab.stoi)
print(len(TRG.vocab.stoi))

13041


In [8]:
from torchtext import data

batch_size = 64

train_dl = data.Iterator(train_ds, batch_size=batch_size, train=True)
val_dl = data.Iterator(val_ds, batch_size=batch_size, train=False, sort=False)
batch = next(iter(val_dl))
print(batch.src[0].shape)
print(batch.trg[0].shape)
print(batch.label_src.shape)

torch.Size([64, 30])
torch.Size([64, 30])
torch.Size([64])


In [21]:
class EncoderRNN(nn.Module):
  def __init__(self, emb_size, hidden_size, vocab_size, text_embedding_vectors, emotion_size, dropout=0):
    super(EncoderRNN, self).__init__()
    self.hidden_size = hidden_size
    if text_embedding_vectors == None:
      self.embedding = nn.Embedding(vocab_size, emb_size)
    else:
      self.embedding = nn.Embedding.from_pretrained(
          embeddings=text_embedding_vectors, freeze=True)
    self.embedding_dropout = nn.Dropout(dropout)
    self.lstm = nn.LSTM(emb_size, hidden_size, batch_first=True, bidirectional=True)
    self.thought = nn.Linear(hidden_size*2+emotion_size, hidden_size)


  def forward(self, input_seq, emotion, hidden=None):
    embedded = self.embedding(input_seq) #[64, 30, 600]
    #embedded = self.embedding_dropout(embedded)
    outputs, (hn, cn) = self.lstm(embedded) #[64, 30, 1200], ([2, 64, 600], [2, 64, 600])
    outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:] #[64, 30, 600]
    thought_vector = torch.cat((hn[0], hn[1], emotion), -1) #[64, 1200]
    thought_vector = self.thought(thought_vector).unsqueeze(0) #[1, 64, 600]

    return outputs, thought_vector

In [10]:
def label2one_hot(emotion):
  result = torch.zeros(batch_size, emotion_size)
  for i, e in enumerate(emotion):
    result[i][e] = 1.
  return result.to(device)

In [22]:
class LuongAttnDecoderRNN(nn.Module):
  def __init__(self, emb_size, hidden_size, text_embedding_vectors, output_size, dropout=0.1):
    super(LuongAttnDecoderRNN, self).__init__()
    self.hidden_size = hidden_size
    self.output_size = output_size
    self.dropout = dropout
    if text_embedding_vectors == None:
      self.embedding = nn.Embedding(output_size, emb_size)
    else:
      self.embedding = nn.Embedding.from_pretrained(
          embeddings=text_embedding_vectors, freeze=True)
    self.embedding_dropout = nn.Dropout(dropout)
    self.lstm = nn.LSTM(emb_size, hidden_size, batch_first=True)
    self.score = nn.Linear(hidden_size, hidden_size)
    self.concat = nn.Linear(hidden_size * 2, hidden_size)
    self.concat_dropout = nn.Dropout(dropout)
    self.out = nn.Linear(hidden_size, output_size)
    
  def forward(self, input_step, decoder_hidden, encoder_outputs):
    embedded = self.embedding(input_step)
    embedded = self.embedding_dropout(embedded)
    embedded = embedded.unsqueeze(1) #[64, 1, 600]
    
    #記憶セルはencoderから引っ張ってこない
    rnn_output, hidden = self.lstm(embedded, decoder_hidden) #[64, 1, 600] ([1, 64, 600], [1, 64, 600])
    energy = self.score(encoder_outputs) # [64, 30, 600]
    attn_weights = torch.sum(rnn_output*energy, dim=2) #[64, 30]
    attn_weights = F.softmax(attn_weights, dim=1).unsqueeze(1) # [64, 1, 30]

    context = attn_weights.bmm(encoder_outputs) #[64, 1, 600]
    rnn_output = rnn_output.squeeze(1) #[64, 600]
    context = context.squeeze(1) #[64, 600]
    concat_input = torch.cat((rnn_output, context), 1) #[64, 1200]
    #concat_input = self.concat_dropout(concat_dropout)
    concat_output = torch.tanh(self.concat(concat_input))
    output = self.out(concat_output)
    output = F.softmax(output, dim=1)

    return output, hidden

In [13]:
def binaryMatrix(l, value=TRG.vocab.stoi['<pad>']):
    m = []
    for i, seq in enumerate(l):
      if seq == TRG.vocab.stoi['<pad>']:
        m.append(False)
      else:
        m.append(True)
    return m

def maskNLLLoss(inp, target):
    mask = target
    mask = binaryMatrix(mask)
    mask = torch.BoolTensor(mask)
    mask = mask.to(device)
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()
  

In [14]:
"""
decoder = LuongAttnDecoderRNN(hidden_size, len(TRG.vocab.stoi),  dropout)
decoder_input = torch.LongTensor([TRG.vocab.stoi['<cls>'] for _ in range(batch_size)])
embedding = nn.Embedding(len(SRC.vocab.stoi), hidden_size)
print(decoder_input.shape)

cn = torch.zeros(1, batch_size, hidden_size)
#cn = torch.zeros(1, batch_size, hidden_size, device=device)
decoder_hidden = (thought_vector, cn)
target_variable = batch.trg[0]


for t in range(30):
  decoder_output, decoder_hidden = decoder(
      decoder_input, decoder_hidden, encoder_outputs
  ) #[64, 単語種類数], [2, 64, 500]
  # Teacher forcing: next input is current target
  _, topi = decoder_output.topk(1)
  decoder_input = torch.LongTensor([topi[i] for i in range(batch_size)])
  mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[:, t])
  
decoder_output, decoder_hidden = decoder(
    decoder_input, decoder_hidden, encoder_outputs
) #[64, 単語種類数], [2, 64, 500]

print('decoder_output', decoder_output.shape)
print('decoder_hidden', decoder_hidden[0].shape)
"""

"\ndecoder = LuongAttnDecoderRNN(hidden_size, len(TRG.vocab.stoi),  dropout)\ndecoder_input = torch.LongTensor([TRG.vocab.stoi['<cls>'] for _ in range(batch_size)])\nembedding = nn.Embedding(len(SRC.vocab.stoi), hidden_size)\nprint(decoder_input.shape)\n\ncn = torch.zeros(1, batch_size, hidden_size)\n#cn = torch.zeros(1, batch_size, hidden_size, device=device)\ndecoder_hidden = (thought_vector, cn)\ntarget_variable = batch.trg[0]\n\n\nfor t in range(30):\n  decoder_output, decoder_hidden = decoder(\n      decoder_input, decoder_hidden, encoder_outputs\n  ) #[64, 単語種類数], [2, 64, 500]\n  # Teacher forcing: next input is current target\n  _, topi = decoder_output.topk(1)\n  decoder_input = torch.LongTensor([topi[i] for i in range(batch_size)])\n  mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[:, t])\n  \ndecoder_output, decoder_hidden = decoder(\n    decoder_input, decoder_hidden, encoder_outputs\n) #[64, 単語種類数], [2, 64, 500]\n\nprint('decoder_output', decoder_output.shap

In [15]:
def a_batch_loss(input_variable, target_variable, emotion, max_target_len, encoder, decoder, 
                 encoder_optimizer, decoder_optimizer, phase):
  total_loss = 0 #1batchのloss
  # Zero gradients
  encoder_optimizer.zero_grad()
  decoder_optimizer.zero_grad()
  n_totals = 0
  print_losses = []
  
  #エンコーダの出力
  encoder_outputs, thought_vector = encoder(input_variable, emotion)
  #['<cls>']を生成
  decoder_input = torch.LongTensor([TRG.vocab.stoi['<cls>'] for _ in range(batch_size)]) #[64]
  decoder_input = decoder_input.to(device)
  #エンコーダの最後の隠れ状態を使用、記憶セルは0を入力
  cn = torch.zeros(1, batch_size, hidden_size, device=device)
  decoder_hidden = (thought_vector, cn)

  #teaching_forceを使う
  loss = 0 #1batchの中の1センテンスのloss
  use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

  if use_teacher_forcing:
    for t in range(max_target_len):
      decoder_output, decoder_hidden = decoder(
          decoder_input, decoder_hidden, encoder_outputs
      ) #[64, 単語種類数], [2, 64, 500]
      # Teacher forcing: next input is current target
      decoder_input = target_variable[:, t] #[64], teaching_forceの場合、正解データを次に入力する
      #loss += criterion(decoder_output, target_variable[:, t])
      mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[:, t])
      loss += mask_loss
      print_losses.append(mask_loss.item() * nTotal)
      n_totals += nTotal
    #total_loss += loss / max_target_len #1バッチ分のloss
    
  else:
    for t in range(max_target_len):
      decoder_output, decoder_hidden = decoder(
          decoder_input, decoder_hidden, encoder_outputs
      ) #[64, 単語種類数], [2, 64, 500]
      # Teacher forcing: next input is current target
      _, topi = decoder_output.topk(1)
      decoder_input = torch.LongTensor([topi[i] for i in range(batch_size)])
      decoder_input = decoder_input.to(device)
      #loss += criterion(decoder_output, target_variable[:, t])
      mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[:, t])
      loss += mask_loss
      print_losses.append(mask_loss.item() * nTotal)
      n_totals += nTotal
    #total_loss += (loss / max_target_len) #1バッチ分のloss
    
  if phase == 'train':
    loss.backward()
    #total_loss.backward()
    _ = nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    encoder_optimizer.step()
    decoder_optimizer.step()
  return sum(print_losses) / n_totals
  #return total_loss #1バッチ分のloss

In [16]:
import random

def train_model(dataloaders_dict, num_epochs, encoder, decoder, encoder_optimizer, decoder_optimizer):
  print("Training...")
  #エポック
  for epoch in range(num_epochs):
    for phase in ['train', 'val']:
      if phase == 'train':
        encoder.train()
        decoder.train()
      else:
        encoder.eval()
        decoder.eval()
      print_loss = 0 #1epochのloss

      for i, batch in enumerate(dataloaders_dict[phase]): 
        input_variable = batch.src[0].to(device) #(64, 30)
        target_variable = batch.trg[0].to(device) #(64, 30)
        emotion = batch.label_src
        emotion = label2one_hot(emotion)
        max_target_len = max(batch.trg[1])
        if target_variable.shape[0] == 64:
          total_loss = a_batch_loss(input_variable, target_variable, emotion, max_target_len, encoder, decoder, encoder_optimizer, decoder_optimizer, phase) #1バッチ分のloss     
          print_loss += total_loss #1epochのlossをprint_lossに加えていく

      #損失をだす
      print("epoch: {}; phase: {}; Average loss: {:.4f}; PPL: {:.4f}".format(epoch+1, phase, print_loss/i, math.exp(print_loss/i) ))  

In [23]:
emb_size = 300
hidden_size = 600
dropout = 0.5
emotion_size = 7


#encoder = EncoderRNN(emb_size, hidden_size, len(SRC.vocab.stoi), None, emotion_size, dropout)
#decoder = LuongAttnDecoderRNN(emb_size, hidden_size, None, len(TRG.vocab.stoi),  dropout)

encoder = EncoderRNN(emb_size, hidden_size, len(SRC.vocab.stoi), SRC.vocab.vectors, emotion_size, dropout)
decoder = LuongAttnDecoderRNN(emb_size, hidden_size, TRG.vocab.vectors, len(TRG.vocab.stoi),  dropout)

encoder = encoder.to(device)
decoder = decoder.to(device)

In [24]:
from torch import optim

clip = 1.0
teacher_forcing_ratio = 1.0
learning_rate = 0.0001
decoder_learning_rate = 5.0
num_epochs = 6


dataloaders_dict = {"train": train_dl, "val": val_dl}

encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate*decoder_learning_rate )

encoder.train()
decoder.train()

LuongAttnDecoderRNN(
  (embedding): Embedding(13041, 300)
  (embedding_dropout): Dropout(p=0.5, inplace=False)
  (lstm): LSTM(300, 600, batch_first=True)
  (score): Linear(in_features=600, out_features=600, bias=True)
  (concat): Linear(in_features=1200, out_features=600, bias=True)
  (concat_dropout): Dropout(p=0.5, inplace=False)
  (out): Linear(in_features=600, out_features=13041, bias=True)
)

In [25]:
train_model(dataloaders_dict, num_epochs, encoder, decoder, encoder_optimizer, decoder_optimizer)

Training...
epoch: 1; phase: train; Average loss: 5.2105; PPL: 183.1815
epoch: 1; phase: val; Average loss: 4.6725; PPL: 106.9627
epoch: 2; phase: train; Average loss: 4.5370; PPL: 93.4103
epoch: 2; phase: val; Average loss: 4.3467; PPL: 77.2196
epoch: 3; phase: train; Average loss: 4.2693; PPL: 71.4738
epoch: 3; phase: val; Average loss: 4.1730; PPL: 64.9109
epoch: 4; phase: train; Average loss: 4.0873; PPL: 59.5798
epoch: 4; phase: val; Average loss: 4.0522; PPL: 57.5220
epoch: 5; phase: train; Average loss: 3.9467; PPL: 51.7651
epoch: 5; phase: val; Average loss: 3.9781; PPL: 53.4170
epoch: 6; phase: train; Average loss: 3.8297; PPL: 46.0509
epoch: 6; phase: val; Average loss: 3.9140; PPL: 50.1007


In [59]:
import pickle

model_path = '/content/drive/My Drive/dataset/DailyDialog/chatbot_encoder0809.pth'
torch.save(encoder.to('cpu').state_dict(), model_path)
model_path = '/content/drive/My Drive/dataset/DailyDialog/chatbot_decoder0809.pth'
torch.save(decoder.to('cpu').state_dict(), model_path)

model_path = '/content/drive/My Drive/dataset/DailyDialog/cuda_chatbot_encoder0809.pth'
torch.save(encoder.to('cuda').state_dict(), model_path)
model_path = '/content/drive/My Drive/dataset/DailyDialog/cuda_chatbot_decoder0809.pth'
torch.save(decoder.to('cuda').state_dict(), model_path)

with open('/content/drive/My Drive/dataset/DailyDialog/voc_word2index0809.pkl', 'wb') as f:
    pickle.dump(SRC.vocab.stoi, f)
with open('/content/drive/My Drive/dataset/DailyDialog/voc_index2word0809.pkl', 'wb') as f:
    pickle.dump(TRG.vocab.itos, f)

In [None]:
encoder = EncoderRNN(emb_size, hidden_size, len(SRC.vocab.stoi), emotion_size, dropout)
decoder = LuongAttnDecoderRNN(emb_size, hidden_size, len(TRG.vocab.stoi),  dropout)
encoder.load_state_dict(torch.load('/content/drive/My Drive/dataset/DailyDialog/cuda_chatbot_encoder0725.pth'))
decoder.load_state_dict(torch.load('/content/drive/My Drive/dataset/DailyDialog/cuda_chatbot_decoder0725.pth'))
encoder = encoder.to(device)
decoder = decoder.to(device)
encoder.eval()
decoder.eval()

LuongAttnDecoderRNN(
  (embedding): Embedding(13046, 600)
  (embedding_dropout): Dropout(p=0.1, inplace=False)
  (lstm): LSTM(600, 600, batch_first=True)
  (score): Linear(in_features=600, out_features=600, bias=True)
  (concat): Linear(in_features=1200, out_features=600, bias=True)
  (out): Linear(in_features=600, out_features=13046, bias=True)
)

In [25]:
class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, emotion, input_length, max_length):
        encoder_outputs, thought_vector = self.encoder(input_seq, emotion)
        cn = torch.zeros(1, 1, hidden_size).to(device)
        decoder_hidden = (thought_vector, cn)
        decoder_input = torch.LongTensor([TRG.vocab.stoi['<cls>']]).to(device)
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        for _ in range(max_length):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            #decoder_input = torch.unsqueeze(decoder_input, 0)
            
        return all_tokens, all_scores

In [26]:
import unicodedata

def indexesFromSentence(sentence):
    return [SRC.vocab.stoi[word] if word in SRC.vocab.stoi else SRC.vocab.stoi['<unk>'] for word in sentence.split(' ')]
    
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s

def evaluate(encoder, decoder, emotion, searcher, sentence, max_length=MAX_LENGTH):
    indexes_batch = indexesFromSentence(sentence)
    lengths = len(indexes_batch)
    input_batch = torch.LongTensor(indexes_batch).view(1, -1).to(device)
    tokens, scores = searcher(input_batch, emotion, lengths, max_length)
    decoded_words = [TRG.vocab.itos[token.item()] for token in tokens]
    return decoded_words

def evaluate_beam(encoder, decoder, searcher, sentence, max_length=MAX_LENGTH):
    indexes_batch = indexesFromSentence(sentence) #[]
    lengths = len(indexes_batch)
    input_batch = torch.LongTensor(indexes_batch).view(1, -1).to(device) #[1, ]
    encoder_outputs, decoder_hidden = encoder(input_batch, label2one_hot([0])) #[64, 30, 600], [1, 64, 600]

    tokens, scores = searcher(input_batch, lengths, max_length)
    decoded_words = [TRG.vocab.itos[token.item()] for token in tokens]
    return decoded_words  


def evaluateInput(encoder, decoder, searcher):
    input_sentence = ''
    while(1):
        try:
            inp = input('> ').lower()
            input_sentence = inp.split('<emo>')[0]
            emotion = label2one_hot([int(inp.split('<emo>')[1])])
            if input_sentence == 'q' or input_sentence == 'quit': break
            #ノイズ処理
            input_sentence = normalizeString(input_sentence)
            output_words = evaluate(encoder, decoder, emotion, searcher, input_sentence)
            output_words[:] = [x for x in output_words if not (x == '<eos>' or x == '<pad>')]
            print('Bot:', ' '.join(output_words))

        except KeyError:
            print("Error: Encountered unknown word.")

In [None]:
encoder.eval()
decoder.eval()

searcher = GreedySearchDecoder(encoder, decoder)
evaluateInput(encoder, decoder, searcher)

> an hour ago, I bought this vase with my tourist.<emo>1
Bot: i m sorry , but i m afraid it s a bit chilly in here . this scarf . more expensive than the new accord . that
> quit


IndexError: ignored

In [55]:
import copy
from heapq import heappush, heappop

class BeamSearchNode(object):
    def __init__(self, h, prev_node, wid, logp, length):
        self.h = h
        self.prev_node = prev_node
        self.wid = wid
        self.logp = logp
        self.length = length

    def eval(self):
        return self.logp / float(self.length - 1 + 1e-6)

def beam_search_decoding(decoder, encoder_output, thought_vector, beam_width, n_best, sos_token, eos_token, max_dec_steps, device):
    assert beam_width >= n_best

    n_best_list = []
    cn = torch.zeros(1, 1, 600).to(device)
    decoder_hidden = (thought_vector, cn) #((1,1,600), (1,1,600))
    decoder_input = torch.LongTensor([sos_token]).to(device) #[1]
    end_nodes = []
    node = BeamSearchNode(h=decoder_hidden, prev_node=None, wid=decoder_input, logp=0, length=1)
    nodes = []
    heappush(nodes, (-node.eval(), id(node), node))
    n_dec_steps = 0
    t = 0
    
    # Start beam search
    while True:
        #最大単語数を越したら
        if n_dec_steps > max_dec_steps:
            break

        # Fetch the best node
        score, _, n = heappop(nodes)
        decoder_input = n.wid #(1)
        print(t, TRG.vocab.itos[decoder_input])
        decoder_hidden = n.h  #((1,1,600), (1,1,600))   

        if n.wid.item() == eos_token and n.prev_node is not None:
            end_nodes.append((score, id(n), n))
            # If we reached maximum # of sentences required
            if len(end_nodes) >= n_best:
                break
            else:
                continue
        
         #(1, 単語数), ((1,1,600), (1,1,600))
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_output)

        topk_log_prob, topk_indexes = torch.topk(decoder_output, beam_width) # (1, bw), (1, bw)
        

        for new_k in range(beam_width):
            decoded_t = topk_indexes[0][new_k].view(1) # (1) new_k番目のindexを取り出す
            logp = topk_log_prob[0][new_k].item() # new_k番目の単語の生成確率を取り出す
            print(t, logp ,TRG.vocab.itos[decoded_t])         
            node = BeamSearchNode(h=decoder_hidden,
                                  prev_node=n,
                                  wid=decoded_t,
                                  logp=n.logp+logp,
                                  length=n.length+1)
            heappush(nodes, (-node.eval(), id(node), node))
        n_dec_steps += beam_width
        t = t+1
        
    #ここでnodesに追加する作業が終わる
    # if there are no end_nodes, retrieve best nodes (they are probably truncated)
    if len(end_nodes) == 0:
        end_nodes = [heappop(nodes) for _ in range(beam_width)]

    # Construct sequences from end_nodes
    n_best_seq_list = []
    for score, _id, n in sorted(end_nodes, key=lambda x: x[0]):
        sequence = [n.wid.item()]
        # back trace from end node
        while n.prev_node is not None:
            n = n.prev_node
            sequence.append(n.wid.item())
        sequence = sequence[::-1] # reverse

        n_best_seq_list.append(sequence)

    n_best_list.append(n_best_seq_list)
        
    return n_best_list

def print_n_best(decoded_seq, itos):
    for rank, seq in enumerate(decoded_seq):
        print(f'Out: Rank-{rank+1}: {" ".join([itos[idx] for idx in seq])}')

In [56]:
encoder.eval()
decoder.eval()

beam_width = 10
n_best = 5

def label2one_hot(emotion):
  result = torch.zeros(1, emotion_size)
  for i, e in enumerate(emotion):
    result[i][e] = 1.
  return result.to(device)

src = normalizeString('Oh, I’m sorry I bothered you. I’m really sorry.').lower()
indexes_batch = indexesFromSentence(src)
input_batch = torch.LongTensor(indexes_batch).view(1, -1).to(device)
emotion = label2one_hot([5])
encoder_outputs, thought_vector = encoder(input_batch, emotion)
n_best_list = beam_search_decoding(decoder, encoder_outputs, thought_vector, beam_width, n_best, TRG.vocab.stoi['<cls>'], TRG.vocab.stoi['<eos>'],  1000, device='cuda')
print_n_best(n_best_list[0], TRG.vocab.itos)

0 <unk>
0 0.193338543176651 i
0 0.07846338301897049 oh
0 0.04833661764860153 well
0 0.04073987901210785 what
0 0.03451855853199959 you
0 0.028839297592639923 it
0 0.028546806424856186 ok
0 0.02565399557352066 yes
0 0.024824874475598335 thank
0 0.021596916019916534 how
1 i
1 0.4404541254043579 m
1 0.07348179072141647 am
1 0.06939078867435455 don
1 0.06604773551225662 ve
1 0.03658263385295868 ll
1 0.021328551694750786 didn
1 0.021228346973657608 know
1 0.017208857461810112 was
1 0.017042480409145355 d
1 0.013192374259233475 thought
2 m
2 0.24160611629486084 sorry
2 0.19874033331871033 glad
2 0.09925078600645065 afraid
2 0.04813940078020096 not
2 0.035299062728881836 sure
2 0.02696317434310913 going
2 0.020230866968631744 so
2 0.017035912722349167 very
2 0.009665507823228836 fine
2 0.008129902184009552 happy
3 sorry
3 0.2920418977737427 to
3 0.2686862051486969 .
3 0.26756495237350464 ,
3 0.03736795857548714 i
3 0.021393297240138054 about
3 0.011193566955626011 you
3 0.00901197548955679 if

In [53]:
import copy
from heapq import heappush, heappop

class BeamSearchNode(object):
    def __init__(self, h, prev_node, wid, logp, length):
        self.h = h
        self.prev_node = prev_node
        self.wid = wid
        self.logp = logp
        self.length = length

    def eval(self):
        return self.logp / float(self.length - 1 + 1e-6)

def my_beam_search(decoder, encoder_output, thought_vector, beam_width, sos_token, eos_token, max_dec_steps, device):
    #assert beam_width >= n_best

    n_best_list = []
    
    cn = torch.zeros(1, 1, 600).to(device)
    decoder_hidden = (thought_vector, cn) #((1,1,600), (1,1,600))
    #<cls>を作成
    decoder_input = torch.LongTensor([sos_token]).to(device) #[1]
    end_nodes = []
    node = BeamSearchNode(h=decoder_hidden, prev_node=None, wid=decoder_input, logp=0, length=1)
    nodes = []
    #初期状態を追加
    heappush(nodes, (-node.eval(), id(node), node))
    n_dec_steps = 0
    #eosに到達した文を数える
    eos_count = 0
    
    #Repeat
    for t in range(max_dec_steps):
        #全ての候補文がeosなら終わり
        if len(end_nodes) >= beam_width:
            break
        t_list = []
        #beam幅分の候補で考える
        for _ in range(len(nodes)):
            #一番スコアの高い文章をだす
            score, _, n = heappop(nodes)
            decoder_input = n.wid #単語ID
            decoder_hidden = n.h  #LSTMに入れる隠れ状態と記憶せる
            
            #<eos>出力されたらこれ以上は出力しない
            if n.wid.item() == eos_token and n.prev_node is not None:
                if (score, id(n), n) not in end_nodes:
                      end_nodes.append((score, id(n), n))
                continue
            #候補文をデコーダに
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_output)
            topk_log_prob, topk_indexes = torch.topk(decoder_output, beam_width)
            #t時刻目のbeam幅分の候補単語とindexをまとめていく
            t_list.append((topk_log_prob, topk_indexes, n))
        
        #t時刻目の候補単語beam_width個を使ってそれぞれ出力候補を確認
        for topk_log_prob, topk_indexes, n in t_list:
            for new_k in range(beam_width):
                decoded_t = topk_indexes[0][new_k].view(1) # (1) new_k番目のindexを取り出す
                logp = topk_log_prob[0][new_k].item() # new_k番目の単語の生成確率を取り出す 
                node = BeamSearchNode(h=decoder_hidden,
                                        prev_node=n,
                                        wid=decoded_t,
                                        logp=n.logp+logp,
                                        length=n.length+1)
                heappush(nodes, (-node.eval(), id(node), node))
        #nodesの中からbeam_width個を選択して次の時刻へ
        nodes = nodes[:beam_width]
            
    # if there are no end_nodes, retrieve best nodes (they are probably truncated)
    #if len(end_nodes) == 0:
    #   end_nodes = [heappop(nodes) for _ in range(beam_width)]
    end_nodes = nodes
    
    # Construct sequences from end_nodes
    n_best_seq_list = []
    for score, _id, n in sorted(end_nodes, key=lambda x: x[0]):
        sequence = [n.wid.item()]
        # back trace from end node
        while n.prev_node is not None:
            n = n.prev_node
            sequence.append(n.wid.item())
        sequence = sequence[::-1] # reverse

        n_best_seq_list.append(sequence)

    n_best_list.append(n_best_seq_list)
        
    return n_best_list

def print_n_best(decoded_seq, itos):
    for rank, seq in enumerate(decoded_seq):
        print(f'Out: Rank-{rank+1}: {" ".join([itos[idx] for idx in seq])}')          

In [58]:
encoder.eval()
decoder.eval()

beam_width = 5
sos_token = TRG.vocab.stoi['<cls>']
eos_token = TRG.vocab.stoi['<eos>']
max_dec_steps = 30

def label2one_hot(emotion):
  result = torch.zeros(1, emotion_size)
  for i, e in enumerate(emotion):
    result[i][e] = 1.
  return result.to(device)

src = normalizeString('Oh, I’m sorry I bothered you. I’m really sorry.').lower()
indexes_batch = indexesFromSentence(src)
input_batch = torch.LongTensor(indexes_batch).view(1, -1).to(device)
emotion = label2one_hot([5])
encoder_outputs, thought_vector = encoder(input_batch, emotion)

n_best_list = my_beam_search(decoder, encoder_outputs, thought_vector, beam_width, sos_token, eos_token, max_dec_steps, device)
print_n_best(n_best_list[0], TRG.vocab.itos)

0 <unk>
0 0.193338543176651 i
0 0.07846338301897049 oh
0 0.04833661764860153 well
0 0.04073987901210785 what
0 0.03451855853199959 you
1 i
1 oh
1 well
1 what
1 you
1 0.4404541254043579 m
1 0.07348179072141647 am
1 0.06939078867435455 don
1 0.06604773551225662 ve
1 0.03658263385295868 ll
1 0.857610821723938 ,
1 0.04525858163833618 .
1 0.028159772977232933 i
1 0.010917921550571918 dear
1 0.005687440279871225 yeah
1 0.8767769932746887 ,
1 0.06446512043476105 .
1 0.02320067025721073 i
1 0.003280176315456629 it
1 0.0022508788388222456 what
1 0.08715584129095078 s
1 0.08712264150381088 do
1 0.08402291685342789 are
1 0.06981416791677475 is
1 0.0623164027929306 did
1 0.20782603323459625 re
1 0.11795832216739655 are
1 0.07846268266439438 know
1 0.06262461096048355 ve
1 0.05818907916545868 mean
2 ,
2 ,
2 m
2 am
2 ve
2 0.20594905316829681 i
2 0.030408266931772232 what
2 0.026398595422506332 mark
2 0.025859219953417778 it
2 0.021568644791841507 thank
2 0.20594905316829681 i
2 0.030408266931772232 