In [None]:
from itertools import chain
import torch
from shared import (
    gpt,
    tokenizer,
    show_token_mapping,
    demo_embedding_table as embedding_table,
)
from functools import partial

def show(*values, sep='\n', end='\n\n', **print_kwargs):
    """Print each value on its own line, then leave one blank line.

    Thin wrapper around the built-in ``print``: ``sep`` and ``end`` default to
    ``'\\n'`` and ``'\\n\\n'`` but can still be overridden per call, and any other
    keyword accepted by ``print`` (e.g. ``file``) is forwarded unchanged.
    """
    print(*values, sep=sep, end=end, **print_kwargs)

In [None]:
# Bare last expression: the notebook renders `gpt` (imported from `shared`) as this cell's output.
gpt

## Vocabulary

Neural networks deal in numbers, not language

### Tokens

In [None]:
# Bare last expression: the notebook renders `tokenizer` (imported from `shared`) as this cell's output.
tokenizer

In [None]:
# Run the 'token->id' mapping demo on a sample sentence.
# NOTE(review): `show_token_mapping` comes from `shared`; presumably it displays
# each token next to its integer ID — confirm against the helper.
show_token_mapping(
    'token->id',
    tokenizer,
    # Adjacent literals are concatenated into one string.
    data=(
        'Tokenizers convert text to integer IDs the model can understand, '
        'breaking words or subwords into consistent units. '
        'Numbers are weird: 12345'
    ),
)

In [None]:
# Run the 'id->token' mapping demo on three small windows of IDs:
# 0-4, 3000-3004 and 40473-40477.
# NOTE(review): the last window presumably ends at the vocabulary size —
# verify against `tokenizer.vocab_size`.
show_token_mapping(
    'id->token',
    tokenizer,
    # `chain` yields the three ranges back to back as a single one-shot iterator.
    data=chain(
        range(5),
        range(3000, 3005),
        range(40473, 40478),
    ),
)

In [None]:
# Print the tokenizer's vocabulary size followed by the model's token- and
# position-embedding modules, one per line.
show(
    tokenizer.vocab_size,
    gpt.transformer.tokens_embed,
    gpt.transformer.positions_embed,
)

### Embeddings

#### What are they?

`King - Man + Woman = Queen` (shout-out to Word2Vec)

![word2vec](assets/word2vec.png)

Learned vector representations where magnitude and direction have meaning

In [None]:
# Pull row 3001 out of the token-embedding weight matrix, i.e. the learned
# vector stored for token ID 3001.
token_embedding = gpt.transformer.tokens_embed.weight[3001]

# `show` already separates its arguments with newlines, so no explicit `sep` is needed.
show(
    token_embedding.shape,
    token_embedding,
)

### H_0

![gpt1math](assets/h0math.png)

In [None]:
# Tokenize the string 'G P T'; `return_tensors='pt'` is passed through to the
# tokenizer (presumably requesting PyTorch tensors — confirm against the tokenizer docs).
tokens = tokenizer(
    'G P T',
    return_tensors='pt',
)

# Print the resulting IDs and then their shape.
show(
    tokens.input_ids,
    tokens.input_ids.shape,
)

#### How does a matrix multiply get us the embeddings?

In [None]:
# Bare last expression: renders `embedding_table` (the `demo_embedding_table`
# imported from `shared` under an alias) as this cell's output.
embedding_table

In [None]:
# Pick out the "10s, 30s, 50s" rows of `embedding_table` (rows 0, 2 and 4)
# using nothing but a matrix multiply.
# NOTE(review): `token_indicies` is a misspelling of "indices"; the name is kept
# in case other cells reference it.

# Each row contains a single 1, so multiplying by it copies out exactly one
# row of the table.
token_indicies = torch.tensor(
    [
        [1, 0, 0, 0, 0],  # picks row 0
        [0, 0, 1, 0, 0],  # picks row 2
        [0, 0, 0, 0, 1],  # picks row 4
    ]
)

# Shapes: (3, 5) @ (5, 10) => (3, 10) — assumes `embedding_table` is 5 x 10; TODO confirm.
token_indicies @ embedding_table

In [None]:
# Build h_0 step by step: token embeddings plus position embeddings.

# 1) Token IDs from the tokenizer output; the size of the last dimension
#    decides how many position IDs are taken below.
input_ids = tokens.input_ids
input_shape = input_ids.shape
show('input_ids:', input_shape, input_ids)

# 2) Take the first `input_shape[-1]` entries of the model's `position_ids`
#    and add a leading dimension in front.
position_ids = gpt.transformer.position_ids[: input_shape[-1]].unsqueeze(0)
show('position_ids:', position_ids.shape, position_ids)

# 3) Look the token IDs up in the token-embedding module.
inputs_embeds = gpt.transformer.tokens_embed(input_ids)
show('inputs_embeds:', inputs_embeds.shape, inputs_embeds)

# 4) Look the position IDs up in the position-embedding module.
position_embeds = gpt.transformer.positions_embed(position_ids)
show('position_embeds:', position_embeds.shape, position_embeds)

# 5) h_0 is the sum of the two embeddings.
h_0 = inputs_embeds + position_embeds
show('h_0:', h_0.shape, h_0)