# Hangman RL Training Script

- Run this notebook to train the Probabilistic Oracle and RL Agent.
- Artifacts are saved to `./artifacts/oracle.pkl` and `./artifacts/agent.pkl`.
- Use `streamlit run app.py` to launch the app. The app will only load artifacts and will NOT retrain on each query.

In [None]:
import os
import random
import pickle
from collections import defaultdict
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

# --- Configuration ---
# Letters the oracle scores and the agent may guess (upper-case only; the
# environment and the oracle upper-case their inputs).
ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
# Default number of wrong guesses HangmanEnv allows before the game is lost.
MAX_WRONG_GUESSES = 6
# Number of self-play episodes passed to HangmanRLAgent.train by the training cell.
NUM_TRAIN_EPISODES = 50000
# Default number of games played by HangmanRLAgent.evaluate.
NUM_EVAL_GAMES = 2000
# Word lists, one word per line (read by load_words).
CORPUS_FILE = 'corpus.txt'
TEST_FILE = 'test.txt'
# Output locations for the pickled oracle/agent state and the reward plot.
ARTIFACT_DIR = 'artifacts'
ORACLE_PATH = os.path.join(ARTIFACT_DIR, 'oracle.pkl')
AGENT_PATH = os.path.join(ARTIFACT_DIR, 'agent.pkl')
TRAIN_PLOT_PATH = os.path.join(ARTIFACT_DIR, 'training_rewards.png')

# --- Reward Design ---
# Rewards returned by HangmanEnv.step. The terminal rewards (win/lose) are
# returned instead of, not in addition to, the per-guess reward of that step.
REWARD_WIN = 50
REWARD_LOSE = -50
REWARD_CORRECT_GUESS = 5
REWARD_WRONG_GUESS = -5
# Returned when the letter was already guessed; no life is deducted.
REWARD_REPEATED_GUESS = -2

# --- Default factories (picklable) ---
def one():
    """Default factory yielding the add-one smoothing count for an unseen key."""
    return 1
def alpha_len():
    """Default factory yielding the initial total for an unseen context.

    The value is the size of ALPHABET, i.e. one smoothing count per letter.
    """
    return len(ALPHABET)
def dd_one():
    """Default factory yielding a fresh row of counts that themselves default to 1."""
    row = defaultdict(one)
    return row
def zeros5():
    """Default factory yielding a new length-5 float array of zeros.

    Used as the initial Q-values (one slot per action) of an unseen state.
    """
    return np.zeros((5,))

# -----------------------------------------------
# PART 0: THE HANGMAN ENVIRONMENT
# -----------------------------------------------
class HangmanEnv:
    """Single-word Hangman game with a reset/step interface.

    An observation is the tuple ``(masked_word, guessed_letters, lives_left)``.
    ``guessed_letters`` is the environment's own live set (not a copy), so it
    keeps changing as further guesses are made.
    """

    def __init__(self, word_list, max_wrong_guesses=MAX_WRONG_GUESSES):
        """Store the word pool and life budget, then start a first game."""
        self.word_list = word_list
        self.max_wrong_guesses = max_wrong_guesses
        self.reset()

    def reset(self, specific_word=None):
        """Begin a new game and return its initial observation.

        Args:
            specific_word: Word to play; when falsy a random word from
                ``word_list`` is drawn instead.
        """
        chosen = specific_word or random.choice(self.word_list)
        self.word = chosen.upper()
        self.masked_word = "_" * len(self.word)
        self.guessed_letters = set()
        self.lives_left = self.max_wrong_guesses
        self.done = False
        return self._get_obs()

    def _get_obs(self):
        """Return the current ``(masked_word, guessed_letters, lives_left)``."""
        return self.masked_word, self.guessed_letters, self.lives_left

    def step(self, letter):
        """Apply one guess.

        Returns:
            ``(observation, reward, done, info)`` where ``info`` carries the
            boolean flags ``"repeated"`` and ``"wrong"``.
        """
        guess = letter.upper()
        info = {"repeated": False, "wrong": False}

        # A repeated guess changes nothing in the game state.
        if guess in self.guessed_letters:
            info["repeated"] = True
            return self._get_obs(), REWARD_REPEATED_GUESS, self.done, info

        self.guessed_letters.add(guess)

        if guess not in self.word:
            # Miss: spend a life; the game is lost once none remain.
            self.lives_left -= 1
            info["wrong"] = True
            out_of_lives = self.lives_left <= 0
            if out_of_lives:
                self.done = True
            reward = REWARD_LOSE if out_of_lives else REWARD_WRONG_GUESS
            return self._get_obs(), reward, self.done, info

        # Hit: uncover every position holding the guessed letter.
        self.masked_word = "".join(
            actual if actual == guess else shown
            for actual, shown in zip(self.word, self.masked_word)
        )
        solved = "_" not in self.masked_word
        if solved:
            self.done = True
        reward = REWARD_WIN if solved else REWARD_CORRECT_GUESS
        return self._get_obs(), reward, self.done, info

