# The Language Model Dataset

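Load *The Time Machine* as a character-level corpus, then read random minibatches of input sequences paired with label sequences (each label sequence is its input shifted by one token).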

In [1]:
import collections
import random
import re
import torch
from d2l import torch as d2l

Download the dataset and read the raw text into a single string.

In [2]:
class TimeMachine(d2l.DataModule):
    """Data module for H. G. Wells' *The Time Machine* text."""
    def _download(self):
        """Fetch ``timemachine.txt`` into ``self.root`` via ``d2l.download``
        and return the whole file as a single string."""
        # NOTE(review): the 40-hex-digit argument is presumably the file's
        # expected SHA-1 digest -- confirm against d2l.download.
        fname = d2l.download(d2l.DATA_URL + 'timemachine.txt', self.root,
                             '090b5e7e70c295757f55df93cb0a180b9691891a')
        # Decode explicitly; otherwise the platform default encoding
        # (e.g. cp1252 on Windows) is used and results vary by machine.
        with open(fname, encoding='utf-8') as f:
            return f.read()

data = TimeMachine()
raw_text = data._download()
raw_text[:60]

'The Time Machine, by H. G. Wells [1898]\n\n\n\n\nI\n\n\nThe Time Tra'

In [3]:
@d2l.add_to_class(TimeMachine)
def _preprocess(self, text):
    """Replace every run of characters other than A-Z/a-z with a single
    space, then lowercase the result."""
    letters_only = re.sub('[^A-Za-z]+', ' ', text)
    return letters_only.lower()

text = data._preprocess(raw_text)
text[:60]

'the time machine by h g wells i the time traveller for so it'

We then split the preprocessed text into a list of tokens; here each token is a single character.

In [4]:
@d2l.add_to_class(TimeMachine)
def _tokenize(self, text):
    """Split ``text`` into character tokens: one list entry per character."""
    return [char for char in text]

tokens = data._tokenize(text)
','.join(tokens[:30])

't,h,e, ,t,i,m,e, ,m,a,c,h,i,n,e, ,b,y, ,h, ,g, ,w,e,l,l,s, '

Models consume numbers rather than strings,
so we need a class
to construct a *vocabulary*
that assigns a unique index
to each distinct token value.

In [5]:
class Vocab:
    """Vocabulary for text: maps tokens to integer indices and back.

    Indices follow lexicographic (sorted) order over the reserved
    ``'<unk>'`` token, any ``reserved_tokens``, and every token whose count
    is at least ``min_freq``. Looking up a token that is not in the
    vocabulary returns the index of ``'<unk>'``.
    """
    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        """Build the vocabulary.

        Args:
            tokens: a flat list of tokens, or a list of token lists (which is
                flattened). ``None`` (the default) means no tokens.
            min_freq: tokens occurring fewer than ``min_freq`` times are left
                out and therefore map to ``'<unk>'``.
            reserved_tokens: tokens always included regardless of frequency;
                any iterable of tokens. ``None`` (the default) means none.
        """
        # None defaults instead of ``[]`` so no list object is shared between
        # calls through the default arguments.
        if tokens is None:
            tokens = []
        if reserved_tokens is None:
            reserved_tokens = []
        reserved_tokens = list(reserved_tokens)
        if tokens and isinstance(tokens[0], list):
            tokens = [token for line in tokens for token in line]
        counter = collections.Counter(tokens)
        # (token, count) pairs, most frequent first.
        self.token_freqs = sorted(counter.items(), key=lambda x: x[1],
                                  reverse=True)
        # The set removes duplicates (e.g. a reserved token that also occurs
        # in the text); sorting makes the index assignment deterministic.
        self.idx_to_token = sorted(set(['<unk>'] + reserved_tokens + [
            token for token, freq in self.token_freqs if freq >= min_freq]))
        self.token_to_idx = {token: idx
                             for idx, token in enumerate(self.idx_to_token)}

    def __len__(self):
        """Number of distinct tokens, including ``'<unk>'``."""
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        """Return the index of a token, or a list of indices for a list/tuple
        of tokens (applied recursively). Unknown tokens map to ``self.unk``."""
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        """Map an index, or a sequence of indices, back to token(s).

        Anything with a length (list, tuple, range, 1-D tensor/array, ...)
        yields a list of tokens, even when it holds zero or one element.
        A scalar index, including a zero-dimensional tensor/array, yields a
        single token string.
        """
        # 0-d tensors/arrays expose __len__ yet len() raises for them, so
        # anything reporting ndim == 0 is treated as a scalar index.
        is_scalar = (not hasattr(indices, '__len__')
                     or getattr(indices, 'ndim', None) == 0)
        if is_scalar:
            return self.idx_to_token[indices]
        return [self.idx_to_token[int(index)] for index in indices]

    @property
    def unk(self):
        """Index of the unknown token ``'<unk>'``."""
        return self.token_to_idx['<unk>']

Construct the vocabulary, then convert the first token to its index and back.

In [6]:
vocab = Vocab(tokens)
# `tokens` is the flat character list, so tokens[0] is a single token and the
# lookup returns a single index rather than a list.
indices = vocab[tokens[0]]
print('indices:', indices)
print('words:', vocab.to_tokens(indices))

indices: 21
words: t


Package everything into the `build` method, which returns the corpus (a list of token indices) together with the vocabulary.

In [7]:
@d2l.add_to_class(TimeMachine)
def build(self, raw_text, vocab=None):
    """Turn raw text into ``(corpus, vocab)``.

    The text is preprocessed and tokenized. When ``vocab`` is None, a new
    Vocab is built from the resulting tokens. ``corpus`` is the list of token
    indices, one per token, looked up in that vocabulary.
    """
    cleaned = self._preprocess(raw_text)
    token_list = self._tokenize(cleaned)
    if vocab is None:
        vocab = Vocab(token_list)
    corpus = [vocab[tok] for tok in token_list]
    return corpus, vocab

corpus, vocab = data.build(raw_text)
len(corpus), len(vocab)

(173428, 28)

In [8]:
@d2l.add_to_class(TimeMachine)
def __init__(self, batch_size, num_steps, num_train=10000, num_val=5000):
    """Load the corpus and cut it into overlapping (input, label) windows.

    Every start offset ``i`` yields one example: ``X[i]`` holds the tokens
    at positions ``i .. i+num_steps-1`` and ``Y[i]`` the same window shifted
    right by one position, so each label is the token following its input.

    Raises:
        ValueError: if the corpus has no more than ``num_steps`` tokens, so
            not even one window of ``num_steps + 1`` tokens fits.
    """
    # Zero-argument super() is unavailable here: this function is defined
    # outside the class body, so it has no __class__ cell.
    super(TimeMachine, self).__init__()
    # NOTE(review): d2l's save_hyperparameters appears to record this frame's
    # local variables; keep it before any other locals are created.
    self.save_hyperparameters()
    corpus, self.vocab = self.build(self._download())
    if len(corpus) <= num_steps:
        raise ValueError(f'corpus has {len(corpus)} tokens; need more than '
                         f'num_steps={num_steps}')
    # unfold returns all len(corpus) - num_steps stride-1 windows of length
    # num_steps + 1 as a view of one tensor, including the final window,
    # without materializing a Python list of slices.
    array = torch.tensor(corpus).unfold(0, num_steps + 1, 1)
    self.X, self.Y = array[:, :-1], array[:, 1:]

Random Sampling

In [9]:
@d2l.add_to_class(TimeMachine)
def get_dataloader(self, train):
    """Return a loader over windows ``[0, num_train)`` when ``train`` is
    true, otherwise over ``[num_train, num_train + num_val)``."""
    if train:
        idx = slice(0, self.num_train)
    else:
        start = self.num_train
        idx = slice(start, start + self.num_val)
    return self.get_tensorloader([self.X, self.Y], train, idx)

Read the first training minibatch with a batch size of 2 and 10 time steps. Each row of `Y` is the matching row of `X` shifted by one token.

In [10]:
data = TimeMachine(batch_size=2, num_steps=10)
# Only the first minibatch is needed; an empty loader prints nothing.
first_batch = next(iter(data.train_dataloader()), None)
if first_batch is not None:
    X, Y = first_batch
    print('X:', X, '\nY:', Y)

X: tensor([[ 0, 21,  9,  6,  0, 13,  2, 21, 21,  6],
        [10, 15,  8,  0,  9, 10, 20,  0, 17,  2]]) 
Y: tensor([[21,  9,  6,  0, 13,  2, 21, 21,  6, 19],
        [15,  8,  0,  9, 10, 20,  0, 17,  2, 21]])
