In [37]:
import torch
import tiktoken

from utils import generate_text, text_to_token_ids, token_ids_to_text
from model_architecture import GPTModel

In [38]:
# Hyper-parameters of the small GPT variant. They must match the architecture the
# checkpoint loaded below was trained with, otherwise load_state_dict (strict by
# default) raises on missing/unexpected keys or mismatched shapes.
GPT_CONFIG_SMALL = dict(
    vocab_size=50257,      # size of the tiktoken "gpt2" BPE vocabulary used below
    context_length=512,    # also passed as context_len when generating
    emb_dim=512,
    num_heads=16,
    n_layers=2,
    dropout=0.1,
    qkv_bias=False,
)

model = GPTModel(GPT_CONFIG_SMALL)

In [39]:
# Pick the preferred backend: CUDA GPU first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device_name, banner = "cuda", "Inference on GPU"
elif torch.backends.mps.is_available():
    device_name, banner = "mps", "Inference on MPS"
else:
    device_name, banner = "cpu", "Inference on CPU"

device = torch.device(device_name)
print(banner)

Inference on MPS


In [40]:
# Restore the trained weights.
# `map_location` only decides where the checkpoint tensors are materialised;
# `load_state_dict` then copies them into the model's existing parameters, which
# were created on the CPU by GPTModel(...). The model therefore has to be moved to
# `device` explicitly -- without this, inference silently runs on the CPU even
# though the device cell reports GPU/MPS.
model_location = '../saved_model/model.pth'
checkpoint = torch.load(model_location, map_location=device, weights_only=True)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()  # switch dropout layers to inference behaviour

GPTModel(
  (token_embedding): Embedding(50257, 512)
  (positional_embedding): Embedding(512, 512)
  (drop_emb): Dropout(p=0.1, inplace=False)
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (attention): MultiHeadAttention(
        (W_query): Linear(in_features=512, out_features=512, bias=False)
        (W_key): Linear(in_features=512, out_features=512, bias=False)
        (W_value): Linear(in_features=512, out_features=512, bias=False)
        (out_proj): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (feedforward): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): GELU()
          (2): Linear(in_features=2048, out_features=512, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (dropout_shortcut): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (attention): Mult

In [41]:
# Tally the scalar weights across every registered parameter tensor
# (Module.parameters() does not include buffers).
total_params = 0
for param_tensor in model.parameters():
    total_params += param_tensor.numel()
print(f'Your model size is: {total_params} parameters.')

Your model size is: 58028032 parameters.


In [42]:
tokenizer = tiktoken.get_encoding("gpt2")

# Seed torch's global RNG so the sampled text below is reproducible under
# Restart & Run All. Set SEED = None to draw a different sample on every run.
SEED = 100
if SEED is not None:
    torch.manual_seed(SEED)

In [43]:
# Near-empty prompt (a single space): the model continues from minimal context.
input_prompt: str = " "

In [61]:
# Sample a continuation of `input_prompt`; top_k and temperature are forwarded to
# generate_text (presumably top-k filtering + temperature scaling -- see utils).
# The prompt ids are placed on the device that actually holds the model weights,
# so this cell works whether the model lives on CPU, CUDA or MPS. Feeding CPU ids
# to a model on an accelerator would otherwise fail with a device mismatch.
model_device = next(model.parameters()).device
with torch.no_grad():  # inference only: skip autograd bookkeeping
    token_ids = generate_text(
        model=model,
        # assumes text_to_token_ids returns a torch.Tensor -- TODO confirm in utils
        idx=text_to_token_ids(input_prompt, tokenizer).to(model_device),
        max_new_tokens=50,
        context_len=GPT_CONFIG_SMALL["context_length"],
        top_k=35,
        temperature=0.7,
    )

In [62]:
# Decode the generated ids back into a string and echo it under the cell.
output = token_ids_to_text(
    token_ids,
    tokenizer,
)
print(output)

 

"Please, let's try to fix the leak?" Tom asked.

"No, it is not a friend. It is mine. You can do it. You have to use it to make it better. You need to take
