In [11]:
from model.game import Action
from ai.treeSearch import *
from ai.utils import *
from ai.agent import AlphaZero
from copy import deepcopy
import numpy as np
import ai.config as config


In [12]:
class RejectedAction(Exception):
    """Raised when the game reports an agent's chosen action as illegal."""

In [14]:
import random

# Seed every global RNG the notebook draws from so a Restart & Run All is
# reproducible: numpy's legacy global generator and the stdlib `random`
# module (PlayGround.fit samples its minibatches with random.sample).
SEED = 5
np.random.seed(SEED)
random.seed(SEED)

In [17]:
import time
import cufflinks as cf
import random
import plotly.graph_objects as go
from IPython.display import display
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
from ai.model import NeuralNetwork
from model.internationalGame import InternationalGame

In [18]:
# Set up plotly's in-notebook rendering and switch cufflinks to offline mode
# before any figures (e.g. PlayGround's FigureWidget) are created.
init_notebook_mode(connected=True)
cf.go_offline()

class PlayGround:
    """Harness for AlphaZero self-play, model-vs-model games and training.

    Holds two networks that start from identical weights: ``current_model``
    (trained by :meth:`fit`) and ``best_model`` (the opponent in
    :meth:`play_best`). During :meth:`self_play` the expected value reported
    by the search for player 1's moves is recorded in ``self.exp`` (with the
    matching turn numbers in ``self.turn``) and mirrored into a plotly
    ``FigureWidget`` that :meth:`plot_figure` displays.
    """

    # Board dimensions used for the action space and the network input shape.
    BOARD_ROWS = 10
    BOARD_COLS = 10
    # Number of input feature planes fed to the network; presumably matches
    # what get_deep_representation_stack() produces — TODO confirm.
    INPUT_PLANES = 25

    def __init__(self):
        self.action_space_shape = len(get_action_space(self.BOARD_ROWS, self.BOARD_COLS))
        input_shape = (self.BOARD_ROWS, self.BOARD_COLS, self.INPUT_PLANES)
        self.current_model = self._build_model(input_shape)
        self.best_model = self._build_model(input_shape)

        # Start both networks from the same weights so early comparisons are fair.
        self.best_model.model.set_weights(self.current_model.model.get_weights())
        self.exp = []   # expected values reported on player 1's moves
        self.turn = []  # turn numbers matching the entries of self.exp

        self.fig = go.FigureWidget()
        self.fig.update_layout(xaxis_title="Turn", yaxis_title='Expected value')
        self.fig.add_scatter()

    def _build_model(self, input_shape):
        """Create a NeuralNetwork using the shared hyper-parameters from ``config``."""
        return NeuralNetwork(config.REG_CONST, config.LEARNING_RATE,
                             input_shape, self.action_space_shape, config.HIDDEN_CNN_LAYERS)

    @staticmethod
    def _tau(turn):
        """Return the MCTS temperature for ``turn``.

        The temperature is 1 before ``config.TURNS_UNTIL_TAU0`` and 0 from then on.
        """
        return 1 if turn < config.TURNS_UNTIL_TAU0 else 0

    def plot_figure(self):
        """Render the expected-value figure in the notebook."""
        display(self.fig)

    def update_plot(self):
        """Push the recorded (turn, expected value) series into the figure."""
        with self.fig.batch_update():
            self.fig.data[0].x = self.turn
            self.fig.data[0].y = self.exp

    def apply_and_update(self, game, playing_agent: AlphaZero, other_agent: AlphaZero, action: Action, action_id: int):
        """Validate and apply ``action`` to ``game``, then re-root the agents' search trees.

        Args:
            game: live game object. It must expose ``grid``, ``isLegalAction``,
                ``currentTurn``, ``getAllPossibleActions`` and ``applyAction``.
            playing_agent: agent whose MCTS root is moved to ``action_id``.
            other_agent: optional second agent whose root is moved as well.
                Pass ``None`` during self-play.
            action: the move to apply.
            action_id: index of ``action`` in the action space, used to
                re-root the search trees.

        Raises:
            RejectedAction: if ``game`` reports ``action`` as illegal. The
                current player and all legal actions are printed first to aid
                debugging.
        """
        print(game.grid)
        print(action)
        if not game.isLegalAction(action):
            print(game.currentTurn)
            for legal_action in game.getAllPossibleActions():
                print(legal_action)
            raise RejectedAction(f"Illegal action {action} for player {game.currentTurn}")
        game.applyAction(action)
        playing_agent.update_root(action_id)
        if other_agent is not None:
            other_agent.update_root(action_id)

    def play_best(self, goes_first=True):
        """Play ``config.EPISODES`` games between the current and the best model.

        Args:
            goes_first: if True the current model plays pov 0 and moves first;
                otherwise the best model does.

        Nothing is recorded or returned.

        NOTE(review): this path builds ``GameState(None, None)`` and calls
        ``is_terminal``/``get_player_turn`` on it, while apply_and_update reads
        ``grid``/``isLegalAction``/``applyAction``. Confirm GameState provides
        all of these; self_play uses InternationalGame instead.
        """
        current_Agent = AlphaZero(config.MCTS_SIMS, config.CPUCT, self.current_model, pov=0, name="Current AlphaZero")
        best_Agent = AlphaZero(config.MCTS_SIMS, config.CPUCT, self.best_model, pov=1, name="Best AlphaZero")
        if goes_first:
            players = [current_Agent, best_Agent]
        else:
            players = [best_Agent, current_Agent]
            best_Agent.pov, current_Agent.pov = (0, 1)

        game = GameState(None, None)
        for _ in range(config.EPISODES):
            game.init()
            current_Agent.build_mcts(StateStack(deepcopy(game)))
            best_Agent.build_mcts(StateStack(deepcopy(game)))

            turn = 1
            while not game.is_terminal():
                player: AlphaZero = players[game.get_player_turn()]
                action, action_id, _state_stack, _value, _pi = player.train_act(self._tau(turn))
                # Both trees are re-rooted whichever agent moved, so the
                # argument order here does not matter.
                self.apply_and_update(game, current_Agent, best_Agent, action, action_id)
                turn += 1

    def self_play(self):
        """Run ``config.EPISODES`` self-play games with the current model.

        A single agent plays both sides; its ``pov`` follows
        ``game.currentTurn``. Every move's search policy is committed to a
        SampleBuilder, and each finished game is labelled with
        ``evaluate(game)``.

        Returns:
            The SampleBuilder holding all collected training samples.

        Raises:
            KeyError: re-raised from ``train_act`` after the search root's
                edges are printed. Continuing would apply an undefined or
                stale action.
            RejectedAction: propagated from apply_and_update on an illegal move.
        """
        current_Agent = AlphaZero(config.MCTS_SIMS, config.CPUCT, self.current_model, pov=0, name="Current AlphaZero")

        game = InternationalGame(1, None, None, None)
        sample_builder = SampleBuilder()

        for _ in range(config.EPISODES):
            game.init()
            current_Agent.build_mcts(StateStack(deepcopy(GameState(game))))

            turn = 1
            player: AlphaZero = current_Agent

            while not game.end():
                player.pov = game.currentTurn
                try:
                    action, action_id, state_stack, value, pi = player.train_act(self._tau(turn))
                except KeyError as e:
                    # Dump the search root for debugging, then stop. Falling
                    # through would reuse the previous move's action, or raise
                    # NameError on the first move.
                    print(e)
                    print(player.mcts.root.edges)
                    raise

                if player.pov == 1:
                    self.turn.append(turn)
                    self.exp.append(value)
                self.apply_and_update(game, current_Agent, None, action, action_id)
                self.update_plot()
                sample_builder.commit_move(state_stack, pi)
                turn += 1
                print('*', end='')

            outcome = evaluate(game)
            sample_builder.commit_sample(outcome, config.MAXIMIZER)

        return sample_builder

    def fit(self, sample_builder: SampleBuilder, iterations: int = 10):
        """Train ``current_model`` on random minibatches drawn from ``sample_builder``.

        Args:
            sample_builder: source of samples. Each sample is a dict with
                ``'state'`` (an object exposing ``get_deep_representation_stack()``),
                ``'value'`` and ``'policy'``.
            iterations: number of minibatches to draw and fit on (default 10).

        Does nothing when ``sample_builder.samples`` is empty.
        """
        samples = sample_builder.samples
        if not samples:
            # Nothing to learn from. An empty minibatch would presumably make
            # model.fit fail on empty arrays.
            return
        batch_size = min(config.BATCH_SIZE, len(samples))
        for _ in range(iterations):
            minibatch = random.sample(samples, batch_size)
            training_states = np.array([row['state'].get_deep_representation_stack() for row in minibatch])
            training_targets = {
                'value_head': np.array([row['value'] for row in minibatch]),
                'policy_head': np.array([row['policy'] for row in minibatch]),
            }
            self.current_model.fit(training_states, training_targets, epochs=config.EPOCHS,
                                   verbose=1, validation_split=0, batch_size=32)

