In [1]:
import torch 
import numpy as np
from tqdm import tqdm
from create_training_examples import *
from model import build_model

In [2]:
def simulate_test_game(word, model, device, max_wrong_guesses=6, verbose=2):
    '''
    Play a game of hangman with the given word using the model.

    Each round, every still-hidden position is masked in turn and scored by the
    model. The highest-scoring character that has not been guessed yet (over
    all positions) becomes the next guess.

    Inputs:
        word: the word to guess
        model: the model to use for guessing; it is called as
            model(tokens, mask_idx, pad_mask) and its result is indexed as
            logits[0, char_id]
        device: the torch device the input tensors are moved to
        max_wrong_guesses: the maximum number of wrong guesses allowed
        verbose: 0 prints nothing, > 0 prints the state before each guess and
            the final result, > 1 additionally prints every guess
    Returns:
        True if the word was guessed correctly, False otherwise (including the
        case where no guess could be produced at all)
    '''
    # Map every character of the word to all positions where it occurs.
    word_ids = {}
    for i, c in enumerate(word):
        word_ids.setdefault(c, []).append(i)

    # Kept as a dict (not a set) so the verbose printout format is unchanged.
    guessed_chars = {}
    encoded_word = '*' * len(word)
    num_wrong_guesses = 0
    neg_inf = -float('inf')

    while encoded_word != word and num_wrong_guesses < max_wrong_guesses:
        if verbose > 0:
            print(f'Current word: {encoded_word}')
            print(f'Guesses so far: {guessed_chars}')

        # NOTE(review): candidates are ranked by raw logits coming from separate
        # forward passes, not by normalized probabilities - confirm intended.
        final_chr, final_score = None, None
        for i, c in enumerate(encoded_word):
            if c != '*':
                continue

            # Replace only the position under consideration by the mask token.
            copy_word = encoded_word[:i] + MASK_CHAR + encoded_word[i + 1:]

            masked_info = create_single_masked_word(copy_word, i)
            if not masked_info:
                continue

            tok, _, idx, pad_mask = masked_info
            tokens_tensor = torch.tensor(tok).unsqueeze(0).to(device)
            mask_idx_tensor = torch.tensor(idx).to(device)
            pad_mask_tensor = torch.tensor(pad_mask).unsqueeze(0).to(device)

            with torch.no_grad():
                logits = model(tokens_tensor, mask_idx_tensor, pad_mask_tensor)
                pred_idx = logits.argmax(dim=1).item()
                score = logits[0, pred_idx].item()
                predicted_char = ID2CHAR[pred_idx]

                # Knock out already-guessed characters one by one. Stop once the
                # best remaining score is -inf: everything has been knocked out
                # and argmax would keep returning the same index forever.
                while predicted_char in guessed_chars and score != neg_inf:
                    logits[0, pred_idx] = neg_inf
                    pred_idx = logits.argmax(dim=1).item()
                    score = logits[0, pred_idx].item()
                    predicted_char = ID2CHAR[pred_idx]

            if predicted_char in guessed_chars:
                # The whole vocabulary is exhausted for this position.
                continue

            if final_chr is None or score > final_score:
                final_chr = predicted_char
                final_score = score

        if final_chr is None:
            # No position yielded a usable candidate; the game cannot continue.
            break

        if final_chr in word:
            # Reveal every occurrence of the correctly guessed character.
            for idx in word_ids[final_chr]:
                encoded_word = encoded_word[:idx] + final_chr + encoded_word[idx + 1:]
        else:
            num_wrong_guesses += 1
        guessed_chars[final_chr] = True

        if verbose > 1:
            print(f'Guessing character: {final_chr}')
            print(f'Hangman state:', encoded_word)
            print(f'Number of wrong guesses: {num_wrong_guesses}')

    if verbose > 0:
        if encoded_word == word:
            print(f'Word guessed correctly: {word}')
            print('You win!')
        else:
            print(f'Word not guessed: {word}')
            print('You lose!')
    return encoded_word == word

In [8]:
# Set up the device: prefer Apple MPS, then CUDA, and fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Rebuild the architecture and restore the trained weights from the checkpoint.
model = build_model(d_model=256, num_encoder_layers=4)
model.load_state_dict(torch.load('model/checkpoint_epoch_6.pt', map_location=device))
model.to(device)
# Switch to evaluation mode so that any dropout / batch-norm layers in the model
# run in inference mode while games are simulated.
model.eval()

# Play one fully verbose demo game as a sanity check.
simulate_test_game('dishallucination', model, device, max_wrong_guesses=6, verbose=2)

Current word: ****************
Guesses so far: {}
Guessing character: y
Hangman state: ****************
Number of wrong guesses: 1
Current word: ****************
Guesses so far: {'y': True}
Guessing character: s
Hangman state: **s*************
Number of wrong guesses: 1
Current word: **s*************
Guesses so far: {'y': True, 's': True}
Guessing character: i
Hangman state: *is******i***i**
Number of wrong guesses: 1
Current word: *is******i***i**
Guesses so far: {'y': True, 's': True, 'i': True}
Guessing character: a
Hangman state: *is*a****i*a*i**
Number of wrong guesses: 1
Current word: *is*a****i*a*i**
Guesses so far: {'y': True, 's': True, 'i': True, 'a': True}
Guessing character: t
Hangman state: *is*a****i*ati**
Number of wrong guesses: 1
Current word: *is*a****i*ati**
Guesses so far: {'y': True, 's': True, 'i': True, 'a': True, 't': True}
Guessing character: d
Hangman state: dis*a****i*ati**
Number of wrong guesses: 1
Current word: dis*a****i*ati**
Guesses so far: {'y': True, 

True

In [None]:
# Test the accuracy on the training set
total_correct = 0
# Read the word list through a context manager so the file handle is closed.
with open('data/train_data.txt') as train_file:
    train_words = train_file.read().splitlines()

# Evaluate on a random sample of at most 2500 training words. The generator is
# seeded so that the reported accuracy can be reproduced on a re-run, and the
# sample size is capped so a shorter file does not make the draw fail.
sample_rng = np.random.default_rng(0)
sample_size = min(2500, len(train_words))
train_words = sample_rng.choice(train_words, size=sample_size, replace=False).tolist()
for word in tqdm(train_words):
    if simulate_test_game(word, model, device, max_wrong_guesses=6, verbose=0):
        total_correct += 1
print(f'Total correct guesses: {total_correct}')
print(f'Accuracy: {total_correct / len(train_words) * 100:.2f}%')

100%|██████████| 2500/2500 [15:40<00:00,  2.66it/s]

Total correct guesses: 1286
Accuracy: 51.44%





In [None]:
# Test the accuracy on the test set
total_correct = 0
# Read the word list through a context manager so the file handle is closed.
with open('data/final_test_words.txt') as test_file:
    test_words = test_file.read().splitlines()

# Play one silent game per word and count the wins.
for word in tqdm(test_words):
    if simulate_test_game(word, model, device, max_wrong_guesses=6, verbose=0):
        total_correct += 1
print(f'Total correct guesses: {total_correct}')
print(f'Accuracy: {total_correct / len(test_words) * 100:.2f}%')

100%|██████████| 2512/2512 [17:02<00:00,  2.46it/s]

Total correct guesses: 1331
Accuracy: 52.99%



