In [None]:
# --- Imports ---
import torch
import matplotlib.pyplot as plt
import sys
sys.path.append('../src')
import bigram
import os

In [None]:
# --- Load Data ---
# Read the entire corpus file into a single in-memory string (`text`).
# Read mode is open()'s default, so it is not spelled out here.
with open('../data/input.txt', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

In [None]:
# --- Train Model ---
# Use the checkpoint at '../model.pth' when it exists; otherwise train on `text`.
# NOTE(review): `block_size` is only bound on the load path below; after a fresh
# training run the name stays undefined.
# NOTE(review): nothing in this cell writes '../model.pth' -- presumably
# bigram.train_bigram persists the checkpoint itself; confirm.
if not os.path.exists('../model.pth'):
    # No checkpoint on disk: train from the loaded corpus.
    model, encode, decode = bigram.train_bigram(text)

else:
    # Checkpoint found: restore the model and its encode/decode helpers.
    print("Loading the pre-trained model")
    model, encode, decode, block_size = bigram.load_bigram_model('../model.pth')

In [None]:
# --- Generate text from the model ---

# Unprompted sample: no start_str is passed for this call.
print("Generating text starting from an empty context")
print(bigram.generate_text(model, encode, decode, max_new_tokens=200))


# Prompted samples: one 200-token generation per seed string, in order.
# `generated_text` ends up holding the sample for the last seed ("This").
for banner, seed in (
    ("Generating text starting from 'T'", "T"),
    ("Generating text starting from 'This'", "This"),
):
    print(banner)
    generated_text = bigram.generate_text(
        model, encode, decode, start_str=seed, max_new_tokens=200
    )
    print(generated_text)

In [None]:
# --- Visualize the embedding table (Optional, but insightful) ---
# Draws model.token_embedding_table.weight as a heatmap. The axis labels treat
# rows as the current token and columns as the next token; with the bigram
# model this is just a direct mapping.
# Only ValueError is handled below; any other exception propagates to the cell.
try:
    # Tensor.numpy() only accepts CPU tensors, so move the weights to host
    # memory first. .cpu() returns the same tensor when it is already on CPU.
    # Converting before plt.figure() also avoids leaving an empty figure behind
    # if the conversion fails.
    embedding_weights = model.token_embedding_table.weight.detach().cpu().numpy()

    plt.figure(figsize=(8, 8))
    plt.imshow(embedding_weights)
    plt.title("Token Embedding Table")
    plt.xlabel("Token Index (Next Token)")
    plt.ylabel("Token Index (Current Token)")
    plt.colorbar(label="Logit Value")

    # Add labels to the axes (if the vocabulary is small enough).
    # NOTE(review): assumes the model's token ids follow the sorted set of
    # characters in `text` -- confirm against the bigram module.
    chars = sorted(set(text))
    itos = dict(enumerate(chars))
    if len(chars) < 50:  # Avoid cluttering the plot if vocab is too large
        tick_positions = range(len(chars))
        tick_labels = [itos[i] for i in tick_positions]
        plt.xticks(tick_positions, tick_labels)
        plt.yticks(tick_positions, tick_labels)
    plt.show()
except ValueError as e:
    print(f"An error occurred during visualization: {e}")
    print("This may happen if Matplotlib is not properly configured, or if you're running in an environment without a display.")