# In [None]:
import os
import time
import json
import requests
import random
import string
import collections
import pickle

import torch
import torch.nn as nn
import torch.optim as optim

# Select the compute device.
# AMD GPUs are reached through the optional torch_directml package. When that
# package is not installed, fall back to CUDA (if available) or the CPU so the
# script still runs instead of failing at import time.
try:
    import torch_directml
except ImportError:
    torch_directml = None

if torch_directml is not None:
    device = torch_directml.device()
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# =============================================================================
# 1. Global Constants & Vocabulary Setup
# =============================================================================

# The eleven letters tracked by the negative-constraint vector
# (chosen because they capture ~70% of letter frequency).
TOP_11_LETTERS = ["E", "I", "A", "R", "N", "O", "S", "T", "L", "C", "U"]
ALL_LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

# Special tokens: padding and mask only (no UNK token is defined).
PAD_TOKEN = "<PAD>"
MASK_TOKEN = "<MASK>"

# token -> index: the two specials take 0 and 1, then each letter receives the
# next free index in alphabetical order (A=2 ... Z=27).
token_to_idx = {PAD_TOKEN: 0, MASK_TOKEN: 1}
for ch in ALL_LETTERS:
    token_to_idx[ch] = len(token_to_idx)

vocab_size = len(token_to_idx)  # 28 tokens: 2 specials + 26 letters
start_idx = vocab_size          # next unused index after the letters

# index -> token: inverse of the mapping above.
idx_to_token = {index: symbol for symbol, index in token_to_idx.items()}

PAD_IDX = token_to_idx[PAD_TOKEN]
MAX_LENGTH = 32  # fixed input length fed to the MLM

# =============================================================================
# 2. Transformer Model Definition (MLM with Negative Constraints)
# =============================================================================

class MiniTransformerMLMTop11(nn.Module):
    """Small Transformer-encoder masked language model with a side input.

    Each position's representation is the sum of a token embedding, a learned
    positional embedding and a linear projection of an 11-dimensional
    "negative" vector that is shared by every position of the same sample.
    The encoder output is projected to per-position logits over the vocabulary.
    """

    def __init__(self, vocab_size, d_model=128, nhead=4, num_layers=2, max_len=32, dropout=0.1):
        """Build the embeddings, encoder stack and output head.

        Args:
            vocab_size: Number of token ids (embedding rows and output logits).
            d_model: Width of all embeddings and of the encoder.
            nhead: Attention heads per encoder layer.
            num_layers: Number of stacked encoder layers.
            max_len: Number of rows in the positional embedding table.
            dropout: Dropout probability inside each encoder layer.
        """
        super().__init__()
        self.d_model, self.max_len = d_model, max_len

        # Token / position lookups and the projection of the 11-d side vector.
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.neg_embed = nn.Linear(11, d_model)

        # Encoder stack; tensors are laid out as (batch, seq, feature).
        layer_options = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=256,
            dropout=dropout,
            activation='relu',
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_options),
            num_layers=num_layers,
        )

        # Per-position classifier over the vocabulary.
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, inp, neg_info):
        """Compute vocabulary logits for every input position.

        Args:
            inp: Tensor of token ids with shape (batch, seq_len).
            neg_info: Tensor of shape (batch, 11) holding the negative vector.

        Returns:
            Tensor of logits with shape (batch, seq_len, vocab_size).
        """
        _, seq_len = inp.shape

        # (1, seq_len) position ids, broadcast over the batch by the addition.
        position_ids = torch.arange(seq_len, device=inp.device)[None, :]
        hidden = self.token_embed(inp) + self.pos_embed(position_ids)

        # The projected negative vector is (batch, d_model); inserting a
        # length-1 sequence axis lets it broadcast to every position.
        hidden = hidden + self.neg_embed(neg_info)[:, None, :]

        return self.fc_out(self.transformer(hidden))

# =============================================================================
# 3. Heuristic Solver (Using Only the Provided Training File)
# =============================================================================

