## This notebook is used to debug the functionality of the dataset and model components

In [1]:
import torch
import torch.nn as nn
import tiktoken
from model import *
from dataset import *
from torch.utils.data import DataLoader

In [2]:
# Load tiktoken's "gpt2" encoding; it is the tokenizer used for the rest of the notebook.
tokenizer = tiktoken.get_encoding("gpt2")
# Vocabulary size reported by the encoding; reused below to size the embedding table.
vocab_size = tokenizer.n_vocab
print("Vocab size:", vocab_size)

Vocab size: 50257


## Read data from a test file

In [3]:
# Read the whole sample text into memory.
# The encoding is pinned to UTF-8 (assumes the file is UTF-8 encoded) so that the
# decoded text -- and therefore the character count and the token ids derived from
# it -- does not depend on the platform's default locale encoding.
with open("../data/the-verdict.txt", "r", encoding="utf-8") as file:
    raw_text = file.read()

print(len(raw_text), "characters")

20479 characters


## Create Dataset object

In [4]:
# Build the dataset from the raw text and the tokenizer. Each sample is 126 tokens
# long (confirmed by the shapes printed by this cell).
# NOTE(review): the trailing 1 is presumably the stride between consecutive
# windows -- confirm against GPTDataset's signature.
dataset = GPTDataset(raw_text, tokenizer, 126, 1) # input phrases of 126 tokens

# Inspect the first sample: an (inputs, labels) pair with identical shapes.
inputs, labels = dataset[0]
print("Inputs shape:", inputs.shape)
print("Labels shape:", labels.shape)

Inputs shape: torch.Size([126])
Labels shape: torch.Size([126])


## Create DataLoader

In [5]:
# Batch the dataset in shuffled order.
# A dedicated, seeded generator drives the shuffle so that "Restart & Run All"
# yields the same first batch every time, without touching the global torch RNG.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)

# Pull a single batch to sanity-check the collated shapes.
dataiter = iter(dataloader)
inputs_batch, labels_batch = next(dataiter)

print("Input batch shape:", inputs_batch.shape)
print("Labels batch shape:", labels_batch.shape)

Input batch shape: torch.Size([32, 126])
Labels batch shape: torch.Size([32, 126])


In [6]:
# Token-embedding lookup table: one row of width `emb_dim` per vocabulary id.
emb_dim = 768
emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_dim)

# Embed the sampled batch of token ids; the cell displays the resulting shape.
vector = emb(inputs_batch)
vector.shape

torch.Size([32, 126, 768])

In [7]:
# Configuration for the attention module: the embedding width matches the 768-d
# embeddings created earlier, and the context length matches the 126-token
# samples produced by the dataset.
args = ModelArgs(emb_dim=768, num_heads=2, context_length=126)

In [8]:
# Build the attention module from the config and run the embedded batch through it.
attention = MultiHeadAttention(args)
z = attention(vector)
print("Context vector shape:", z.shape)

torch.Size([32, 2, 126, 126])


tensor([[[ 4.2722e-02, -2.8851e-01,  2.9485e-01,  ..., -2.3117e-01,
           6.9260e-02, -6.1598e-02],
         [ 1.2487e-01, -4.5144e-01, -6.5690e-03,  ..., -2.4244e-01,
           6.6749e-02,  4.6221e-02],
         [ 2.3317e-01, -4.0140e-01,  6.4555e-02,  ..., -8.5278e-02,
          -1.4484e-01, -4.9935e-02],
         ...,
         [ 5.3691e-03, -2.9656e-02,  2.7759e-02,  ..., -1.0381e-02,
          -4.4288e-02, -8.7973e-02],
         [ 3.6611e-02, -4.6819e-02,  2.6386e-02,  ..., -3.0780e-02,
          -6.0035e-02, -4.1340e-02],
         [ 4.5664e-02, -1.5850e-02,  8.4789e-03,  ..., -2.5191e-02,
          -6.1355e-02, -7.1134e-02]],

        [[ 5.7726e-01, -1.4309e-01, -2.5202e-01,  ..., -3.8832e-01,
          -1.4586e-01, -1.4087e-01],
         [ 4.0782e-01,  2.0625e-02, -3.4876e-01,  ..., -4.9718e-01,
           5.5107e-02, -2.7463e-02],
         [ 4.3375e-01, -1.1046e-01, -5.7712e-02,  ..., -2.3475e-02,
          -3.2688e-02, -1.8665e-01],
         ...,
         [ 4.4627e-03, -9