In [19]:
# Build the harness: creates the current/best networks and the expected-value
# figure (which is only shown after calling playground.plot_figure()).
playground = PlayGround()

In [13]:
#playground.plot_figure()

In [20]:
# Run self-play and keep the returned SampleBuilder so it can be passed to
# playground.fit(). An illegal move aborts the run; sample_builder then stays
# None and the partial samples are discarded.
sample_builder = None
try:
    sample_builder = playground.self_play()
except RejectedAction as e:
    print(e)

     1  2  3  4  5  6  7  8  9  10
 1|  .  B  .  B  .  B  .  B  .  B 
 2|  B  .  B  .  B  .  B  .  B  . 
 3|  .  B  .  B  .  B  .  B  .  B 
 4|  B  .  B  .  B  .  B  .  B  . 
 5|  .  .  .  .  .  .  .  .  .  . 
 6|  .  .  .  .  .  .  .  .  .  . 
 7|  .  W  .  W  .  W  .  W  .  W 
 8|  W  .  W  .  W  .  W  .  W  . 
 9|  .  W  .  W  .  W  .  W  .  W 
10|  W  .  W  .  W  .  W  .  W  . 

(7,6)------->>>(6,5)
*     1  2  3  4  5  6  7  8  9  10
 1|  .  B  .  B  .  B  .  B  .  B 
 2|  B  .  B  .  B  .  B  .  B  . 
 3|  .  B  .  B  .  B  .  B  .  B 
 4|  B  .  B  .  B  .  B  .  B  . 
 5|  .  .  .  .  .  .  .  .  .  . 
 6|  .  .  .  .  W  .  .  .  .  . 
 7|  .  W  .  W  .  .  .  W  .  W 
 8|  W  .  W  .  W  .  W  .  W  . 
 9|  .  W  .  W  .  W  .  W  .  W 
