In [8]:
import time
from dataclasses import replace

import numpy as np
from scipy.stats import ttest_rel

import cpp_game
from ai.ai import *
from ai.rollout import do_batch_rollout_cpp
from ai.tree_search import *
from game.settings import *

In [9]:
# Enable IPython's autoreload extension in mode 2 so edits to project modules are picked up live.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
# Start from the default preset and pin both difficulty bounds to 7, so every
# game in this notebook is generated with min_difficulty == max_difficulty == 7.
settings = replace(
    get_preset(DEFAULT_PRESET),
    min_difficulty=7,
    max_difficulty=7,
)

# C++-side view of the same settings, consumed by the cpp_game bindings below.
cpp_settings = settings.to_cpp()

# Default AI (no tree-search settings passed) and a single engine instance.
ai = get_ai(settings)
engine = cpp_game.Engine(cpp_settings)

In [4]:
# Time one batched rollout driven by `ai.pv_model` with `argmax=True`, and report
# the mean of the returned `win` field.
#
# NOTE: the rollout object is deliberately NOT called `batch_rollout`. The
# tree-search comparison cell further down calls a `batch_rollout(...)` function;
# rebinding that name to a `cpp_game.BatchRollout` instance here would make that
# call fail when the notebook is run top to bottom.
cpp_settings = settings.to_cpp()
# The second argument (200) is presumably the number of parallel rollouts — TODO confirm.
cpp_batch_rollout = cpp_game.BatchRollout(cpp_settings, 200)

# perf_counter is monotonic, so it is the right clock for measuring an interval.
start = time.perf_counter()
td = do_batch_rollout_cpp(
    cpp_batch_rollout, pv_model=ai.pv_model, argmax=True, record_cpp_time=True
)
print(f"elapsed={time.perf_counter() - start:.3f}")
print(f"win_rate={td['win'].float().mean().item():.3f}")
# Presumably puts the model back into single-step mode after batched use — TODO confirm.
ai.pv_model.start_single_step()

C++ time: 0.054s
elapsed=0.215
win_rate=0.465


In [43]:
# Compare several tree-search configurations on the same set of rollout seeds.
# Every configuration shares num_iters=100 and seed=42; only the overrides listed
# here differ. The first (empty) entry is the baseline that the paired t-tests
# at the bottom compare every other configuration against.
ts_overrides = [
    {},
    {"root_noise": False},
    {"skip_thresh": 0.9},
    {"num_parallel": 10},
    {"num_parallel": 30},
]
ts_settings_li = [
    TreeSearchSettings(num_iters=100, seed=42, **overrides)
    for overrides in ts_overrides
]
num_rollouts = 50
engines = [cpp_game.Engine(cpp_settings) for _ in range(num_rollouts)]
total_time = 0.0
wins_li = []

for ts_settings in ts_settings_li:
    # NOTE: this rebinds the notebook-level `ai`; re-run the setup cell before
    # re-running earlier cells that expect the default AI.
    ai = get_ai(settings, ts_settings, num_rollouts)
    # Identical seeds for every configuration so results are paired per rollout.
    # A fresh list is built each iteration in case the callee mutates it.
    seeds = list(range(42, 42 + num_rollouts))

    start_time = time.perf_counter()
    wins = batch_rollout(engines, ai, seeds=seeds, use_tree_search=True)
    total_time += time.perf_counter() - start_time
    wins_li.append(wins)

# One row per configuration, one column per rollout.
wins = np.array(wins_li).astype(np.float32)

win_rates = np.mean(wins, axis=1).tolist()
print(win_rates)
# Paired t-test of the baseline (row 0) against each other configuration.
tstats = [ttest_rel(wins[0], x).statistic for x in wins[1:]]
print(tstats)
# Previously accumulated but never shown: wall time summed over all configurations.
print(f"total_time={total_time:.3f}")

[0.7200000286102295, 0.6200000047683716, 0.6600000262260437, 0.6399999856948853, 0.6200000047683716]
[np.float64(1.9414507513241666), np.float64(1.3527288841185778), np.float64(1.6614942244436857), np.float64(1.9414507513241666)]
