In [None]:
# --- Imports ---
import torch
import matplotlib.pyplot as plt
import sys
import os

# `__file__` only exists when this code runs as a script; inside a Jupyter kernel
# it raises NameError. Fall back to the kernel's working directory, which Jupyter
# normally sets to the notebook's own folder.
_notebook_dir = os.path.dirname(os.path.abspath(__file__)) if '__file__' in globals() else os.getcwd()

# Add the utils directory to the Python path
# We go up to the project root, then down into utils
_utils_dir = os.path.normpath(os.path.join(_notebook_dir, '..', '..', 'utils'))
if _utils_dir not in sys.path:  # re-running this cell should not stack duplicate entries
    sys.path.append(_utils_dir)

# Import the utility modules
import data_utils
import model_utils
import eval_utils

# Add src to path
_src_dir = os.path.normpath(os.path.join(_notebook_dir, '..', 'src'))
if _src_dir not in sys.path:
    sys.path.append(_src_dir)
import bigram  # Import the bigram module from the src directory

In [None]:
# --- Load and Prepare Data ---
# 1. Load the text data
# `__file__` is undefined inside a Jupyter kernel; fall back to the working
# directory, which Jupyter normally sets to the notebook's own folder.
_notebook_dir = os.path.dirname(os.path.abspath(__file__)) if '__file__' in globals() else os.getcwd()
data_filepath = os.path.join(_notebook_dir, '..', 'data', 'input.txt')
text = data_utils.load_text_data(data_filepath)

In [None]:
# 2. Create the vocabulary
# stoi / itos are the forward and reverse token lookups ("string to int" /
# "int to string"); itos is indexed by integer ids 0..vocab_size-1 in the
# visualization cell. Exact container types come from data_utils — presumably
# dict/list, verify there.
stoi, itos = data_utils.create_vocabulary(text)
# Number of distinct tokens; passed to BigramLanguageModel and to the checkpoint helpers.
vocab_size = len(stoi)

In [None]:
# 3. Encode the text
# Map the raw text to token ids using the stoi lookup; the result is what gets split
# into train/val/test below.
encoded_text = data_utils.encode_text(text, stoi)

In [None]:
# 4. Split the data so the model can be evaluated on text it was not trained on.
# The split ratios (and whether the split is shuffled) are decided inside
# data_utils.split_data, not here.
# val_data drives the validation loss during training.
# NOTE(review): test_data is not referenced by the training or generation cells;
# confirm it is used elsewhere or drop it.
train_data, val_data, test_data = data_utils.split_data(encoded_text)

In [None]:
# 5. Hyperparameters (you could also load these from a config file)
batch_size = 32
block_size = 8  # context length
epochs = 10  # training epochs
learning_rate = 1e-2

# Seed torch's global RNG so weight init, any torch-based batch sampling, and text
# generation are repeatable across Restart & Run All.
# NOTE(review): if data_utils.split_data is randomized, it has already run before this seed.
seed = 1337
torch.manual_seed(seed)

# `__file__` is undefined inside a Jupyter kernel; fall back to the working
# directory, which Jupyter normally sets to the notebook's own folder.
# The checkpoint at this path is reused as-is by the next cell, whatever the
# hyperparameters above are — delete it to force retraining.
_notebook_dir = os.path.dirname(os.path.abspath(__file__)) if '__file__' in globals() else os.getcwd()
model_save_path = os.path.join(_notebook_dir, '..', 'model.pth')

In [None]:
# --- Model Training or Loading ---
# Reuse the checkpoint at model_save_path if it exists; otherwise train from scratch
# and save one. Delete the file to force retraining after changing hyperparameters.
if os.path.exists(model_save_path):
    # Load the pre-trained model
    print("Loading the pre-trained model...")
    checkpoint = model_utils.load_checkpoint(model_save_path, bigram.BigramLanguageModel, vocab_size=vocab_size)  # Pass vocab_size to the model.
    model = checkpoint['model']
    # stoi/itos are not model weights, but they are saved in the checkpoint next to
    # them (see save_checkpoint in the training branch) so that the vocabulary the
    # model was trained on can be checked against the current one.
    loaded_stoi = checkpoint['stoi']
    loaded_itos = checkpoint['itos']

    # Sanity check: a model trained on a different vocabulary would decode garbage.
    assert stoi == loaded_stoi, "Loaded stoi does not match current stoi!"
    assert itos == loaded_itos, "Loaded itos does not match current itos!"

