In [1]:
import torch

from MrML import Tokenizer, Embedder
from vocab import vocab, vocab_len, PAD, SOS, EOS

BATCH_SIZE = 4

# Fix torch's global RNG so the randomly initialised embedding weights, and
# therefore the printed embeddings below, are the same on every Restart & Run All.
# NOTE(review): assumes Embedder draws its initial weights from torch's global RNG — confirm.
SEED = 42
torch.manual_seed(SEED)

# Input size is 3n, so the window has to hold 3 * MAX_N tokens.
# For efficiency SEQ_LEN should be a multiple of 8:
# 3 * (8 * 4) = 96, which means our max n can be 8 * 4 (32).
MAX_N = 8 * 4
SEQ_LEN = 3 * MAX_N
assert SEQ_LEN % 8 == 0, "SEQ_LEN should be a multiple of 8 for efficiency"

# d_model doesn't need to be too large because our tokens are just single
# characters and don't have complex meanings
D_MODEL = 128

# The tokenizer converts text to a tensor of token int values
tokenizer = Tokenizer(vocab=vocab, PAD=PAD, SOS=SOS, EOS=EOS)

# The embedder converts token sequences to vectorized embeddings
embedder = Embedder(seq_len=SEQ_LEN, d_model=D_MODEL, vocab_len=vocab_len)

# Some example prompts to demonstrate that the tokenizer and embedder work
example_prompts = [
    "aaaaabbbaaa",
    "ccdddcc",
    "eeeeeeefffffff",
]

# Print a title for the program
print("Example tokenization and embedding of prompts")

# Print each example prompt's number, text, tokens, and embedding
for i, prompt in enumerate(example_prompts):
    print(f"\n{i + 1}. {prompt}")

    tokens = tokenizer.tokenize(prompt)
    print("\nTokens:", tokens.size(), tokens)

    # Fail loudly rather than hand the embedder a sequence longer than its
    # window; how embed() treats overlong input is not visible from here.
    if len(tokens) > SEQ_LEN:
        raise ValueError(
            f"prompt {i + 1} has {len(tokens)} tokens, more than SEQ_LEN={SEQ_LEN}"
        )

    embeddings = embedder.embed(tokens, PAD=PAD, window_size=SEQ_LEN)
    print("\nEmbeddings:", embeddings.size(), embeddings)

    # Every prompt, whatever its length, comes back as a full (SEQ_LEN, D_MODEL) window
    assert embeddings.size() == (SEQ_LEN, D_MODEL), embeddings.size()


Example tokenization and embedding of prompts

1. aaaaabbbaaa

Tokens: torch.Size([11]) tensor([10, 10, 10, 10, 10, 11, 11, 11, 10, 10, 10], dtype=torch.int32)

Embeddings: torch.Size([96, 128]) tensor([[ 1.1094,  1.2869,  1.5717,  ...,  0.8383,  0.4019,  0.9553],
        [ 1.9509,  0.8272,  2.3334,  ...,  0.8383,  0.4020,  0.9553],
        [ 2.0187, -0.1292,  2.5588,  ...,  0.8383,  0.4021,  0.9553],
        ...,
        [ 0.0859, -1.6921,  1.7134,  ...,  0.8634, -0.7830,  1.5570],
        [ 0.7889, -1.0401,  2.3477,  ...,  0.8634, -0.7829,  1.5570],
        [ 1.7175, -1.2794,  3.1772,  ...,  0.8634, -0.7827,  1.5570]])

2. ccdddcc

Tokens: torch.Size([7]) tensor([12, 12, 13, 13, 13, 12, 12], dtype=torch.int32)

Embeddings: torch.Size([96, 128]) tensor([[ 1.4990,  2.5824, -0.1531,  ...,  0.9696,  0.0856,  0.3651],
        [ 2.3405,  2.1227,  0.6086,  ...,  0.9696,  0.0857,  0.3651],
        [ 0.7572, -0.3391, -0.0713,  ...,  2.0113,  1.1177,  2.1963],
        ...,
        [ 0.0859, -1