class ConfidenceBasedHangmanSolver:
    """Chooses Hangman letters with a fixed heuristic plus an optional MLM.

    The solver tracks which letters were guessed and which of them turned out
    to be wrong. The first two guesses (and any word longer than MAX_LENGTH)
    use a fixed letter order; afterwards, when a transformer model is supplied,
    the letter with the highest probability summed over all masked positions
    is chosen.

    The state needed for those decisions (number of guesses, wrong letters) is
    derived inside ``hybrid_guess`` from the guessed letters and the visible
    pattern, so the solver works even when ``update_state`` is never called.
    """

    def __init__(self, training_file):
        """Load the training words and start with an empty game state.

        Args:
            training_file: Path to a UTF-8 text file with one word per line.
                Lines that are empty or not purely alphabetic are skipped.
        """
        # Load training words ONLY from the provided file.
        with open(training_file, 'r', encoding='utf-8') as f:
            stripped_lines = (line.strip() for line in f)
            # "".isalpha() is False, so blank lines are dropped here as well.
            self.training_words = {w.lower() for w in stripped_lines if w.isalpha()}
        self.reset_game()

    def reset_game(self):
        """Forget everything learned about the previous game."""
        self.guessed_letters = set()
        self.incorrect = set()  # Guessed letters that are not in the target word
        self.current_pattern = None  # Current masked pattern (e.g., "a__le")
        self.guess_count = 0

    @staticmethod
    def _clean_pattern(pattern):
        """Return the pattern lower-cased with all spaces removed."""
        return pattern.lower().replace(" ", "")

    def _mark_incorrect(self, clean_pattern, target_word=None):
        """Add every guessed letter that is absent from the word to ``incorrect``.

        When the true word is known it is used as the reference; otherwise the
        visible pattern is used, on the basis that a correctly guessed letter
        is revealed in the pattern.
        """
        reference = target_word.lower() if target_word else clean_pattern
        self.incorrect.update(
            letter for letter in self.guessed_letters if letter not in reference
        )

    def _commit(self, guess):
        """Record ``guess`` as guessed (unless it is None) and return it."""
        if guess is not None:
            self.guessed_letters.add(guess)
        return guess

    def get_length_based_guess(self):
        """Return the first unguessed letter of the fixed order "eaiornstlcu".

        Returns:
            A lower-case letter, or None when all eleven were already guessed.
        """
        return next((c for c in "eaiornstlcu" if c not in self.guessed_letters), None)

    def get_global_fallback(self):
        """Return the first unguessed letter in alphabetical order, or None."""
        return next(
            (c for c in string.ascii_lowercase if c not in self.guessed_letters),
            None,
        )

    def get_negative_vector(self):
        """Build the 11-d binary vector aligned with TOP_11_LETTERS.

        Returns:
            A list with 1 where the corresponding letter is in
            ``self.incorrect`` and 0 elsewhere.
        """
        return [1 if letter.lower() in self.incorrect else 0 for letter in TOP_11_LETTERS]

    def update_state(self, pattern, target_word=None):
        """Update solver state after a guess has been answered.

        Args:
            pattern: The current masked word (e.g., "a__le" or "_ A _ E").
            target_word: The true word when it is known (simulation/training).
                When omitted or None, wrong letters are inferred from
                ``pattern`` instead of raising.
        """
        self.current_pattern = self._clean_pattern(pattern)
        self.guess_count += 1
        self._mark_incorrect(self.current_pattern, target_word)

    def _mlm_guess(self, clean_pattern, transformer_model):
        """Return the best unguessed letter according to the MLM, or None.

        None is returned when the pattern contains a symbol outside the
        vocabulary, when it has no masked position, or when every letter has
        already been guessed; the caller then falls back to the heuristic.
        """
        mask_id = token_to_idx[MASK_TOKEN]

        # Convert the pattern into token ids ('_' -> MASK, letters -> their id).
        token_ids = []
        for ch in clean_pattern:
            if ch == '_':
                token_ids.append(mask_id)
                continue
            token_id = token_to_idx.get(ch.upper())
            if token_id is None:
                # There is no UNK token, so such a pattern cannot be encoded.
                return None
            token_ids.append(token_id)

        mask_positions = [i for i, tid in enumerate(token_ids) if tid == mask_id]
        if not mask_positions:
            return None

        # Pad/truncate to MAX_LENGTH.
        token_ids = token_ids[:MAX_LENGTH]
        token_ids += [token_to_idx[PAD_TOKEN]] * (MAX_LENGTH - len(token_ids))

        input_tensor = torch.tensor([token_ids], dtype=torch.long).to(device)
        neg_tensor = torch.tensor([self.get_negative_vector()], dtype=torch.float).to(device)
        transformer_model.eval()
        with torch.no_grad():
            logits = transformer_model(input_tensor, neg_tensor)  # (1, seq_len, vocab_size)

        # Sum the per-position probabilities over all masked positions in one
        # pass on the CPU instead of reading each value individually.
        masked_logits = logits[0].detach().cpu()[mask_positions]
        scores = torch.softmax(masked_logits, dim=-1).sum(dim=0).tolist()

        # Indices 0 and 1 are PAD and MASK; only letters are candidates.
        candidates = [
            i for i in range(2, vocab_size)
            if idx_to_token[i].lower() not in self.guessed_letters
        ]
        if not candidates:
            return None
        # max() keeps the first candidate on ties, i.e. alphabetical order.
        best = max(candidates, key=lambda i: scores[i])
        return idx_to_token[best].lower()

    def hybrid_guess(self, pattern, target_word=None, transformer_model=None):
        """Return the next letter to guess and record it as guessed.

        Strategy:
          - If the clean pattern (without spaces) is longer than MAX_LENGTH,
            or fewer than two guesses have been made, use the fixed heuristic.
          - Otherwise, if a transformer_model is provided, use its prediction.
          - If that yields nothing, use the alphabetical fallback.

        Args:
            pattern: Current masked word; '_' marks unknown letters and spaces
                are ignored.
            target_word: Optional true word; when given it is used instead of
                the pattern to decide which guessed letters were wrong.
            transformer_model: Optional model called as
                ``model(token_ids, negative_vector)``.

        Returns:
            A lower-case letter, or None when every letter was already guessed.
        """
        clean_pattern = self._clean_pattern(pattern)

        # Refresh the wrong-letter set here so the negative vector is correct
        # even if the caller never invokes update_state().
        self._mark_incorrect(clean_pattern, target_word)
        # Each earlier call recorded one letter, so the size of the guessed set
        # is the number of guesses made; guess_count is honoured when larger.
        guesses_made = max(self.guess_count, len(self.guessed_letters))

        # Too long for the MLM, or still within the first two guesses.
        if len(clean_pattern) > MAX_LENGTH or guesses_made < 2:
            return self._commit(self.get_length_based_guess() or self.get_global_fallback())

        if transformer_model is not None:
            guess = self._mlm_guess(clean_pattern, transformer_model)
            if guess is not None:
                return self._commit(guess)

        # No transformer model, or the MLM branch produced no candidate.
        return self._commit(self.get_global_fallback())