10|  W  .  W  .  W  .  W  .  W  . 

(4,9)------->>>(5,8)
*     1  2  3  4  5  6  7  8  9  10
 1|  .  B  .  B  .  B  .  B  .  B 
 2|  B  .  B  .  B  .  B  .  B  . 
 3|  .  B  .  B  .  B  .  B  .  B 
 4|  B  .  B  .  B  .  B  .  .  . 
 5|  .  .

*     1  2  3  4  5  6  7  8  9  10
 1|  .  B  .  B  .  B  .  B  .  B 
 2|  B  .  B  .  B  .  B  .  .  . 
 3|  .  .  .  B  .  B  .  B  .  . 
 4|  .  .  .  .  .  .  B  .  B  . 
 5|  .  .  .  .  .  .  .  B  .  . 
 6|  B  .  .  .  W  .  .  .  .  . 
 7|  .  .  .  W  .  W  .  W  .  W 
 8|  .  .  W  .  .  .  W  .  W  . 
 9|  .  W  .  W  .  .  .  W  .  W 
10|  W  .  W  .  W  .  W  .  W  . 

(3,4)------->>>(4,3)
*     1  2  3  4  5  6  7  8  9  10
 1|  .  B  .  B  .  B  .  B  .  B 
 2|  B  .  B  .  B  .  B  .  .  . 
 3|  .  .  .  .  .  B  .  B  .  . 
 4|  .  .  B  .  .  .  B  .  B  . 
 5|  .  .  .  .  .  .  .  B  .  . 
 6|  B  .  .  .  W  .  .  .  .  . 
 7|  .  .  .  W  .  W  .  W  .  W 
 8|  .  .  W  .  .  .  W  .  W  . 
 9|  .  W  .  W  .  .  .  W  .  W 
