In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

from gpt import Head, MultiHeadAttention, FeedForward, Block, GPTLanguageModel

# Run on the GPU when PyTorch reports a usable CUDA device, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

model_path = '../data/models/gpt_simon_1000.pth'  # Update with your model path
# NOTE(review): torch.load deserializes via pickle; only load checkpoints from a trusted source.
# map_location remaps stored tensors onto the device selected above.
checkpoint = torch.load(model_path, map_location=device)

# Tokenizer tables saved alongside the weights.
vocab_size, stoi, itos = (
    checkpoint[key] for key in ('vocab_size', 'stoi', 'itos')
)
# Architecture hyperparameters needed to rebuild the network.
n_embd, n_head, n_layer, block_size, dropout = (
    checkpoint[key]
    for key in ('n_embd', 'n_head', 'n_layer', 'block_size', 'dropout')
)

# Rebuild the network from the hyperparameters read out of the checkpoint.
model = GPTLanguageModel(
    vocab_size=vocab_size,
    n_embd=n_embd,
    n_head=n_head,
    n_layer=n_layer,
    block_size=block_size,
    dropout=dropout,
    device=device,
)
model = model.to(device)
# Restore the saved weights, then switch to evaluation mode for inference.
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

def encode(s):
    """Return the token ids for the characters of ``s``; characters missing from ``stoi`` map to 0."""
    return [stoi.get(ch, 0) for ch in s]


def decode(l):
    """Return the string for the token ids in ``l``; ids missing from ``itos`` are dropped."""
    return ''.join(itos.get(tok, '') for tok in l)

def generate_text(prompt, max_new_tokens=100):
    """Sample text from the loaded model, seeded with ``prompt``.

    Args:
        prompt: Seed string; tokenized with ``encode``.
        max_new_tokens: Passed through to ``model.generate``.

    Returns:
        The decoded string for the first sequence returned by
        ``model.generate``.
        NOTE(review): presumably the prompt followed by the newly sampled
        tokens — confirm against ``GPTLanguageModel.generate``.
    """
    # Batch of one: shape (1, len(prompt)).
    # NOTE(review): an empty prompt yields a (1, 0) tensor; whether
    # model.generate accepts that is not visible from here.
    idx = torch.tensor([encode(prompt)], dtype=torch.long, device=device)
    # Sampling never backpropagates, so skip autograd bookkeeping
    # (same as evaluate_model does for its forward passes).
    with torch.no_grad():
        idx_new = model.generate(idx, max_new_tokens)
    return decode(idx_new[0].tolist())

# Seed text for sampling.
prompt = "L"
# Sample a short continuation and show it; print() keeps newlines readable.
generated = generate_text(prompt=prompt, max_new_tokens=200)
print(generated)

def evaluate_model(test_text):
    """Print and return the average loss of ``model`` over ``test_text``.

    The encoded text is cut into consecutive, non-overlapping windows of
    ``block_size`` tokens, where ``block_size`` is the size of the model's
    position-embedding table. Each window is scored against the same window
    shifted one token to the right. Trailing tokens that do not fill a whole
    window (plus its shifted target) are ignored.

    Args:
        test_text: Text to evaluate on; tokenized with ``encode``.

    Returns:
        The mean of the per-window losses as a float. The value is also
        printed.

    Raises:
        ValueError: If ``test_text`` encodes to ``block_size`` tokens or
            fewer, so that not even one full window fits.
    """
    data = torch.tensor(encode(test_text), dtype=torch.long, device=device)
    block_size = model.position_embedding_table.num_embeddings
    n_tokens = data.size(0)
    # Without this guard the loop below runs zero times and the average
    # fails with an opaque ZeroDivisionError.
    if n_tokens <= block_size:
        raise ValueError(
            f"test_text encodes to {n_tokens} tokens; need more than "
            f"{block_size} to form one evaluation window"
        )

    losses = []
    model.eval()
    with torch.no_grad():
        for i in range(0, n_tokens - block_size, block_size):
            # `data` already lives on `device`, so the slices do too.
            x = data[i:i + block_size].unsqueeze(0)
            y = data[i + 1:i + block_size + 1].unsqueeze(0)
            _, loss = model(x, y)
            losses.append(loss.item())
    avg_loss = sum(losses) / len(losses)
    print(f"Average loss on test dataset: {avg_loss:.4f}")
    return avg_loss

# Optional: if you have a held-out test text file, uncomment the two lines below and set its path.
# test_text = open('/path/to/test.txt', 'r', encoding='utf-8').read()
# evaluate_model(test_text)


In [None]:
# Longer sample from the same prompt.
generated = generate_text(prompt=prompt, max_new_tokens=1500)
print(generated)