In [10]:
import copy
import numpy as np
import matplotlib.pyplot as plt
import gym

In [40]:
# tree node class definition
class TreeNode():
    def __init__(self, env, parent, done):
        # init associated board state
        self.env = env
        
        # init is node terminal flag
        self.is_terminal = done
        self.is_fully_expanded = self.is_terminal
        self.parent = parent
        self.visits = 0
        self.score = 0
        self.children = {}

In [54]:
# MCTS class definition
class MCTS_FOMDP():
    # search for the best move in the current position
    def search(self, initial_state, iteration=1000):
        # create root node
        self.root = TreeNode(initial_state, None, False)

        # walk through 1000 iterations
        for iteration in range(iteration):
            # select a node (selection phase)
            node = self.select(self.root)
            
            # scrore current node (simulation phase)
            score = self.simulation(node.env)
            
            # backpropagate results
            self.backpropagate(node, score)
        
        # pick up the best move in the current position
        try:
            return self.get_best_move(self.root, 0)
        
        except:
            pass
    
    # select most promising node
    def select(self, node):
        # make sure that we're dealing with non-terminal nodes
        while not node.is_terminal:
            # case where the node is fully expanded
            if node.is_fully_expanded:
                node = self.get_best_move(node, 2)
            
            # case where the node is not fully expanded 
            else:
                # otherwise expand the node
                return self.expand(node)
       # return node
        return node

    # expand node
    def expand(self, node):
        # gym environment의 action_space는 integer
        actions = node.env.action_space.n
        parent_config = node.env.unwrapped.spec.id
        parent_state = node.env.unwrapped.state
        parent_env = gym.make(parent_config)
        parent_env.reset()
        parent_env.unwrapped.state = parent_state

        # loop over generated states (moves)
        for action in range(actions):
            # make sure that current state (move) is not present in child nodes
            if action not in node.children:
                # create a new node
                _,_,done,_ = parent_env.step(action)
                new_node = TreeNode(parent_env, node, done)
                
                # add child node to parent's node children list (dict)
                node.children[action] = new_node
                
                # Restore parent state
                parent_env.unwrapped.state = parent_state

                # case when node is fully expanded
                if len(range(actions)) == len(node.children):
                    node.is_fully_expanded = True
                
                # return newly created node
                return new_node
        
        # debugging
        print('접근 안되는 구간입니다!!!') # 디버그 용도

    # simulate the game via making random moves until reach end of the game
    def simulation(self, env):
        total_reward = 0
        done = False
        while not done:
            next_state, reward, done, _ = env.step(env.action_space.sample())
            total_reward += reward
            if done:
                break
        return total_reward
        

    # backpropagate the number of visits and score up to the root node
    def backpropagate(self, node, score, gamma=0.99):
        # update nodes's up to root node
        while node is not None:
            # update node's score
            node.score = gamma*node.score*node.visits + score
            node.score /= (node.visits+1)
            
            # update node's visits
            node.visits += 1
            node = node.parent
            
    # select the best node basing on UCB1 formula
    def get_best_move(self, node, C_const=2):
        best_score = -np.inf
        best_moves = []

        for child_node in node.children.values():        
            # get move score using UCB formula
            #move_score = current_player*child_node.score/child_node.visits+\
            move_score = child_node.score+\
                C_const*np.sqrt(np.log(node.visits/child_node.visits))

            # better move has been found
            if move_score > best_score:
                best_score = move_score
                best_moves = [child_node]

            # found as good move as alread avilable
            elif move_score == best_score:
                best_moves.append(child_node)

        # return one of the best moves randomly
        return np.random.choice(best_moves)

In [55]:
mcts = MCTS_FOMDP()

env = gym.make('CartPole-v0')
state = env.reset()
done = False
action_list = []
total_reward = 0
while True:
    action = mcts.search(env)
    next_state, reward, done, _ = env.step(action)
    total_reward += reward 
    action_list.append(action)
    if done:
        break
print(f'Total reward: {total_reward}')

  logger.warn(


AssertionError: <__main__.TreeNode object at 0x7f6d9bb6d580> (<class '__main__.TreeNode'>) invalid

In [28]:
env.reset()

array([ 0.01197783, -0.02286854, -0.03317878, -0.01879819], dtype=float32)

In [47]:
env.unwrapped.state

array([-0.0306542 ,  0.00124451, -0.04181773,  0.03405016])

In [50]:
env.unwrapped.action_space

Discrete(2)

In [51]:
env.unwrapped.__dict__

{'gravity': 9.8,
 'masscart': 1.0,
 'masspole': 0.1,
 'total_mass': 1.1,
 'length': 0.5,
 'polemass_length': 0.05,
 'force_mag': 10.0,
 'tau': 0.02,
 'kinematics_integrator': 'euler',
 'theta_threshold_radians': 0.20943951023931953,
 'x_threshold': 2.4,
 'action_space': Discrete(2),
 'observation_space': Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32),
 'screen': None,
 'clock': None,
 'isopen': True,
 'state': array([-0.0306542 ,  0.00124451, -0.04181773,  0.03405016]),
 'steps_beyond_done': None,
 'spec': EnvSpec(entry_point='gym.envs.classic_control:CartPoleEnv', reward_threshold=195.0, nondeterministic=False, max_episode_steps=200, order_enforce=True, kwargs={}, namespace=None, name='CartPole', version=0),
 '_np_random': RandomNumberGenerator(PCG64) at 0x7F6D9BC16220}