Creating token embeddings

In [27]:
import sys
import os

# Make the parent directory importable (presumably where the
# Creating_Input_Output_Pairs module lives -- confirm against the repo layout).
# ".." is resolved relative to the kernel's current working directory.
# The membership check keeps this cell idempotent: re-running it does not
# pile duplicate entries onto sys.path.
parent_dir = os.path.abspath("..")
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [28]:
import torch
from Creating_Input_Output_Pairs import create_dataloader

In [25]:
# Load the whole corpus into memory as one string.
# The encoding is pinned so the text decodes identically on every platform
# instead of falling back to the locale default (e.g. cp1252 on Windows).
# NOTE(review): assumes the file is UTF-8 (plain ASCII is a subset) -- confirm.
with open('../Tokenization/alice_in_wonderland.txt', encoding='utf-8') as f:
    raw_text = f.read()

# Preview the first 100 characters as a quick sanity check on the load.
print(raw_text[:100])

TITLE: Alice's Adventures in Wonderland
AUTHOR: Lewis Carroll


= CHAPTER I = 
=( Down the Rabbit-Ho


In [29]:
# Size of the lookup table: one row per token id, one column per embedding
# feature.
vocab_size = 50257  # number of distinct token ids the layer can look up
output_dim = 256    # length of the vector produced for each token id

# Embedding table of shape (vocab_size, output_dim); weights are randomly
# initialised by torch.nn.Embedding.
token_embedding_layer = torch.nn.Embedding(
    num_embeddings=vocab_size,
    embedding_dim=output_dim,
)

In [30]:
# Number of token ids per sample; reused below as the window stride.
max_length = 4

# Build the loader over the raw text. stride == max_length presumably makes
# consecutive windows non-overlapping, and shuffle=False presumably keeps them
# in document order -- confirm both against create_dataloader.
dataloader = create_dataloader(
    raw_text,
    max_length=max_length,
    stride=max_length,
    batch_size=8,
    shuffle=False,
)

# Pull only the first (inputs, targets) batch for inspection.
data_iter = iter(dataloader)
inputs, targets = next(data_iter)

In [32]:
# Show the token ids of the first batch under a heading; sep="\n" puts the
# tensor on the line after the heading.
print("\nToken ids:", inputs, sep="\n")


Token ids:
tensor([[49560,  2538,    25, 14862],
        [  338, 15640,   287, 42713],
        [  198,    32, 24318,  1581],
        [   25, 10174, 21298,   628],
        [  198,    28,  5870, 29485],
        [  314,   796,   220,   198],
        [16193,  5588,   262, 25498],
        [   12,    39,  2305,  1267]])


Finding the Vector Embeddings for these token ids

The token_embedding_layer maps each token id to a 256-dimensional vector: the embedding for that token.

In [35]:
# Look up one embedding vector per token id in the batch. The result keeps the
# layout of `inputs` and appends the embedding dimension: (8, 4, 256) here,
# i.e. 8 sequences x 4 tokens x 256 features.
token_embeddings = token_embedding_layer(inputs)
print(token_embeddings.shape)

torch.Size([8, 4, 256])


Adding positional embeddings

In [36]:
# Learnable positional embeddings: one output_dim-sized vector per position in
# a sequence, i.e. a context_length x output_dim table (4 x 256 here).
# Each input sequence holds context_length (= max_length = 4) tokens, and every
# token is represented by a 256-dimensional vector.
# output_dim is deliberately reused from the token-embedding cell rather than
# re-assigned here, so both embedding tables always share the same width and
# their outputs can be added together.
context_length = max_length
position_embedding_layer = torch.nn.Embedding(context_length, output_dim)

In [37]:
# Look up the embedding of every position 0 .. context_length - 1.
# The indices are generated from context_length (the size the layer was built
# with) rather than max_length, so they stay tied to the table they index.
pos_embeddings = position_embedding_layer(torch.arange(context_length))
print(pos_embeddings.shape)

torch.Size([4, 256])


final_embedding is the input embedding for the first batch of data (batch_size=8, context_length=4) from the alice_in_wonderland text: the token embeddings plus the positional embeddings.

In [39]:
# Element-wise sum of token and positional embeddings. pos_embeddings has
# shape (4, 256) and is broadcast across the leading batch dimension of
# token_embeddings (8, 4, 256), so every sequence in the batch receives the
# same per-position vectors; the result keeps the (8, 4, 256) shape.
final_embedding = token_embeddings + pos_embeddings
print(final_embedding.shape)

torch.Size([8, 4, 256])