10|  W  .  W  .  W  .  W  .  W  . 

(9,4)------->>>(8,5)
*     1  2  3  4  5  6  7  8  9  10
 1|  .  B  .  B  .  B  .  B  .  B 
 2|  B  .  B  .  B  .  B  .  .  . 
 3|  .  .  .  .  .  B  .  B  .  . 
 4|  .  .  B  .  .  .  B  .  B  . 
 5|  .  

*     1  2  3  4  5  6  7  8  9  10
 1|  .  B  .  B  .  B  .  .  .  B 
 2|  .  .  B  .  .  .  .  .  B  . 
 3|  .  .  .  B  .  B  .  B  .  . 
 4|  B  .  .  .  .  .  B  .  B  . 
 5|  .  B  .  .  .  .  .  B  .  . 
 6|  B  .  .  .  W  .  .  .  .  . 
 7|  .  W  .  W  .  W  .  W  .  W 
 8|  .  .  W  .  W  .  W  .  W  . 
 9|  .  .  .  W  .  .  .  W  .  W 
10|  W  .  .  .  W  .  .  .  W  . 

(10,1)------->>>(9,2)
*     1  2  3  4  5  6  7  8  9  10
 1|  .  B  .  B  .  B  .  .  .  B 
 2|  .  .  B  .  .  .  .  .  B  . 
 3|  .  .  .  B  .  B  .  B  .  . 
 4|  B  .  .  .  .  .  B  .  B  . 
 5|  .  B  .  .  .  .  .  B  .  . 
 6|  B  .  .  .  W  .  .  .  .  . 
 7|  .  W  .  W  .  W  .  W  .  W 
 8|  .  .  W  .  W  .  W  .  W  . 
 9|  .  W  .  W  .  .  .  W  .  W 
10|  .  .  .  .  W  .  .  .  W  . 

(1,6)------->>>(2,7)
*     1  2  3  4  5  6  7  8  9  10
 1|  .  B  .  B  .  .  .  .  .  B 
 2|  .  .  B  .  .  .  B  .  B  . 
 3|  .  .  .  B  .  B  .  B  .  . 
 4|  B  .  .  .  .  .  B  .  B  . 
 5|  . 

*     1  2  3  4  5  6  7  8  9  10
 1|  .  B  .  .  .  .  .  .  .  B 
 2|  .  .  .  .  .  .  .  .  .  . 
 3|  .  B  .  B  .  .  .  B  .  B 
 4|  B  .  B  .  B  .  B  .  B  . 
 5|  .  B  .  .  .  B  .  B  .  W 
 6|  B  .  W  .  W  .  .  .  W  . 
 7|  .  W  .  W  .  W  .  W  .  W 
 8|  .  .  W  .  W  .  W  .  .  . 
 9|  .  .  .  W  .  .  .  W  .  W 
10|  .  .  .  .  W  .  .  .  .  . 

(1,2)------->>>(2,3)
*     1  2  3  4  5  6  7  8  9  10
 1|  .  .  .  .  .  .  .  .  .  B 
 2|  .  .  B  .  .  .  .  .  .  . 
 3|  .  B  .  B  .  .  .  B  .  B 
 4|  B  .  B  .  B  .  B  .  B  . 
 5|  .  B  .  .  .  B  .  B  .  W 
 6|  B  .  W  .  W  .  .  .  W  . 
 7|  .  W  .  W  .  W  .  W  .  W 
 8|  .  .  W  .  W  .  W  .  .  . 
 9|  .  .  .  W  .  .  .  W  .  W 
