In [59]:
import random
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim)
        )

    def forward(self, x):
        return self.model(x)

class HangmanRL:
    def __init__(self, word_list=None, max_wrong_guesses=6):
        if word_list is None:
            with open('words_250000_train.txt', 'r') as f:
                self.word_list = [line.strip() for line in f]
        else:
            self.word_list = word_list
        self.max_wrong_guesses = max_wrong_guesses
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.q_network = DQN(26 + 26 + 26, 26).to(self.device)
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=0.001)
        self.criterion = nn.MSELoss()
        self.reset()

    def reset(self):
        self.target_word = random.choice(self.word_list)
        self.current_word = ['_' for _ in self.target_word]
        self.guessed_letters = set()
        self.incorrect_letters = set()
        self.possible_words = self.word_list[:]
        self.wrong_guesses = 0
        return self.get_state()

    def get_state(self):
        known_vec = [1 if chr(i + ord('a')) in self.guessed_letters else 0 for i in range(26)]
        incorrect_vec = [1 if chr(i + ord('a')) in self.incorrect_letters else 0 for i in range(26)]
        word_vec = [0]*26
        for ch in self.current_word:
            if ch != '_':
                word_vec[ord(ch) - ord('a')] += 1
        return known_vec + incorrect_vec + word_vec

    def step(self, action):
        letter = chr(ord('a') + action)

        if letter in self.guessed_letters or letter in self.incorrect_letters:
            return self.get_state(), -1.0, False

        self.guessed_letters.add(letter)
        done = False
        prev_word = self.current_word[:]

        if letter in self.target_word:
            for i, ch in enumerate(self.target_word):
                if ch == letter:
                    self.current_word[i] = letter
        else:
            self.incorrect_letters.add(letter)
            self.wrong_guesses += 1

        self.update_possible_words()

        if '_' not in self.current_word:
            done = True
            reward = 50.0
        elif self.wrong_guesses >= self.max_wrong_guesses:
            done = True
            reward = -20.0
        elif self.current_word == prev_word:
            reward = -1.0
        else:
            reward = 2.0 if letter in self.target_word else -3.0

        return self.get_state(), reward, done

    def update_possible_words(self):
        pattern = ''.join(['.' if c == '_' else c for c in self.current_word])
        new_list = []
        for word in self.possible_words:
            if len(word) != len(self.current_word):
                continue
            if any(ch in word for ch in self.incorrect_letters):
                continue
            match = True
            for pc, wc in zip(pattern, word):
                if pc != '.' and pc != wc:
                    match = False
                    break
            if match:
                new_list.append(word)
        if new_list:
            self.possible_words = new_list

    def train(self, episodes=1000, gamma=0.9, epsilon=0.2):
        for ep in range(episodes):
            state = self.reset()
            done = False
            episode_letters = []
            
            while not done:
                if random.random() < epsilon:
                    action = random.randint(0, 25)
                else:
                    with torch.no_grad():
                        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
                        q_vals = self.q_network(state_tensor)
                        for i in range(26):
                            ch = chr(ord('a') + i)
                            if ch in self.guessed_letters or ch in self.incorrect_letters:
                                q_vals[0, i] = float('-inf')
                        action = q_vals.argmax().item()

                letter = chr(ord('a') + action)
                episode_letters.append(letter)
                
                next_state, reward, done = self.step(action)
                next_q = self.q_network(torch.FloatTensor(next_state).unsqueeze(0).to(self.device)).max().item()
                target = reward + gamma * next_q * (0 if done else 1)

                pred = self.q_network(torch.FloatTensor(state).unsqueeze(0).to(self.device))[0, action]
                loss = self.criterion(pred, torch.tensor(target).to(self.device))

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                state = next_state
            
            print(f"Episode {ep+1}/{episodes}: Word: {self.target_word}, Guessed: {''.join(episode_letters)}, {'Success' if '_' not in self.current_word else 'Failed'}")


In [60]:
# Initialize the RL agent
rl_agent = HangmanRL()
# Need to implement a train method or use the correct method for training
rl_agent.train(episodes=2000)  # Train for 2000 episodes - not implemented yet

Episode 349/2000: Word: earlship, Guessed: tanepoxkrz, Failed
Episode 350/2000: Word: almirah, Guessed: atjneok, Failed
Episode 351/2000: Word: lasala, Guessed: atnetowr, Failed
Episode 352/2000: Word: subers, Guessed: anteorozsl, Failed
Episode 353/2000: Word: precleaning, Guessed: anteorzslyignyf, Failed
Episode 354/2000: Word: mustachios, Guessed: oantecrzpd, Failed
Episode 355/2000: Word: stereochemistry, Guessed: atneorzsmgyliv, Failed
Episode 356/2000: Word: kephalins, Guessed: attneorzhszeylq, Failed
Episode 357/2000: Word: chape, Guessed: atnneorzs, Failed
Episode 358/2000: Word: pretensed, Guessed: atneyorbzk, Failed
Episode 359/2000: Word: gipps, Guessed: atneor, Failed
Episode 360/2000: Word: currack, Guessed: atrneozos, Failed
Episode 361/2000: Word: juxtaspinal, Guessed: aftneorzsw, Failed
Episode 362/2000: Word: cystignathine, Guessed: aatneocrrzetsyligfw, Failed
Episode 363/2000: Word: imprest, Guessed: atneomrzsypl, Failed
Episode 364/2000: Word: leprosaria, Guessed: an

KeyboardInterrupt: 