# BetaBae Text Generation: Bhagavad Gita Analysis

This notebook explores the text generation results from BetaBae trained on the Bhagavad Gita dataset.

## What We're Observing:
- **Attention Patterns**: How does the model learn to attend to relevant parts of the text?
- **Language Learning**: Can we see the model learning grammar, structure, and meaning?
- **Representation Evolution**: How do embeddings organize to capture linguistic structure?
- **Emergence of Understanding**: When does the model start generating coherent text?

## Key Questions:
1. Does attention focus on semantically related words?
2. Can we see the model learning sentence structure?
3. How do representations cluster around different concepts?
4. What patterns emerge in the generated text over time?


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import re

# Set style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Load dataset info
from betabae.text_logger import GitaDataset

# Tokenization settings are named once so the summary printed below cannot
# drift from the values actually handed to GitaDataset.
SEQ_LEN = 64
VOCAB_SIZE = 128

# NOTE(review): machine-specific absolute path into the kagglehub cache;
# adjust it when running this notebook under a different user or host.
dataset_path = '/home/abhaydjoshi/.cache/kagglehub/datasets/madhurpant/bhagavad-gita-verses-dataset/versions/1/bhagavad_gita_verses.csv'
dataset = GitaDataset(dataset_path, seq_len=SEQ_LEN, vocab_size=VOCAB_SIZE)

# Summarize what was loaded: corpus size, vocabulary, and a short preview.
print("Dataset Information:")
print(f"Text length: {len(dataset.text):,} characters")
print(f"Vocabulary size: {len(dataset.vocab)}")
print(f"Number of sequences: {len(dataset.tokens):,}")
print(f"Sequence length: {SEQ_LEN} tokens")
print(f"\nSample vocabulary: {dataset.vocab[:30]}")
print(f"\nSample text: {dataset.text[:300]}...")


In [None]:
# Check if training has produced any logs yet
log_dir = Path('text_outputs/logs')
if log_dir.exists():
    # Per-epoch metric archives and the text samples written alongside them.
    epoch_files = sorted(log_dir.glob('epoch_*.npz'))
    generation_files = sorted(log_dir.glob('generations_epoch_*.txt'))

    print(f"Found {len(epoch_files)} epoch files")
    print(f"Found {len(generation_files)} generation files")

    if epoch_files:
        print(f"Epoch files: {[f.name for f in epoch_files[:5]]}")

        # Load first epoch to check data structure. np.load returns an
        # NpzFile that keeps the archive open; the context manager closes
        # the underlying file handle once the shapes have been printed.
        with np.load(epoch_files[0]) as data:
            print("\nData structure:")
            for key in data.keys():
                print(f"  {key}: {data[key].shape}")

    if generation_files:
        print(f"\nGeneration files: {[f.name for f in generation_files[:3]]}")

        # Show sample generations (first 200 characters of the first two files)
        for gen_file in generation_files[:2]:
            print(f"\n{gen_file.name}:")
            with open(gen_file, 'r') as f:
                content = f.read()
            print(content[:200] + "...")
else:
    print("Training logs not found yet. Training may still be in progress...")


In [None]:
# Helpers and entry point for analyzing text generation training logs
def _epoch_number(epoch_file):
    """Return the integer N parsed from an ``epoch_<N>.npz`` file name."""
    return int(epoch_file.stem.split('_')[1])


def _mean_per_epoch(values, epoch_numbers):
    """Average ``values`` per epoch in a single pass.

    Returns an ``(epochs, means)`` pair where ``epochs`` is sorted ascending
    and ``means[i]`` is the mean of the values recorded for ``epochs[i]``.
    """
    grouped = {}
    for epoch_num, value in zip(epoch_numbers, values):
        grouped.setdefault(epoch_num, []).append(value)
    epochs = sorted(grouped)
    return epochs, [np.mean(grouped[epoch]) for epoch in epochs]


