In [30]:
# Disable the notebook's periodic autosave (an interval of 0 turns it off).
%autosave 0

Autosave disabled


In [31]:
# Load IPython's autoreload extension and select mode 2, so edits to the
# imported project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [40]:
from typing import Optional
from games.connect4 import Connect4
from games.game import Game
from injectors import get_mcts, get_network, get_replay_buffer, get_trainer
from train import self_play
import playing

In [19]:
import cProfile
import pstats

# When truthy, the self-play cell runs under cProfile and leaves the profile in
# `pr`, which the stats cell then prints and dumps to disk.
profiling: bool = True

In [27]:
# Device string passed to get_trainer below; swap in "cpu" to run without CUDA.
device = "cuda"

# Build the Connect4 network, then the replay buffer and the MCTS built from
# that network.
network = get_network(Connect4)
replay_buffer = get_replay_buffer(Connect4)
mcts = get_mcts(network)

# The trainer receives the device, the network and the replay buffer.
trainer = get_trainer(device, network, replay_buffer)

# Game instance handed to playing.play_game in the interactive cells.
game = Connect4()


def mcts_policy(game: Game):
    """Policy callback for play_game: run the module-level MCTS on ``game``.

    Returns whatever ``mcts.search`` returns for the given game.
    """
    search_result = mcts.search(game)
    return search_result

In [15]:
%%time
pr: Optional[cProfile.Profile] = None
if profiling:
    with cProfile.Profile() as pr:
        self_play(Connect4, mcts, replay_buffer, num_games=10)
else:
    self_play(Connect4, mcts, replay_buffer, num_games=1)

Games played:   0%|          | 0/1 [00:00<?, ?it/s]

CPU times: user 18.7 s, sys: 971 μs, total: 18.7 s
Wall time: 18.8 s


In [8]:
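# Uncomment to save the current replay buffer to disk.
# NOTE(review): save() is given no extension while the load cell reads
# "10games_played.npz" — presumably save() appends ".npz"; confirm.
# replay_buffer.save("10games_played")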

In [21]:
# Load previously saved replay data from "10games_played.npz" (path is relative
# to the kernel's working directory).
# NOTE(review): whether load() replaces or extends the games produced by the
# self-play cell is not visible here — confirm.
replay_buffer.load("10games_played.npz")

In [22]:
# Only report when the self-play cell actually produced a profile.
if pr is not None:
    # sort_stats returns the Stats object itself, so `stats` is the sorted report.
    stats = pstats.Stats(pr).sort_stats("cumtime")
    stats.print_stats(30)  # limit the printed report to 30 entries
    # Also write the raw profile data to disk.
    pr.dump_stats("self_play.prof")

In [26]:
# Run training with a batch size of 1.
# NOTE(review): batch_size=1 looks like a smoke-test value — confirm before a
# real training run.
trainer.train(batch_size=1)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [39]:
# Interactive game against the MCTS policy; play_game returns the final reward.
final_reward = playing.play_game(game, mcts_policy_fn=mcts_policy)
# A reward of exactly 1 is reported as a human win; any other value prints "AI won".
# NOTE(review): a drawn game would presumably also land in the "AI won" branch —
# confirm what play_game returns for draws.
if final_reward == 1:
    print("Game result: You won")
else:
    print("Game result: AI won")

0 0 0 0 0 0 0
0 0 0 0 0 0 0
0 0 0 0 0 0 0
0 0 0 0 0 0 0
0 0 0 0 0 0 0
0 0 0 0 0 0 0
-------
0 1 2 3 4 5 6
Legal actions: [0, 1, 2, 3, 4, 5, 6]
Enter your action: 1
0 0 0 0 0 0 0
0 0 0 0 0 0 0
0 0 0 0 0 0 0
0 0 0 0 0 0 0
0 0 0 0 0 0 0
0 1 -1 0 0 0 0
-------
0 1 2 3 4 5 6
Legal actions: [0, 1, 2, 3, 4, 5, 6]
Enter your action: 1
0 0 0 0 0 0 0
0 0 0 0 0 0 0
0 0 0 0 0 0 0
0 0 0 0 0 0 0
0 1 0 0 0 0 0
0 1 -1 0 0 -1 0
-------
0 1 2 3 4 5 6
Legal actions: [0, 1, 2, 3, 4, 5, 6]
Enter your action: 1
0 0 0 0 0 0 0
0 0 0 0 0 0 0
0 0 0 0 0 0 0
0 1 0 0 0 0 0
0 1 0 0 0 0 0
0 1 -1 0 0 -1 -1
-------
0 1 2 3 4 5 6
Legal actions: [0, 1, 2, 3, 4, 5, 6]
Enter your action: 1
0 0 0 0 0 0 0
0 0 0 0 0 0 0
0 1 0 0 0 0 0
0 1 0 0 0 0 0
0 1 0 0 0 0 0
0 1 -1 0 0 -1 -1
-------
0 1 2 3 4 5 6
Game result: You won


In [None]:
# Interactive game against the MCTS policy; play_game returns the final reward.
final_reward = playing.play_game(game, mcts_policy_fn=mcts_policy)
# A reward of exactly 1 is reported as a human win; any other value prints "AI won".
# NOTE(review): a drawn game would presumably also land in the "AI won" branch —
# confirm what play_game returns for draws.
if final_reward == 1:
    print("Game result: You won")
else:
    print("Game result: AI won")

⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
―――――――――――――
0 1 2 3 4 5 6
Legal actions: [0, 1, 2, 3, 4, 5, 6]
Enter your action: 1
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ 🔴 ⚪ ⚪ 🟡 ⚪ ⚪
―――――――――――――
0 1 2 3 4 5 6
Legal actions: [0, 1, 2, 3, 4, 5, 6]
Enter your action: 2
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ 🟡 ⚪ ⚪ ⚪ ⚪ ⚪
⚪ 🔴 🔴 ⚪ 🟡 ⚪ ⚪
―――――――――――――
0 1 2 3 4 5 6
Legal actions: [0, 1, 2, 3, 4, 5, 6]
Enter your action: 3
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ ⚪ ⚪ ⚪ ⚪ ⚪ ⚪
⚪ 🟡 ⚪ ⚪ ⚪ ⚪ ⚪
🟡 🔴 🔴 🔴 🟡 ⚪ ⚪
―――――――――――――
0 1 2 3 4 5 6
Legal actions: [0, 1, 2, 3, 4, 5, 6]
