# Simulating games of Hangman using a heuristic policy

In this notebook, we run example games of Hangman with an intuitive greedy heuristic policy built on top of the transformer model pre-trained in `pre_train_notebook.ipynb`.

In [1]:
import torch
from utils.utils import MyTokenizer, MyMasker, TextDataset
from torch.utils.data import Dataset, DataLoader, random_split

# Load the full word list, then re-create the train/validation split used for
# pre-training (same split fraction and same generator seed), so the validation
# words evaluated below are the held-out ones.
dataset = TextDataset('./data/words_250000_train.txt')
n_words = len(dataset)

train_split_percent = 0.99
train_size = int(train_split_percent * n_words)
test_size = n_words - train_size
print(f'Training size : {train_size}\nValidation size : {test_size}')

split_generator = torch.Generator().manual_seed(0)  # must match the seed used for training
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=split_generator,
)

Training size : 225027
Validation size : 2273


### Importing the custom Gym-based Hangman environment
The environment's source code lives in the `env` directory. The environment:
* Follows the Gym API.
* Can be vectorized and supports multithreading for parallel computation.

In [2]:
import gymnasium as gym
from env.hangman import HangmanEnv

In [3]:
# Choose and split the dataset into `n` buckets for generating words for hangman:
# each environment draws its hidden words from its own bucket, and
# random_split produces non-overlapping subsets.

n = 32  # number of environments to run in parallel

dataset = val_dataset
# Seed the split so the word -> environment assignment is reproducible across
# runs; without an explicit generator, random_split uses torch's global RNG.
n_datasets = random_split(dataset, [1/n]*n, generator=torch.Generator().manual_seed(0))

def make_envs(dataset):
    """Return a zero-argument factory that builds a HangmanEnv over `dataset`.

    Vector envs take callables rather than env instances. Binding `dataset` as
    this function's argument gives every factory its own bucket, which avoids
    the late-binding pitfall of a lambda written directly in a comprehension.
    """
    return lambda: HangmanEnv(dataset=dataset)

# One environment per word bucket, wrapped in a synchronous vector env that
# steps them in lockstep and batches their observations.
env_fns = [make_envs(bucket) for bucket in n_datasets]
envs = gym.vector.SyncVectorEnv(env_fns)

### Implementing the agent that interacts with the Hangman environment

In [4]:
import torch.nn as nn
import torch.nn.functional as F

from model.Models import Transformer

