Skip to content

Commit

Permalink
pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
daizutabi committed Apr 28, 2020
1 parent 29febe6 commit 10dd471
Show file tree
Hide file tree
Showing 13 changed files with 123 additions and 89 deletions.
7 changes: 1 addition & 6 deletions ivory/callbacks/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,11 @@ def log_metrics(self, run_id, metrics, step=0):

def set_tags(self, run_id, tags):
for key, value in tags.items():
self.client.set_tag(run_id, key, value)
self.client.set_tag(run_id, key, to_str(value))

def set_parent_run_id(self, run_id, parent_run_id):
    """Record `parent_run_id` on run `run_id` under the parent-run tag.

    The value is written with the client's ``set_tag`` using the
    ``MLFLOW_PARENT_RUN_ID`` key.
    """
    tag_key = MLFLOW_PARENT_RUN_ID
    self.client.set_tag(run_id, tag_key, parent_run_id)

def create_tracker(self):
from ivory.core.tracker import Tracker

return Tracker(self.tracking_uri)


def to_str(value):
if isinstance(value, (list, tuple)):
Expand Down
27 changes: 13 additions & 14 deletions ivory/core/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy

from ivory import utils
from ivory.core import instance
from ivory.core import default, instance
from ivory.core.dict import Dict


Expand Down Expand Up @@ -38,30 +38,29 @@ def experiment_id(self):
def experiment_name(self):
return self.params["experiment"]["name"]

def create_params(self, args=None, **kwargs):
def create_params(self, args=None, name="run", **kwargs):
params = copy.deepcopy(self.params)
update, args = utils.create_update(params["run"], args, **kwargs)
utils.update_dict(params["run"], update)
if name not in params:
params.update(default.get(name))
update, args = utils.create_update(params[name], args, **kwargs)
utils.update_dict(params[name], update)
return params, args

def create_run(self, args=None, class_name="Run", **kwargs):
params, args = self.create_params(args, **kwargs)
name = class_name.lower()
if name not in params:
params[name] = {}
def create_run(self, args=None, name="run", **kwargs):
params, args = self.create_params(args, name, **kwargs)
if self.tracker:
run_name = self.tracker.create_run_name(self.experiment_id, class_name)
run_name = self.tracker.create_run_name(self.experiment_id, name)
params[name]["name"] = run_name
run = instance.create_base_instance(params, name, self.source_name)
if self.tracker:
run.set_tracker(self.tracker)
args = {arg: utils.get_value(run.params["run"], arg) for arg in args}
args = {arg: utils.get_value(run.params[name], arg) for arg in args}
run.tracking.log_params(run.id, args)
return run

def create_instance(self, name: str, args=None, **kwargs):
params, _ = self.create_params(args, **kwargs)
return instance.create_instance(params["run"], name)
def create_instance(self, instance_name: str, args=None, name="run", **kwargs):
params, _ = self.create_params(args, name, **kwargs)
return instance.create_instance(params[name], instance_name)


class Callback:
Expand Down
35 changes: 21 additions & 14 deletions ivory/core/client.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import os
import re
import subprocess
from typing import Iterator
from typing import Any, Dict, Iterator, Tuple

import ivory.utils.data
from ivory import utils
from ivory.core import instance
from ivory.core import default, instance
from ivory.core.base import Base
from ivory.core.experiment import Experiment
from ivory.core.run import Run
from ivory.utils.tqdm import tqdm


class Client(Base):
def create_experiment(self, path: str, name: str = "") -> Experiment:
params, source_name = utils.load_params(path, self.source_name)
if "experiment" not in params:
params["experiment"] = {}
params.update(default.get("experiment"))
if "name" not in params["experiment"]:
if name:
path = ".".join([path, name])
Expand All @@ -25,24 +26,32 @@ def create_experiment(self, path: str, name: str = "") -> Experiment:
experiment.set_tracker(self.tracker)
return experiment

def search_run_ids(self, name: str = "", **query) -> Iterator[str]:
def search_run_ids(
self,
name: str = "",
parent_run_id: str = "",
parent_only: bool = False,
**query
) -> Iterator[str]:
for experiment in self.tracker.list_experiments():
if name and not re.match(name, experiment.name):
continue
yield from self.tracker.search_run_ids(experiment.experiment_id, **query)
yield from self.tracker.search_run_ids(
experiment.experiment_id, parent_run_id, parent_only, **query
)

