# The Language Model Dataset

Read minibatches of input sequences and their label sequences at random; each label sequence is the corresponding input sequence shifted forward by one token

In [1]:
import random
import torch
from d2l import torch as d2l

Random Sampling

In [3]:
@d2l.add_to_class(d2l.TimeMachine)  
def __init__(self, batch_size, num_steps, num_train=10000, num_val=5000):
    """Initialise the dataset; attached to ``d2l.TimeMachine`` as ``__init__``.

    Args:
        batch_size: Number of sequences per minibatch (forwarded to the
            data loader by ``get_dataloader``).
        num_steps: Number of tokens in each input/label sequence.
        num_train: Number of leading corpus tokens used for training.
        num_val: Number of corpus tokens, taken right after the training
            tokens, used for validation.
    """
    # Two-argument super() is required here because this function is defined
    # outside the class body, so the zero-argument form is unavailable.
    super(d2l.TimeMachine, self).__init__()
    # NOTE(review): presumably stores the constructor arguments as attributes
    # (self.batch_size, self.num_steps, self.num_train, self.num_val are read
    # elsewhere without being assigned explicitly) - confirm against d2l.
    self.save_hyperparameters()
    # NOTE(review): prepare_data is defined outside this section; since
    # get_dataloader reads self.corpus, it presumably populates it - confirm.
    self.prepare_data()

class LMDataLoader(d2l.HyperParameters):
    """Iterate over a token corpus in minibatches of (input, label) pairs.

    The corpus is cut into non-overlapping windows of ``num_steps`` tokens.
    Each yielded pair ``(X, Y)`` holds ``batch_size`` windows, where every row
    of ``Y`` is the matching row of ``X`` shifted forward by one token.
    Windows that do not fill a whole batch are dropped.

    Args:
        corpus: Sliceable sequence of token indices.
        batch_size: Number of windows per minibatch.
        num_steps: Number of tokens per window.
        train: If true, start at a random offset and shuffle the windows on
            every pass; otherwise iterate the windows in corpus order.
    """

    def __init__(self, corpus, batch_size, num_steps, train):
        # NOTE(review): presumably stores the constructor arguments as
        # attributes (self.corpus, self.batch_size, ... are read below).
        # Kept as the first statement, before any other local is created -
        # confirm against d2l.HyperParameters.
        self.save_hyperparameters()
        # One token is reserved for the shifted label. In training mode a
        # further num_steps tokens are reserved because the random start
        # offset in __iter__ can discard up to num_steps - 1 tokens.
        usable_tokens = len(corpus) - 1 - (num_steps if train else 0)
        # Clamp at zero: a corpus shorter than one window would otherwise
        # give a negative count, and len() raises ValueError for that.
        self.num_batches = max(
            0, usable_tokens // (self.num_steps * self.batch_size))

    def __len__(self):
        """Return the number of minibatches produced per pass (never negative)."""
        return self.num_batches

    def __iter__(self):
        """Yield ``(X, Y)`` tensor pairs, one per minibatch."""
        if self.train:
            # A random start offset makes the window boundaries differ
            # between passes over the corpus.
            offset = random.randint(0, self.num_steps - 1)
            corpus = self.corpus[offset:]
        else:
            corpus = self.corpus
        # The "- 1" keeps one token in reserve for the last window's label.
        num_windows = (len(corpus) - 1) // self.num_steps
        initial_indices = list(
            range(0, num_windows * self.num_steps, self.num_steps))
        if self.train:
            random.shuffle(initial_indices)
        for i in range(self.num_batches):
            batch_indices = initial_indices[
                i * self.batch_size:(i + 1) * self.batch_size]
            X = [corpus[j:j + self.num_steps] for j in batch_indices]
            Y = [corpus[j + 1:j + 1 + self.num_steps] for j in batch_indices]
            yield torch.tensor(X), torch.tensor(Y)


@d2l.add_to_class(d2l.TimeMachine)
def get_dataloader(self, train):
    """Build the loader for the training or the validation split.

    The first ``num_train`` corpus tokens form the training split; the
    ``num_val`` tokens that follow them form the validation split.

    Args:
        train: Selects the training split when true, validation otherwise.

    Returns:
        An ``LMDataLoader`` over the selected slice of ``self.corpus``.
    """
    if train:
        split = self.corpus[:self.num_train]
    else:
        val_end = self.num_train + self.num_val
        split = self.corpus[self.num_train:val_end]
    return LMDataLoader(split, self.batch_size, self.num_steps, train)

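Load the Time Machine dataset with a batch size of 2 and 10 time steps, then print the first training minibatch; each row of Y is the matching row of X shifted by one token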

In [4]:
# Build the dataset with 2 sequences per minibatch and 10 tokens per sequence.
data = d2l.TimeMachine(batch_size=2, num_steps=10)
# Fetch only the first training minibatch; print nothing if the loader is empty.
first_batch = next(iter(data.train_dataloader()), None)
if first_batch is not None:
    X, Y = first_batch
    print('X:', X, '\nY:', Y)

X: tensor([[24,  9, 26, 20,  9, 16, 22, 13,  5,  0],
        [ 2, 19,  6, 15, 21,  4, 19, 26, 20, 21]]) 
Y: tensor([[ 9, 26, 20,  9, 16, 22, 13,  5,  0,  9],
        [19,  6, 15, 21,  4, 19, 26, 20, 21,  2]])