# -----------------------------------------------
# PART 1: THE PROBABILISTIC ORACLE (n-gram letter model; called the "HMM" elsewhere in this file)
# -----------------------------------------------
class ProbabilisticOracle:
    """Add-one-smoothed letter statistics used to rank candidate guesses.

    The model keeps four count tables learned from a word list: unigram
    counts, bigram counts, trigram counts (keyed by the two preceding
    characters) and per-position letter counts. Words are padded with ``^``
    and ``$`` for the n-gram tables. Every count starts at 1 and every total
    at ``len(ALPHABET)`` via the module-level default factories.
    """

    def __init__(self):
        self.unigrams = defaultdict(one)
        self.total_unigrams = len(ALPHABET)
        self.bigrams = defaultdict(dd_one)
        self.bigram_totals = defaultdict(alpha_len)
        self.trigrams = defaultdict(dd_one)
        self.trigram_totals = defaultdict(alpha_len)
        self.positional_freq = defaultdict(dd_one)
        self.positional_totals = defaultdict(alpha_len)
        self.max_word_len = 0

    def train(self, corpus):
        """Accumulate counts from ``corpus``, an iterable of words.

        Each word is stripped and upper-cased; words that are not purely
        alphabetic are skipped. Calling this again adds to existing counts.
        """
        print(f"Training oracle on {len(corpus)} words...")
        for word in corpus:
            word = word.strip().upper()
            if not word.isalpha():
                continue
            self.max_word_len = max(self.max_word_len, len(word))
            padded_word = f"^{word}$"
            for i, char in enumerate(word):
                self.unigrams[char] += 1
                self.total_unigrams += 1
                self.positional_freq[i][char] += 1
                self.positional_totals[i] += 1
            for i in range(len(padded_word) - 1):
                c1, c2 = padded_word[i:i+2]
                self.bigrams[c1][c2] += 1
                self.bigram_totals[c1] += 1
                # A trigram needs one more character after the bigram.
                if i < len(padded_word) - 2:
                    c1, c2, c3 = padded_word[i:i+3]
                    self.trigrams[c1+c2][c3] += 1
                    self.trigram_totals[c1+c2] += 1
        print("Oracle training complete.")

    @staticmethod
    def _pair_count(table, context, symbol):
        """Return ``table[context][symbol]`` without inserting missing keys.

        Missing entries yield the same default (``one()``) that indexing the
        defaultdict would, but the table is left untouched.
        """
        row = table.get(context)
        if row is None:
            return one()
        return row.get(symbol, one())

    @staticmethod
    def _context_total(totals, context):
        """Return ``totals[context]`` without inserting a missing key."""
        return totals.get(context, alpha_len())

    def get_letter_probabilities(self, masked_word, guessed_letters):
        """Score every unguessed letter against the blanks of ``masked_word``.

        Args:
            masked_word: Current pattern, with ``_`` marking hidden letters.
            guessed_letters: Container of letters already tried.

        Returns:
            Dict mapping each unguessed letter of ALPHABET to a probability
            (a softmax over summed log scores), or ``{}`` when every letter
            has been guessed.

        The lookup is read-only: scoring never adds keys to the count tables,
        so the model that is later serialized contains training data only.
        """
        unguessed_letters = [l for l in ALPHABET if l not in guessed_letters]
        if not unguessed_letters:
            return {}

        padded_masked = f"^{masked_word}$"

        # The neighbourhood of each blank does not depend on the candidate
        # letter, so gather it once instead of once per letter.
        blanks = []
        for i, char in enumerate(masked_word):
            if char != "_":
                continue
            p_i = i + 1  # index of this blank inside the padded string
            c_prev = padded_masked[p_i - 1]
            c_next = padded_masked[p_i + 1]
            # The first position has only the "^" pad before it: no trigram.
            tri_context = padded_masked[p_i - 2] + c_prev if p_i > 1 else None
            blanks.append((i, c_prev, c_next, tri_context))

        pair_count = self._pair_count
        context_total = self._context_total

        scores = {}
        for letter in unguessed_letters:
            score = np.log(self.unigrams.get(letter, one()) / self.total_unigrams)
            for i, c_prev, c_next, tri_context in blanks:
                pos_score = (pair_count(self.positional_freq, i, letter)
                             / context_total(self.positional_totals, i))
                bi_score_1 = (pair_count(self.bigrams, c_prev, letter)
                              / context_total(self.bigram_totals, c_prev))
                # NOTE(review): when c_next is "_" the numerator is always the
                # default count while the denominator is the letter's real
                # total, so this term is smaller for letters that were frequent
                # in training. Looks unintended; left as-is to keep scores
                # unchanged - confirm before altering.
                bi_score_2 = (pair_count(self.bigrams, letter, c_next)
                              / context_total(self.bigram_totals, letter))
                tri_score = 1.0
                if tri_context is not None:
                    tri_score = (pair_count(self.trigrams, tri_context, letter)
                                 / context_total(self.trigram_totals, tri_context))
                # The trigram evidence is weighted twice as heavily.
                score += np.log(pos_score) + np.log(bi_score_1) + np.log(bi_score_2) + (np.log(tri_score) * 2.0)
            scores[letter] = score

        # Softmax, shifted by the maximum for numerical stability.
        max_score = max(scores.values())
        exp_scores = {l: np.exp(s - max_score) for l, s in scores.items()}
        total_exp_score = sum(exp_scores.values())
        if total_exp_score == 0:
            return {l: 1.0 / len(unguessed_letters) for l in unguessed_letters}
        probs = {l: s / total_exp_score for l, s in exp_scores.items()}
        return probs