# =============================================================================
# 4. Hybrid Hangman API (Using our Heuristic-MLM Solver)
# =============================================================================

class HangmanAPIHybrid:
    """HTTP client that plays Hangman on the remote service with the hybrid solver."""

    def __init__(self, access_token, training_file, transformer_model, timeout=2000):
        """Create the HTTP session and the solver.

        Args:
            access_token: Token sent as the ``access_token`` query parameter
                on every request (omitted when falsy).
            training_file: Path handed to ConfidenceBasedHangmanSolver.
            transformer_model: Model forwarded to the solver on every guess.
            timeout: Timeout passed to each HTTP request.
        """
        self.hangman_url = "https://trexsim.com/trexsim/hangman"
        self.access_token = access_token
        self.session = requests.Session()
        self.timeout = timeout
        # hybrid solver
        self.solver = ConfidenceBasedHangmanSolver(training_file)
        self.transformer_model = transformer_model

    def start_game(self, practice=True, verbose=True):
        """Reset the solver and request a new game.

        Returns:
            ``(game_id, word)`` when the server answers with status
            "approved", otherwise ``(None, None)``.
        """
        self.solver.reset_game()  # Reset solver state at game start.
        response = self.request("/new_game", {"practice": practice})
        if response.get("status") != "approved":
            print("Failed to start a new game")
            return None, None

        game_id = response.get("game_id")
        word = response.get("word")
        tries_remains = response.get("tries_remains")
        if verbose:
            print(f"Game Started: ID={game_id}, Tries Remaining={tries_remains}, Word='{word}'")
        return game_id, word

    def guess(self, word):
        """Return the solver's next letter for the masked word (e.g., "_ A _ E").

        The true word is unknown during play, so only the pattern is passed on.
        """
        return self.solver.hybrid_guess(word, transformer_model=self.transformer_model)

    def play_game(self, practice=True):
        """Play one game to completion.

        Returns:
            True when the server reports status "success"; False on status
            "failed", when tries run out, when the game cannot be started, or
            when a request yields no usable response.
        """
        game_id, word = self.start_game(practice=practice)
        if not game_id or not word:
            return False

        tries_remains = 6
        while tries_remains > 0:
            guess_letter = self.guess(word)
            if not guess_letter:
                break

            # Send the guess to the server
            res = self.request("/guess_letter", {"game_id": game_id, "letter": guess_letter})
            if not res:
                # request() already reported the error; an empty answer says
                # nothing about remaining tries, so do not report "out of tries".
                print(f"Game {game_id}: no usable response from server, aborting.")
                return False

            status = res.get("status")
            # A JSON null would otherwise break the `> 0` comparison above.
            tries_remains = res.get("tries_remains", 0) or 0
            word = res.get("word", word)

            print(f"Game {game_id}: Guessed '{guess_letter}', Tries Remaining: {tries_remains}, Word: {word}")

            if status == "success":
                print(f"Game {game_id}: **WON**!")
                return True
            if status == "failed":
                print(f"Game {game_id}: **LOST**!")
                return False
            # Else: game is ongoing.

        print(f"Game {game_id}: Out of tries, lost!")
        return False

    def my_status(self):
        """Return the decoded response of the ``/my_status`` endpoint."""
        return self.request("/my_status", {})

    def request(self, path, args=None):
        """Issue a GET request and return the decoded JSON body.

        Args:
            path: Endpoint path appended to the base URL.
            args: Optional query parameters. The mapping is copied, so the
                caller's dict is never modified.

        Returns:
            The decoded JSON response, or ``{}`` when the request fails or the
            body is not valid JSON.
        """
        # Copy before adding the token so the caller's dict stays untouched.
        params = dict(args or {})
        if self.access_token:
            params["access_token"] = self.access_token
        try:
            # NOTE(review): verify=False disables TLS certificate validation.
            # It is kept because this file does not show whether the server's
            # certificate validates; enable verification if it does.
            response = self.session.get(
                self.hangman_url + path,
                params=params,
                timeout=self.timeout,
                verify=False
            )
            response.raise_for_status()
        except requests.RequestException as e:
            print(f"API request failed: {e}")
            return {}

        try:
            return response.json()
        except ValueError as e:
            # Raised when the body is not valid JSON.
            print(f"API request failed: {e}")
            return {}

