In [72]:
# Imports 
import numpy as np
from english_words import english_words_set
import gym
from gym import spaces
from utils.exploration.base_exploration_model import BaseExplorationModel
from gym.utils.env_checker import check_env
from collections import deque
from utils.infrastructure.logger import Logger
from utils.infrastructure.schedule_utils import *

from collections import defaultdict
import re

In [99]:
class WordleSimple(gym.Env): 
    
    def __init__(self,
                 answer: str = None, 
                 valid_words: list = None, 
                 keep_answer_on_reset: bool = False, 
                 logdir: str = 'data'): 
        
        # Store attributes 
        assert(valid_words is not None), 'Must pass valid words'
        self.valid_words = valid_words
        self.n_valid_words = len(self.valid_words)
        self.answer = answer if answer is not None else np.random.choice(self.valid_words)
        self.keep_answer_on_reset = keep_answer_on_reset
        
        # Action + Observation Space
        self.action_space = gym.spaces.Discrete(self.n_valid_words)
        self.observation_space = gym.spaces.Box(low = 0, 
                                                high = 1, 
                                                shape = (self.n_valid_words,), 
                                                dtype = int)

        #  self.observation_space = gym.spaces.MultiDiscrete([2] * self.n_valid_words)
        
        # Init Stuff 
        self.state = np.ones(len(self.valid_words), dtype = int)
        self.guess_count = 0
        self.alphabet = list('abcdefghijklmnopqrstuvwxyz')
        self.possible_words = self.valid_words
        self.n_possible_words = len(self.possible_words)
        self.patterns = ['[abcdefghijklmnopqrstuvwxyz]']*5
        
        # Logging
        self.logging_freq = 500
        self.num_games = 0
        self.victory_buffer = deque(maxlen = self.logging_freq)
        self.win = False
        self.logger = Logger(logdir)   
        
        # Track yellow letters to enforce they are in the word 
        self.all_greens = set()
        self.all_yellows = set()
        self.all_grays = set()
        
    def _create_pattern(self, guess): 
        
        # Init structures to check which letters are green and which are yellow
        greens = dict(zip(range(5), ['']*5))
        yellows = defaultdict(list)
        grays = []

        # Get which words belong where 
        for idx, (guess_letter, answer_letter) in enumerate(zip(guess, self.answer)): 
            if guess_letter == answer_letter: # green letter
                greens.update({idx: guess_letter})
            elif guess_letter in self.answer: # yellow letter
                yellows[idx].append(guess_letter)
                self.all_yellows.add(guess_letter)
            else: 
                grays.append(guess_letter) # gray letter

        # Remove grays from alphabet
        self.alphabet = sorted(set(self.alphabet) - set(grays))

        for i in range(5): 

            if greens[i] != '': 

                # If we have the green letter, we should just replace the whole pattern with this
                self.patterns[i] = '[' + greens[i] + ']'

            elif len(yellows[i])  > 0: 

                # If we get another yellow, remove it from the pattern 
                for letter in yellows[i]: 
                    self.patterns[i] = self.patterns[i].replace(letter, '')

            else: 
                self.patterns[i] = '[' + ''.join(self.alphabet) + ']'


        # Combine patterns into single new pattern 
        pattern = "".join(self.patterns)
        
        return pattern

    def _compute_reward(self, guess): 
        
        # Create pattern 
        pattern = self._create_pattern(guess)
        
        print(pattern)

        # Get possible words by matching regex
        new_possible_words = [word for word in self.possible_words if re.match(pattern, word)]
        
        # Compute reward
        reward = (len(self.possible_words) - len(new_possible_words))/len(self.possible_words)
        
        # Check if won 
        won = bool(guess == self.answer)

            
        return reward, won, new_possible_words
                
    def step(self, action): 
        
        # Grab decoded word 
        guess = self.valid_words[action]
        # Compute reward
        reward, win, new_possible_words = self._compute_reward(guess)
        
        # Add win/loss penalty
        if win: 
            reward += 1
        else: 
            reward -= 1
        
        # Add possible word penalty 
        if guess not in self.possible_words: 
            reward -= 1

        # ACTUALLY GIVE IT THIS REWARD
        # reward = 1 if guess not in self.possible_words else -1
        
        # Update state
        self.state = np.array([1 if word in new_possible_words else 0 for word in self.valid_words], dtype=int)
        assert(self.state.shape == self.observation_space.shape), f'{self.state.shape}'
        self.possible_words = new_possible_words
        self.n_possible_words = len(self.possible_words)

        
        # Increment guess count 
        self.guess_count += 1
        
        # Check if done
        done = (win) or (self.guess_count == 6)
        
        # Info 
        info = {'guess_count': self.guess_count, 'won': win}
        
        self.win = win
                
        return self.state, reward, done, info
        
    def reset(self): 

        # Win Ratio logging
        self.victory_buffer.append(self.win)
        self.num_games += 1
        if not self.num_games % self.logging_freq:
            logs = {
                "win ratio": self._compute_win_ratio()
            }
            self.do_logging(logs, self.num_games)
    
       
        # Reset possible words = all valid words
        self.possible_words = self.valid_words
        self.patterns = ['[abcdefghijklmnopqrstuvwxyz]']*5
        
        # Reset alphabet, state and guess count
        self.alphabet = list('abcdefghijklmnopqrstuvwxyz')
        self.state = np.ones(len(self.valid_words), dtype = int)
        self.guess_count = 0

        # Reset answer 
        if not self.keep_answer_on_reset:
            self.answer = np.random.choice(self.valid_words)
        
        
        return self.state
 
    def _compute_win_ratio(self):
        """
        Computes the win ration of games currently in the victory buffer.
        :return: the win ratio
        """
        wins = sum(self.victory_buffer)
        return wins/len(self.victory_buffer)

    def do_logging(self, logs, num_games):
        """
        :param logs: dictionary containing values to be logged
        :param num_games: the number of games
        :return: logs values to tensorboard
        """
        print(f"Number of Games: {num_games}")
        print(f"State: \n {self.state}")
        print(f"Possible Words: \n {self.possible_words}")
        print(f"Answer: \n {self.answer}")
        for key, value in logs.items():
            print('{} : {}'.format(key, value))
            self.logger.log_scalar(value, key, num_games)
        print("\n")

In [94]:
words = [word.replace('\n', '').replace(',', '') for word in list(open('scripts/wordle_subset.txt', 'r'))]

In [95]:
env = WordleSimple(valid_words = words, answer = 'abide')

########################
logging outputs to  data
########################


In [96]:
env.step(env.valid_words.index('abuse'))

[a][b][abcdefghijklmnopqrtvwxyz][abcdefghijklmnopqrtvwxyz][e]


(array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 -0.010000000000000009,
 False,
 {'guess_count': 1, 'won': False})

In [97]:
env.step(env.valid_words.index('birth'))

[a][b][abcdefgijklmnopqvwxyz][abcdefgijklmnopqvwxyz][abcdefgijklmnopqvwxyz]


(array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 -2.0,
 False,
 {'guess_count': 2, 'won': False})

In [98]:
env.all_yellows

{'b', 'i'}