def load_params(self, run_id):
def load_params(self, run_id: str) -> Dict[str, Any]:
return self.tracker.load_params(run_id)

def load_run(self, run_id, mode="test"):
def load_run(self, run_id: str, mode: str = "test") -> Run:
run = self.tracker.load_run(run_id, mode)
run.set_tracker(self.tracker)
return run

def load_instance(self, run_id, name, mode="test"):
def load_instance(self, run_id: str, name: str, mode: str = "test") -> Any:
return self.tracker.load_instance(run_id, name, mode)

def load_results(self, run_ids, verbose=True):
def load_results(self, run_ids: str, verbose: bool = True) -> Tuple:
if verbose:
run_ids = tqdm(list(run_ids))
it = (self.load_instance(run_id, "results", "test") for run_id in run_ids)
Expand All @@ -60,12 +69,10 @@ def create_client(path="client", directory=".", tracker=True) -> Client:
source_name = utils.normpath(path, directory)
if os.path.exists(source_name):
params, _ = utils.load_params(source_name)
if not tracker and "tracker" in params["client"]:
params["client"].pop("tracker")
else:
params = {"client": {}}
if tracker:
params["client"]["tracker"] = {}
params = default.get("client")
if not tracker and "tracker" in params["client"]:
params["client"].pop("tracker")
with utils.chdir(source_name):
client = instance.create_base_instance(params, "client", source_name)
return client
18 changes: 15 additions & 3 deletions ivory/core/default.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
DEFAULT_CLASS = {}
import copy
from typing import Any, Dict

# Default parameter trees, keyed by section name. Each value is the complete
# top-level mapping to merge into a params dict when that section is missing.
DEFAULTS: Dict[str, Any] = {}

DEFAULTS["client"] = {"client": {"tracker": {}}}
DEFAULTS["experiment"] = {"experiment": {}}
DEFAULTS["task"] = {"task": {}}


def get(name: str) -> Dict[str, Any]:
    """Return an independent copy of the default parameters for `name`.

    Args:
        name: Section name such as ``"client"``, ``"experiment"`` or
            ``"task"``.

    Returns:
        A deep copy of ``DEFAULTS[name]``, so callers may mutate the result
        freely. When no default is registered for `name` (for example
        ``"study"`` or ``"run"``), a fresh ``{name: {}}`` is returned instead
        of raising ``KeyError``, so such sections simply start out empty.
    """
    try:
        registered = DEFAULTS[name]
    except KeyError:
        # No registered default: hand back an empty section without
        # touching DEFAULTS itself.
        return {name: {}}
    return copy.deepcopy(registered)


DEFAULT_CLASS: Dict[str, Any] = {}