# -----------------------------------------------
# PART 2: THE REINFORCEMENT LEARNING AGENT
# -----------------------------------------------
class HangmanRLAgent:
    """Tabular Q-learning agent that chooses a guessing strategy each turn.

    The state is ``(lives_left, blanks_bucket)`` and there are five actions,
    each of which selects a letter from the oracle's ranking:

        0: highest-ranked letter      3: highest-ranked vowel
        1: second-ranked letter       4: highest-ranked consonant
        2: third-ranked letter

    When the requested pick does not exist, the highest-ranked letter is used.
    """

    def __init__(self, learning_rate=0.1, discount_factor=0.9,
                 exploration_rate=1.0, exploration_decay=0.9999, min_exploration=0.01):
        self.q_table = defaultdict(zeros5)
        self.lr = learning_rate
        self.gamma = discount_factor
        self.epsilon = exploration_rate
        self.epsilon_decay = exploration_decay
        self.min_epsilon = min_exploration
        self.vowels = "AEIOU"

    def _get_state(self, lives_left, masked_word):
        """Return the Q-table key ``(lives_left, blanks_bucket)``.

        Buckets: 1, 2 and 3 blanks map to themselves; 0, 4 or 5 blanks map to
        4; more than 5 blanks map to 5.
        """
        num_blanks = masked_word.count("_")
        if num_blanks == 1: blanks_state = 1
        elif num_blanks == 2: blanks_state = 2
        elif num_blanks == 3: blanks_state = 3
        elif num_blanks <= 5: blanks_state = 4
        else: blanks_state = 5
        return (lives_left, blanks_state)

    def _get_letter_from_action(self, action_idx, hmm_probs, guessed_letters):
        """Translate an action index into a concrete letter.

        Returns ``None`` only when ``hmm_probs`` offers no unguessed letter.
        """
        sorted_probs = sorted(hmm_probs.items(), key=lambda item: item[1], reverse=True)
        unguessed_sorted = [l for l, p in sorted_probs if l not in guessed_letters]
        if not unguessed_sorted:
            return None
        unguessed_vowels = [l for l in unguessed_sorted if l in self.vowels]
        unguessed_consonants = [l for l in unguessed_sorted if l not in self.vowels]
        letter_to_guess = None
        if action_idx == 0: letter_to_guess = unguessed_sorted[0]
        elif action_idx == 1: letter_to_guess = unguessed_sorted[1] if len(unguessed_sorted) > 1 else None
        elif action_idx == 2: letter_to_guess = unguessed_sorted[2] if len(unguessed_sorted) > 2 else None
        elif action_idx == 3: letter_to_guess = unguessed_vowels[0] if unguessed_vowels else None
        elif action_idx == 4: letter_to_guess = unguessed_consonants[0] if unguessed_consonants else None
        # Requested pick unavailable: fall back to the top-ranked letter.
        if letter_to_guess is None:
            letter_to_guess = unguessed_sorted[0]
        return letter_to_guess

    def choose_action(self, state, hmm_probs, guessed_letters, is_training=True):
        """Pick an action epsilon-greedily (greedy only when not training).

        Returns:
            ``(action_idx, letter)``; ``letter`` may be ``None``.
        """
        if is_training and np.random.rand() < self.epsilon:
            action_idx = np.random.randint(5)
        else:
            action_idx = int(np.argmax(self.q_table[state]))
        letter = self._get_letter_from_action(action_idx, hmm_probs, guessed_letters)
        return action_idx, letter

    def update_q_table(self, state, action_idx, reward, next_state, done):
        """Apply one Q-learning update; terminal steps have no bootstrap term."""
        old_q_value = self.q_table[state][action_idx]
        next_max_q = np.max(self.q_table[next_state]) if not done else 0
        new_q_value = old_q_value + self.lr * (reward + self.gamma * next_max_q - old_q_value)
        self.q_table[state][action_idx] = new_q_value

    def decay_epsilon(self):
        """Multiply epsilon by the decay factor, never going below the floor."""
        self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)

    def _save_training_plot(self, rewards_per_episode, window=100):
        """Write the per-episode reward curve to TRAIN_PLOT_PATH.

        The moving average is drawn only when at least ``window`` episodes
        exist, and is plotted against the episode at which each window ends.
        The figure is always closed so repeated training runs do not
        accumulate open matplotlib figures.
        """
        os.makedirs(ARTIFACT_DIR, exist_ok=True)
        fig = plt.figure(figsize=(12, 6))
        try:
            plt.plot(rewards_per_episode, label='Total Reward per Episode')
            if len(rewards_per_episode) >= window:
                moving_avg = np.convolve(rewards_per_episode, np.ones(window) / window, mode='valid')
                window_ends = np.arange(window - 1, len(rewards_per_episode))
                plt.plot(window_ends, moving_avg, label=f'{window}-episode Moving Average',
                         linewidth=2, color='red')
            plt.title("RL Agent Training Progress")
            plt.xlabel("Episode")
            plt.ylabel("Total Reward")
            plt.legend()
            plt.grid(True)
            plt.savefig(TRAIN_PLOT_PATH)
        finally:
            plt.close(fig)

    def train(self, env, oracle, num_episodes):
        """Run ``num_episodes`` games on ``env``, learning after every step.

        Epsilon is decayed once per episode. Afterwards a reward plot is saved
        to TRAIN_PLOT_PATH (ARTIFACT_DIR is created if needed).
        """
        print(f"Training RL agent for {num_episodes} episodes...")
        rewards_per_episode = []
        for _ in tqdm(range(num_episodes)):
            obs = env.reset()
            masked_word, guessed_letters, lives_left = obs
            done = False
            total_reward = 0
            while not done:
                hmm_probs = oracle.get_letter_probabilities(masked_word, guessed_letters)
                state = self._get_state(lives_left, masked_word)
                action_idx, letter_to_guess = self.choose_action(state, hmm_probs, guessed_letters, is_training=True)
                # No letter left to try: abandon the episode.
                if letter_to_guess is None:
                    break
                next_obs, reward, done, info = env.step(letter_to_guess)
                next_masked_word, next_guessed_letters, next_lives_left = next_obs
                next_state = self._get_state(next_lives_left, next_masked_word)
                self.update_q_table(state, action_idx, reward, next_state, done)
                masked_word, guessed_letters, lives_left = next_masked_word, next_guessed_letters, next_lives_left
                total_reward += reward
            self.decay_epsilon()
            rewards_per_episode.append(total_reward)
        print("RL agent training complete.")
        self._save_training_plot(rewards_per_episode)
        print(f"Training reward plot saved to '{TRAIN_PLOT_PATH}'")

    def evaluate(self, env, oracle, test_words, num_games=NUM_EVAL_GAMES):
        """Play ``num_games`` greedy games and return summary metrics.

        When ``test_words`` has fewer than ``num_games`` entries the list is
        cycled; otherwise ``num_games`` distinct words are sampled.

        Returns:
            Dict with keys ``games``, ``wins``, ``success_rate``,
            ``wrong_guesses``, ``repeated_guesses`` and ``final_score``, where
            ``final_score = success_rate * 2000 - 5 * wrong - 2 * repeated``
            (the 2000 is a fixed constant, independent of ``num_games``).

        Raises:
            ValueError: If ``num_games`` is not positive or ``test_words`` is
                empty (both previously surfaced as ZeroDivisionError).
        """
        if num_games <= 0:
            raise ValueError("num_games must be a positive integer")
        if not test_words:
            raise ValueError("test_words must contain at least one word")
        print(f"Evaluating agent on {num_games} test words...")
        total_success = 0
        total_wrong_guesses = 0
        total_repeated_guesses = 0
        if len(test_words) < num_games:
            repeats, remainder = divmod(num_games, len(test_words))
            eval_words = test_words * repeats + test_words[:remainder]
        else:
            eval_words = random.sample(test_words, num_games)
        for word in tqdm(eval_words):
            obs = env.reset(specific_word=word)
            masked_word, guessed_letters, lives_left = obs
            done = False
            game_wrong_guesses = 0
            game_repeated_guesses = 0
            while not done:
                hmm_probs = oracle.get_letter_probabilities(masked_word, guessed_letters)
                state = self._get_state(lives_left, masked_word)
                action_idx, letter_to_guess = self.choose_action(state, hmm_probs, guessed_letters, is_training=False)
                if letter_to_guess is None:
                    break
                next_obs, reward, done, info = env.step(letter_to_guess)
                if info["repeated"]: game_repeated_guesses += 1
                if info["wrong"]: game_wrong_guesses += 1
                masked_word, guessed_letters, lives_left = next_obs
            # A game counts as won when no blank remains.
            if "_" not in masked_word:
                total_success += 1
            total_wrong_guesses += game_wrong_guesses
            total_repeated_guesses += game_repeated_guesses
        success_rate = total_success / num_games
        final_score = (success_rate * 2000) - (total_wrong_guesses * 5) - (total_repeated_guesses * 2)
        return {
            'games': num_games,
            'wins': total_success,
            'success_rate': success_rate,
            'wrong_guesses': total_wrong_guesses,
            'repeated_guesses': total_repeated_guesses,
            'final_score': final_score
        }

