<a href="https://colab.research.google.com/github/hyngon90/StatQuestTutorial/blob/main/05_StatQuest_Tutorial_GPT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install lightning
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader

import lightning as L

Collecting lightning
  Downloading lightning-2.5.0.post0-py3-none-any.whl.metadata (40 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/40.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.4/40.4 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning)
  Downloading lightning_utilities-0.11.9-py3-none-any.whl.metadata (5.2 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.6.1-py3-none-any.whl.metadata (21 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.5.0.post0-py3-none-any.whl.metadata (21 kB)
Downloading lightning-2.5.0.post0-py3-none-any.whl (815 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m815.2/815.2 kB[0m [31m33.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.11.9-py3-none-any.whl (28 kB)
Downloading torchmetrics-1.6.1-py3-none-any

In [3]:
'''
What is StatQuest? -> Awesome!
StatQuest is what? -> Awesome!

There are four + one elements of Token
Token = {"what","is", "StatQuest", "awesome", "<EOS>"}
'''

token_to_id = {
    'what': 0,
    'is': 1,
    'statquest':2,
    'awesome':3,
    '<EOS>':4,
}
id_to_token = dict(map(reversed, token_to_id.items()))

In [4]:
'''
What is StatQuest? -> Awesome!
StatQuest is what? -> Awesome!
'''

inputs = torch.tensor([
    [token_to_id['what'], token_to_id['is'], token_to_id['statquest'], token_to_id['<EOS>'], token_to_id['awesome']],
    [token_to_id['statquest'], token_to_id['is'], token_to_id['what'], token_to_id['<EOS>'], token_to_id['awesome']],

])
labels = torch.tensor([
    [token_to_id['is'], token_to_id['statquest'], token_to_id['<EOS>'], token_to_id['awesome'], token_to_id['<EOS>']],
    [token_to_id['is'], token_to_id['what'], token_to_id['<EOS>'], token_to_id['awesome'], token_to_id['<EOS>']],
])

dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)

In [5]:
class PositionEncoding(nn.Module):
  def __init__(self, d_model = 2, max_len = 6):
    super().__init__()

    pe = torch.zeros(max_len, d_model)

    pos = torch.arange(start = 0, end = max_len, step=1).float().unsqueeze(1)
    idx = torch.arange(start = 0, end = d_model, step=2).float()

    div = 1/torch.tensor(10000.0)**(idx / d_model)

    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)

    self.register_buffer('pe', pe)

  def forward(self, word_embeddings):
    return word_embeddings + self.pe[:word_embeddings.size(0), :]

In [6]:
class Attention(nn.Module):
  def __init__(self, d_model = 2):
    super().__init__()

    self.row_dim = 0
    self.col_dim = 1

    self.W_q = nn.Linear(in_features = d_model, out_features = d_model, bias = False)
    self.W_k = nn.Linear(in_features = d_model, out_features = d_model, bias = False)
    self.W_v = nn.Linear(in_features = d_model, out_features = d_model, bias = False)

  def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask = None):
    q = self.W_q(encodings_for_q)
    k = self.W_k(encodings_for_k)
    v = self.W_v(encodings_for_v)

    sims = torch.matmul(q, k.transpose(dim0 = self.row_dim , dim1=self.col_dim))
    scaled_sims = sims / torch.tensor(k.size(self.col_dim)**0.5)

    if mask is not None:
      scaled_sims = scaled_sims.masked_fill(mask = mask, value=-1e9)

    attention_percents = F.softmax(scaled_sims, dim = self.col_dim)
    attention_score = torch.matmul(attention_percents, v)

    return attention_score

In [7]:
class DecoderOnlyTransformer(L.LightningModule):
  def __init__(self, num_tokens = 4, d_model = 2, max_len = 6):
    super().__init__()

    self.we = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)

    self.pe = PositionEncoding(d_model = d_model, max_len = max_len)

    self.self_attention = Attention(d_model = d_model)

    self.fc_layer = nn.Linear(in_features=d_model, out_features=num_tokens)

    self.loss = nn.CrossEntropyLoss()

  def forward(self, token_ids):

    word_embeddings = self.we(token_ids)
    print(token_ids)
    print(self.we)
    print(word_embeddings)
    position_encoded = self.pe(word_embeddings)

    mask = torch.tril(torch.ones((token_ids.size(dim=0), token_ids.size(dim=0))))
    mask = mask == 0

    self_attention_values = self.self_attention(position_encoded,position_encoded,position_encoded,mask=mask)
    residual_connection_values = position_encoded + self_attention_values

    output = self.fc_layer(residual_connection_values)

    return output

  def configure_optimizers(self):
    return Adam(self.parameters(), lr=0.1)

  def training_step(self, batch, batch_idx):
    inputs, labels = batch
    output = self.forward(inputs[0])
    loss = self.loss(output,labels[0])

    return loss


In [8]:
model = DecoderOnlyTransformer(num_tokens=len(token_to_id), d_model = 2, max_len = 6)

trainer = L.Trainer(max_epochs=30)
trainer.fit(model=model, train_dataloaders=dataloader)

model_input = torch.tensor([
    token_to_id["what"],
    token_to_id["is"],
    token_to_id["statquest"],
    token_to_id["<EOS>"],
])

# model_input = torch.tensor([
#     token_to_id["statquest"],
#     # token_to_id["is"],
#     # token_to_id["what"],
#     token_to_id["<EOS>"],
# ])


input_length = model_input.size(dim=0)

predictions = model(model_input)
predicted_id = torch.tensor([torch.argmax(predictions[-1,:])])
predicted_ids = predicted_id

max_length = 6

for i in range(input_length, max_length):
  if(predicted_id == token_to_id["<EOS>"]):
    break

    model_input = torch.cat((model_input, predicted_id))

    predictions = model(model_input)
    predicted_id = torch.tensor([torch.argmax(predictions[-1,:])])
    predicted_ids = torch.cat((predicted_ids, predicted_id))

print("Predicted Tokens: \n")
for id in predicted_ids:
  print("\t", id_to_token[id.item()])

INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: 
  | Name           | Type             | Params | Mode 
------------------------------------------------------------
0 | we             | Embedding        | 10     | train
1 | pe             | PositionEncoding | 0      | train
2 | self_attention | Attention        | 12     | train
3 | fc_layer       | Linear           | 15     | train
4 | loss           | CrossEntropyLoss | 0      | train
------------------------------------------------------------
37        Trainable params
0         Non-trainable params
37        Total params
0.000     Total estimated model params size (MB)
8         Modules in train mode
0         

Training: |          | 0/? [00:00<?, ?it/s]

tensor([0, 1, 2, 4, 3])
Embedding(5, 2)
tensor([[ 0.4670,  1.6469],
        [ 0.3186,  1.0466],
        [ 0.3999, -0.2383],
        [ 0.3520,  1.1040],
        [ 0.3112, -1.2546]], grad_fn=<EmbeddingBackward0>)
tensor([2, 1, 0, 4, 3])
Embedding(5, 2)
tensor([[ 0.2999, -0.1383],
        [ 0.4186,  1.1466],
        [ 0.5670,  1.5469],
        [ 0.2520,  1.0040],
        [ 0.2112, -1.1546]], grad_fn=<EmbeddingBackward0>)
tensor([0, 1, 2, 4, 3])
Embedding(5, 2)
tensor([[ 0.6300,  1.5021],
        [ 0.3471,  1.1237],
        [ 0.2623, -0.1099],
        [ 0.1577,  0.9057],
        [ 0.2229, -1.0586]], grad_fn=<EmbeddingBackward0>)
tensor([2, 1, 0, 4, 3])
Embedding(5, 2)
tensor([[ 0.2516, -0.0672],
        [ 0.2926,  1.1177],
        [ 0.7076,  1.4444],
        [ 0.0810,  0.8131],
        [ 0.2755, -0.9672]], grad_fn=<EmbeddingBackward0>)
tensor([0, 1, 2, 4, 3])
Embedding(5, 2)
tensor([[ 0.7835,  1.4009],
        [ 0.2231,  1.0960],
        [ 0.2624, -0.0421],
        [ 0.0259,  0.7322],
    

INFO: `Trainer.fit` stopped: `max_epochs=30` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=30` reached.


tensor([0, 1, 2, 4, 3])
Embedding(5, 2)
tensor([[ 2.5904,  1.7040],
        [-2.7101,  2.2561],
        [ 3.4925, -0.4738],
        [-0.1533,  0.4951],
        [ 3.2677, -1.8998]], grad_fn=<EmbeddingBackward0>)
tensor([2, 1, 0, 4, 3])
Embedding(5, 2)
tensor([[ 3.5152, -0.4872],
        [-2.7328,  2.2582],
        [ 2.6121,  1.7128],
        [-0.1524,  0.4991],
        [ 3.2808, -1.9045]], grad_fn=<EmbeddingBackward0>)
tensor([0, 1, 2, 4, 3])
Embedding(5, 2)
tensor([[ 2.6320,  1.7211],
        [-2.7536,  2.2606],
        [ 3.5373, -0.5008],
        [-0.1534,  0.5041],
        [ 3.2929, -1.9088]], grad_fn=<EmbeddingBackward0>)
tensor([2, 1, 0, 4, 3])
Embedding(5, 2)
tensor([[ 3.5574, -0.5132],
        [-2.7745,  2.2646],
        [ 2.6555,  1.7257],
        [-0.1540,  0.5085],
        [ 3.3042, -1.9127]], grad_fn=<EmbeddingBackward0>)
tensor([0, 1, 2, 4])
Embedding(5, 2)
tensor([[ 2.6771,  1.7302],
        [-2.7932,  2.2686],
        [ 3.5773, -0.5265],
        [-0.1562,  0.5138]], grad_f