In [47]:
from u import *
from ut import *

import os
import torch

from collections import Counter


class Dictionary(object):
    """Vocabulary mapping words to dense integer ids, with occurrence counts.

    Ids are handed out in first-seen order starting at 0, so ``idx2word[i]``
    is the word whose id is ``i``.
    """

    def __init__(self):
        self.word2idx = {}        # word -> id
        self.idx2word = []        # id -> word (list position is the id)
        self.counter = Counter()  # id -> number of add_word() calls for that word
        self.total = 0            # total number of add_word() calls

    def add_word(self, word):
        """Register one occurrence of ``word`` and return its id.

        A previously unseen word is assigned the next free id; every call
        bumps that id's entry in ``counter`` and the running ``total``.
        """
        idx = self.word2idx.get(word)
        if idx is None:
            idx = len(self.idx2word)
            self.idx2word.append(word)
            self.word2idx[word] = idx
        self.counter[idx] += 1
        self.total += 1
        return idx

    def __len__(self):
        """Number of distinct words registered so far."""
        return len(self.idx2word)


class Corpus(object):
    """Word-level corpus built from ``train.txt``, ``valid.txt`` and ``test.txt``.

    All three splits share one ``Dictionary``, so token ids are consistent
    across splits. Each split is stored as a 1-D ``torch.long`` tensor of token
    ids, with an ``'<eos>'`` token appended after every line.

    Args:
        path: directory containing the three split files.
    """

    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenize a UTF-8 text file into a 1-D tensor of token ids.

        Each line is split on whitespace and terminated with ``'<eos>'``. Every
        token is first registered in ``self.dictionary`` (updating its counts);
        the file is then read a second time to fill a preallocated id tensor.

        Args:
            path: path to the text file.

        Returns:
            ``torch.long`` tensor of shape ``(n_tokens,)``.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not os.path.exists(path):
            raise FileNotFoundError(path)

        # Pass 1: build the vocabulary and count tokens, so the id tensor can be
        # allocated once at its final size. Explicit UTF-8 so decoding does not
        # depend on the machine's locale.
        with open(path, 'r', encoding='utf8') as f:
            tokens = 0
            for line in f:
                words = line.split() + ['<eos>']
                tokens += len(words)
                for word in words:
                    self.dictionary.add_word(word)

        # Pass 2: fill the tensor one line at a time -- a single slice
        # assignment per line instead of one tensor __setitem__ per token.
        ids = torch.empty(tokens, dtype=torch.long)
        word2idx = self.dictionary.word2idx
        token = 0
        with open(path, 'r', encoding='utf8') as f:
            for line in f:
                words = line.split() + ['<eos>']
                ids[token:token + len(words)] = torch.tensor(
                    [word2idx[word] for word in words], dtype=torch.long)
                token += len(words)

        return ids

# Reload all imported modules (including the star-imported `u` / `ut`) before each cell runs.
# Re-running this cell prints an "already loaded" notice, which is harmless.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [31]:
# Raw per-split token counts and a global Counter over space-separated tokens,
# as a cross-check for the Corpus-based tokenization below.
wik = Data / 'wikitext-103'
vocab = Counter()
for x in 'train', 'valid', 'test':
    text = (wik / x + '.txt').load()
    tokens = text.split(' ')
    # Drop the trailing '\n' element only if it is really there; an unconditional
    # pop() would silently discard a real token from a file not ending in ' \n'.
    # NOTE(review): if lines are space-prefixed, the split also yields a leading ''
    # that is counted here -- confirm that is intended.
    if tokens and tokens[-1] == '\n':
        tokens.pop()
    vocab.update(tokens)
    print(len(tokens))

103227021
217646
245569


In [39]:
# All (token, count) pairs, most frequent first; the stable sort keeps first-seen order among ties.
tokens = sorted(vocab.items(), key=lambda item: item[1], reverse=True)

In [49]:
# Build the shared vocabulary and id tensors for train/valid/test.
corpus = Corpus(path=wik)

In [55]:
# Convert each split's id tensor with from_torch, keyed by split name.
data = {split: from_torch(getattr(corpus, split)) for split in ('train', 'valid', 'test')}

In [58]:
# id -> word lookup as a NumPy array: element i is the word with id i.
idx2word = np.asarray(corpus.dictionary.idx2word)

In [61]:
# Persist the id -> word table next to the cached splits.
# NOTE(review): this writes to the cache root (Cache/vocab.npy), while the split arrays
# below go under Cache/wikitext-103/ -- another dataset's vocab would overwrite this file.
(Cache.mk() / 'vocab.npy').save(idx2word)

In [66]:
# Downcast every split's id array to int32 (the vocabulary has ~268k entries).
for k, ids in data.items():
    data[k] = ids.astype(np.int32)

In [68]:
# Save each split as Cache/wikitext-103/<split>.npy. The loop-invariant
# `.mk()` call on the output directory is hoisted out of the loop.
wt103_cache = (Cache / 'wikitext-103').mk()
for k in data:
    (wt103_cache / k + '.npy').save(data[k])

In [70]:
# Round-trip check: the saved train split must reload with the same dtype and values.
train_reloaded = (Cache / 'wikitext-103/train.npy').load()
assert train_reloaded.dtype == np.int32
assert np.array_equal(train_reloaded, data['train'])
train_reloaded

array([ 0,  1,  2, ..., 15,  0,  0], dtype=int32)

In [52]:
# Vocabulary size (Dictionary.__len__ counts distinct registered words).
len(corpus.dictionary)

267735