In [5]:
import torch
from d2l import torch as d2l

In [6]:
@d2l.add_to_class(d2l.TimeMachine)  # @save
def __init__(self, batch_size, num_steps, num_train=10000, num_val=5000):
    """Build overlapping (input, target) token windows from the corpus.

    Args:
        batch_size: Stored by ``save_hyperparameters()``; not used directly
            in this method.
        num_steps: Number of tokens in each input (and each target)
            subsequence, i.e. the window length.
        num_train: Stored by ``save_hyperparameters()``; ``get_dataloader``
            reads it as ``self.num_train`` to bound the training rows.
        num_val: Stored by ``save_hyperparameters()``; ``get_dataloader``
            reads it as ``self.num_val`` to bound the validation rows.
    """
    super(d2l.TimeMachine, self).__init__()
    self.save_hyperparameters()

    # Presumably: fetch the raw Time Machine text, tokenize it, and map the
    # tokens to vocabulary indices -- confirm against `_download` / `build`.
    raw_text = self._download()
    corpus, self.vocab = self.build(raw_text)

    # Slide a window of num_steps + 1 tokens over the corpus, one start
    # offset per row. The extra token lets inputs and targets share a row.
    window_len = num_steps + 1
    num_windows = len(corpus) - num_steps
    windows = torch.tensor(
        [corpus[start : start + window_len] for start in range(num_windows)]
    )

    # Y is X shifted one token ahead: Y[r, t] is the token following X[r, t].
    self.X = windows[:, :-1]
    self.Y = windows[:, 1:]

In [7]:
@d2l.add_to_class(d2l.TimeMachine)  # @save
def get_dataloader(self, train):
    """Return a loader over the training or the validation rows of X/Y.

    The first ``num_train`` rows form the training split; the following
    ``num_val`` rows form the validation split.

    Args:
        train: Truthy selects the training rows, falsy the validation rows.
            It is also forwarded unchanged to ``get_tensorloader``.

    Returns:
        Whatever ``self.get_tensorloader`` returns for the selected rows
        (presumably a minibatch iterator -- confirm against the d2l base
        class).
    """
    if train:
        first, last = 0, self.num_train
    else:
        first, last = self.num_train, self.num_train + self.num_val
    # The slice indexes rows (windows) of X and Y, not individual tokens.
    return self.get_tensorloader([self.X, self.Y], train, slice(first, last))

In [12]:
# Sanity check: build the dataset with a tiny batch size and print the first
# training minibatch. Each row of Y should equal the matching row of X shifted
# one token to the left, with one new token appended at the end.
data = d2l.TimeMachine(batch_size=2, num_steps=10)
for X, Y in data.train_dataloader():
    print("X:", X, "\nY:", Y)
    break  # one minibatch is enough for inspection

X: tensor([[14,  6, 15, 20, 10, 16, 15, 20,  0,  7],
        [ 0, 20,  4,  2, 21, 21,  6, 19,  6,  5]]) 
Y: tensor([[ 6, 15, 20, 10, 16, 15, 20,  0,  7, 16],
        [20,  4,  2, 21, 21,  6, 19,  6,  5,  0]])


- To train a language model, we randomly sample pairs of input sequences and target sequences in minibatches. After training, we use **perplexity** to measure the quality of the language model.

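- **Perplexity**: the exponential of the *average* information ($-\log P$) carried by each token $x_t$ given the tokens that precede it, $x_{t-1}, \ldots, x_1$: $\exp\left(-\frac{1}{n}\sum_{t=1}^{n} \log P(x_t \mid x_{t-1}, \ldots, x_1)\right)$. Averaging over the $n$ tokens (rather than only summing) makes sequences of different lengths comparable.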