In [19]:
import os
import collections

import torchtext

In [20]:
# Make sure the local ./data directory exists; exist_ok=True keeps re-runs of
# this cell from failing when the directory is already there.
os.makedirs('./data', exist_ok=True)
# Download the AG_NEWS dataset into the data directory.
train_dataset, test_dataset = torchtext.datasets.AG_NEWS(root='./data')
# Human-readable names of the four news categories.
# NOTE(review): AG_NEWS labels are presumably 1-based, in which case label k
# maps to classes[k - 1] — confirm wherever this list is indexed.
classes = ['World', 'Sports', 'Business', 'Sci/Tech']

In [21]:
# Materialise both splits as in-memory lists (presumably the loader returns
# lazy iterables — TODO confirm) so later cells can index and re-iterate them.
train_dataset = [*train_dataset]
test_dataset = [*test_dataset]

In [22]:
# tokenizer: callable that splits a sentence into a list of word tokens.
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')

In [23]:
# first_sentence = train_dataset[0][1]
# f_tokens = tokenizer(first_sentence)
# 29 ['wall', 'st', '.', 'bears', 'claw', 'back', 'into', 'the', 'black', '(', 'reuters', ')', 'reuters', '-', 'short-sellers', ',', 'wall', 'street', "'", 's', 'dwindling\\band', 'of', 'ultra-cynics', ',', 'are', 'seeing', 'green', 'again', '.']

In [24]:
# Count how many times each token occurs across the whole training set.
# Each dataset item is a (label, text) pair; only the text is tokenized.
counter = collections.Counter(
    token
    for _label, text in train_dataset
    for token in tokenizer(text)
)

In [25]:
# vocab: holds every word that appears in the dataset together with its
# integer word ID. Built from the token counts in `counter`; with min_freq=1
# no counted token is dropped for being too rare.
vocab = torchtext.vocab.vocab(counter, min_freq=1)
# Number of distinct tokens in the vocabulary.
vocab_size = len(vocab)

In [26]:
# word_lookup = [list((vocab[w], w)) for w in f_tokens]
# [[0, 'wall'], [1, 'st'], [2, '.'], [3, 'bears'], [4, 'claw'], [5, 'back'], [6, 'into'], [7, 'the'], [8, 'black'], [9, '('], [10, 'reuters'], [11, ')'], [10, 'reuters'], [12, '-'], [13, 'short-sellers'], [14, ','], [0, 'wall'], [15, 'street'], [16, "'"], [17, 's'], [18, 'dwindling\\band'], [19, 'of'], [20, 'ultra-cynics'], [14, ','], [21, 'are'], [22, 'seeing'], [23, 'green'], [24, 'again'], [2, '.']]

In [27]:
def encode(x):
    """Convert a piece of text into a list of word IDs.

    The text is split with the module-level ``tokenizer`` and every token is
    looked up in the module-level ``vocab``.

    Args:
        x: The text to encode.

    Returns:
        list: One word ID per token, in token order.

    Raises:
        KeyError: If a token is not present in the vocabulary mapping.
    """
    # Fetch the token -> ID mapping once per call rather than once per token;
    # calling get_stoi() inside the comprehension repeats that work for
    # every single token.
    stoi = vocab.get_stoi()
    return [stoi[s] for s in tokenizer(x)]

# Convert the words in a sentence to word IDs
# vec = encode(first_sentence)
# wall -> 0
# st -> 1
# . -> 2
# bears -> 3
# claw -> 4
# back -> 5

In [28]:
def decode(x):
    """Convert a sequence of word IDs back into their tokens.

    Args:
        x: Iterable of integer word IDs that index into the module-level
            ``vocab``.

    Returns:
        list: The token for every ID in ``x``, in the same order.

    Raises:
        IndexError: If an ID is out of range (assuming ``get_itos()`` returns
            a plain list — TODO confirm).
    """
    # Fetch the ID -> token table once per call rather than once per ID;
    # calling get_itos() inside the comprehension repeats that work for
    # every single element.
    itos = vocab.get_itos()
    return [itos[i] for i in x]

# Convert word IDs back to words
# decoded_vec = decode(vec)
# 0 -> wall
# 1 -> st
# 2 -> .
# 3 -> bears
# 4 -> claw
# 5 -> back

In [29]:
from torchtext.data.utils import ngrams_iterator

# The phrase 'hot dog' means something entirely different from the individual words 'hot' and 'dog'.
# To capture this, bi_vocab (N-grams with n=2) stores word pairs as vocabulary entries.

def create_bi_vocab(dataset):
    """Build a vocabulary from the n-grams (``ngrams=2``) found in ``dataset``.

    Args:
        dataset: Iterable of ``(label, text)`` pairs; the label is ignored.

    Returns:
        The vocabulary built by ``torchtext.vocab.vocab`` from the n-gram
        counts with ``min_freq=2``.
    """
    # Tokenize every text with the module-level tokenizer and count each
    # entry that ngrams_iterator yields for it.
    ngram_counts = collections.Counter(
        gram
        for _label, text in dataset
        for gram in ngrams_iterator(tokenizer(text), ngrams=2)
    )
    return torchtext.vocab.vocab(ngram_counts, min_freq=2)

In [None]:
# Convert text to a list of IDs
def encode_ngram(texts, vocab):
    """Encode ``texts`` as a list of IDs looked up in ``vocab``.

    The text is split with the module-level ``tokenizer``. A token is skipped
    when the module-level ``counter`` has seen it exactly once, or when it is
    not present in ``vocab`` at all (for example a word that never occurred
    in the data the vocabulary was built from); previously the latter case
    raised ``KeyError``.

    NOTE(review): despite the name, only individual tokens are looked up
    here; no n-grams are generated from the text. Confirm whether
    ``ngrams_iterator`` was meant to be applied before the lookup.

    Args:
        texts: The text to encode.
        vocab: Vocabulary object exposing ``get_stoi()``; presumably the
            result of ``create_bi_vocab`` — verify against the caller.

    Returns:
        list: IDs of the retained tokens, in token order.
    """
    # Fetch the token -> ID mapping once instead of once per token.
    stoi = vocab.get_stoi()
    bi_vocabs = []
    for s in tokenizer(texts):
        # Skip tokens seen only once, and tokens the vocabulary does not
        # contain, so that unseen words cannot crash the encoding.
        if counter[s] == 1 or s not in stoi:
            continue
        bi_vocabs.append(stoi[s])
    return bi_vocabs