In [None]:
import gymnasium as gym
from gymnasium.spaces import MultiDiscrete, Discrete
from gymnasium import Env
import numpy as np
from itertools import permutations

# Candidate secrets: every 2-digit number whose two digits are distinct and drawn from {1, 2, 3}.
# Slot 0 holds the sentinel 0, which the observation uses to mean "empty".
numbers = [0] + [
    tens * 10 + ones
    for tens, ones in permutations(range(1, 4), 2)
]

# Bull/cow feedback labels; slot 0 ('0') is the "empty" sentinel.
# '0A0B' and '1A1B' are absent: any two numbers above share at least one digit,
# so those outcomes cannot occur.
AB = ['0', '0A1B', '0A2B', '1A0B', '2A0B']


def determine_AB(a, b):
    """Score guess ``a`` against secret ``b`` and return the index of the result in ``AB``.

    A (bull) counts digits in the same position; B (cow) counts digits of ``a``
    that are not bulls but appear somewhere in ``b``.

    Args:
        a: the guessed 2-digit number.
        b: the secret 2-digit number.

    Returns:
        Index into ``AB`` of the string '<A>A<B>B'.

    Raises:
        ValueError: if that string is not listed in ``AB`` (e.g. '0A0B', which
            cannot happen for two entries of ``numbers``).
    """
    guess_digits = divmod(a, 10)   # (tens, ones)
    target_digits = divmod(b, 10)
    pairs = list(zip(guess_digits, target_digits))
    bulls = sum(g == t for g, t in pairs)
    cows = sum(g != t and g in target_digits for g, t in pairs)
    return AB.index(f'{bulls}A{cows}B')

class BullsnCows(Env):
    '''
    Reduced Bulls and Cows environment.

    The secret ("correct word") is drawn from the non-zero entries of the module-level
    ``numbers`` list (2-digit numbers with distinct digits from {1, 2, 3}), and the agent
    has ``max_guesses`` tries (3 by default) to find it.

    State Space:
        max_guesses*2 grid, each row = (guessed number, feedback), flattened into one array:
        even indexes hold the guess as an index into ``numbers``, odd indexes hold the
        feedback as an index into ``AB``. Initialized with all 0's (empty).
    Action Space:
        Discrete(len(numbers) - 1); action k guesses ``numbers[k + 1]`` because index 0
        of ``numbers`` is the "empty" sentinel.
    Reward:
        +1 for 2A0B (exact match), -1 for not getting 2A0B within max_guesses tries,
        0 otherwise.
    '''

    # Gymnasium reads "render_modes"; the legacy "render.modes" key is kept for old callers.
    metadata = {"render_modes": ["human"], "render.modes": ["human"]}

    def __init__(self, max_guesses=3):
        '''
        Args:
            max_guesses: number of guesses allowed per episode (default 3). Must be >= 1.

        Raises:
            ValueError: if max_guesses < 1.
        '''
        super().__init__()
        if max_guesses < 1:
            raise ValueError(f"max_guesses must be >= 1, got {max_guesses}")
        self.max_guesses = max_guesses
        # state space: one (number index, AB index) pair per allowed guess
        self.observation_space = MultiDiscrete(np.tile([len(numbers), len(AB)], max_guesses))
        # action space - every candidate number except the 0 "empty" sentinel
        self.action_space = Discrete(len(numbers) - 1)
        self._reset_episode()

    def _reset_episode(self):
        '''Clear the guess history and draw a fresh secret.'''
        self.num_guesses = 0
        # shifted action indexes (into `numbers`) guessed so far, in order
        self.guesses = []
        self.state = np.zeros(2 * self.max_guesses, dtype=int)
        # use the env's own generator (seeded by reset(seed=...)) rather than the
        # global np.random, so seeding makes the secret reproducible; skip the 0 sentinel
        self.correct_word = self.np_random.choice(numbers[1:])
        self.terminated = False
        self.truncated = False

    def step(self, action):
        '''
        Apply one guess.

        Args:
            action: integer (or single-element numpy array, as returned by SB3 predict())
                in [0, action_space.n).

        Returns:
            (observation copy, reward, terminated, truncated, info dict).

        Raises:
            RuntimeError: if the episode has already terminated (call reset() first).
            ValueError: if action is outside the action space.
        '''
        if self.terminated:
            raise RuntimeError("step() called on a finished episode; call reset() first")
        index = int(np.asarray(action).item())
        if not 0 <= index < self.action_space.n:
            raise ValueError(f"action {index} is outside {self.action_space}")
        # since 0 means empty, shift by one so the action indexes a real number
        action = index + 1
        self.guesses.append(action)

        # write (guess, feedback) into the row for the current turn
        slot = 2 * self.num_guesses
        check_AB = determine_AB(numbers[action], self.correct_word)
        self.state[slot] = action
        self.state[slot + 1] = check_AB
        self.num_guesses += 1

        # guessed the correct word -> win; checked before the guess limit so a last-try hit still wins
        if check_AB == AB.index('2A0B'):
            reward = 1
            self.terminated = True
        # out of guesses -> lose
        elif self.num_guesses == self.max_guesses:
            reward = -1
            self.terminated = True
        else:
            reward = 0

        # return a copy so callers holding an observation are unaffected by later in-place writes
        return self.state.copy(), reward, self.terminated, self.truncated, {}

    def reset(self, seed=None, options=None):
        '''
        Start a new episode.

        Args:
            seed: optional seed for self.np_random (forwarded to Env.reset).
            options: unused.

        Returns:
            (all-zero observation, empty info dict).
        '''
        super().reset(seed=seed)
        self._reset_episode()
        return self.state.copy(), {}

    def render(self, mode="human"):
        '''Print the secret and, for each guess so far, the number and its AB feedback.'''
        print("Correct Word:", self.correct_word)
        for n, guess in enumerate(self.guesses, start=1):
            feedback = self.state[2 * (n - 1) + 1]
            print(f'Guess {n}: {numbers[guess]}, AB: {AB[feedback]}')