class Agent(nn.Module):
    """Greedy Hangman agent built on the pre-trained character transformer.

    The transformer produces logits over a 28-token vocabulary for each position
    of the current (partially revealed) word. The agent sums the per-position
    probabilities over all non-padding positions. It then guesses the most
    probable letter that has not been guessed yet.
    """

    def __init__(self, weights_path='./weights/aaa_best_weights'):
        super().__init__()

        # pretrained model outputs raw logits of `expected` word from supervised learning
        self.pretrainedLLM = Transformer(src_vocab=28, d_model=128, max_seq_len=32, N=12, heads=8, dropout=0.1)
        # Load onto the CPU first so a checkpoint saved from a GPU also loads on
        # CPU-only machines; callers move the whole module with `.to(device)`.
        # torch.load unpickles the file -- only load trusted checkpoints.
        self.pretrainedLLM.load_state_dict(torch.load(weights_path, map_location='cpu'))

    @torch.no_grad()  # action selection is pure inference; no autograd graph needed
    def act(self, x, guessed_letters):
        """Pick one letter per environment.

        Args:
            x: integer token tensor of the current masked words; token 0 is
                padding. Assumed shape (batch, seq_len) -- TODO confirm against
                HangmanEnv's observation space.
            guessed_letters: one iterable of already-guessed lowercase letters
                per environment.

        Returns:
            (action, None): `action` is a LongTensor of shape (batch,) holding
            token indices, where 1..26 correspond to 'a'..'z'. The second
            element is a placeholder.
        """
        valid_actions = self.get_valid_actions(guessed_letters, device=x.device)

        mask = (x != 0).unsqueeze(-2)  # (batch, 1, seq_len): non-padding positions
        logits = self.pretrainedLLM(x)
        probs = F.softmax(logits, dim=-1)

        # Sum each token's probability over the non-padding positions:
        # (batch, 1, seq_len) @ (batch, seq_len, vocab) -> (batch, vocab)
        probs = torch.matmul(mask.to(probs.dtype), probs).squeeze(1)
        # Normalise each row independently (guarding an all-padding row); this
        # does not change the argmax below.
        probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(1e-12)

        fprobs = probs * valid_actions  # zero out probabilities of invalid actions
        action = torch.argmax(fprobs, dim=-1)  # greedy heuristic policy on letter frequency

        return action, None

    @staticmethod
    def get_valid_actions(guessed_letters, device=None):
        """Build a 0/1 mask of the tokens each environment may still guess.

        Args:
            guessed_letters: sequence with one entry per environment. Each entry
                is an iterable of lowercase letters already guessed.
            device: device of the returned tensor. Defaults to CUDA when it is
                available, otherwise CPU.

        Returns:
            Float tensor of shape (len(guessed_letters), 28), with 1 for allowed
            tokens. Letters map 'a' -> 1 ... 'z' -> 26. Token 0 (padding) and
            token 27 (presumably the blank/mask token -- confirm against the
            tokenizer) are never allowed.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        valid_actions = torch.ones((len(guessed_letters), 28), device=device)
        valid_actions[:, 0] = 0.
        valid_actions[:, -1] = 0.

        for i, letters in enumerate(guessed_letters):
            idx = [ord(char) - ord('a') + 1 for char in letters]
            if idx:
                valid_actions[i, idx] = 0.

        return valid_actions

In [5]:
# Run on the GPU when one is present, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Build the agent (this loads the pre-trained weights), move it to `device`,
# and switch to eval mode so dropout is disabled. The last expression displays
# the module summary.
agent = Agent()
agent = agent.to(device)
agent.eval()

Agent(
  (pretrainedLLM): Transformer(
    (encoder): Encoder(
      (embed): Embedder(
        (embed): Embedding(28, 128, padding_idx=0)
      )
      (pe): PositionalEncoder(
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (layers): ModuleList(
        (0-11): 12 x EncoderLayer(
          (norm_1): Norm()
          (norm_2): Norm()
          (attn): MultiHeadAttention(
            (q_linear): Linear(in_features=128, out_features=128, bias=True)
            (v_linear): Linear(in_features=128, out_features=128, bias=True)
            (k_linear): Linear(in_features=128, out_features=128, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (out): Linear(in_features=128, out_features=128, bias=True)
          )
          (ff): FeedForward(
            (linear_1): Linear(in_features=128, out_features=2048, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (linear_2): Linear(in_features=2048, out_features=128, bias=True)
    

In [6]:
import time

# dictionary of results per given hidden word
wins = {}         # hidden word -> number of games won
total_games = {}  # hidden word -> number of games played


def _print_running_stats(wins, total_games):
    """Overwrite the current output line with the aggregate statistics so far.

    Assumes at least one game has finished (it divides by the total game count).
    """
    n_wins = sum(wins.values())
    n_games = sum(total_games.values())
    print('\rwins : %d \t total games : %d \t total unique games : %d \t win rate : %.03f%%'
          % (n_wins, n_games, len(total_games), 100 * n_wins / n_games), end='', flush=True)


start_time = time.perf_counter()

state, info = envs.reset()
state = torch.tensor(state).to(device)

print_counter = 0
# Pure inference: disable autograd so no computation graph is built each step.
with torch.no_grad():
    while True:
        action_ints, _ = agent.act(state, info['guessed_letters'])

        # Convert action indices (1 -> 'a', ..., 26 -> 'z') to letter guesses;
        # .tolist() copies the whole batch to the host in one transfer.
        action_strs = [chr(idx - 1 + ord('a')) for idx in action_ints.tolist()]

        # Take a step in every env
        state, reward, terminated, truncated, info = envs.step(action_strs)
        state = torch.tensor(state).to(device)
        done = terminated | truncated

        # Record finished games and print running statistics.
        # NOTE(review): gymnasium < 1.0's SyncVectorEnv resets finished envs inside
        # step(). In that case `info` may already describe the next episode, and the
        # finished one's info is under info['final_info']. Confirm the installed
        # gymnasium version and what HangmanEnv returns from reset().
        if done.any():
            for hidden_word, r, d in zip(info['hidden_word'], reward, done):
                if not d:
                    continue
                total_games[hidden_word] = total_games.get(hidden_word, 0) + 1
                wins[hidden_word] = wins.get(hidden_word, 0) + int(r)

                if print_counter % 50 == 0:
                    _print_running_stats(wins, total_games)
                print_counter += 1

        # Stop once every word of the evaluated dataset has been played at least
        # once (assumes the dataset holds no duplicate words; otherwise this
        # condition can never be met).
        if len(total_games) == len(dataset):
            _print_running_stats(wins, total_games)
            end_time = time.perf_counter()
            break

# "True" win rate: the mean of each unique word's own win rate, so words that
# happened to be drawn several times do not carry extra weight.
true_win_rate = sum(wins[word] / total_games[word] for word in total_games) / len(total_games)
n_played = sum(total_games.values())
elapsed = end_time - start_time

# PRINT RESULTS
print('\n--------------------------------------------------------')
print(f'True win rate \t\t\t: \t {100*true_win_rate:.03f}%')
print(f'Time taken to run {n_played} games \t: \t {elapsed:.03f} s')
print(f'Mean time per game \t\t: \t {1000*elapsed/n_played:.03f} ms')

wins : 1590 	 total games : 2404 	 total unique games : 2273 	 win rate : 66.140%
--------------------------------------------------------
True win rate 			: 	 66.432%
Time take to run 2404 games 	: 	 12.715 s
Mean time per game 		: 	 5.289 ms


### Conclusion
With the greedy heuristic policy, the pre-trained model performs fairly well even on the held-out validation words, reaching a win rate of roughly 66%. The model could likely be improved further with Reinforcement Learning techniques, which are left as a direction for future work.