DEFAULT_CLASS["core"] = {
"client": "ivory.core.client.Client",
Expand Down Expand Up @@ -43,6 +57,4 @@ def update_class(params, library="core"):
params[key][kind] = DEFAULT_CLASS[library][key]
elif key in DEFAULT_CLASS["core"]:
params[key][kind] = DEFAULT_CLASS["core"][key]
else:
raise ValueError(f"Can't find class for {key}.")
update_class(value, library)
6 changes: 3 additions & 3 deletions ivory/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ def set_tracker(self, tracker):
self["tracker"] = tracker

def create_task(self):
return self.create_run(class_name="Task")
return self.create_run(name="task")

def create_study(self, run_number: int = 0):
return self.create_run(class_name="Study", run_number=run_number)
def create_study(self):
return self.create_run(name="study")

def update_params(self, **default):
self.tracker.update_params(self.id, **default)
9 changes: 4 additions & 5 deletions ivory/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,9 @@ def literal_eval(value):
return value


def product(args=None, **kwargs):
params = parse_args(args, **kwargs)
def product(params):
for values in itertools.product(*params.values()):
update = {}
args = {}
for name, value in zip(params.keys(), values):
update[name] = value
yield update
args[name] = value
yield args
6 changes: 3 additions & 3 deletions ivory/core/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ def create_run(self, args):
return run

def product(self, args=None, repeat=1, **kwargs):
for args in tqdm(list(parser.product(args, **kwargs))):
run = self.create_run(args)
yield run
params = parser.parse_args(args, **kwargs)
for args in tqdm(list(parser.product(params))):
yield self.create_run(args)


class Study(Task):
Expand Down
41 changes: 34 additions & 7 deletions ivory/core/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,40 @@ def create_tracking(self):
def list_experiments(self, view_type=None):
    """Return whatever the underlying client lists as experiments.

    `view_type` is forwarded to the client unchanged.
    """
    experiments = self.client.list_experiments(view_type)
    return experiments

def list_run_ids(self, experiment_id: str):
for run_info in self.client.list_run_infos(experiment_id):
yield run_info.run_id

def search_run_ids(self, experiment_id: str, **query):
for run_id in self.list_run_ids(experiment_id):
def list_run_ids(self, experiment_id: str, parent_run_id: str = ""):
    """Yield run ids of an experiment, optionally limited to nested runs.

    With a non-empty `parent_run_id` the work is delegated to
    ``list_nested_run_ids``; otherwise every run info the client lists for
    `experiment_id` contributes its ``run_id``.
    """
    if not parent_run_id:
        infos = self.client.list_run_infos(experiment_id)
        yield from (info.run_id for info in infos)
        return
    yield from self.list_nested_run_ids(experiment_id, parent_run_id)

def list_nested_run_ids(self, experiment_id: str, parent_run_id: str):
    """Yield ids of runs whose parent-run tag equals `parent_run_id`.

    The client's ``search_runs`` is queried with a filter on the
    ``mlflow.parentRunId`` tag; the id is quoted via ``repr``.
    """
    condition = f"tags.mlflow.parentRunId={parent_run_id!r}"
    children = self.client.search_runs(experiment_id, condition)
    yield from (child.info.run_id for child in children)

def list_parent_run_ids(self, experiment_id: str):
    """Yield ids of runs in `experiment_id` that carry no parent-run tag.

    A run is skipped when ``"mlflow.parentRunId"`` is among its tags.
    """
    # NOTE(review): search_runs is called without an explicit result limit;
    # confirm the client's default limit is acceptable for large experiments.
    for candidate in self.client.search_runs(experiment_id):
        if "mlflow.parentRunId" in candidate.data.tags:
            continue
        yield candidate.info.run_id

def search_run_ids(
self,
experiment_id: str,
parent_run_id: str = "",
parent_only: bool = False,
**query,
):
if parent_only:
run_ids = self.list_parent_run_ids(experiment_id)
else:
run_ids = self.list_run_ids(experiment_id, parent_run_id)
for run_id in run_ids:
if query:
params = self.load_params(run_id)
try:
params = self.load_params(run_id)
except FileNotFoundError:
continue
if utils.match(params, **query):
yield run_id
else:
Expand All @@ -67,6 +93,7 @@ def get_run_number(self, experiment_id: str, prefix: str):
return run_number

def create_run_name(self, experiment_id: str, prefix: str):
    """Build the next run name for `prefix`, e.g. ``"Task#004"``.

    Args:
        experiment_id: Experiment whose existing runs are counted.
        prefix: Run kind such as ``"task"``; only its first character is
            upper-cased, the rest is kept as given.

    Returns:
        ``"<Prefix>#<NNN>"`` where ``NNN`` is one more than the value
        returned by ``get_run_number``, zero-padded to three digits.
    """
    # Slice rather than index so an empty prefix yields "" instead of
    # raising IndexError; for non-empty prefixes the result is identical.
    prefix = prefix[:1].upper() + prefix[1:]
    run_number = self.get_run_number(experiment_id, prefix)
    return f"{prefix}#{run_number + 1:03d}"

Expand Down
35 changes: 14 additions & 21 deletions ivory/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,37 +33,30 @@ def cli():
@click.option("-r", "--repeat", default=1, help="Number of repeatation.")
@click.option("--notest", is_flag=True, help="Skip test after training.")
@click.option("--notrack", is_flag=True, help="No tracking mode.")
@click.option("-m", "--message", default="", help="Message for tracking.")
def run(path, args, repeat, notest, notrack, message):
def run(path, args, repeat, notest, notrack):
client = ivory.create_client(tracker=not notrack)
experiment = client.create_experiment(path)
for run in experiment.start(args, repeat=repeat, message=message):
task = experiment.create_task()
for run in task.product(args, repeat=repeat):
run.start("train")
if not notest and not notrack:
run = experiment.load_run(run.id, "best")
run = client.load_run(run.id, "best")
try:
run.start("test")
except TestDataNotFoundError:
pass


# @cli.command(help="Optimize hyper parameters.")
# @click.argument("path", callback=normpath)
# @click.argument("name")
# @click.argument("args", nargs=-1)
# @click.option("-m", "--message", default="", help="Message for tracking.")
# def optimize(path, name, args, message):
# client = create_client(path)
# client.optimize(name, args, message=message)
#
#
# @cli.command(help="Search runs.")
# @click.argument("path", callback=normpath)
# @click.argument("args", nargs=-1)
# @click.option("-m", "--message", default="", help="Message for tracking.")
# def search(path, args, message):
# pass
#
@cli.command(help="Optimize hyper parameters.")
@click.argument("path")
@click.argument("name")
@click.option("-n", "--n-trials", default=3, help="Number of trials.")
@click.option("--notrack", is_flag=True, help="No tracking mode.")
def optimize(path, name, n_trials, notrack):
    # Creates a client (without a tracker when --notrack is given), builds the
    # experiment from `path`, and runs a study's `optimize` for `name`.
    # The trial count was previously fixed at 3; it is now an option whose
    # default keeps the old behavior.
    client = ivory.create_client(tracker=not notrack)
    experiment = client.create_experiment(path)
    study = experiment.create_study()
    study.optimize(name, n_trials=n_trials)


@cli.command(help="Start tracking UI.")
Expand Down
8 changes: 5 additions & 3 deletions ivory/utils/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ def update_dict(org: Dict[str, Any], update: Dict[str, Any]) -> None:
x[k] = value
elif isinstance(x[k], str) and x[k].startswith("$"):
x[k] = value
elif type(x[k]) is not type(value):
elif type(x[k]) is not type(value) and x[k] is not None:
raise ValueError(f"different type: {type(x[k])} != {type(value)}")
else:
if isinstance(value, dict):
if isinstance(x[k], dict):
x[k].update(value)
else:
x[k] = value
Expand Down Expand Up @@ -142,11 +142,13 @@ def create_update(params, args=None, **kwargs):
args = {}
args.update(kwargs)
args = colon_to_list(args)
args_ = {}
update = {}
for name, value in args.items():
for fullname in get_fullnames(params, name):
update[fullname] = value
return update, args
args_[name] = value
return update, args_


def get_value(params, name):
Expand Down
3 changes: 2 additions & 1 deletion ivory/utils/tqdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

with warnings.catch_warnings():
warnings.simplefilter("ignore")
from tqdm.autonotebook import tqdm
# from tqdm.autonotebook import tqdm
from tqdm import tqdm

__all__ = ["tqdm"]
8 changes: 0 additions & 8 deletions tests/core/test_core_default.py

This file was deleted.

9 changes: 8 additions & 1 deletion tests/core/test_core_task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
def test_experiment_run_str(task):
def test_task(client, task):
    """Check the runs produced by ``task.product`` and the run-id searches."""
    expected_folds = [1, 1, 2, 2]
    expected_epochs = [3, 4, 3, 4]
    runs = task.product(["fold=1-2"], max_epochs="3,4")
    for k, run in enumerate(runs):
        assert run.dataloaders.fold == expected_folds[k]
        assert run.trainer.max_epochs == expected_epochs[k]
        # The first run is created but deliberately not started.
        if k != 0:
            run.start()

    search = client.search_run_ids
    nested = list(search("example", parent_run_id=task.id))
    assert len(nested) == 4
    nested_fold_1 = list(search("example", parent_run_id=task.id, fold=1))
    assert len(nested_fold_1) == 1
    # `run` is the last run yielded by the loop above.
    parents = list(client.search_run_ids(parent_only=True))
    assert run.id not in parents

0 comments on commit 10dd471

Please sign in to comment.