# 🦖 Data Loader

#### 📚 Libraries
Import libraries and configure the environment.

In [34]:
import json
import torch

#### 📂 Data
Load the data from the local `data/` directory.

In [35]:
# Read both saved splits from disk in one go (train first, then validation).
# NOTE(review): torch.load relies on pickle, so only point it at files you trust.
train_data, val_data = (
    torch.load("data/train_data.pt"),
    torch.load("data/val_data.pt"),
)

In [36]:
# Read the stoi / itos lookup tables from the JSON file.
# The encoding is pinned so the result does not depend on the platform default.
with open("data/encoder_dict.json", "r", encoding="utf-8") as f:
    encoder_dict = json.load(f)

stoi = encoder_dict["stoi"]
# JSON object keys are always strings, so convert them back to int ids.
itos = {int(k): v for k, v in encoder_dict["itos"].items()}

In [37]:
def decode(integers: list, itos: dict = itos) -> str:
    """Turn a sequence of token ids back into the text they stand for.

    Args:
        integers: Token ids to look up, in order.
        itos: Lookup table from token id to string. Defaults to the
            module-level ``itos`` object as it was bound when this function
            was defined.

    Returns:
        The looked-up strings concatenated with no separator.

    Raises:
        KeyError: If ``itos`` is a dict and an id is not one of its keys.
    """
    return "".join(itos[token_id] for token_id in integers)

#### 🦮 Batches of data

**Test the data loader**

In [38]:
# Pin torch's global RNG so the randomly sampled batches below are repeatable.
torch.manual_seed(42)

# Batch geometry read by get_batch.
block_size = 8  # the context length (elements per sequence)
batch_size = 4  # how many sequences are sampled per batch


def get_batch(split):
    """Draw a random batch of input/target windows from one data split.

    Args:
        split: ``"train"`` reads from ``train_data``; any other value falls
            back to ``val_data``.

    Returns:
        A pair ``(x, y)`` of stacked tensors, each built from ``batch_size``
        windows of ``block_size`` consecutive elements. Every row of ``y`` is
        the matching row of ``x`` moved one position forward in the data, so
        ``y[b, t]`` is the element that follows ``x[b, : t + 1]``.
    """
    source = train_data if split == "train" else val_data
    # randint's upper bound is exclusive, so the largest start still leaves
    # room for the target window that ends one element later.
    starts = torch.randint(len(source) - block_size, (batch_size,)).tolist()
    inputs = torch.stack([source[s : s + block_size] for s in starts])
    targets = torch.stack([source[s + 1 : s + block_size + 1] for s in starts])
    return inputs, targets

In [39]:
# Sample one training batch to inspect.
xb, yb = get_batch("train")

# For each half of the batch print its label, shape and contents on separate
# lines, with a blank line between the two halves.
print("Inputs", xb.shape, xb, sep="\n")
print()
print("Targets", yb.shape, yb, sep="\n")

Inputs
torch.Size([4, 8])
tensor([[85, 72, 74, 76, 86, 87, 85, 82],
        [88, 81,  9, 86, 88, 86,  9, 70],
        [86, 72, 81, 68, 79, 68, 85,  9],
        [83, 72, 85, 86, 82, 81, 68, 79]])

Targets
torch.Size([4, 8])
tensor([[72, 74, 76, 86, 87, 85, 82,  9],
        [81,  9, 86, 88, 86,  9, 70, 82],
        [72, 81, 68, 79, 68, 85,  9, 84],
        [72, 85, 86, 82, 81, 68, 79,  9]])


In [40]:
# Walk through every sequence in the batch and list each growing context
# (as raw integer ids) next to the id that should follow it.
for row in range(batch_size):
    print(f"Batch {row+1}")
    inputs, targets = xb[row].tolist(), yb[row].tolist()
    for step in range(1, block_size + 1):
        print(f"{step}: {inputs[:step]} -> {targets[step - 1]}")
    print("-" * 10)

Batch 1
1: [85] -> 72
2: [85, 72] -> 74
3: [85, 72, 74] -> 76
4: [85, 72, 74, 76] -> 86
5: [85, 72, 74, 76, 86] -> 87
6: [85, 72, 74, 76, 86, 87] -> 85
7: [85, 72, 74, 76, 86, 87, 85] -> 82
8: [85, 72, 74, 76, 86, 87, 85, 82] -> 9
----------
Batch 2
1: [88] -> 81
2: [88, 81] -> 9
3: [88, 81, 9] -> 86
4: [88, 81, 9, 86] -> 88
5: [88, 81, 9, 86, 88] -> 86
6: [88, 81, 9, 86, 88, 86] -> 9
7: [88, 81, 9, 86, 88, 86, 9] -> 70
8: [88, 81, 9, 86, 88, 86, 9, 70] -> 82
----------
Batch 3
1: [86] -> 72
2: [86, 72] -> 81
3: [86, 72, 81] -> 68
4: [86, 72, 81, 68] -> 79
5: [86, 72, 81, 68, 79] -> 68
6: [86, 72, 81, 68, 79, 68] -> 85
7: [86, 72, 81, 68, 79, 68, 85] -> 9
8: [86, 72, 81, 68, 79, 68, 85, 9] -> 84
----------
Batch 4
1: [83] -> 72
2: [83, 72] -> 85
3: [83, 72, 85] -> 86
4: [83, 72, 85, 86] -> 82
5: [83, 72, 85, 86, 82] -> 81
6: [83, 72, 85, 86, 82, 81] -> 68
7: [83, 72, 85, 86, 82, 81, 68] -> 79
8: [83, 72, 85, 86, 82, 81, 68, 79] -> 9
----------


In [41]:
# Same walk-through as with the raw ids, but with every context and target
# passed through decode so they are shown as text.
for row in range(batch_size):
    print(f"Batch {row+1}")
    inputs, targets = xb[row].tolist(), yb[row].tolist()
    for step in range(1, block_size + 1):
        shown_context = decode(inputs[:step])
        shown_target = decode([targets[step - 1]])
        print(f"{step}: {shown_context} -> {shown_target}")
    print("-" * 10)

Batch 1
1: r -> e
2: re -> g
3: reg -> i
4: regi -> s
5: regis -> t
6: regist -> r
7: registr -> o
8: registro ->  
----------
Batch 2
1: u -> n
2: un ->  
3: un  -> s
4: un s -> u
5: un su -> s
6: un sus ->  
7: un sus  -> c
8: un sus c -> o
----------
Batch 3
1: s -> e
2: se -> n
3: sen -> a
4: sena -> l
5: senal -> a
6: senala -> r
7: senalar ->  
8: senalar  -> q
----------
Batch 4
1: p -> e
2: pe -> r
3: per -> s
4: pers -> o
5: perso -> n
6: person -> a
7: persona -> l
8: personal ->  
----------
