A0C (continuous-action MCTS) on a toy problem

In [1]:
from collections import defaultdict
import numpy as np
import math

In [6]:
# # tree is state, state-action nodes. Works for both discrete and continuous actions
# class StateNode:
#   def __init__(self, state):
#     self.state = state
#     self.children = {} # action -> StateActionNode
  
#   def find_random_child(self):
#     # in continuous case, sample from distribution
#     # in discrete case, choose from children
#     return None

# class StateActionNode:
#   def __init__(self, action):
#     self.action = action
#     self.children = {} # action -> StateActionNode

In [46]:
# toy problem: move towards 10
def step(s, a, goal=10):
  """Deterministic toy dynamics: add the action to the state.

  Args:
    s: current state (a number).
    a: action, added to the state.
    goal: target value the agent is rewarded for approaching (default 10).

  Returns:
    (next_state, reward), where reward is minus the distance from the
    next state to ``goal`` (0 is the best achievable reward).
  """
  s_next = s + a
  return s_next, -abs(s_next - goal) # returns next state, reward

def sample_action(s):
  """Draw a candidate action for state ``s``.

  The state argument is accepted but not used: every state shares the same
  proposal, a uniform draw between -1 and 1 from NumPy's global RNG.
  """
  low, high = -1, 1
  return np.random.uniform(low, high)

In [95]:
class MCTS:
  """Depth-limited Monte Carlo tree search with progressive widening.

  Statistics are stored in flat dictionaries keyed by state (and action), so
  states and actions must be hashable. New actions are drawn from a sampler,
  which lets the search handle continuous action spaces.

  Args:
    exploration_weight: multiplier on the exploration bonus in `select_action`.
    gamma: discount factor applied to the value of the successor state.
    k, alpha: progressive-widening parameters; a state visited n times may
      hold up to ``k * n ** alpha`` sampled actions.
    step_fn: optional ``(s, a) -> (next_state, reward)`` transition function.
      When None, the module-level ``step`` is looked up at call time.
    action_sampler: optional ``s -> action`` proposal function. When None,
      the module-level ``sample_action`` is looked up at call time.
  """
  def __init__(self, exploration_weight=1, gamma=0.95, k=1, alpha=0.5,
               step_fn=None, action_sampler=None):
    self.gamma = gamma
    self.exploration_weight = exploration_weight
    self.k = k
    self.alpha = alpha
    self.step_fn = step_fn
    self.action_sampler = action_sampler
    self.N = defaultdict(int) # (s,a) -> visits
    self.Q = defaultdict(float) # (s,a) -> mean value
    self.Ns = defaultdict(int) # s -> total visits
    self.children = defaultdict(list) # s -> actions sampled so far

  def run(self, s, d, n_sims):
    """Run ``n_sims`` simulations of depth ``d`` from ``s``; return the best action.

    Raises:
      ValueError: if ``s`` has no sampled actions afterwards (e.g. ``n_sims``
        is 0 and ``s`` was never searched before).
    """
    for _ in range(n_sims):
      self.simulate(s, d)
    return self.get_best_action(s)

  def simulate(self, s, d=10):
    """Run one MCTS simulation from state ``s`` to depth ``d``.

    Returns the discounted return observed along the simulated path and
    updates the visit counts and running mean values on the way back up.
    """
    if d <= 0:
      # TODO: bootstrap with a value net here, i.e. return self.V(s)
      return 0

    actions = self.children[s]
    max_children = self.k * self.Ns[s] ** self.alpha # progressive widening
    # Test for an *empty* child list rather than a missing key: merely reading
    # self.children[s] elsewhere auto-inserts an empty list, and selecting
    # from an empty list would crash.
    if not actions or len(actions) < max_children:
      sampler = sample_action if self.action_sampler is None else self.action_sampler
      a = sampler(s)
      actions.append(a)
    else:
      a = self.select_action(s)                           # selection
    env_step = step if self.step_fn is None else self.step_fn
    s2, r = env_step(s, a)                                # expansion
    q = r + self.gamma * self.simulate(s2, d-1)           # simulation
    self.N[(s,a)] += 1
    self.Q[(s,a)] += (q-self.Q[(s,a)])/self.N[(s,a)]      # backpropagation (incremental mean)
    self.Ns[s] += 1
    return q

  def select_action(self, s):
    """Select among the sampled actions of ``s`` by the PUCT explore/exploit score.

    Raises:
      ValueError: if no action has been sampled for ``s`` yet.
    """
    actions = self.children.get(s) # .get avoids inserting an empty entry
    if not actions:
      raise ValueError("state {!r} has no sampled actions to select from".format(s))
    sqrt_visits = math.sqrt(self.Ns[s]) # same for every action, so compute once
    def puct(a):
      # TODO: add policy network self.pi(a|s)
      return self.Q[(s,a)] + self.exploration_weight * sqrt_visits/(self.N[(s,a)]+1)
    return max(actions, key=puct)

  def get_best_action(self, s):
    """Return the sampled action of ``s`` with the highest mean value Q.

    Raises:
      ValueError: if no action has been sampled for ``s`` yet.
    """
    actions = self.children.get(s) # .get avoids inserting an empty entry
    if not actions:
      raise ValueError("state {!r} has not been searched yet; run simulations first".format(s))
    return max(actions, key=lambda a: self.Q[(s,a)])

In [96]:
# Seed NumPy's global RNG (used by sample_action) so that a fresh
# Restart & Run All produces the same search results every time.
np.random.seed(0)

mcts = MCTS() # for this very simple problem we can even use gamma = 0

# 301 warm-up simulations from the start state, using simulate's default depth
for _ in range(301):
  mcts.simulate(0)

In [97]:
# Greedy pick at the start state: the sampled action with the highest mean Q so far.
mcts.get_best_action(0)

0.916248534598993

In [98]:
# Play one episode: plan with MCTS at every state, then take the recommended action.
n_moves, n_sims, search_depth = 30, 100, 10  # episode length, simulations per move, planning depth
s = 0  # start state

for _ in range(n_moves):
  best_a = mcts.run(s, d=search_depth, n_sims=n_sims)
  (s_next, r) = step(s, best_a)
  print("state", s, "action", best_a, "reward", r)
  # print("ranked actions", sorted(mcts.children[s], key=lambda a: mcts.Q[(s,a)], reverse=True))
  s = s_next  # advance to the state reached by the chosen action


state 0 action 0.916248534598993 reward -9.083751465401008
state 0.916248534598993 action 0.9563765773781576 reward -8.12737488802285
state 1.8726251119771506 action 0.8229031812788448 reward -7.304471706744005
state 2.6955282932559954 action 0.982181275558391 reward -6.322290431185614
state 3.6777095688143864 action 0.8292299095258497 reward -5.493060521659764
state 4.506939478340236 action 0.9351034352357555 reward -4.557957086424008
state 5.442042913575992 action 0.7622983426507535 reward -3.795658743773255
state 6.204341256226745 action 0.8948850539934985 reward -2.900773689779756
state 7.099226310220244 action 0.9191120484373752 reward -1.9816616413423809
state 8.01833835865762 action 0.4064847827577298 reward -1.5751768585846513
state 8.424823141415349 action 0.7999610502232746 reward -0.7752158083613772
state 9.224784191638623 action 0.7347145990738773 reward -0.04050120928750012
state 9.9594987907125 action -0.2193597512933705 reward -0.2598609605808697
state 9.74013903941913 a