In [None]:
from stable_baselines3 import DQN, PPO, A2C
from stable_baselines3.common.evaluation import evaluate_policy
import os
import torch

# TensorBoard log directories, one per algorithm
logdir_dqn = "reduced_logs/dqn/"
logdir_ppo = "reduced_logs/ppo/"
logdir_a2c = "reduced_logs/a2c/"

for logdir in [logdir_a2c, logdir_dqn, logdir_ppo]:
    # exist_ok keeps this cell safe to re-run
    os.makedirs(logdir, exist_ok=True)

# seed for the models' RNGs (SB3 also uses it to seed the env on its first reset)
SEED = 0

env = BullsnCows()
# fall back to CPU so the notebook still runs on machines without a CUDA GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# all three models share the same underlying env instance
model_dqn = DQN("MlpPolicy", env, verbose=1, learning_starts=50000, tensorboard_log=logdir_dqn, device=device, seed=SEED)
model_ppo = PPO("MlpPolicy", env, verbose=1, tensorboard_log=logdir_ppo, device=device, seed=SEED)
model_a2c = A2C("MlpPolicy", env, verbose=1, tensorboard_log=logdir_a2c, device=device, seed=SEED)

In [None]:
# Train DQN for 100k environment steps, then save it; learn() returns the model itself.
model_dqn.learn(total_timesteps=100_000, log_interval=4).save("dqn_BullsnCows_reduced")

In [None]:
# Roll out one episode with the trained DQN agent, printing the guess history after each step.
state, _ = env.reset()
while True:
    action, _states = model_dqn.predict(state, deterministic=True)
    state, reward, terminated, truncated, info = env.step(action)
    env.render()
    if terminated or truncated:
        break


In [None]:
# Train PPO for 100k environment steps, then save it; learn() returns the model itself.
model_ppo.learn(total_timesteps=100_000, log_interval=4).save("ppo_BullsnCows_reduced")

In [None]:
# Roll out one episode with the trained PPO agent, printing the guess history after each step.
state, _ = env.reset()
while True:
    action, _states = model_ppo.predict(state, deterministic=True)
    state, reward, terminated, truncated, info = env.step(action)
    env.render()
    if terminated or truncated:
        break

In [None]:
# Train A2C for 100k environment steps, then save it; learn() returns the model itself.
model_a2c.learn(total_timesteps=100_000, log_interval=4).save("a2c_BullsnCows_reduced")

In [None]:
# Roll out one episode with the trained A2C agent, printing the guess history after each step.
# (Previously this cell evaluated model_ppo by mistake.)
state, _ = env.reset()
done = False
while not done:
    action, _states = model_a2c.predict(state, deterministic=True)
    state, reward, terminated, truncated, info = env.step(action)
    # rely on the flags step() returns rather than the env's internal attribute
    done = terminated or truncated
    env.render()

In [None]:
# Show TensorBoard inline for the parent of the dqn/ppo/a2c log directories created above.
%load_ext tensorboard
%tensorboard --logdir='reduced_logs' --port 6007