# -----------------------------------------------
# Serialization helpers (avoid lambda pickling issues)
# -----------------------------------------------
def serialize_oracle(oracle: ProbabilisticOracle) -> dict:
    """Flatten an oracle into plain dicts and scalars.

    Every defaultdict is converted to a regular ``dict`` (nested rows too) and
    positional keys are coerced to ``int``, so the result carries no reference
    to the default-factory functions.
    """
    def plain_rows(table):
        # Copy each inner row into a regular dict.
        return {context: dict(row) for context, row in table.items()}

    positional_rows = {int(pos): dict(row) for pos, row in oracle.positional_freq.items()}
    positional_totals = {int(pos): total for pos, total in oracle.positional_totals.items()}

    state = {}
    state['unigrams'] = dict(oracle.unigrams)
    state['total_unigrams'] = oracle.total_unigrams
    state['bigrams'] = plain_rows(oracle.bigrams)
    state['bigram_totals'] = dict(oracle.bigram_totals)
    state['trigrams'] = plain_rows(oracle.trigrams)
    state['trigram_totals'] = dict(oracle.trigram_totals)
    state['positional_freq'] = positional_rows
    state['positional_totals'] = positional_totals
    state['max_word_len'] = oracle.max_word_len
    return state

def deserialize_oracle(state: dict) -> ProbabilisticOracle:
    """Rebuild an oracle from the dict produced by ``serialize_oracle``.

    The plain dicts are wrapped back into defaultdicts using the module-level
    factories, so unseen keys get the same defaults as in a fresh oracle.
    """
    def smoothed_rows(raw_table):
        # Re-wrap each row so that unseen symbols default to a count of 1.
        return {context: defaultdict(one, row) for context, row in raw_table.items()}

    restored = ProbabilisticOracle()

    restored.unigrams = defaultdict(one, state['unigrams'])
    restored.total_unigrams = state['total_unigrams']

    restored.bigrams = defaultdict(dd_one, smoothed_rows(state['bigrams']))
    restored.bigram_totals = defaultdict(alpha_len, state['bigram_totals'])

    restored.trigrams = defaultdict(dd_one, smoothed_rows(state['trigrams']))
    restored.trigram_totals = defaultdict(alpha_len, state['trigram_totals'])

    # Positional tables are keyed by integer index.
    positional_rows = {
        int(pos): defaultdict(one, row)
        for pos, row in state['positional_freq'].items()
    }
    positional_totals = {
        int(pos): total
        for pos, total in state['positional_totals'].items()
    }
    restored.positional_freq = defaultdict(dd_one, positional_rows)
    restored.positional_totals = defaultdict(alpha_len, positional_totals)

    restored.max_word_len = state['max_word_len']
    return restored

