In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [2]:
import torchtext
from torchtext.data.utils import get_tokenizer

MAX_LENGTH = 40

# Define the Fields that preprocess the text columns.
# fix_length is the number of tokens per example.
# The source and target columns use identical settings, so the options are
# written once and each Field is built from them as a separate instance.
_FIELD_KWARGS = dict(
    sequential=True,
    use_vocab=True,
    lower=True,
    include_lengths=True,
    batch_first=True,
    fix_length=MAX_LENGTH,
    eos_token='<eos>',
)
SRC = torchtext.data.Field(**_FIELD_KWARGS)
TRG = torchtext.data.Field(**_FIELD_KWARGS)

# Note: when saving the csv with pandas, cast the label column to int first,
# otherwise loading raises an error.
train_ds, val_ds = torchtext.data.TabularDataset.splits(
    path='/content/drive/My Drive/dataset/TIU/twitter',
    train='train.csv',
    validation='val.csv',
    format='csv',
    fields=[('src', SRC), ('trg', TRG)],
)

In [3]:
# Build the source/target vocabularies from the training split.
# The earlier `build_vocab(..., vectors=english_fasttext_vectors)` calls were
# removed: that name is not defined anywhere in the visible notebook (so a
# fresh "Restart & Run All" raised NameError), and the vocabularies they
# produced were overwritten by the plain rebuilds below anyway.
SRC.build_vocab(train_ds)
TRG.build_vocab(train_ds)
print(TRG.vocab.stoi)
print(len(TRG.vocab.stoi))

