In [1]:
import numpy as np

from connect_four import MCTS

In [2]:
# Pre-train a search tree with a fixed iteration budget.
# NOTE(review): single_game() creates its own local MCTS(game), so this tree is
# not used there unless MCTS shares state across instances — confirm this cell
# is still needed.
PRETRAIN_MAXITER = 78980

mcts = MCTS()
mcts.run(maxiter=PRETRAIN_MAXITER)

In [None]:
def _prompt_column(n_cols):
    """Ask the user for a column until they enter an integer in ``[0, n_cols)``.

    Non-integer or out-of-range input triggers a message and a new prompt,
    rather than silently falling back to some other move.
    """
    while True:
        raw = input(f"Choose a column (0-{n_cols - 1}): ")
        try:
            col = int(raw)
        except ValueError:
            col = -1  # guarantees the range check below fails
        if 0 <= col < n_cols:
            return col
        print(f"Invalid input. Please enter a number between 0 and {n_cols - 1}.")


def single_game(start_state=None, human_intervention=False, maxiter=2048):
    """Play one Connect Four game driven by MCTS, drawing the board every turn.

    Each turn the MCTS is asked for its best action (``maxiter`` iterations).
    The q-value of every candidate action it reports is shown above that
    column, and the chosen move is applied with ``game.step``. When
    ``game.step`` reports ``done``, 'Win' / 'Lose' / 'Draw' is printed from
    the sign of the final reward and the final board is drawn.

    Args:
        start_state: optional 2-D board (cell codes -1 / 0 / 1). Defaults to
            the hard-coded mid-game position below. It is copied, so a
            caller-supplied array is never touched by the game.
        human_intervention: if True, the human is prompted for every move;
            the MCTS q-values are still displayed as guidance.
        maxiter: MCTS iterations per move.

    NOTE(review): ``Connect4`` is not imported in the visible cells
    (presumably it comes from ``connect_four``), and ``display`` is defined in
    a later cell. Confirm both are in scope before calling.
    """
    if start_state is None:
        start_state = [
            [0, 1, -1, -1, 1, -1, 0],
            [0, 1, 1, -1, -1, 1, 0],
            [0, -1, -1, -1, 1, -1, 0],
            [0, 1, 1, 1, -1, 1, 0],
            [0, 1, -1, -1, 1, -1, 0],
            [-1, 1, 1, 1, -1, -1, 1],
        ]
    # np.array copies its input, so whatever game.step does to game.state
    # never reaches a caller-supplied board.
    start_state = np.array(start_state)
    n_cols = start_state.shape[1]

    game = Connect4()
    game.state = start_state
    mcts = MCTS(game)

    done = False
    while not done:
        # best_action[0] is the chosen action; best_action[1] holds the
        # candidates. Each exposes .action (column) and .qval.
        best_action = mcts.get_best_action(maxiter=maxiter)

        # Rebuild the labels every turn. Otherwise a column missing from this
        # turn's candidates (presumably a full column) would keep a stale
        # q-value from an earlier turn.
        display_actions = [{'index': col, 'val': 'X'} for col in range(n_cols)]
        for action in best_action[1]:
            display_actions[action.action]['val'] = round(action.qval, 2)
        display(game.state, display_actions)

        act = best_action[0].action
        if human_intervention:
            act = _prompt_column(n_cols)

        reward, done = game.step(act)
        if done:
            if reward > 0:
                print('Win')
            elif reward < 0:
                print('Lose')
            else:
                print('Draw')
            display(game.state, display_actions)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

def display(board, values):
    """Draw a Connect Four board with a text label above each column.

    NOTE: defining this shadows IPython's built-in ``display`` for the rest of
    the notebook session.

    Args:
        board: 2-D array-like of cell codes. 0 is drawn white, 1 green, and
            -1 (the encoding used by ``single_game``'s start state) or 2 blue.
        values: iterable of dicts with ``'index'`` (column number) and
            ``'val'`` (label, rendered via ``str``) keys.
    """
    n_rows, n_cols = len(board), len(board[0])

    fig, ax = plt.subplots(figsize=[7, 6])
    # One bin per cell code: [-1.5,-0.5)->-1, [-0.5,0.5)->0, [0.5,1.5)->1,
    # [1.5,2.5]->2. -1 needs its own bin: if it shared one with 0, those
    # pieces would be drawn as empty cells.
    cmap = mcolors.ListedColormap(['blue', 'white', 'green', 'blue'])
    norm = mcolors.BoundaryNorm([-1.5, -0.5, 0.5, 1.5, 2.5], cmap.N)
    ax.matshow(board, cmap=cmap, norm=norm)

    # matshow centres cell (r, c) at (x=c, y=r), so cell borders are at half-integers.
    for x in range(n_cols + 1):
        ax.plot([x - .5, x - .5], [-.5, n_rows - .5], 'k')
    for y in range(n_rows + 1):
        ax.plot([-.5, n_cols - .5], [y - .5, y - .5], 'k')

    # matshow puts row 0 at the top, so y=-1 places the labels just above the board.
    for v in values:
        ax.text(v['index'], -1, str(v['val']), ha='center', va='center', fontsize=20, color='black')

    ax.set_axis_off()
    plt.show()
    plt.close(fig)