In [None]:
%autosave 0

In [None]:
%load_ext autoreload
%autoreload 2

In [1]:

from sympy import im
import torch
from injectors import (
    get_inferer_factory,
    get_mcts_factory,
    get_network,
    get_replay_buffer,
    get_trainer,
)
from typing import Optional
from games.connect4 import Connect4
from games.game import Game
from network import AlphaZeroNetwork
from train import self_play_and_train_loop
import playing

In [1]:
import sys

sys.path.append("../build/training/")
sys.path.append(".")


import importlib.util
import sys

spec = importlib.util.spec_from_file_location(
    "self_play_bind", "./self_play_bind.cpython-313-x86_64-linux-gnu.so"
)
module = importlib.util.module_from_spec(spec)
sys.modules["self_play_bind"] = module
spec.loader.exec_module(module)

import self_play_bind

ImportError: /home/piotrek/miniconda3/envs/mpum-big-project/lib/python3.12/site-packages/zmq/backend/cython/../../../../.././libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /home/piotrek/Studia/2-rok/4-sem/MPUM/Project/AlphaZero/python/self_play_bind.cpython-313-x86_64-linux-gnu.so)

In [None]:
import cProfile
import pstats

profiling = True

In [2]:
# device = "cpu"
device = torch.device("cuda")

network = get_network(Connect4)
network.save_az_network("AZNetwork")
inferer_factory = get_inferer_factory(AlphaZeroNetwork, "AZNetwork", device)
replay_buffer = get_replay_buffer(Connect4)


game = Connect4()

## Profiling

### Self play profiling

In [None]:
%%time
pr: Optional[cProfile.Profile] = None
if profiling:
    with cProfile.Profile() as pr:
        self_play(
            Connect4, inferer_factory, replay_buffer, get_mcts_factory, num_games=1
        )

In [None]:
if pr is not None:
    stats = pstats.Stats(pr)
    stats.sort_stats("cumtime").print_stats(30)
    pr.dump_stats("self_play.prof")

### Trainer profiling

In [None]:
%%time
replay_buffer.load("10games_played.npz")

network = AlphaZeroNetwork.load_az_network("AZNetwork", device)
print(next(network.parameters()).device)
if profiling:
    with cProfile.Profile() as pr:
        trainer = get_trainer(
            network,
            device,
            replay_buffer,
        )

        network.train()
        trainer.train(batch_size=1)

In [None]:
if pr is not None:
    stats = pstats.Stats(pr)
    stats.sort_stats("cumtime").print_stats(30)
    pr.dump_stats("train.prof")

In [None]:
replay_buffer.save("10games_played")

## Self play and training loop

In [None]:
# TODO: write this
self_play_and_train_loop(
    AlphaZeroNetwork,
    "AZNetwork",
    network_device=device,
    game=Connect4,
    load_replay_buffer=get_replay_buffer,
    trainer_factory=get_trainer,
    inferer_provider_getter=get_inferer_factory,
    mcts_factory_getter=get_mcts_factory,
    loop_iterations=1,
    games_in_each_iteration=1,
    batch_size=1,
)

## Playing the Game

In [3]:
network.eval()

mcts_fac = get_mcts_factory(inferer_factory)
mcts = mcts_fac.get_mcts()


def mcts_policy(game: Game):
    return mcts.search(game)


print()




In [4]:
final_reward = playing.play_game(game, mcts_policy_fn=mcts_policy)
print(f"Game result: {'You won' if final_reward == 1 else 'AI won'}")

⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
―――――――――――――
0 1 2 3 4 5 6
Legal actions: [0, 1, 2, 3, 4, 5, 6]
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ 🔴 🟡 ⚪
―――――――――――――
0 1 2 3 4 5 6
Legal actions: [0, 1, 2, 3, 4, 5, 6]
Invalid action, try again.
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ 🟡 ⚪ 🔴 🔴 🟡 ⚪
―――――――――――――
0 1 2 3 4 5 6
Legal actions: [0, 1, 2, 3, 4, 5, 6]
Invalid action, try again.
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ 🟡 ⚪
⚪ 🟡 🔴 🔴 🔴 🟡 ⚪
―――――――――――――
0 1 2 3 4 5 6
Legal actions: [0, 1, 2, 3, 4, 5, 6]
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ 🟡 🔴 🟡 ⚪
⚪ 🟡 🔴 🔴 🔴 🟡 ⚪
―――――――――――――
0 1 2 3 4 5 6
Legal actions: [0, 1, 2, 3, 4, 5, 6]
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ 🟡 ⚪ ⚪ ⚪
⚪ ⚪ 🔴 🟡 🔴 🟡 ⚪
⚪ 🟡 🔴 🔴 🔴 🟡 ⚪
―――――――――――――
0 1 2 3 4 5 6
Legal actions: [0, 1, 2, 3, 4, 5, 6]
Invalid action, try again.
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪

In [None]:
final_reward = playing.play_game(game, mcts_policy_fn=mcts_policy)
print(f"Game result: {'You won' if final_reward == 1 else 'AI won'}")