10|  .  .  .  .  W  .  .  .  .  . 

(9,10)------->>>(8,9)
*     1  2  3  4  5  6  7  8  9  10
 1|  .  .  .  .  .  .  .  .  .  B 
 2|  .  .  B  .  .  .  .  .  .  . 
 3|  .  B  .  B  .  .  .  B  .  B 
 4|  B  .  B  .  B  .  B  .  B  . 
 5|  . 

*     1  2  3  4  5  6  7  8  9  10
 1|  .  .  .  .  .  .  .  .  .  . 
 2|  .  .  .  .  .  .  .  .  B  . 
 3|  .  B  .  B  .  .  .  B  .  B 
 4|  B  .  .  .  B  .  .  .  B  . 
 5|  .  .  .  .  .  W  .  .  .  W 
 6|  B  .  .  .  .  .  W  .  W  . 
 7|  .  W  .  B  .  .  .  W  .  . 
 8|  .  .  W  .  W  .  .  .  W  . 
 9|  .  .  .  W  .  .  .  W  .  . 
10|  .  .  .  .  W  .  .  .  .  . 

(7,4)------->>>(9,6)
*     1  2  3  4  5  6  7  8  9  10
 1|  .  .  .  .  .  .  .  .  .  . 
 2|  .  .  .  .  .  .  .  .  B  . 
 3|  .  B  .  B  .  .  .  B  .  B 
 4|  B  .  .  .  B  .  .  .  B  . 
 5|  .  .  .  .  .  W  .  .  .  W 
 6|  B  .  .  .  .  .  W  .  W  . 
 7|  .  W  .  .  .  .  .  W  .  . 
 8|  .  .  W  .  .  .  .  .  W  . 
 9|  .  .  .  W  .  B  .  W  .  . 
10|  .  .  .  .  W  .  .  .  .  . 

(10,5)------->>>(8,7)
*     1  2  3  4  5  6  7  8  9  10
 1|  .  .  .  .  .  .  .  .  .  . 
 2|  .  .  .  .  .  .  .  .  B  . 
 3|  .  B  .  B  .  .  .  B  .  B 
 4|  B  .  .  .  B  .  .  .  B  . 
 5|  . 

*     1  2  3  4  5  6  7  8  9  10
 1|  .  .  .  .  .  .  .  .  .  . 
 2|  .  .  .  .  .  .  W  .  B  . 
 3|  .  .  .  .  .  .  .  .  .  B 
 4|  .  .  B  .  .  .  W  .  .  . 
 5|  .  B  .  .  .  .  .  .  .  W 
 6|  B  .  .  .  W  .  .  .  W  . 
 7|  .  W  .  .  .  .  .  W  .  W 
 8|  .  .  W  .  .  .  .  .  .  . 
 9|  .  .  .  .  .  .  .  W  .  . 
10|  .  .  .  .  .  .  .  .  .  . 

(2,7)------->>>(1,8)
*     1  2  3  4  5  6  7  8  9  10
 1|  .  .  .  .  .  .  .  W  .  . 
 2|  .  .  .  .  .  .  .  .  B  . 
 3|  .  .  .  .  .  .  .  .  .  B 
 4|  .  .  B  .  .  .  W  .  .  . 
 5|  .  B  .  .  .  .  .  .  .  W 
 6|  B  .  .  .  W  .  .  .  W  . 
 7|  .  W  .  .  .  .  .  W  .  W 
 8|  .  .  W  .  .  .  .  .  .  . 
 9|  .  .  .  .  .  .  .  W  .  . 
10|  .  .  .  .  .  .  .  .  .  . 

(4,3)------->>>(5,4)
*     1  2  3  4  5  6  7  8  9  10
 1|  .  .  .  .  .  .  .  W  .  . 
 2|  .  .  .  .  .  .  .  .  B  . 
 3|  .  .  .  .  .  .  .  .  .  B 
 4|  .  .  .  .  .  .  W  .  .  . 
 5|  .  