def serialize_agent(agent: HangmanRLAgent) -> dict:
    """Flatten an agent into plain Python values.

    Q-value arrays become lists keyed by their state; the hyper-parameters and
    the vowel string are copied as-is.
    """
    q_values = {state: values.tolist() for state, values in agent.q_table.items()}
    return dict(
        q_table=q_values,
        lr=agent.lr,
        gamma=agent.gamma,
        epsilon=agent.epsilon,
        epsilon_decay=agent.epsilon_decay,
        min_epsilon=agent.min_epsilon,
        vowels=agent.vowels,
    )

def deserialize_agent(state: dict) -> HangmanRLAgent:
    """Rebuild an agent from the dict produced by ``serialize_agent``.

    Q-table keys that arrive as strings starting with ``(`` (a stringified
    state tuple) are parsed back into tuples; all other keys are used
    unchanged. Q-value lists are converted back to numpy arrays.

    Raises:
        KeyError: If ``state`` lacks one of the expected entries.
        ValueError, SyntaxError: If a stringified key is not a valid Python
            literal.
    """
    import ast

    def restore_key(key):
        # SECURITY: keys come from an artifact file, so they are parsed with
        # ast.literal_eval (literals only) rather than eval(), which would
        # execute arbitrary code embedded in a tampered artifact.
        if isinstance(key, str) and key.startswith('('):
            return ast.literal_eval(key)
        return key

    a = HangmanRLAgent()
    a.q_table = defaultdict(
        zeros5,
        {restore_key(k): np.array(v) for k, v in state['q_table'].items()},
    )
    a.lr = state['lr']
    a.gamma = state['gamma']
    a.epsilon = state['epsilon']
    a.epsilon_decay = state['epsilon_decay']
    a.min_epsilon = state['min_epsilon']
    a.vowels = state['vowels']
    return a

