In [19]:
import gymnasium as gym
from gymnasium.spaces import MultiDiscrete, Discrete
from gymnasium import Env
import numpy as np
from itertools import permutations

# Lookup table of candidate secrets. Index 0 holds the placeholder value 0,
# which stands for "empty slot", so real candidates start at index 1.
numbers = [0] + [12, 21]

# Lookup table of feedback labels. Index 0 ('0') is likewise a placeholder
# meaning "no feedback recorded yet"; real feedback codes start at index 1.
AB = ['0'] + ['0A2B', '2A0B']

def determine_AB(a, b, labels=None):
    """Score a two-digit guess against a two-digit secret.

    A digit of ``a`` that matches the digit of ``b`` at the same position
    counts as an "A"; a digit of ``a`` that does not match in place but occurs
    elsewhere in ``b`` counts as a "B".

    Args:
        a: The guessed number; its tens and ones digits are compared.
        b: The secret number, split into digits the same way.
        labels: Sequence of feedback strings such as ``'0A2B'``. Defaults to
            the module-level ``AB`` table.

    Returns:
        The index of the ``'<A>A<B>B'`` feedback string inside ``labels``.

    Raises:
        ValueError: If the computed feedback string is not present in
            ``labels``.
    """
    if labels is None:
        labels = AB

    # divmod(x, 10) yields (tens, ones), i.e. the two digits in reading order
    guess_digits = divmod(a, 10)
    secret_digits = divmod(b, 10)

    bulls = 0
    cows = 0
    for position, digit in enumerate(guess_digits):
        if digit == secret_digits[position]:
            bulls += 1
        elif digit in secret_digits:
            cows += 1

    code = f"{bulls}A{cows}B"
    try:
        return labels.index(code)
    except ValueError:
        # Replace the opaque "'x' is not in list" with an actionable message
        raise ValueError(
            f"feedback {code!r} (guess {a} vs secret {b}) is not one of {list(labels)}"
        ) from None

