<a href="https://colab.research.google.com/github/lucasBertola/AlphaZero/blob/main/Play_again_alpha_zero.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

To load the model and compute its moves faster, it is recommended to enable a GPU in Google Colab. To do so, follow these steps:

1. Click `Runtime` in the top menu (it is labelled `Execution` in some localized interfaces).
2. Select `Change runtime type` from the dropdown menu.
3. In the `Hardware accelerator` section, choose `GPU` from the dropdown menu.
4. Click `Save` to apply the changes.

After activating the GPU, the model should load and run faster.

In [None]:
!git clone  https://github.com/lucasBertola/AlphaZero.git
!pip install -r AlphaZero/requirement.txt
import numpy as np
import os
import time
import io
import sys
import random
from PIL import Image
from IPython.display import display
from IPython.display import display, HTML
os.chdir("AlphaZero") if os.path.isdir("AlphaZero") else None

from connect_four_gymnasium import ConnectFourEnv
from connect_four_gymnasium.players import ConsolePlayer
from alphaFour import AlphaFour
import base64

class MCTSPlayer:
    """Player that chooses Connect Four moves by running AlphaFour's MCTS search.

    Mirrors the player interface used by this notebook (``play``, ``getName``,
    ``is_deterministic``, ``get_elo``) so it can take turns with a
    ``ConsolePlayer`` in the game loop.
    """

    def __init__(self, ai: "AlphaFour", name="MCTS", deterministic=True):
        """
        Args:
            ai: AlphaFour instance whose ``mcts_parallel`` performs the search.
            name: display name returned by ``getName``.
            deterministic: flag reported by ``is_deterministic``.
        """
        self.ai = ai
        self.name = name
        self.deterministic = deterministic

    @staticmethod
    def _env_from_board(board):
        """Return a freshly reset ConnectFourEnv whose board is replaced by *board*."""
        env = ConnectFourEnv()
        env.reset()
        env.board = board
        return env

    def play_single(self, observation):
        """Search a single board silently and return the argmax policy action.

        Returns ``np.argmax(policies, axis=1)``, i.e. a length-1 array holding
        the chosen column index.
        """
        predicted_policies, _ = self.ai.mcts_parallel([self._env_from_board(observation)])
        return np.argmax(predicted_policies, axis=1)

    def play(self, observations):
        """Search one board or a list of boards and return the argmax actions.

        Args:
            observations: a single board, or a ``list`` of boards; every board
                is wrapped in its own env and all are searched in one
                ``mcts_parallel`` call.

        Returns:
            ``np.argmax(policies, axis=1)``: one action per searched board.

        Side effect: prints the AI's estimated winning chance for the first board.
        """
        if isinstance(observations, list):
            envs = [self._env_from_board(board) for board in observations]
        else:
            envs = [self._env_from_board(observations)]
        predicted_policies, values = self.ai.mcts_parallel(envs)

        # The (v + 1) / 2 mapping implies the value head outputs v in [-1, 1];
        # it is converted to a 0-100 % winning estimate.
        print(f'IA: I think I have a {round(((values[0][0] + 1) / 2) * 100)}% chance of winning')
        return np.argmax(predicted_policies, axis=1)

    def getName(self):
        """Return the display name given at construction."""
        return self.name

    def is_deterministic(self):
        """Return the ``deterministic`` flag given at construction."""
        return self.deterministic

    def get_elo(self):
        """Return None: this player has no known Elo rating."""
        return None


def find_latest_model(models_dir="models"):
    """Locate the checkpoint with the highest generation number.

    Checkpoints are files named ``model_<generation>.pt`` inside *models_dir*.
    The directory is scanned, instead of probing a fixed range of generation
    numbers, so checkpoints of any generation (including 500 and above) are
    found with a single directory listing.

    Args:
        models_dir: directory holding the checkpoints; defaults to ``"models"``,
            relative to the current working directory.

    Returns:
        ``(path, generation)`` for the newest checkpoint, or ``(None, None)``
        when the directory is missing or holds no matching file.
    """
    import re

    # Canonical decimal only (no leading zeros), so the returned path matches
    # the ``model_{i}.pt`` naming used when the checkpoints are written.
    pattern = re.compile(r"model_(0|[1-9][0-9]*)\.pt")

    try:
        names = os.listdir(models_dir)
    except (FileNotFoundError, NotADirectoryError):
        return None, None

    best_path, best_generation = None, None
    for name in names:
        match = pattern.fullmatch(name)
        if match is None:
            continue
        path = os.path.join(models_dir, name)
        if not os.path.isfile(path):
            continue
        generation = int(match.group(1))
        if best_generation is None or generation > best_generation:
            best_path, best_generation = path, generation

    return best_path, best_generation


def display_rgb_array(rgb_array):
    """Build an RGB PIL image from *rgb_array* and open it with ``Image.show``."""
    Image.fromarray(rgb_array, 'RGB').show()

def display_centered_image(rgb_array):
    """Show *rgb_array* as a horizontally centred PNG in the notebook output.

    The array is turned into an RGB image, PNG-encoded in memory, embedded as a
    base64 data URI and displayed inside a flex container, with the ``<img>``
    width/height set to the image's pixel size.
    """
    picture = Image.fromarray(rgb_array, 'RGB')
    png_buffer = io.BytesIO()
    picture.save(png_buffer, format='PNG')
    encoded_png = base64.b64encode(png_buffer.getvalue()).decode()
    width, height = picture.size
    markup = (
        '<div style="display: flex; justify-content: center;">'
        f'<img src="data:image/png;base64,{encoded_png}" width="{width}" height="{height}" />'
        '</div>'
    )
    display(HTML(markup))



def main(iteration=2500):
    """Play one interactive Connect Four game: you against the latest AlphaFour model.

    Loads the newest checkpoint found by ``find_latest_model``, shuffles the
    two players to pick who moves first, renders the board after every move
    and stops as soon as the environment reports the game done or truncated.

    Args:
        iteration: forwarded to ``AlphaFour(..., iteration=...)``; presumably
            the MCTS search budget per move (TODO confirm in alphaFour).
    """
    latest_model_path, generation = find_latest_model()
    if latest_model_path is None:
        print('Error: No model found')
        # Return rather than call exit(): exit() is the `site` helper meant
        # for the interactive interpreter only, and inside a notebook kernel
        # it does not simply end this function.
        return

    print(f'Loading model generation {generation}')
    ai_instance = AlphaFour(latest_model_path, iteration=iteration)

    mcts_player = MCTSPlayer(ai_instance, 'MCTS')
    human_player = ConsolePlayer()
    env = ConnectFourEnv(render_mode="rgb_array", main_player_name="You")

    observation, _ = env.reset()
    display_centered_image(env.render())

    players = [human_player, mcts_player]
    random.shuffle(players)  # randomly decide who moves first

    # Safety cap on the number of rounds; a normal game ends through
    # `done`/`truncated` long before it is reached.
    for _ in range(5000):
        for player in players:
            action = player.play(observation)
            observation, _reward, done, truncated, _info = env.step(action)
            display_centered_image(env.render())
            if truncated or done:
                print('finish')
                return


if __name__ == '__main__':
    main()