A new sample, adapted from https://qiita.com/y_yoko/items/07b9e3e8d4a43c61d39f

This sample uses Monte Carlo tree search (MCTS) to minimize

$$
f(x) = (x_1 - a)^2 + (x_2 - b)^2 + (x_3 - c)^2 + (x_4 - d)^2
$$

whose minimum value, $f(x) = 0$, is attained exactly at $x = (a, b, c, d)$.

In our sample problem, each move is chosen from $\{-4, \dots, 4\}$: move $\pm n$ changes $x_n$ by $\pm 1$ (for $n = 1, \dots, 4$), while move $0$ leaves $x$ unchanged.


In [1]:
from math import *
from copy import copy

import numpy as np
import random


class environment:
    def __init__(self):
        self.a = 0.
        self.b = 2.
        self.c = 1.
        self.d = 2.
        self.x1 = -12.
        self.x2 = 14.
        self.x3 = -13.
        self.x4 = 9.
        self.f = self.func()

    def func(self):
        return (self.x1 - self.a)**2+(self.x2 - self.b)**2+(self.x3 - self.c)**2+(self.x4 - self.d)**2

    def step(self, action):
        f_before = self.func()

        if action == 1:
            self.x1 += 1.
        elif action == -1:
            self.x1 -= 1.
        elif action ==  2:
            self.x2 += 1.
        elif action == -2:
            self.x2 -= 1.
        elif action ==  3:
            self.x3 += 1.
        elif action == -3:
            self.x3 -= 1.
        elif action == 4:
            self.x4 += 1.
        elif action == -4:
            self.x4 -= 1.

        f_after = self.func()
        if f_after == 0:
            reward = 100
            terminal = True
        else:
            reward = 1 if abs(f_before) - abs(f_after) > 0 else -1
            terminal = False

        return f_after, reward, terminal, [self.x1, self.x2, self.x3, self.x4]

    def action_space(self):
        return [-4, -3, -2, -1, 0, 1, 2, 3, 4]

    def sample(self):
        return np.random.choice([-4, -3, -2, -1, 0, 1, 2, 3, 4])

class Node():
    """One node of the search tree: the state reached by ``action`` from ``parent``.

    ``num_visits`` starts at 1 rather than 0, which keeps ``reward / num_visits``
    and ``log(parent.num_visits)`` well-defined for freshly created nodes.
    """

    def __init__(self, parent, action):
        self.num_visits = 1
        self.reward = 0.          # accumulated (summed) reward, not a mean
        self.children = []
        self.parent = parent      # None for the root
        self.action = action      # move leading here from parent; None for the root

    def update(self, reward):
        """Add ``reward`` to the accumulated reward and count one more visit."""
        self.reward += reward
        # Was ``self.visits``, an attribute that never exists (AttributeError).
        self.num_visits += 1

    def __repr__(self):
        s = "Reward/Visits =  %.1f/%.1f (Child %d)" % (self.reward, self.num_visits, len(self.children))
        return s

def ucb(node, exploration=1.0):
    """Return the upper-confidence score used to pick a child during selection.

    Score = mean reward + exploration * sqrt(ln(parent visits) / node visits).
    The default ``exploration=1.0`` reproduces the original formula; the
    classic UCB1 weight is ``sqrt(2)``.

    Args:
        node: a non-root node; it must expose ``reward``, ``num_visits`` and
            ``parent.num_visits``.
        exploration: weight of the exploration bonus (0 gives pure greedy).

    Returns:
        The UCB score as a float.
    """
    exploitation = node.reward / node.num_visits
    bonus = sqrt(log(node.parent.num_visits) / node.num_visits)
    return exploitation + exploration * bonus

def reward_rate(node):
    """Return the mean reward per visit of ``node`` (accumulated reward / visits)."""
    total, visits = node.reward, node.num_visits
    return total / visits

SHOW_INTERMEDIATE_RESULTS = False

env = environment()

num_mainloops = 20       # independent MCTS runs, each from env's initial state
max_playout_depth = 5    # random moves per playout, after the child's own move
num_tree_search = 180    # search iterations per run

best_sum_reward = -inf
best_action_sequence = []   # was misspelled ``best_acdtion_sequence``
best_f = 0
best_x = []


def _backpropagate(node, reward):
    """Credit ``reward`` and one visit to ``node`` and every ancestor up to the root."""
    while node is not None:
        node.num_visits += 1
        node.reward += reward
        node = node.parent


