Skip to content

Commit

Permalink
Support client settings (#17)
Browse files Browse the repository at this point in the history
* Allow user to pass clientconfig to tuning-client

* Fix returned value

* Adjust benchmark to include client settings
  • Loading branch information
kiudee committed Feb 6, 2020
1 parent 43b42f6 commit c2f783f
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 16 deletions.
14 changes: 14 additions & 0 deletions examples/client.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"engine1": {
"Threads": 2,
"Backend": "cudnn-auto",
"BackendOptions": "",
"NNCacheSize": 200000,
"MinibatchSize": 256,
"MaxPrefetch": 32
},
"engine2": {
"Threads": 4
},
"SyzygyPath": "path/to/tb/wdl"
}
5 changes: 3 additions & 2 deletions tune/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ def cli():
@click.option("--verbose", "-v", is_flag=True, default=False, help="Turn on debug output.")
@click.option("--logfile", default=None, help="Path to where the log is saved to.")
@click.option("--terminate-after", default=0, help="Terminate the client after x minutes.")
@click.option("--clientconfig", default=None, help="Path to the client configuration file.")
@click.argument("dbconfig")
def run_client(verbose, logfile, terminate_after, dbconfig):
def run_client(verbose, logfile, terminate_after, clientconfig, dbconfig):
""" Run the client to generate games for distributed tuning.
In order to connect to the database you need to provide a valid DBCONFIG
Expand All @@ -27,7 +28,7 @@ def run_client(verbose, logfile, terminate_after, dbconfig):
logging.basicConfig(
level=log_level, filename=logfile, format="%(asctime)s %(levelname)-8s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
tc = TuningClient(dbconfig_path=dbconfig, terminate_after=terminate_after)
tc = TuningClient(dbconfig_path=dbconfig, terminate_after=terminate_after, clientconfig=clientconfig)
tc.run()


Expand Down
94 changes: 83 additions & 11 deletions tune/db_workers/tuning_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from time import sleep, time
from psycopg2.extras import DictCursor

from ..io import InitStrings
from .utils import parse_timecontrol, MatchResult, TimeControl

CLIENT_VERSION = 1
Expand All @@ -18,7 +19,7 @@


class TuningClient(object):
def __init__(self, dbconfig_path, terminate_after=0, **kwargs):
def __init__(self, dbconfig_path, terminate_after=0, clientconfig=None, **kwargs):
self.end_time = None
if terminate_after != 0:
start_time = time()
Expand All @@ -34,7 +35,18 @@ def __init__(self, dbconfig_path, terminate_after=0, **kwargs):
self.logger.debug(f"Reading DB config:\n{config}")
self.connect_params = json.loads(config)
else:
raise ValueError("No config file found at provided path")
raise ValueError(f"No config file found at provided path:\n{dbconfig_path}")

self.client_config = None
if clientconfig is not None:
if os.path.isfile(clientconfig):
with open(clientconfig, "r") as ccfile:
config = ccfile.read().replace("\n", "")
self.client_config = json.loads(config)
else:
raise ValueError(
f"Client configuration file not found:\n{clientconfig}"
)

def interrupt_handler(self, sig, frame):
if self.interrupt_pressed:
Expand Down Expand Up @@ -78,10 +90,12 @@ def run_experiment(self, time_control, cutechess_options):
"-repeat",
"-games",
"2",
# "-tb", "/path/to/tb", # TODO: Support tablebases
"-pgnout",
"out.pgn",
]
if "syzygy_path" in cutechess_options:
st.insert(-2, "-tb")
st.insert(-2, cutechess_options["syzygy_path"])
out = subprocess.run(st, capture_output=True)
return self.parse_experiment(out)

Expand All @@ -106,18 +120,49 @@ def parse_experiment(self, results):
w, l, d = [float(x) for x in re.findall("[0-9]", result)]
return MatchResult(wins=w, losses=l, draws=d)

def run_benchmark(self):
def run_benchmark(self, config):
def uci_to_cl(k, v):
mapping = {"Threads": "--threads",
"NNCacheSize": "--nncache",
"Backend": "--backend",
"BackendOptions": "--backend-opts",
"MinibatchSize": "--minibatch-size",
"MaxPrefetch": "--max-prefetch"}
if k in mapping:
return f"{mapping[k]}={v}"
return None

def cl_arguments(init_strings):
cl_args = []
for k, v in init_strings.items():
arg = uci_to_cl(k, v)
if arg is not None:
cl_args.append(arg)
return cl_args

self.logger.debug(f"Before benchmark engine 1:\n{config['engine'][0]['initStrings']}")
args = cl_arguments(InitStrings(config["engine"][0]["initStrings"]))
path = os.path.join(os.path.curdir, "lc0")
out = subprocess.run([path, "benchmark"], capture_output=True)
out = subprocess.run([path, "benchmark"] + args, capture_output=True)
s = out.stdout.decode("utf-8")
result = float(re.findall(r"([0-9]+\.[0-9]+)\snodes per second", s)[0])
try:
result = float(re.findall(r"([0-9]+\.[0-9]+)\snodes per second", s)[0])
except IndexError:
self.logger.error(f"Error while parsing engine1 benchmark:\n{s}")
sys.exit(1)
self.lc0_benchmark = result

self.logger.debug(f"Before benchmark engine 2:\n{config['engine'][1]['initStrings']}")
num_threads = InitStrings(config["engine"][1]["initStrings"])["Threads"]
path = os.path.join(os.path.curdir, "sf")
out = subprocess.run([path, "bench"], capture_output=True)
out = subprocess.run([path, "bench", "16", str(int(num_threads))], capture_output=True)
# Stockfish outputs results as stderr:
s = out.stderr.decode("utf-8")
result = float(re.findall(r"Nodes/second\s+:\s([0-9]+)", s)[0])
try:
result = float(re.findall(r"Nodes/second\s+:\s([0-9]+)", s)[0])
except IndexError:
self.logger.error(f"Error while parsing engine2 benchmark:\n{s}")
sys.exit(1)
self.sf_benchmark = result

def adjust_time_control(self, time_control, lc0_nodes, sf_nodes):
Expand Down Expand Up @@ -169,10 +214,34 @@ def pick_job(self, jobs, mix=0.25):
self.logger.debug(f"Picked job {rand_i} (job_id={jobs[rand_i]['job_id']})")
return jobs[rand_i]

def incorporate_config(self, job):
admissible_uci = [
"Threads",
"Backend",
"BackendOptions",
"NNCacheSize",
"MinibatchSize",
"MaxPrefetch",
]
engines = job["config"]["engine"]
for i, e in enumerate(engines):
engine_str = f"engine{i+1}"
if engine_str in self.client_config:
init_strings = InitStrings(e["initStrings"])
for k, v in self.client_config[engine_str].items():
if k in admissible_uci:
init_strings[k] = v
if "SyzygyPath" in self.client_config:
path = self.client_config["SyzygyPath"]
for e in engines:
init_strings = InitStrings(e["initStrings"])
init_strings["SyzygyPath"] = path
job["config"]["cutechess"]["syzygy_path"] = path

def run(self):
while True:
if self.interrupt_pressed:
self.logger.info('Shutting down after receiving shutdown signal.')
self.logger.info("Shutting down after receiving shutdown signal.")
sys.exit(0)
if self.end_time is not None and self.end_time < time():
self.logger.info("Shutdown timer triggered. Closing")
Expand Down Expand Up @@ -205,6 +274,8 @@ def run(self):

# 2. Set up experiment
# a) write engines.json
if self.client_config is not None:
self.incorporate_config(job)
job_id = job["job_id"]
config = job["config"]
self.logger.debug(f"Received config:\n{config}")
Expand All @@ -216,9 +287,10 @@ def run(self):
# b) Adjust time control:
if self.lc0_benchmark is None:
self.logger.info(
"Running initial nodes/second benchmark to calibrate time controls..."
"Running initial nodes/second benchmark to calibrate time controls."
"Ensure that your pc is idle to get a good reading."
)
self.run_benchmark()
self.run_benchmark(config)
self.logger.info(
f"Benchmark complete. Results: lc0: {self.lc0_benchmark} nps, sf: {self.sf_benchmark} nps"
)
Expand Down
14 changes: 11 additions & 3 deletions tune/io.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import re
import sys
from collections.abc import MutableMapping


# TODO: Backup file to restore it, should there be an error
def uci_tuple(uci_string):
name, value = re.findall(r"name (\w+) value (-?[0-9.]+)", uci_string)[0]
value = float(value)
return name, value
try:
name, value = re.findall(r"name (\w+) value (-?[0-9.]+|\w*)", uci_string)[0]
except IndexError:
print(f"Error parsing UCI tuples:\n{uci_string}")
sys.exit(1)
try:
tmp = float(value)
except ValueError:
tmp = value
return name, tmp


def set_option(name, value):
Expand Down

0 comments on commit c2f783f

Please sign in to comment.