def save_oracle(oracle, path=ORACLE_PATH):
    """Pickle the serialized oracle state to ``path``.

    The parent directory is created when ``path`` has one. A bare file name
    (no directory component) is written to the current directory; previously
    this case failed because ``os.makedirs('')`` raises FileNotFoundError.
    """
    # Serialize first so a failure here cannot leave a truncated file behind.
    payload = serialize_oracle(oracle)
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(payload, f)
    print(f"Saved oracle to {path}")

def save_agent(agent, path=AGENT_PATH):
    """Pickle the serialized agent state to ``path``.

    The parent directory is created when ``path`` has one. A bare file name
    (no directory component) is written to the current directory; previously
    this case failed because ``os.makedirs('')`` raises FileNotFoundError.
    """
    # Serialize first so a failure here cannot leave a truncated file behind.
    payload = serialize_agent(agent)
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(payload, f)
    print(f"Saved agent to {path}")

def load_words(filepath):
    """Read a one-word-per-line file into a list of upper-cased words.

    Lines are stripped of surrounding whitespace; lines that are not purely
    alphabetic after stripping (including blank lines) are dropped.

    Returns:
        The list of words, or ``None`` (after printing an error) when
        ``filepath`` does not exist.
    """
    if os.path.exists(filepath):
        words = []
        with open(filepath, 'r') as f:
            for line in f:
                candidate = line.strip()
                if candidate.isalpha():
                    words.append(candidate.upper())
        return words

    print(f"Error: File not found at {filepath}")
    return None

In [None]:
# Training driver: run this cell manually whenever the artifacts should be rebuilt.
train_words = load_words(CORPUS_FILE)
test_words = load_words(TEST_FILE)
if train_words is not None and test_words is not None:
    # 1) Fit the letter-statistics oracle on the training corpus.
    oracle = ProbabilisticOracle()
    oracle.train(train_words)
    # 2) Train the Q-learning agent by self-play on the same corpus.
    env = HangmanEnv(train_words)
    agent = HangmanRLAgent()
    agent.train(env, oracle, num_episodes=NUM_TRAIN_EPISODES)
    # 3) Persist both artifacts so the app can load them.
    save_oracle(oracle, ORACLE_PATH)
    save_agent(agent, AGENT_PATH)
    # 4) Optional sanity check of the freshly trained agent on the test words.
    eval_env = HangmanEnv(test_words)
    metrics = agent.evaluate(eval_env, oracle, test_words, num_games=NUM_EVAL_GAMES)
    print("\n--- Evaluation Results ---")
    print(metrics)
else:
    print("Could not load word files. Ensure 'corpus.txt' and 'test.txt' exist in this directory.")