In [1]:
import torch
from torch import nn
import collections,random,re

In [2]:
class Vocab:
    """
    Vocabulary that maps tokens to integer indices and back.

    Indices follow the lexicographic order of the kept tokens (the special
    ``'<unk>'`` token and any reserved tokens included), not frequency order.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, tokens=[], min_freq=0, reserved_tokens=[]) -> None:
        """
        Build the vocabulary from a token collection.

        Args:
            tokens: A flat list of tokens, or a list of token lists. When the
                first element is a list, the input is flattened one level.
            min_freq: Tokens whose count is below this value are dropped.
            reserved_tokens: Extra tokens that are always part of the
                vocabulary, regardless of frequency.
        """
        # The list defaults above are only read, never mutated, so sharing
        # them between calls is harmless here.
        if tokens and isinstance(tokens[0], list):
            tokens = [token for line in tokens for token in line]
        counter = collections.Counter(tokens)
        # (token, count) pairs, most frequent first.
        self.token_freqs = sorted(
            counter.items(), key=lambda x: x[1], reverse=True)
        # De-duplicate via a set, then sort so the index assignment is
        # deterministic.
        self.idx_to_token = list(sorted(set(
            ['<unk>']+reserved_tokens + [token for token, freq in self.token_freqs if freq >= min_freq])))
        # Inverse lookup of idx_to_token.
        self.token_to_idx = {token: idx for idx,
                             token in enumerate(self.idx_to_token)}

    def __len__(self):
        """Return the number of distinct tokens, including ``'<unk>'``."""
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        """
        Map a token, or a (possibly nested) list/tuple of tokens, to indices.

        Unknown tokens map to the index of ``'<unk>'``.
        """
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        """
        Map an index, or a sequence of indices, back to tokens.

        Args:
            indices: A single index, a list/tuple of indices, or any other
                object exposing ``__len__`` (e.g. a tensor).

        Returns:
            A list of tokens for a list/tuple of any length (including empty
            or single-element) and for other sized objects holding more than
            one index; otherwise the single token at ``indices``.
        """
        # Lists and tuples always yield a list, mirroring __getitem__.
        # Previously a list of length 0 or 1 fell through to the scalar
        # lookup below and raised TypeError.
        if isinstance(indices, (list, tuple)):
            return [self.idx_to_token[idx] for idx in indices]
        # Other sized containers keep the original behaviour: more than one
        # element -> list, exactly one element -> scalar lookup.
        if hasattr(indices, '__len__') and len(indices) > 1:
            return [self.idx_to_token[idx] for idx in indices]
        return self.idx_to_token[indices]

    @property
    def unk(self):
        """Index of the ``'<unk>'`` (unknown) token."""
        return self.token_to_idx['<unk>']

In [12]:
from d2l_common import DataModule


class TimeMachine(DataModule):
    """
    Character-level sequence dataset built from a local text file.

    The text is lower-cased, every run of non-letter characters is collapsed
    to a single space, and the result is split into characters. Overlapping
    windows of ``num_steps + 1`` token indices are stacked into ``self.array``;
    ``self.X`` is each window without its last token and ``self.Y`` is the
    same window shifted by one position.
    """

    def __init__(self, batch_size, num_steps, num_train=10000, num_val=5000, fname='timemachine.txt', root='../data'):
        """
        Read the text file, build the vocabulary and the windowed tensors.

        Args:
            batch_size: Stored on the instance as ``self.batch_size``.
            num_steps: Number of tokens in each input/target sequence.
            num_train: Number of leading windows used for training.
            num_val: Number of windows, directly after the training ones,
                used for validation.
            fname: Name of the text file inside ``root``.
            root: Directory that contains ``fname``.
        """
        # NOTE(review): DataModule.__init__ is not invoked here; confirm the
        # base class does not need its own initialisation.
        self.fname = fname
        self.root = root
        self.batch_size = batch_size
        self.num_steps = num_steps
        self.num_train = num_train
        self.num_val = num_val
        # Keep the vocabulary on the instance so indices can be decoded and
        # the same vocabulary can be passed back to build() later. It used to
        # be discarded right after construction.
        corpus, self.vocab = self.build(self._download())
        # One row per start position; each row has num_steps + 1 indices so
        # that inputs and shifted targets can be sliced from the same row.
        array = torch.tensor([corpus[i:i+num_steps+1]
                             for i in range(len(corpus)-num_steps)])
        self.array = array
        self.X, self.Y = array[:, :-1], array[:, 1:]

    def get_dataloader(self, train):
        """
        Return a loader over the training or the validation windows.

        Training uses rows ``[0, num_train)``; validation uses rows
        ``[num_train, num_train + num_val)``.
        """
        idx = slice(0, self.num_train) if train else slice(
            self.num_train, self.num_train+self.num_val)
        # NOTE(review): `train` only selects the slice and is not forwarded to
        # get_tensorloader; confirm this matches the signature of
        # DataModule.get_tensorloader.
        return self.get_tensorloader([self.X, self.Y], idx)

    def _download(self):
        """Read ``root/fname`` from disk and return its full text."""
        with open(self.root+'/'+self.fname) as f:
            return f.read()

    def _preprocess(self, text):
        """Replace each run of non-letter characters with one space and lower-case."""
        return re.sub('[^A-Za-z]+', ' ', text).lower()

    def _tokenize(self, text):
        """Split the text into a list of single characters."""
        return list(text)

    def build(self, raw_txt, vocab=None):
        """
        Turn raw text into a list of token indices.

        Args:
            raw_txt: Unprocessed text.
            vocab: Existing vocabulary to index with; a new ``Vocab`` is built
                from the tokens when this is ``None``.

        Returns:
            A ``(corpus, vocab)`` tuple where ``corpus`` is the list of token
            indices and ``vocab`` is the vocabulary that produced them.
        """
        tokens = self._tokenize(self._preprocess(raw_txt))
        if vocab is None:
            vocab = Vocab(tokens)
        corpus = [vocab[token] for token in tokens]
        return corpus, vocab


# data = TimeMachine(batch_size=2, num_steps=10)
# raw_txt = data._download()
# raw_txt[:60]

In [21]:
# Instantiate the dataset with batch_size=2 and num_steps=10.
data = TimeMachine(batch_size=2, num_steps=10)

# Show the first ten windows, then the shapes of the full array and of the
# X / Y views sliced from it.
print(data.array[:10])
for tensor in (data.array, data.X, data.Y):
    print(tensor.shape)

# Inspect only the first training minibatch (nothing is printed if the
# loader yields no batches).
first_batch = next(iter(data.train_dataloader()), None)
if first_batch is not None:
    X, Y = first_batch
    print('X:', X.shape, '\nY:', Y.shape)

tensor([[ 0, 21,  9,  6,  0, 17, 19, 16, 11,  6,  4],
        [21,  9,  6,  0, 17, 19, 16, 11,  6,  4, 21],
        [ 9,  6,  0, 17, 19, 16, 11,  6,  4, 21,  0],
        [ 6,  0, 17, 19, 16, 11,  6,  4, 21,  0,  8],
        [ 0, 17, 19, 16, 11,  6,  4, 21,  0,  8, 22],
        [17, 19, 16, 11,  6,  4, 21,  0,  8, 22, 21],
        [19, 16, 11,  6,  4, 21,  0,  8, 22, 21,  6],
        [16, 11,  6,  4, 21,  0,  8, 22, 21,  6, 15],
        [11,  6,  4, 21,  0,  8, 22, 21,  6, 15,  3],
        [ 6,  4, 21,  0,  8, 22, 21,  6, 15,  3,  6]])
torch.Size([192747, 11])
torch.Size([192747, 10])
torch.Size([192747, 10])
X: torch.Size([1, 10]) 
Y: torch.Size([1, 10])