def analyze_text_results(log_dir):
    """Analyze text generation training results.

    Reads every ``epoch_<N>.npz`` file in ``log_dir`` in ascending epoch
    order, plots step-wise and per-epoch loss/perplexity curves, and prints
    a short summary.

    Args:
        log_dir: Directory holding the ``epoch_*.npz`` archives (``Path`` or
            ``str``). Each archive must contain a ``losses`` array and may
            contain a ``perplexities`` array; when the latter is absent,
            perplexity is computed as ``exp(loss)``.

    Returns:
        ``(all_losses, all_perplexities, epoch_numbers)`` as three lists with
        one entry per training step, or ``None`` when no epoch files exist or
        they contain no loss values.

    Raises:
        ValueError: If a matching file name does not carry an integer epoch.
    """
    log_dir = Path(log_dir)

    # Sort numerically: a plain lexicographic sort puts epoch_10 before
    # epoch_2, which would scramble the step-wise curves and make the
    # "final" loss come from the wrong epoch.
    epoch_files = sorted(log_dir.glob('epoch_*.npz'), key=_epoch_number)

    if not epoch_files:
        print("No training data found yet!")
        return

    # Collect all metrics
    all_losses = []
    all_perplexities = []
    epoch_numbers = []

    for epoch_file in epoch_files:
        # The context manager closes the archive's file handle after reading.
        with np.load(epoch_file) as data:
            losses = data['losses']
            if 'perplexities' in data:
                perplexities = data['perplexities']
            else:
                # Calculate perplexity from loss
                perplexities = np.exp(losses)

        all_losses.extend(losses)
        all_perplexities.extend(perplexities)
        epoch_numbers.extend([_epoch_number(epoch_file)] * len(losses))

    if not all_losses:
        # Archives exist but hold no steps; nothing to plot or summarize.
        print("No training data found yet!")
        return

    epochs, epoch_losses = _mean_per_epoch(all_losses, epoch_numbers)
    _, epoch_perplexities = _mean_per_epoch(all_perplexities, epoch_numbers)

    # Plot learning curves
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Loss over time
    axes[0, 0].plot(all_losses, alpha=0.7, linewidth=0.5)
    axes[0, 0].set_title('Training Loss Evolution')
    axes[0, 0].set_xlabel('Training Step')
    axes[0, 0].set_ylabel('Cross-Entropy Loss')
    axes[0, 0].grid(True, alpha=0.3)

    # Perplexity over time
    axes[0, 1].plot(all_perplexities, alpha=0.7, linewidth=0.5, color='orange')
    axes[0, 1].set_title('Perplexity Evolution')
    axes[0, 1].set_xlabel('Training Step')
    axes[0, 1].set_ylabel('Perplexity')
    axes[0, 1].grid(True, alpha=0.3)

    # Loss by epoch (x-axis carries the real epoch numbers, not list indices)
    axes[1, 0].plot(epochs, epoch_losses, marker='o', markersize=4)
    axes[1, 0].set_title('Average Loss per Epoch')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Average Loss')
    axes[1, 0].grid(True, alpha=0.3)

    # Perplexity by epoch
    axes[1, 1].plot(epochs, epoch_perplexities, marker='o', markersize=4, color='green')
    axes[1, 1].set_title('Average Perplexity per Epoch')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Average Perplexity')
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print summary statistics
    print("\nTraining Summary:")
    print(f"Total training steps: {len(all_losses)}")
    print(f"Number of epochs: {len(epochs)}")
    print(f"Initial loss: {all_losses[0]:.4f}")
    print(f"Final loss: {all_losses[-1]:.4f}")
    print(f"Initial perplexity: {all_perplexities[0]:.2f}")
    print(f"Final perplexity: {all_perplexities[-1]:.2f}")

    return all_losses, all_perplexities, epoch_numbers

# Analyze only once the training run has created its log directory
if not log_dir.exists():
    print("Training data not available yet. Run this cell after training produces logs.")
else:
    analyze_text_results(log_dir)
