In [48]:
import random
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque, namedtuple, defaultdict
from tqdm import tqdm
import math

# Experience Replay Buffer
# One recorded environment step; the field order is the positional order
# expected by ReplayBuffer.push(state, action, next_state, reward, done).
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward', 'done'],
)

class ReplayBuffer:
    """Bounded FIFO memory of Transition records used for experience replay."""

    def __init__(self, capacity):
        # A deque with maxlen drops its oldest entry automatically once it
        # already holds `capacity` items and a new one is appended.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Bundle the positional fields into a Transition and store it."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return `batch_size` stored transitions drawn at random without replacement.

        Raises ValueError (from random.sample) when `batch_size` exceeds the
        number of stored transitions.
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently held."""
        return len(self.memory)

class HangmanRLAgent:
    def __init__(self, vocab, max_lives=6, epsilon=1.0, gamma=0.99, lr=0.00025,
                 epsilon_decay_steps=200000, # Slower decay based on total steps
                 min_epsilon=0.05,
                 buffer_size=50000,
                 batch_size=128,
                 target_update_freq=500): # Update target net less frequently (steps)
        self.vocab = vocab
        self.char_to_idx = {ch: i for i, ch in enumerate(vocab)}
        self.idx_to_char = {i: ch for ch, i in self.char_to_idx.items()}
        self.vocab_size = len(vocab)

        self.max_word_len = 25
        # State includes: one-hot encoded masked word + binary vector for guessed letters
        self.state_dim = (self.vocab_size + 1) * self.max_word_len + self.vocab_size
        self.hidden_size = 512 # Increased hidden size
        self.output_size = self.vocab_size

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.policy_net = self.build_model().to(self.device)
        self.target_net = self.build_model().to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval() # Target network is only for inference

        self.optimizer = optim.AdamW(self.policy_net.parameters(), lr=lr, amsgrad=True)
        self.criterion = nn.SmoothL1Loss() # Huber Loss
        self.replay_buffer = ReplayBuffer(buffer_size)

        self.max_lives = max_lives
        self.gamma = gamma
        self.epsilon = epsilon
        self.initial_epsilon = epsilon
        self.epsilon_decay_steps = epsilon_decay_steps # Steps over which epsilon decays
        self.min_epsilon = min_epsilon
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq # In terms of optimization steps

        self.total_steps = 0
        self.total_optim_steps = 0


    def build_model(self):
        # Deeper network
        return nn.Sequential(
            nn.Linear(self.state_dim, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(self.hidden_size, self.hidden_size // 2), # Extra layer
            nn.ReLU(),
            nn.Linear(self.hidden_size // 2, self.output_size)
        )

    def encode_state(self, masked_word, guessed_letters):
        # Encode masked word (one-hot + mask token)
        word_state = []
        for i in range(self.max_word_len):
            one_hot = [0] * (self.vocab_size + 1)
            if i < len(masked_word):
                ch = masked_word[i]
                if ch == '_':
                    one_hot[-1] = 1 # Mask token
                elif ch in self.char_to_idx:
                    one_hot[self.char_to_idx[ch]] = 1
                else:
                    # Should not happen if dictionary is pre-filtered
                    # print(f"Warning: Unexpected character '{ch}' in masked word. Treating as mask.")
                    one_hot[-1] = 1 # Treat unknown revealed char as mask
            else:
                 one_hot[-1] = 1 # Pad with mask token
            word_state.extend(one_hot)

        # Encode guessed letters (binary vector)
        guessed_state = [1.0 if c in guessed_letters else 0.0 for c in self.vocab]

        # Combine states
        state = word_state + guessed_state

        return torch.tensor(state, dtype=torch.float32, device=self.device)

    def guess(self, state, guessed_letters):
        """ Get action using epsilon-greedy policy """
        self.total_steps += 1
        # Epsilon decay based on total steps (slower decay)
        self.epsilon = self.min_epsilon + \
            (self.initial_epsilon - self.min_epsilon) * \
            math.exp(-1. * self.total_steps / self.epsilon_decay_steps)

        if random.random() < self.epsilon:
            # Exploration: Choose a random valid action
            options = [c for c in self.vocab if c not in guessed_letters]
            if not options: # Should not happen if game isn't over
                options = list(self.vocab) # Fallback just in case
            action_char = random.choice(options)
            action_idx = self.char_to_idx[action_char]
            return action_idx, action_char
        else:
            # Exploitation: Choose the best action according to the policy network
            with torch.no_grad():
                self.policy_net.eval() # Set to eval mode for inference
                q_values = self.policy_net(state)
                self.policy_net.train() # Set back to train mode

                # Mask already guessed letters by setting their Q-values to negative infinity
                for i, ch in enumerate(self.vocab):
                    if ch in guessed_letters:
                        q_values[i] = -float('inf')

                # Handle case where all actions might be masked (shouldn't happen in valid states)
                if torch.isinf(q_values).all():
                     # Fallback: choose a random valid action if all are masked
                     options = [c for c in self.vocab if c not in guessed_letters]
                     if not options: options = list(self.vocab)
                     action_char = random.choice(options)
                     action_idx = self.char_to_idx[action_char]
                     # print("Warning: All Q-values were -inf, choosing random valid action.")
                     return action_idx, action_char

                best_action_idx = torch.argmax(q_values).item()
                best_action_char = self.idx_to_char[best_action_idx]
                return best_action_idx, best_action_char

    def learn(self):
        """ Sample batch from replay buffer and perform Double Q-learning update """
        if len(self.replay_buffer) < self.batch_size:
            return None # Not enough samples yet

        transitions = self.replay_buffer.sample(self.batch_size)
        # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for details)
        batch = Transition(*zip(*transitions))

        # Create tensors from batch components
        state_batch = torch.stack(batch.state)
        action_batch = torch.tensor(batch.action, dtype=torch.long, device=self.device).unsqueeze(1)
        reward_batch = torch.tensor(batch.reward, dtype=torch.float32, device=self.device)
        next_state_batch = torch.stack(batch.next_state)
        done_mask = torch.tensor(batch.done, dtype=torch.bool, device=self.device) # Boolean mask for final states

        # --- Q-value computation ---
        # Get Q(s_t, a) for the actions taken
        # We get Q(s_t) from policy_net and select the column corresponding to the action taken
        self.policy_net.train() # Ensure policy net is in train mode
        state_action_values = self.policy_net(state_batch).gather(1, action_batch)

        # --- Target Q-value computation (Double DQN) ---
        next_state_values = torch.zeros(self.batch_size, device=self.device)
        with torch.no_grad():
            # 1. Select the best action for s_{t+1} using the *policy* network
            # We only need to compute this for non-final states
            next_policy_q_values = self.policy_net(next_state_batch[~done_mask])
            best_next_actions = next_policy_q_values.argmax(1).unsqueeze(1) # Get indices of max Q-values

            # 2. Evaluate the Q-value of that action using the *target* network
            next_target_q_values = self.target_net(next_state_batch[~done_mask])
            # Use gather to select the Q-value corresponding to the best action chosen by the policy net
            next_state_values[~done_mask] = next_target_q_values.gather(1, best_next_actions).squeeze(1)

        # Compute the expected Q values: reward + gamma * Q_target(s_{t+1}, argmax_a Q_policy(s_{t+1}, a))
        expected_state_action_values = (next_state_values * self.gamma) + reward_batch

        # --- Loss Calculation and Optimization ---
        # Compute Huber loss between current Q values and target Q values
        loss = self.criterion(state_action_values, expected_state_action_values.unsqueeze(1))

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        # Gradient Clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
        self.optimizer.step()

        self.total_optim_steps += 1

        # --- Target Network Update ---
        # Periodically update the target network by copying weights from the policy network
        if self.total_optim_steps % self.target_update_freq == 0:
            # print(f"Updating target network at optimization step {self.total_optim_steps}")
            self.target_net.load_state_dict(self.policy_net.state_dict())

        return loss.item()


    def train_episode(self, word):
        """ Run a single episode, store transitions, return stats """
        guessed_letters = set()
        masked_word = ['_'] * len(word)
        lives = self.max_lives
        episode_reward = 0.0
        steps = 0
        step_penalty = -0.01 # Small penalty per step to encourage efficiency

        current_masked_word_str = "".join(masked_word)
        state = self.encode_state(current_masked_word_str, guessed_letters)

        while lives > 0 and '_' in masked_word:
            action_idx, guess_char = self.guess(state, guessed_letters)

            # Determine reward and next state based on the guess
            if guess_char in guessed_letters:
                 # Penalty for re-guessing (wasted step)
                 reward = -0.1
                 next_masked_word_str = current_masked_word_str # State doesn't change
                 # No life lost for re-guessing
            else:
                guessed_letters.add(guess_char)
                correct_guess = False
                newly_revealed_count = 0
                next_masked_word_list = list(current_masked_word_str)

                for i, char_in_word in enumerate(word):
                    if char_in_word == guess_char:
                        if next_masked_word_list[i] == '_':
                           next_masked_word_list[i] = guess_char
                           newly_revealed_count += 1
                        correct_guess = True

                next_masked_word_str = "".join(next_masked_word_list)

                # Adjusted Reward Structure
                if correct_guess:
                    # Positive reward for correct guess, scaled by reveal count
                    reward = 0.1 + 0.5 * (newly_revealed_count / len(word))
                else:
                    # Increased penalty for wrong guess
                    reward = -0.3 # Was -0.2
                    lives -= 1

            # Apply step penalty
            reward += step_penalty

            # Check for game end conditions
            game_over = (lives <= 0 or '_' not in next_masked_word_str)
            win = (game_over and '_' not in next_masked_word_str)

            # Add adjusted terminal rewards/penalties (applied on the final step)
            if win:
                reward += 4.0 # Was +5.0
            elif game_over and not win:
                reward += -3.0 # Was -2.0

            episode_reward += reward
            next_state = self.encode_state(next_masked_word_str, guessed_letters)

            # Store transition in replay buffer
            # Ensure all components are correctly formatted if needed (state should be tensor already)
            self.replay_buffer.push(state, action_idx, next_state, reward, game_over)

            # Move to the next state
            state = next_state
            current_masked_word_str = next_masked_word_str
            steps += 1

            # Note: Learning step is now handled in the main train loop based on total_steps

        # Return episode statistics
        return episode_reward, steps, win, self.epsilon


    def train(self, dictionary, episodes=20000, learn_every_n_steps=4, updates_per_step=1):
        """ Main training loop """
        wins_in_log_interval = 0
        total_wins = 0
        total_episode_rewards = 0
        total_steps_in_log_interval = 0
        total_loss_in_log_interval = 0
        n_losses = 0

        # Pre-filter dictionary
        valid_dictionary = [
            word.lower() for word in dictionary
            if len(word) <= self.max_word_len and all(c in self.vocab for c in word.lower()) and len(word) > 0
        ]
        print(f"Filtered dictionary size: {len(valid_dictionary)} words")
        if not valid_dictionary:
            print("Error: No valid words found in the dictionary for training.")
            return

        progress_bar = tqdm(range(episodes), desc="Training Progress")
        for ep in progress_bar:
            word = random.choice(valid_dictionary)

            # Run one episode
            ep_reward, ep_steps, win, current_epsilon = self.train_episode(word)

            # Accumulate stats for logging interval
            total_episode_rewards += ep_reward
            total_steps_in_log_interval += ep_steps
            if win:
                wins_in_log_interval += 1
                total_wins += 1

            # Perform learning steps periodically based on total steps taken across episodes
            # Check if enough steps have passed since the last learning phase
            # Note: self.total_steps is incremented inside self.guess()
            if self.total_steps // learn_every_n_steps > (self.total_steps - ep_steps) // learn_every_n_steps:
                 if len(self.replay_buffer) >= self.batch_size:
                    for _ in range(updates_per_step): # Perform multiple updates if specified
                        loss = self.learn()
                        if loss is not None:
                            total_loss_in_log_interval += loss
                            n_losses += 1


            # Log progress periodically
            log_interval = 100 # Log every 100 episodes
            if (ep + 1) % log_interval == 0:
                avg_reward = total_episode_rewards / log_interval
                avg_loss = total_loss_in_log_interval / n_losses if n_losses > 0 else 0
                # Calculate win rate over the *entire training* so far for a smoother trend
                overall_win_rate = total_wins / (ep + 1)
                # Calculate win rate *within the current log interval*
                interval_win_rate = wins_in_log_interval / log_interval
                avg_steps = total_steps_in_log_interval / log_interval

                progress_bar.set_description(
                    f"Ep {ep+1}/{episodes} | Win Rate (Overall): {overall_win_rate:.2%} | "
                    f"Win Rate (Last {log_interval}): {interval_win_rate:.2%} | "
                    f"Avg Reward: {avg_reward:.2f} | Avg Loss: {avg_loss:.4f} | "
                    f"Avg Steps: {avg_steps:.1f} | Epsilon: {current_epsilon:.4f}"
                )

                # Reset stats for the next logging interval
                total_episode_rewards = 0
                total_steps_in_log_interval = 0
                total_loss_in_log_interval = 0
                wins_in_log_interval = 0
                n_losses = 0


    def save_model(self, path):
        """ Saves the policy network state dictionary. """
        torch.save(self.policy_net.state_dict(), path)
        print(f"Model policy network saved to {path}")

    def load_model(self, path):
        """ Loads the policy network state dictionary and syncs the target network. """
        try:
            # Load state dict onto the correct device
            state_dict = torch.load(path, map_location=self.device)
            self.policy_net.load_state_dict(state_dict)
            self.target_net.load_state_dict(self.policy_net.state_dict()) # Sync target net

            # Ensure models are on the correct device (redundant if map_location worked, but safe)
            self.policy_net.to(self.device)
            self.target_net.to(self.device)

            # Set networks to evaluation mode by default after loading
            self.policy_net.eval()
            self.target_net.eval()
            print(f"Model loaded from {path} and target network synced.")
        except FileNotFoundError:
            print(f"Error: Model file not found at {path}. Starting with untrained model.")
        except Exception as e:
            print(f"Error loading model from {path}: {e}. Starting with untrained model.")
            # Reset networks to initial state if loading fails badly
            self.policy_net = self.build_model().to(self.device)
            self.target_net = self.build_model().to(self.device)
            self.target_net.load_state_dict(self.policy_net.state_dict())
            self.target_net.eval()
            # Re-initialize optimizer with potentially new parameters
            self.optimizer = optim.AdamW(self.policy_net.parameters(), lr=self.optimizer.defaults['lr'], amsgrad=True)

In [49]:
# Step 0: Seed the RNGs used for word sampling, exploration, weight init and
# dropout so that Restart & Run All repeats the same training run.
# (Bit-exact repeatability on GPU may additionally depend on backend settings.)
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Step 1: Load the training word list (one word per line)
with open("words_250000_train.txt", "r") as f:
    dictionary = f.read().splitlines()

# Step 2: Initialize and train agent
agent = HangmanRLAgent(vocab=list("abcdefghijklmnopqrstuvwxyz"))
agent.train(dictionary, episodes=200000)  # You can increase this as needed

# Step 3: Save the trained model
agent.save_model("hangman_model.pth")

Using device: cpu
Filtered dictionary size: 227295 words


Ep 200000/200000 | Win Rate (Overall): 10.39% | Win Rate (Last 100): 8.00% | Avg Reward: -2.97 | Avg Loss: 0.4208 | Avg Steps: 11.4 | Epsilon: 0.0500: 100%|██████████| 200000/200000 [13:10:42<00:00,  4.22it/s]     


Model policy network saved to hangman_model.pth
