In [6]:
import os
from collections import Counter

import numpy as np

In [7]:
def open_file(filename, mode='r'):
    """
    Open a text file as UTF-8, skipping any bytes that cannot be decoded.

    All text I/O in this notebook goes through this helper so that the
    encoding policy lives in a single place.

    :param filename: path of the file to open
    :param mode: mode string passed to ``open`` ('r' to read, 'w' to write)
    :return: the opened file object; the caller is responsible for closing it
    """
    text_options = {'encoding': 'utf-8', 'errors': 'ignore'}
    return open(filename, mode, **text_options)

In [8]:
class Corpus(object):
    """
    Character-level corpus backed by a text file and a vocabulary file.

    Attributes set by the constructor:
        words: vocabulary entries, one per line of the vocabulary file.
        word_to_id: dict mapping each entry of ``words`` to its index.
        data: numpy array with the id of every corpus character that is
            present in the vocabulary (other characters are skipped).
    """

    def __init__(self, train_dir, vocab_dir):
        """
        :param train_dir: path of the corpus text file; must exist
        :param vocab_dir: path of the vocabulary file; built from the corpus
            when it does not exist yet
        """
        assert os.path.exists(train_dir)

        # Every character except '\n' is treated as one token.
        with open_file(train_dir) as f:
            data = list(f.read().replace('\n', ''))

        if not os.path.exists(vocab_dir):
            self._build_vocab(data, vocab_dir)

        with open_file(vocab_dir) as f:
            # Trim newlines only: a bare strip() would also remove a
            # whitespace token sitting on the first or last line.
            self.words = f.read().strip('\n').split('\n')
        self.word_to_id = dict(zip(self.words, range(len(self.words))))

        data = [self.word_to_id[x] for x in data if x in self.word_to_id]
        # dtype=int keeps the array integral even when no token matched.
        self.data = np.array(data, dtype=int)

    def _build_vocab(self, data, vocab_dir):
        """
        Write the distinct tokens of ``data`` to ``vocab_dir``, one per line,
        ordered from most to least frequent.

        :param data: list of tokens
        :param vocab_dir: path of the vocabulary file to create
        """
        words = [word for word, _ in Counter(data).most_common()]
        with open_file(vocab_dir, 'w') as f:
            f.write('\n'.join(words) + '\n')

    def to_word(self, ids):
        """
        :param ids: iterable of vocabulary indices
        :return: list with the vocabulary entry for each index
        """
        return [self.words[x] for x in ids]

    def __repr__(self):
        return 'Corpus length: %d, Vocabulary size: %d.' % (len(self.data), len(self.words))

In [15]:
class LMDataset(object):
    """
    Cuts a 1-D sequence of token ids into fixed-size (input, target) batches
    for next-token prediction.

    ``data[i]`` and ``target[i]`` both have shape ``(seq_len, batch_size)``;
    ``target`` covers the same positions as ``data`` shifted one token ahead.
    Trailing tokens that do not fill a whole batch are dropped.
    """

    def __init__(self, raw_data, batch_size, seq_len):
        """
        :param raw_data: 1-D array-like of token ids
        :param batch_size: number of sequences per batch
        :param seq_len: number of tokens per sequence
        """
        raw_data = np.asarray(raw_data)

        # The targets are the inputs shifted by one token, so one token beyond
        # the last input is required. Counting batches over len - 1 avoids a
        # reshape failure when len(raw_data) is an exact multiple of
        # batch_size * seq_len.
        num_batch = max(len(raw_data) - 1, 0) // (batch_size * seq_len)
        num_token = num_batch * batch_size * seq_len

        # An explicit shape (instead of -1) also works when num_batch is 0.
        shape = (num_batch, batch_size, seq_len)

        self.data = raw_data[:num_token].reshape(shape).swapaxes(1, 2)
        self.target = raw_data[1:num_token + 1].reshape(shape).swapaxes(1, 2)

    def __getitem__(self, index):
        """Return the ``(data, target)`` pair of batch ``index``."""
        return self.data[index], self.target[index]

    def __len__(self):
        """Return the number of batches."""
        return len(self.data)

In [10]:
# Corpus text and its vocabulary file (the latter is generated from the text when absent).
train_dir, vocab_dir = 'data/weicheng.txt', 'data/weicheng.vocab.txt'

# Load the text, map its characters to ids and print the summary line.
corpus = Corpus(train_dir=train_dir, vocab_dir=vocab_dir)
print(corpus)

Corpus length: 242052, Vocabulary size: 3423.


In [17]:
# Batches of 10 sequences, 30 tokens each, cut from the encoded corpus.
train_data = LMDataset(raw_data=corpus.data, batch_size=10, seq_len=30)