In [1]:
# Install pinned releases of torchvision and torchtext via the shell.
!pip install torchvision==0.2.1
!pip install torchtext==0.3.1

Collecting torchvision==0.2.1
[?25l  Downloading https://files.pythonhosted.org/packages/ca/0d/f00b2885711e08bd71242ebe7b96561e6f6d01fdb4b9dcf4d37e2e13c5e1/torchvision-0.2.1-py2.py3-none-any.whl (54kB)
[K    100% |████████████████████████████████| 61kB 5.1MB/s 
Collecting torch (from torchvision==0.2.1)
[?25l  Downloading https://files.pythonhosted.org/packages/49/0e/e382bcf1a6ae8225f50b99cc26effa2d4cc6d66975ccf3fa9590efcbedce/torch-0.4.1-cp36-cp36m-manylinux1_x86_64.whl (519.5MB)
[K    100% |████████████████████████████████| 519.5MB 29kB/s 
tcmalloc: large alloc 1073750016 bytes == 0x58c7c000 @  0x7f390763a2a4 0x594e17 0x626104 0x51190a 0x4f5277 0x510c78 0x5119bd 0x4f5277 0x4f3338 0x510fb0 0x5119bd 0x4f5277 0x4f3338 0x510fb0 0x5119bd 0x4f5277 0x4f3338 0x510fb0 0x5119bd 0x4f6070 0x510c78 0x5119bd 0x4f5277 0x4f3338 0x510fb0 0x5119bd 0x4f6070 0x4f3338 0x510fb0 0x5119bd 0x4f6070
[?25hCollecting pillow>=4.1.1 (from torchvision==0.2.1)
[?25l  Downloading https://files.pythonhosted.org

In [0]:
# Download the simple-examples archive (no progress output; resume a partial
# download if one exists) and unpack it into the working directory.
!wget --quiet --continue http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
!tar -xzf simple-examples.tgz

In [3]:
# List the contents of the extracted simple-examples directory.
!ls simple-examples

1-train		   5-one-iter		       9-char-based-lm	temp
2-nbest-rescore    6-recovery-during-training  data
3-combination	   7-dynamic-evaluation        models
4-data-generation  8-direct		       rnnlm-0.2b


In [0]:
# Move the ptb.{train,valid,test}.txt files out of the archive's data folder
# into the working directory under the shorter names used by the cells below.
!mv ./simple-examples/data/ptb.train.txt train.txt
!mv ./simple-examples/data/ptb.valid.txt valid.txt
!mv ./simple-examples/data/ptb.test.txt test.txt

In [5]:
# List the working directory to check that train.txt, valid.txt and test.txt are present.
!ls

sample_data	 simple-examples.tgz  train.txt
simple-examples  test.txt	      valid.txt


In [0]:
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torch.nn.functional as F
from torchtext import data
from torchtext import vocab
from torchtext import datasets

%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt

In [7]:
# Choose the compute device: CUDA when available, otherwise the CPU.
# Tensors and the model are moved onto it later with .to(device).
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [10]:
# Set up the torchtext Field objects that handle preprocessing.
# TEXT is the field applied to the corpus; LABEL is created here but is not
# referenced by the dataset call below.
TEXT = data.Field(batch_first=True)
LABEL = data.LabelField()

# Load the train / validation / test splits from the text files in the
# current directory, all processed through the TEXT field.
train_dataset, val_dataset, test_dataset = datasets.LanguageModelingDataset.splits(
    path=".",
    train="train.txt",
    validation="valid.txt",
    test="test.txt",
    text_field=TEXT,
)

# Build the vocabulary from the training split only and attach the
# 300-dimensional GloVe 6B vectors to it.
TEXT.build_vocab(train_dataset, vectors=vocab.GloVe(name='6B', dim=300))

.vector_cache/glove.6B.zip: 862MB [02:31, 5.70MB/s]                           
100%|█████████▉| 399951/400000 [00:45<00:00, 8694.83it/s]

In [11]:
# Total number of entries in the vocabulary
vocab_size = len(TEXT.vocab)
print(vocab_size)
# The ten most frequent tokens together with their counts
print(TEXT.vocab.freqs.most_common(10))
# The first ten entries of the index-to-string table
print(TEXT.vocab.itos[:10])

# Pretrained embedding vectors attached to the vocabulary by build_vocab
word_embeddings = TEXT.vocab.vectors
# Hyperparameters
# The embedding width is read from the loaded vectors instead of repeating the
# GloVe dimension as a literal, so it cannot get out of sync with build_vocab.
embedding_length = word_embeddings.size(1)
hidden_size = 256
batch_size = 32

10001
[('the', 50770), ('<unk>', 45020), ('<eos>', 42068), ('N', 32481), ('of', 24400), ('to', 23638), ('a', 21196), ('in', 18000), ('and', 17474), ("'s", 9784)]
['<unk>', '<pad>', 'the', '<eos>', 'N', 'of', 'to', 'a', 'in', 'and']


In [12]:
# BPTTIterator builds the iterators used for language modelling.
# Each batch it yields has a `text` attribute and a `target` attribute.
# The batch size comes from the shared `batch_size` hyperparameter: the
# training loop sizes the LSTM's initial states with the same variable, so a
# separate literal here could silently diverge from it.
train_iter, val_iter, test_iter = data.BPTTIterator.splits(
    (train_dataset, val_dataset, test_dataset),
    batch_size=batch_size,
    bptt_len=30,
    repeat=False,
)

# Number of batches per split
print(len(train_iter))
print(len(val_iter))
print(len(test_iter))

969
77
86


In [13]:
# Inspect the first training batch only; the loop exits after one iteration.
for i, train in enumerate(train_iter):
    print("データの形状確認")
    print(train.text.size())
    print(train.target.size())

    # Swap the two axes so that the batch dimension comes first.
    print("permuteでバッチを先にする")
    text = train.text.permute(1, 0)
    print(text.size())
    target = train.target.permute(1, 0)
    print(target.size())

    # Shape and raw token ids of the first row, for both input and target.
    print("１データ目の形状とデータを確認")
    first_text, first_target = text[0, :], target[0, :]
    print(first_text.size())
    print(first_target.size())
    print(first_text.tolist())
    print(first_target.tolist())

    # Decode the first two rows back into words through the vocabulary.
    for heading, row in (("１データ目の単語列を表示", 0), ("２データ目の単語列を表示", 1)):
        print(heading)
        print([TEXT.vocab.itos[token_id] for token_id in text[row, :].tolist()])
        print([TEXT.vocab.itos[token_id] for token_id in target[row, :].tolist()])

    break

データの形状確認
torch.Size([30, 32])
torch.Size([30, 32])
permuteでバッチを先にする
torch.Size([32, 30])
torch.Size([32, 30])
１データ目の形状とデータを確認
torch.Size([30])
torch.Size([30])
[9971, 9972, 9973, 9975, 9976, 9977, 9981, 9982, 9983, 9984, 9985, 9987, 9988, 9989, 9990, 9992, 9993, 9994, 9995, 9996, 9997, 9998, 9999, 10000, 3, 9257, 0, 4, 73, 394]
[9972, 9973, 9975, 9976, 9977, 9981, 9982, 9983, 9984, 9985, 9987, 9988, 9989, 9990, 9992, 9993, 9994, 9995, 9996, 9997, 9998, 9999, 10000, 3, 9257, 0, 4, 73, 394, 34]
１データ目の単語列を表示
['aer', 'banknote', 'berlitz', 'calloway', 'centrust', 'cluett', 'fromstein', 'gitano', 'guterman', 'hydro-quebec', 'ipo', 'kia', 'memotec', 'mlx', 'nahb', 'punts', 'rake', 'regatta', 'rubens', 'sim', 'snack-food', 'ssangyong', 'swapo', 'wachter', '<eos>', 'pierre', '<unk>', 'N', 'years', 'old']
['banknote', 'berlitz', 'calloway', 'centrust', 'cluett', 'fromstein', 'gitano', 'guterman', 'hydro-quebec', 'ipo', 'kia', 'memotec', 'mlx', 'nahb', 'punts', 'rake', 'regatta', 'rubens', 'sim'

In [0]:
class LstmLangModel(nn.Module):
    """Language model: embedding -> single-layer LSTM -> linear layer over the vocabulary.

    Args:
        batch_size: stored on the instance; the layers themselves do not use it.
        hidden_size: number of hidden units in the LSTM.
        vocab_size: number of embedding rows and number of output logits.
        embedding_length: width of each embedding vector (the LSTM input size).
        weights: optional tensor copied into the embedding table as its initial
            value. When omitted, the default ``nn.Embedding`` initialisation is kept.
    """

    def __init__(self, batch_size, hidden_size, vocab_size, embedding_length, weights=None):
        super(LstmLangModel, self).__init__()
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, embedding_length)
        if weights is not None:
            # Start from the supplied vectors; the table remains trainable.
            self.embed.weight.data.copy_(weights)
        self.lstm = nn.LSTM(embedding_length, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, h=None):
        """Run the model over a batch of token-id sequences.

        Args:
            x: tensor of token ids, batch dimension first (the LSTM is
                constructed with ``batch_first=True``).
            h: optional ``(h_0, c_0)`` pair of initial LSTM states. When
                ``None``, ``nn.LSTM`` starts from zero states.

        Returns:
            ``(out, (h, c))`` where ``out`` holds one row of vocabulary logits
            per (batch, time-step) position, flattened to
            ``(batch * seq_len, vocab_size)``, and ``(h, c)`` are the final
            LSTM states.
        """
        x = self.embed(x)
        output_seq, (h, c) = self.lstm(x, h)
        # Flatten batch and time into one axis: (batch * seq_len, hidden_size)
        out = output_seq.reshape(output_seq.size(0) * output_seq.size(1), output_seq.size(2))
        out = self.fc(out)
        return out, (h, c)

# Instantiate the language model from the hyperparameters and pretrained
# vectors defined earlier, then move its parameters onto the chosen device.
net = LstmLangModel(
    batch_size=batch_size,
    hidden_size=hidden_size,
    vocab_size=vocab_size,
    embedding_length=embedding_length,
    weights=word_embeddings,
).to(device)


In [0]:
# Define the loss function and the optimiser.
criterion = nn.CrossEntropyLoss()
# The variable name `optim` rebinds the `torch.optim` alias imported at the top
# of the notebook. Reaching the module through `torch.optim` keeps this cell
# re-runnable: with `optim.Adam(...)` a second execution would look up `Adam`
# on the previously created optimiser object and fail.
optim = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()))

In [0]:
num_epochs = 200  # number of passes over train_iter made by the training loop
train_loss_list = []  # average training loss of each epoch, appended in order

# Truncated backpropagation:
# stop gradients from flowing back past the current batch.
def detach(states):
    """Return a list holding a detached copy of every tensor in ``states``."""
    detached = []
    for state in states:
        detached.append(state.detach())
    return detached

for epoch in range(num_epochs):
    train_loss = 0
    # Start every epoch from zero hidden and cell states.
    states = (torch.zeros(1, batch_size, hidden_size).to(device),
              torch.zeros(1, batch_size, hidden_size).to(device))
    # Training phase
    net.train()
    for batch in train_iter:
        # Move the batch to the device and put the batch axis first, which is
        # the layout the model's batch_first layers expect.
        text = batch.text.to(device).permute(1, 0)
        labels = batch.target.to(device).permute(1, 0)

        optim.zero_grad()
        # Carry the state values over from the previous batch but cut the
        # graph, so backward() only spans the current batch.
        states = detach(states)
        outputs, states = net(text, states)
        # outputs has one row of logits per position; flatten labels to match.
        loss = criterion(outputs, labels.reshape(-1))
        train_loss += loss.item()
        loss.backward()
        optim.step()
    avg_train_loss = train_loss / len(train_iter)
    # Perplexity is exp of the mean per-batch cross-entropy for the epoch.
    print('Epoch [{}/{}], Loss: {loss:.4f}, Perplexity: {perp:5.2f}'
          .format(epoch+1, num_epochs, loss=avg_train_loss, perp=np.exp(avg_train_loss)))
    train_loss_list.append(avg_train_loss)


Epoch [1/200], Loss: 2.8084, Perplexity: 16.58
Epoch [2/200], Loss: 2.7145, Perplexity: 15.10
Epoch [3/200], Loss: 2.6533, Perplexity: 14.20
Epoch [4/200], Loss: 2.6033, Perplexity: 13.51
Epoch [5/200], Loss: 2.5600, Perplexity: 12.94
Epoch [6/200], Loss: 2.5214, Perplexity: 12.45
Epoch [7/200], Loss: 2.4883, Perplexity: 12.04
Epoch [8/200], Loss: 2.4579, Perplexity: 11.68
Epoch [9/200], Loss: 2.4304, Perplexity: 11.36
Epoch [10/200], Loss: 2.4037, Perplexity: 11.06
Epoch [11/200], Loss: 2.3780, Perplexity: 10.78
Epoch [12/200], Loss: 2.3503, Perplexity: 10.49
Epoch [13/200], Loss: 2.3240, Perplexity: 10.22
Epoch [14/200], Loss: 2.2988, Perplexity:  9.96
Epoch [15/200], Loss: 2.2728, Perplexity:  9.71
Epoch [16/200], Loss: 2.2487, Perplexity:  9.48
Epoch [17/200], Loss: 2.2275, Perplexity:  9.28
Epoch [18/200], Loss: 2.2074, Perplexity:  9.09
Epoch [19/200], Loss: 2.1858, Perplexity:  8.90
Epoch [20/200], Loss: 2.1657, Perplexity:  8.72
Epoch [21/200], Loss: 2.1490, Perplexity:  8.58
E

In [19]:
num_samples = 1000     # number of words to generate
# Generate text with the trained model.
net.eval()
with torch.no_grad():
    # Initial hidden and cell states for a single sequence (batch size 1).
    states = (torch.zeros(1, 1, hidden_size).to(device),
              torch.zeros(1, 1, hidden_size).to(device))

    # Pick the first word id uniformly at random from the vocabulary.
    # Named `current_word` rather than `input` so the builtin is not shadowed
    # for the rest of the notebook.
    current_word = torch.multinomial(torch.ones(vocab_size), num_samples=1).unsqueeze(1).to(device)

    pieces = []
    for step in range(num_samples):
        output, states = net(current_word, states)
        # Greedy decoding: take the id with the highest logit.
        word_id = output.max(1)[1].item()
        # Feed the chosen word back in as the next time step's input.
        current_word.fill_(word_id)
        # Map the id back to its token.
        word = TEXT.vocab.itos[word_id]
        # End-of-sentence markers become line breaks; other tokens are space-separated.
        pieces.append('\n' if word == '<eos>' else word + ' ')

    # Join once at the end instead of growing a string inside the loop.
    text = ''.join(pieces)
    # Show the generated text.
    print(text)


input word effectively
input word effectively
input word hhs
input word jan
input word sens.
input word boren
input word perestroika
input word republicans
input word are
input word <unk>
input word to
input word <unk>
input word conservative
input word and
input word that
input word the
input word president
input word 's
input word veto
input word power
input word is
input word a
input word <unk>
input word of
input word the
input word preamble
input word <eos>
input word the
input word <unk>
input word <unk>
input word was
input word the
input word <unk>
input word of
input word the
input word <unk>
input word <eos>
input word the
input word <unk>
input word <unk>
input word <unk>
input word in
input word a
input word <unk>
input word <unk>
input word <unk>
input word <unk>
input word N
input word <unk>
input word and
input word that
input word it
input word was
input word <unk>
input word <unk>
input word and
input word <unk>
input word <eos>
input word the
input word <unk>
input wo