# =============================================================================
# 5. Main Function
# =============================================================================

def main(training_file=None, access_token=None, model_path=None):
    """Load the pretrained model, play one practice game and print the status.

    Each setting is resolved in this order: explicit argument, environment
    variable, built-in default. The defaults are the values that were
    previously hard-coded, so calling ``main()`` behaves as before.

    Args:
        training_file: Path to the training word list
            (environment variable ``HANGMAN_TRAINING_FILE``).
        access_token: Access token for the Hangman service
            (environment variable ``HANGMAN_ACCESS_TOKEN``).
        model_path: Path to the saved model weights
            (environment variable ``HANGMAN_MODEL_PATH``).
    """
    training_file = training_file or os.environ.get(
        "HANGMAN_TRAINING_FILE",
        r"C:\Users\jaska\OneDrive\Desktop\Hangman Official\words_250000_train.txt",
    )
    # NOTE(review): a credential is embedded in source as the default; prefer
    # supplying it through HANGMAN_ACCESS_TOKEN and removing the literal.
    access_token = access_token or os.environ.get(
        "HANGMAN_ACCESS_TOKEN",
        "f03eaa6b58e7172ace02888be0bf36",
    )
    model_path = model_path or os.environ.get(
        "HANGMAN_MODEL_PATH",
        r"C:\Users\jaska\OneDrive\Desktop\Hangman Official\model_best.pt",
    )

    # Instantiate the transformer with the architecture the checkpoint expects.
    transformer_model = MiniTransformerMLMTop11(
        vocab_size=vocab_size,
        d_model=128,
        nhead=4,
        num_layers=2,
        max_len=MAX_LENGTH,
        dropout=0.1
    ).to(device)

    # Load pretrained weights.
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    transformer_model.load_state_dict(torch.load(model_path, map_location=device))

    # Instantiate the hybrid Hangman API using our solver and transformer.
    api = HangmanAPIHybrid(
        access_token=access_token,
        training_file=training_file,
        transformer_model=transformer_model,
        timeout=2000
    )

    # Play a practice game:
    success = api.play_game(practice=True)
    print("Practice game result:", "Win" if success else "Loss")

    # To check game statistics:
    status = api.my_status()
    print("Game status:", status)

if __name__ == "__main__":
    main()