else:
    print("Training a new model...")
    # 1. Create DataLoaders
    train_loader = data_utils.get_data_loader(train_data, batch_size=batch_size, block_size=block_size)
    val_loader = data_utils.get_data_loader(val_data, batch_size=batch_size, block_size=block_size)

    # 2. Instantiate the model and optimizer
    model = bigram.BigramLanguageModel(vocab_size)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # Report about five times over the run, plus after the first epoch.
    # max(1, ...) keeps the modulo below from dividing by zero when epochs < 5.
    eval_every = max(1, epochs // 5)

    # 3. Training loop
    for epoch in range(epochs):
        model.train()  # set to train mode
        train_loss_sum, train_batches = 0.0, 0
        for xb, yb in train_loader:
            # Forward pass and loss
            logits, loss = model(xb, yb)

            # Backward pass and optimization
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            # Accumulate so the report shows the epoch mean, not just the last batch.
            train_loss_sum += loss.item()
            train_batches += 1

        if (epoch + 1) % eval_every == 0 or epoch == 0:
            # --- Validation loss, averaged over all batches of val_loader ---
            model.eval()  # Set the model to evaluation mode
            with torch.no_grad():
                val_loss_sum, val_batches = 0.0, 0
                for v_xb, v_yb in val_loader:
                    v_logits, v_loss = model(v_xb, v_yb)
                    val_loss_sum += v_loss.item()
                    val_batches += 1

            # Report nan rather than crash if a loader produced no batches
            # (e.g. a split shorter than block_size).
            train_loss = train_loss_sum / train_batches if train_batches else float('nan')
            val_loss = val_loss_sum / val_batches if val_batches else float('nan')

            print(f"Epoch {epoch+1}/{epochs}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

    # 4. Save the trained model
    model_utils.save_checkpoint(model, optimizer, epochs, model_save_path, stoi=stoi, itos=itos, vocab_size=vocab_size, block_size=block_size)  # Include stoi, itos in the checkpoint
    print(f"Model saved to {model_save_path}")

In [None]:
# --- Text Generation ---
# generate_text is told to run on model_utils.get_device(). A freshly trained model
# is constructed without an explicit device and never moved, so move it to that
# same device here. Otherwise the weights and the generated inputs can end up on
# different devices. Module.to() is a no-op if the model is already there.
device = model_utils.get_device()
model.to(device)
model.eval()  # generation is inference only

print("\nGenerating text:")
generated_text = eval_utils.generate_text(model, stoi, itos, max_length=200, device=device)
print(generated_text)

print("\nGenerating text starting from 'T':")
start_tokens = data_utils.encode_text("T", stoi)
generated_text = eval_utils.generate_text(model, stoi, itos, start_tokens=start_tokens, max_length=200, device=device)
print(generated_text)

In [None]:
# --- (Optional) Visualization of Embedding Table ---
# This part is useful for visualizing what the model has learned,
# although with the bigram model, it's just a direct mapping.
# The axis labels treat row i as the scores for the token that follows token i.
# This assumes the table is (vocab_size, vocab_size) — TODO confirm against bigram.BigramLanguageModel.
fig = None
try:
    fig, ax = plt.subplots(figsize=(8, 8))
    weights = model.token_embedding_table.weight.detach().cpu().numpy()  # Move to CPU for plotting
    image = ax.imshow(weights)
    ax.set_title("Token Embedding Table")
    ax.set_xlabel("Token Index (Next Token)")
    ax.set_ylabel("Token Index (Current Token)")
    fig.colorbar(image, ax=ax, label="Logit Value")

    # Add labels to the axes (if the vocabulary is small enough).
    # Escape each token via repr() so whitespace tokens such as '\n' or '\t' show as
    # visible escapes instead of breaking the tick layout.
    if vocab_size < 50:
        tick_labels = [repr(str(itos[i]))[1:-1] for i in range(vocab_size)]
        ax.set_xticks(range(vocab_size))
        ax.set_xticklabels(tick_labels)
        ax.set_yticks(range(vocab_size))
        ax.set_yticklabels(tick_labels)
    plt.show()

except ValueError as e:
    if fig is not None:
        plt.close(fig)  # don't leave a half-drawn figure behind
    print(f"An error occurred during visualization: {e}")
    print("This may happen if Matplotlib is not properly configured.")
except RuntimeError as e:
    if fig is not None:
        plt.close(fig)  # don't leave a half-drawn figure behind
    print(f"Runtime error during visualization: {e}")  # Handle runtime errors.