for _ in range(num_mainloops):
    root = Node(None, None)

    for _ in range(num_tree_search):
        # Simulate on a copy so env itself never moves; a shallow copy is
        # enough because every environment attribute is a plain float.
        env_copy = copy(env)

        terminal = False
        sum_reward = 0

        # 1) Selection: descend by UCB, replaying each chosen move.
        current_node = root
        while current_node.children:
            current_node = max(current_node.children, key=ucb)
            _, reward, terminal, _ = env_copy.step(current_node.action)
            sum_reward += reward

        if terminal:
            # A solved leaf is never expanded, so it gets no playout. Credit
            # the path to it here; otherwise its statistics never change and
            # every later iteration re-selects the same leaf and does nothing.
            _backpropagate(current_node, sum_reward)
            continue

        # 2) Expansion: one child per available move.
        current_node.children = [
            Node(current_node, action) for action in env_copy.action_space()
        ]

        for c in current_node.children:
            # 3) Playout: the child's move, then up to max_playout_depth random moves.
            env_playout = copy(env_copy)
            _, sum_reward_playout, terminal, _ = env_playout.step(c.action)
            action_sequence = [c.action]

            while not terminal and len(action_sequence) <= max_playout_depth:
                action = env_playout.sample()
                _, reward, terminal, _ = env_playout.step(action)
                sum_reward_playout += reward
                action_sequence.append(action)

            if terminal:
                print("Terminal reached during a playout. #########")

            # 4) Backpropagation of the path reward plus the playout reward.
            _backpropagate(c, sum_reward + sum_reward_playout)

    # Decision: greedily follow the child with the highest mean reward.
    current_node = root
    action_sequence = []
    while current_node.children:
        current_node = max(current_node.children, key=reward_rate)
        action_sequence.append(current_node.action)

    # Replay the chosen sequence from the initial state, stopping once solved.
    env_copy = copy(env)
    sum_reward = 0
    for action in action_sequence:
        _, reward, terminal, _ = env_copy.step(action)
        sum_reward += reward
        if terminal:
            break

    f = env_copy.func()
    x = [env_copy.x1, env_copy.x2, env_copy.x3, env_copy.x4]

    if SHOW_INTERMEDIATE_RESULTS:
        print("Action sequence: ", str(action_sequence))
        print("Sum_reward: ", str(sum_reward))

        print("f, x (original): ", env.f, str([env.x1, env.x2, env.x3, env.x4]))
        print("f, x (after MCT): ", str(f), str(x))
        print("----------")

    if sum_reward > best_sum_reward:
        print(f'updated best: {x}')
        best_sum_reward = sum_reward
        best_action_sequence = action_sequence
        best_f = f
        best_x = x

print("Best Action sequence: ", str(best_action_sequence))
print("Action sequence length: ", str(len(best_action_sequence)))
print("Best Sum_reward: ", str(best_sum_reward))
print("f, x (original): ", env.f, str([env.x1, env.x2, env.x3, env.x4]))
print("f, x (after MCTS): ", str(best_f), str(best_x))


updated best: [-2.0, 2.0, -5.0, -1.0]
Terminal reached during a playout. #########
Terminal reached during a playout. #########
updated best: [0.0, 2.0, 1.0, 2.0]
Terminal reached during a playout. #########
Terminal reached during a playout. #########
Terminal reached during a playout. #########
Terminal reached during a playout. #########
Terminal reached during a playout. #########
Terminal reached during a playout. #########
Terminal reached during a playout. #########
Terminal reached during a playout. #########
Best Action sequence:  [1, 3, 3, 1, -2, 1, 3, 1, -2, -3, 1, 1, -4, 3, 3, -4, -2, -3, -1, -2, -2, -2, 3, 1, -4, 3, 1, 1, 3, -2, -2, -4, -3, -4, 3, -3, -4, -3, -1, 1, -4, -3, -4, -2, 4, -2, -4, 2, 1, -4, 3, -3, 3, 3, 3, -3, -2, -2, 1, 4, 4, -4, -2, 3, 3, 3, 1, 2, -1, -4, -4, 4, -4, 3, 3, 4, 1, -1, 1, 4, -3, 3, 3, -2, 4, -4, 3, 1, -3, -4, -1, 1, 4, -1, -4, -4, 1, 3, 4, -4, -2, -2, 2, -1, 0, -3, -1, 1, 4, 3, 3, -4, -4, 1, -1, -4, 4, -1, 2, 4, -2, 4, -1, 1, -3, 0, 1, 3, -4, -1,