class BullsnCows(Env):
    '''
    Bulls and Cows guessing game over the two-digit candidates in ``numbers``.

    State Space:
        Flat vector of length 2*max_guesses, initialised with all 0's (empty).
        Starting from zero, even indexes hold the index into ``numbers`` of the
        guess made on that turn, odd indexes hold the index into ``AB`` of the
        feedback for that guess.
    Action Space:
        Discrete(len(numbers)-1); action ``k`` guesses ``numbers[k+1]``
        (index 0 of ``numbers`` is the "empty" placeholder).
    Reward:
        +1 for guessing the secret (2A0B), -1 for not getting it within
        max_guesses tries, 0 otherwise.
    '''

    # Gymnasium reads the "render_modes" key (the old gym key was "render.modes")
    metadata = {"render_modes": ["human"]}

    def __init__(self, max_guesses=3):
        """Create the environment.

        Args:
            max_guesses: Number of guesses allowed per episode (default 3).

        Raises:
            ValueError: If ``max_guesses`` is smaller than 1.
        """
        super().__init__()
        if max_guesses < 1:
            raise ValueError("max_guesses must be at least 1")
        self.max_guesses = max_guesses
        # state space - one (guess index, feedback index) pair per allowed guess
        obs_row = np.array([len(numbers), len(AB)])
        self.observation_space = MultiDiscrete(np.tile(obs_row, max_guesses))
        # action space - size is total number of possible numbers
        self.action_space = Discrete(len(numbers)-1)
        # initialize counters, state, secret and the terminated/truncated flags
        self._start_episode()

    def _start_episode(self):
        """Reset all per-episode attributes and draw a new secret."""
        # number of guesses made so far
        self.num_guesses = 0
        # indices (into ``numbers``) that have already been guessed
        self.guesses = []
        # start with all zeros (empty); use the space's dtype so the
        # observation is always a member of observation_space
        self.state = np.zeros(2*self.max_guesses, dtype=self.observation_space.dtype)
        # draw the secret from the env's own RNG so reset(seed=...) is
        # reproducible - exclude the placeholder at index 0
        self.correct_word = int(self.np_random.choice(numbers[1:]))
        self.terminated = False
        self.truncated = False

    def step(self, action):
        """Apply one guess.

        Args:
            action: Integer in ``[0, action_space.n)``; a 0-d integer array is
                accepted as well.

        Returns:
            Tuple ``(observation, reward, terminated, truncated, info)``; the
            observation is a copy of the internal state vector.

        Raises:
            RuntimeError: If the episode has already ended.
            ValueError: If ``action`` is outside the action space.
        """
        if self.terminated or self.truncated:
            raise RuntimeError("Episode has ended; call reset() before step().")
        action = int(action)
        if not 0 <= action < self.action_space.n:
            raise ValueError(f"action {action} is outside {self.action_space}")

        # since 0 means empty, shift the action by one to index into ``numbers``
        guess = action + 1
        self.guesses.append(guess)

        # record the guess in the slot for the current turn
        slot = 2*self.num_guesses
        self.state[slot] = guess
        # map the guess to its number and score it against the secret
        check_AB = determine_AB(numbers[guess], self.correct_word)
        self.state[slot + 1] = check_AB
        self.num_guesses += 1

        if check_AB == AB.index('2A0B'):
            # guessed the secret - terminate with a win
            reward = 1
            self.terminated = True
        elif self.num_guesses == self.max_guesses:
            # that was the final allowed guess - terminate with a loss
            reward = -1
            self.terminated = True
        else:
            # wrong guess with tries remaining: no reward, episode continues
            reward = 0
        # hand out a copy so callers never alias the mutable internal state
        return self.state.copy(), reward, self.terminated, self.truncated, {}

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: Optional seed forwarded to ``Env.reset``; it seeds the RNG
                used to draw the secret.
            options: Unused.

        Returns:
            Tuple ``(observation, info)`` with an all-zero observation.
        """
        super().reset(seed=seed)
        self._start_episode()
        return self.state.copy(), {}

    def render(self, mode="human"):
        """Print the secret and every guess made so far with its feedback."""
        print("Correct Word:", self.correct_word)
        for n, guess in enumerate(self.guesses, start=1):
            # feedback for the n-th guess lives at odd index 2*(n-1)+1
            feedback = self.state[2*n - 1]
            print(f'Guess {n}: {numbers[guess]}, AB: {AB[feedback]}')


In [20]:
# Build a fresh environment and report the size of its spaces: first the
# per-slot cardinalities of the observation vector, then the number of actions.
env = BullsnCows()
for space_size in (env.observation_space.nvec, env.action_space.n):
    print(space_size)

[3 3 3 3 3 3]
2


In [21]:
# Sanity check: play episodes with uniformly random guesses and print each one.
env = BullsnCows()
episodes = 20
for episode in range(1, episodes+1):
    state, _ = env.reset()
    score = 0
    done = False

    # Drive the loop from the flags step() returns rather than from the env's
    # internal attributes, so it keeps working if the env is wrapped or truncates.
    while not done:
        action = env.action_space.sample()
        state, reward, terminated, truncated, info = env.step(action)
        env.render()
        score += reward
        done = terminated or truncated

    print(f'Episode: {episode}, Score: {score}')

Correct Word: 12
Guess 1: 12, AB: 2A0B
Episode: 1, Score: 1
Correct Word: 21
Guess 1: 21, AB: 2A0B
Episode: 2, Score: 1
Correct Word: 12
Guess 1: 12, AB: 2A0B
Episode: 3, Score: 1
Correct Word: 12
Guess 1: 21, AB: 0A2B
Correct Word: 12
Guess 1: 21, AB: 0A2B
Guess 2: 12, AB: 2A0B
Episode: 4, Score: 1
Correct Word: 21
Guess 1: 21, AB: 2A0B
Episode: 5, Score: 1
Correct Word: 12
Guess 1: 21, AB: 0A2B
Correct Word: 12
Guess 1: 21, AB: 0A2B
Guess 2: 12, AB: 2A0B
Episode: 6, Score: 1
Correct Word: 12
Guess 1: 12, AB: 2A0B
Episode: 7, Score: 1
Correct Word: 12
Guess 1: 21, AB: 0A2B
Correct Word: 12
Guess 1: 21, AB: 0A2B
Guess 2: 12, AB: 2A0B
Episode: 8, Score: 1
Correct Word: 21
Guess 1: 21, AB: 2A0B
Episode: 9, Score: 1
Correct Word: 21
Guess 1: 12, AB: 0A2B
Correct Word: 21
Guess 1: 12, AB: 0A2B
Guess 2: 12, AB: 0A2B
Correct Word: 21
Guess 1: 12, AB: 0A2B
Guess 2: 12, AB: 0A2B
Guess 3: 12, AB: 0A2B
Episode: 10, Score: -1
Correct Word: 12
Guess 1: 21, AB: 0A2B
Correct Word: 12
Guess 1: 21, AB

In [22]:
from stable_baselines3 import DQN, PPO, A2C
from stable_baselines3.common.evaluation import evaluate_policy
import os
import torch

# One TensorBoard log directory per algorithm
logdir_dqn = "1221_logs/dqn/"
logdir_ppo = "1221_logs/ppo/"
logdir_a2c = "1221_logs/a2c/"

# exist_ok=True makes re-running this cell safe without a separate existence check
for logdir in (logdir_a2c, logdir_dqn, logdir_ppo):
    os.makedirs(logdir, exist_ok=True)

env = BullsnCows()
# Fall back to the CPU so the notebook also runs on machines without a CUDA GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_dqn = DQN("MlpPolicy", env, verbose=1, learning_starts=5000, tensorboard_log=logdir_dqn, device=device)
model_ppo = PPO("MlpPolicy", env, verbose=1, tensorboard_log=logdir_ppo, device=device)
model_a2c = A2C("MlpPolicy", env, verbose=1, tensorboard_log=logdir_a2c, device=device)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [23]:
# Train the DQN agent, then persist it; learn() hands back the model, so save() is chained
model_dqn.learn(total_timesteps=10000, log_interval=4).save("dqn_BullsnCows_1221")

Logging to 1221_logs/dqn/DQN_3
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.75     |
|    ep_rew_mean      | 0.5      |
|    exploration_rate | 0.993    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 642      |
|    time_elapsed     | 0        |
|    total_timesteps  | 7        |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.88     |
|    ep_rew_mean      | 0.5      |
|    exploration_rate | 0.986    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 1056     |
|    time_elapsed     | 0        |
|    total_timesteps  | 15       |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.75     |
|    ep_rew_mean      | 0.667    |
|    exploration_rate | 0.98     |
| time/               | 

In [89]:
# Play one episode with the trained DQN policy (greedy actions) and print it.
state, _ = env.reset()
done = False
# Use the flags returned by step() instead of reading the env's internal attributes
while not done:
    action, _states = model_dqn.predict(state, deterministic=True)
    state, reward, terminated, truncated, info = env.step(action)
    env.render()
    done = terminated or truncated

Correct Word: 21
Guess 1: 12, AB: 0A2B
Correct Word: 21
Guess 1: 12, AB: 0A2B
Guess 2: 21, AB: 2A0B


In [25]:
# Train the PPO agent, then persist it; learn() hands back the model, so save() is chained
model_ppo.learn(total_timesteps=10000, log_interval=4).save("ppo_BullsnCows_1221")

Logging to 1221_logs/ppo/PPO_3
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 1.62        |
|    ep_rew_mean          | 1           |
| time/                   |             |
|    fps                  | 313         |
|    iterations           | 4           |
|    time_elapsed         | 26          |
|    total_timesteps      | 8192        |
| train/                  |             |
|    approx_kl            | 0.044624418 |
|    clip_fraction        | 0.294       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.498      |
|    explained_variance   | 0.036       |
|    learning_rate        | 0.0003      |
|    loss                 | 0.0658      |
|    n_updates            | 30          |
|    policy_gradient_loss | -0.0505     |
|    value_loss           | 0.182       |
-----------------------------------------


In [26]:
# Play one episode with the trained PPO policy (greedy actions) and print it.
state, _ = env.reset()
done = False
# Use the flags returned by step() instead of reading the env's internal attributes
while not done:
    action, _states = model_ppo.predict(state, deterministic=True)
    state, reward, terminated, truncated, info = env.step(action)
    env.render()
    done = terminated or truncated

Correct Word: 21
Guess 1: 21, AB: 2A0B


In [27]:
# Train the A2C agent, then persist it; learn() hands back the model, so save() is chained
model_a2c.learn(total_timesteps=10000, log_interval=4).save("a2c_BullsnCows_1221")

Logging to 1221_logs/a2c/A2C_3
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 2         |
|    ep_rew_mean        | 0.8       |
| time/                 |           |
|    fps                | 254       |
|    iterations         | 4         |
|    time_elapsed       | 0         |
|    total_timesteps    | 20        |
| train/                |           |
|    entropy_loss       | -0.687    |
|    explained_variance | -1.08e+04 |
|    learning_rate      | 0.0007    |
|    n_updates          | 3         |
|    policy_loss        | 0.325     |
|    value_loss         | 0.335     |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 1.86      |
|    ep_rew_mean        | 0.81      |
| time/                 |           |
|    fps                | 278       |
|    iterations         | 8         |
|    time_elapsed       | 0         |
|    total_timestep

In [28]:
# Play one episode with the trained A2C policy (greedy actions) and print it.
state, _ = env.reset()
done = False
# Use the flags returned by step() instead of reading the env's internal attributes
while not done:
    action, _states = model_a2c.predict(state, deterministic=True)
    state, reward, terminated, truncated, info = env.step(action)
    env.render()
    done = terminated or truncated

Correct Word: 21
Guess 1: 21, AB: 2A0B


In [88]:
%load_ext tensorboard
%tensorboard --logdir='1221_logs' --port 6007


The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6007 (pid 175112), started 0:00:01 ago. (Use '!kill 175112' to kill it.)