defaultdict(<function _default_unk_index at 0x7fd43c13b488>, {'<unk>': 0, '<pad>': 1, '<eos>': 2, '。': 3, 'です': 4, 'て': 5, '、': 6, 'の': 7, 'は': 8, '私': 9, 'た': 10, 'に': 11, 'も': 12, 'が': 13, 'ね': 14, 'ます': 15, 'で': 16, '!': 17, 'し': 18, 'な': 19, 'と': 20, 'よ': 21, 'か': 22, 'ない': 23, 'ん': 24, 'だ': 25, 'から': 26, '...': 27, 'を': 28, '️': 29, '?': 30, 'まし': 31, 'ござい': 32, 'さん': 33, 'てる': 34, 'ありがとう': 35, 'ので': 36, 'けど': 37, '!!': 38, 'って': 39, 'そう': 40, 'たら': 41, 'いい': 42, '笑': 43, '〜': 44, '・': 45, 'たい': 46, 'お': 47, 'w': 48, 'う': 49, '目': 50, '^^': 51, '好き': 52, '́': 53, '口': 54, 'ふた': 55, 'さ': 56, 'こと': 57, 'ω': 58, 'ちゃん': 59, 'い': 60, 'ば': 61, 'れ': 62, 'ー': 63, 'とか': 64, 'する': 65, 'なっ': 66, 'でも': 67, '嬉しい': 68, '見': 69, 'ください': 70, 'それ': 71, '今日': 72, 'ある': 73, '*': 74, 'ませ': 75, '人': 76, 'よう': 77, 'だけ': 78, '大丈夫': 79, '...。': 80, 'これ': 81, 'わ': 82, '!!!': 83, '方': 84, 'や': 85, 'なら': 86, 'でし': 87, '気': 88, '......': 89, 'だっ': 90, 'なり': 91, 'この': 92, '思っ': 93, 'やっ': 94, 'せ': 95, '今': 96,

In [6]:
from torchtext import data

batch_size = 128

# Batch iterators over the training and validation splits.
train_dl = data.Iterator(train_ds, batch_size=batch_size, train=True)
val_dl = data.Iterator(val_ds, batch_size=batch_size, train=False, sort=False)

# Sanity check: pull one validation batch and print the shape of each column.
# Element [0] of batch.src / batch.trg is the tensor whose shape is shown
# (presumably the token ids, with lengths alongside, since the Fields use
# include_lengths=True -- confirm against the torchtext docs).
batch = next(iter(val_dl))
for column in (batch.src, batch.trg):
    print(column[0].shape)

torch.Size([128, 40])
torch.Size([128, 40])


In [7]:
class EncoderRNN(nn.Module):
  """Bidirectional LSTM encoder over embedded token ids.

  Args:
    hidden_size: Width of the embedding and of each LSTM direction.
    vocab_size: Number of rows in the embedding table.
    dropout: Accepted for interface compatibility; not used by this class.
  """

  def __init__(self, hidden_size, vocab_size, dropout=0):
    super().__init__()
    self.hidden_size = hidden_size
    self.embedding = nn.Embedding(vocab_size, hidden_size)
    self.lstm = nn.LSTM(
        hidden_size, hidden_size, batch_first=True, bidirectional=True)
    # Projects the concatenated final states of both directions back down
    # to hidden_size.
    self.thought = nn.Linear(hidden_size * 2, hidden_size)

  def forward(self, input_seq, hidden=None):
    """Encode a batch of token-id sequences.

    Args:
      input_seq: Integer tensor of token ids; assumed (batch, seq_len)
        because the LSTM is batch_first -- confirm against the caller.
      hidden: Accepted but ignored; the LSTM always starts from its
        default (zero) initial state.

    Returns:
      outputs: (batch, seq_len, hidden_size) -- the forward and backward
        LSTM outputs added together element-wise.
      thought_vector: (1, batch, hidden_size) -- linear projection of the
        two directions' final hidden states concatenated.
    """
    token_vectors = self.embedding(input_seq)
    bi_outputs, (final_h, _final_c) = self.lstm(token_vectors)

    # The last dim holds [forward | backward]; split it and sum the halves.
    fwd_half, bwd_half = torch.split(bi_outputs, self.hidden_size, dim=2)
    outputs = fwd_half + bwd_half

    # final_h[0] / final_h[1] are the forward / backward final hidden states.
    both_directions = torch.cat((final_h[0], final_h[1]), dim=-1)
    thought_vector = self.thought(both_directions).unsqueeze(0)

    return outputs, thought_vector

In [10]:
class LuongAttnDecoderRNN(nn.Module):
  """Single-step LSTM decoder with attention over the encoder outputs.

  Args:
    hidden_size: Width of the embedding, LSTM and attention layers.
    output_size: Size of the target vocabulary (embedding rows and width of
      the output distribution).
    dropout: Dropout probability applied to the input embedding.
  """

  def __init__(self, hidden_size, output_size, dropout=0.1):
    super().__init__()
    self.hidden_size = hidden_size
    self.output_size = output_size
    self.dropout = dropout

    self.embedding = nn.Embedding(output_size, hidden_size)
    self.embedding_dropout = nn.Dropout(dropout)
    self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
    self.score = nn.Linear(hidden_size, hidden_size)
    self.concat = nn.Linear(hidden_size * 2, hidden_size)
    self.out = nn.Linear(hidden_size, output_size)

  def forward(self, input_step, decoder_hidden, encoder_outputs):
    """Decode one time step.

    Args:
      input_step: Token ids for the current step; assumed shape (batch,)
        -- confirm against the caller.
      decoder_hidden: Initial state handed straight to the LSTM
        (presumably an (h, c) pair, each (1, batch, hidden_size)).
      encoder_outputs: Encoder states, (batch, src_len, hidden_size).

    Returns:
      output: (batch, output_size) probabilities (softmax already applied,
        not log-probabilities).
      hidden: The LSTM state after this step.
    """
    step_input = self.embedding_dropout(self.embedding(input_step))
    step_input = step_input.unsqueeze(1)  # add a length-1 time axis

    # Original author's note: the memory cell is not taken from the encoder.
    rnn_output, hidden = self.lstm(step_input, decoder_hidden)

    # Attention score per source position: dot product between the decoder
    # output and a linear projection of each encoder state.
    projected = self.score(encoder_outputs)
    scores = torch.sum(rnn_output * projected, dim=2)
    attn_weights = F.softmax(scores, dim=1).unsqueeze(1)

    # Attention-weighted sum of the encoder states.
    context = attn_weights.bmm(encoder_outputs).squeeze(1)

    features = torch.cat((rnn_output.squeeze(1), context), 1)
    concat_output = torch.tanh(self.concat(features))
    output = F.softmax(self.out(concat_output), dim=1)

    return output, hidden

In [12]:
def binaryMatrix(l, value=TRG.vocab.stoi['<pad>']):
    """Return a padding mask for a 1-D sequence of token ids.

    Args:
        l: Iterable of token ids (Python ints or 0-d tensors).
        value: Id treated as padding. Defaults to the '<pad>' index of the
            target vocabulary, captured when this function is defined.

    Returns:
        A list of Python bools, one per element of ``l``: False where the
        element equals ``value``, True everywhere else.
    """
    # The original compared against TRG.vocab.stoi['<pad>'] directly and
    # ignored `value`; honouring the parameter gives the same result for the
    # default and makes a caller-supplied pad id actually take effect.
    # bool(...) collapses 0-d tensor comparisons to plain Python bools.
    return [bool(token != value) for token in l]

def maskNLLLoss(inp, target):
    """Mean negative log-likelihood over the non-padding targets.

    Args:
        inp: Float tensor of probabilities (not log-probabilities); assumed
            (batch, vocab_size) -- the code gathers along dim 1.
        target: Integer tensor of target token ids, flattened to one id per
            row of ``inp``. Must be on the same device as ``inp``.

    Returns:
        (loss, nTotal): ``loss`` is a 0-d tensor holding the mean of
        ``-log(p[target])`` over positions whose target is not '<pad>';
        ``nTotal`` is the number of such positions as a Python int.
        When every target is padding, ``loss`` is a zero that is still
        attached to ``inp``'s graph (the previous version returned NaN from
        the mean of an empty selection).
    """
    pad_index = TRG.vocab.stoi['<pad>']

    # Vectorised mask built directly on target's device. This replaces the
    # per-element Python loop that built a CPU BoolTensor and then copied it
    # to the device.
    mask = target.view(-1) != pad_index
    nTotal = int(mask.sum().item())

    if nTotal == 0:
        # Nothing to average: return a zero that keeps the autograd graph so
        # callers can still accumulate it or call backward().
        return inp.sum() * 0.0, 0

    # Probability assigned to the correct token in each row, then -log.
    picked = torch.gather(inp, 1, target.view(-1, 1)).squeeze(1)
    crossEntropy = -torch.log(picked)

    # The result already lives on inp's device, so no extra .to(device).
    loss = crossEntropy.masked_select(mask).mean()
    return loss, nTotal