In [2]:
from model.game import GameState, Action
from ai.MCTree import MCTree, Node, Edge
from ai.utils import ActionEncoder, StateStack, mirror_action, to_label, SampleBuilder, get_action_space
from ai.agent import AlphaZero
from copy import deepcopy
import numpy as np
import ai.config as config


In [3]:
class RejectedAction(Exception):
    """Raised when a proposed move fails ``game.is_legal_action`` and is therefore not applied."""

In [17]:
import time
import cufflinks as cf
import random
import plotly.graph_objects as go
from IPython.display import display
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
from ai.model import NeuralNetwork

# Enable plotly rendering inside the notebook; connected=True loads plotly.js
# from a CDN instead of embedding it in the notebook file.
init_notebook_mode(connected=True)
# Put cufflinks into offline mode (no plot uploads to a plotly account).
cf.go_offline()

class PlayGround:
    """Notebook harness for running AlphaZero self-play / evaluation games and training.

    Holds two networks (``current_model`` and ``best_model``, initialised with
    identical weights), a plotly ``FigureWidget`` that charts the ``value``
    reported by the agent per turn, and the ``turn``/``exp`` lists backing it.
    """

    def __init__(self):
        # NOTE(review): the 10x10 board and 20 input planes are hard-coded; they
        # presumably must match GameState's board and deep representation.
        input_shape = (10, 10, 20)
        self.action_space_shape = len(get_action_space(10, 10))
        self.current_model = NeuralNetwork(config.REG_CONST, config.LEARNING_RATE,
                                           input_shape, self.action_space_shape,
                                           config.HIDDEN_CNN_LAYERS)
        self.best_model = NeuralNetwork(config.REG_CONST, config.LEARNING_RATE,
                                        input_shape, self.action_space_shape,
                                        config.HIDDEN_CNN_LAYERS)

        # Both networks start from the same weights.
        self.best_model.model.set_weights(self.current_model.model.get_weights())

        # Plot series: y = value reported by train_act, x = turn number.
        self.exp = []
        self.turn = []

        self.fig = go.FigureWidget()
        self.fig.update_layout(xaxis_title="Turn", yaxis_title='Expected value')
        self.fig.add_scatter()

    def plot_figure(self):
        """Display the live value-per-turn figure in the notebook output."""
        display(self.fig)

    def update_plot(self):
        """Push the current ``self.turn`` / ``self.exp`` lists into the figure's single trace."""
        # batch_update sends x and y to the front-end as one update.
        with self.fig.batch_update():
            self.fig.data[0].x = self.turn
            self.fig.data[0].y = self.exp

    def apply_and_update(self, game: GameState, playing_agent: AlphaZero, other_agent: AlphaZero, action: Action, action_id: int):
        """Apply ``action`` to ``game`` and keep the agents' search trees in sync.

        Args:
            game: game state, mutated in place via ``apply_action``.
            playing_agent: agent that chose the move; ``update_root(action_id)`` is called on it.
            other_agent: optional second agent that also gets ``update_root(action_id)``,
                or ``None`` in single-agent self-play.
            action: the move to apply.
            action_id: identifier of ``action`` passed to ``update_root``
                (presumably its index in the action space).

        Raises:
            RejectedAction: if ``game.is_legal_action(action)`` is false. Neither the
                game nor the agents are modified in that case.
        """
        if not game.is_legal_action(action):
            raise RejectedAction(f"Illegal action rejected: {action!r}")
        game.apply_action(action)
        playing_agent.update_root(action_id)
        if other_agent is not None:
            other_agent.update_root(action_id)

    def play_best(self, goes_first=True):
        """Play ``config.EPISODES`` games between the current and the best network.

        Args:
            goes_first: if True the current model's agent is player 0, otherwise
                the best model's agent is.

        Prints one ``*`` per move. Nothing is returned yet.
        TODO: record game outcomes so the caller can decide whether to promote
        ``current_model``.
        """
        current_agent = AlphaZero(config.MCTS_SIMS, config.CPUCT, self.current_model, pov=0, name="Current AlphaZero")
        best_agent = AlphaZero(config.MCTS_SIMS, config.CPUCT, self.best_model, pov=1, name="Best AlphaZero")
        if goes_first:
            players = [current_agent, best_agent]
        else:
            players = [best_agent, current_agent]
            best_agent.pov, current_agent.pov = (0, 1)

        game = GameState(None, None)
        for _ in range(config.EPISODES):
            game.init()
            # Each agent builds its tree over its own deep copy of the initial state.
            current_agent.build_mcts(StateStack(deepcopy(game)))
            best_agent.build_mcts(StateStack(deepcopy(game)))

            turn = 1
            while not game.is_terminal():
                player: AlphaZero = players[game.get_player_turn()]
                opponent = best_agent if player is current_agent else current_agent
                # tau is 1 before TURNS_UNTIL_TAU0 and 0 afterwards
                # (presumably the move-selection temperature).
                tau = 1 if turn < config.TURNS_UNTIL_TAU0 else 0
                action, action_id, _state_stack, _value, _pi = player.train_act(tau)
                # Both agents must re-root on the played move, mover first.
                self.apply_and_update(game, player, opponent, action, action_id)
                turn += 1
                print('*', end='')

    def self_play(self):
        """Play ``config.EPISODES`` self-play games with the current model and collect samples.

        Each move's ``(state_stack, pi)`` from ``train_act`` is committed to a
        ``SampleBuilder``. At the end of each game ``commit_sample`` is called with
        ``game.get_value()``. When the mover's pov equals ``config.MAXMIZER``, the
        reported ``value`` is appended to ``self.exp`` (with the turn number in
        ``self.turn``) and the live plot is refreshed.

        Returns:
            SampleBuilder: all samples gathered across the episodes.

        Raises:
            KeyError: re-raised from ``train_act`` after printing the root's edges.
            RejectedAction: if the agent proposes an illegal move.
        """
        current_agent = AlphaZero(config.MCTS_SIMS, config.CPUCT, self.current_model, pov=0, name="Current AlphaZero")
        game = GameState(None, None)
        sample_builder = SampleBuilder()

        for _ in range(config.EPISODES):
            game.init()
            current_agent.build_mcts(StateStack(deepcopy(game)))

            turn = 1
            while not game.is_terminal():
                # One agent plays both sides, so its perspective follows the side to move.
                current_agent.pov = game.get_player_turn()
                tau = 1 if turn < config.TURNS_UNTIL_TAU0 else 0
                try:
                    action, action_id, state_stack, value, pi = current_agent.train_act(tau)
                except KeyError as e:
                    # Dump diagnostics, then propagate: continuing would use undefined
                    # (first move) or stale (later moves) action/value/pi.
                    print(e)
                    print(current_agent.mcts.root.edges)
                    raise

                # NOTE(review): attribute name as spelled in ai.config — verify.
                if current_agent.pov == config.MAXMIZER:
                    self.turn.append(turn)
                    self.exp.append(value)
                self.apply_and_update(game, current_agent, None, action, action_id)
                self.update_plot()
                sample_builder.commit_move(state_stack, pi)
                turn += 1
                print('*', end='')

            # NOTE(review): the meaning of commit_sample's second argument (0) is not visible here.
            sample_builder.commit_sample(game.get_value(), 0)

        return sample_builder

    def fit(self, sample_builder: SampleBuilder, n_batches: int = 10):
        """Train ``current_model`` on random minibatches drawn from ``sample_builder.samples``.

        Args:
            sample_builder: source of samples. Each sample is a mapping with a
                ``'state'`` entry (exposing ``get_deep_representation_stack()``) and
                ``'value'`` / ``'policy'`` training targets.
            n_batches: number of minibatches to draw and fit on (default 10).

        Returns:
            list: the value returned by ``current_model.fit`` for each minibatch.
            The list is empty when there are no samples to train on.
        """
        samples = sample_builder.samples
        # Guard: random.sample(samples, 0) would yield an empty batch and model.fit would fail.
        if not samples:
            return []

        histories = []
        for _ in range(n_batches):
            minibatch = random.sample(samples, min(config.BATCH_SIZE, len(samples)))
            training_states = np.array([row['state'].get_deep_representation_stack() for row in minibatch])
            training_targets = {'value_head': np.array([row['value'] for row in minibatch]),
                                'policy_head': np.array([row['policy'] for row in minibatch])}

            histories.append(self.current_model.fit(training_states, training_targets, epochs=config.EPOCHS,
                                                    verbose=1, validation_split=0, batch_size=32))
        return histories

In [18]:
# Build both networks (current and best, same initial weights) and the live value plot.
playground = PlayGround()

In [19]:
# Show the live "Expected value vs Turn" widget; it refreshes as self_play() runs.
playground.plot_figure()

FigureWidget({
    'data': [{'type': 'scatter', 'uid': 'a13340b1-6f4f-42c6-861e-1a1a951b99b3'}],
    'layout':…

In [None]:
# Keep the returned SampleBuilder so it can be passed to playground.fit().
# If an illegal move aborts self-play, sample_builder stays None and the error is shown.
sample_builder = None
try:
    sample_builder = playground.self_play()
except RejectedAction as e:
    print(e)

*********************************