Skip to content

Commit

Permalink
Change signature for client.load_results
Browse files Browse the repository at this point in the history
  • Loading branch information
daizutabi committed May 15, 2020
1 parent e2607e2 commit 0ad9dd2
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 71 deletions.
31 changes: 17 additions & 14 deletions ivory/callbacks/results.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""A container to store training, validation and test results."""
from typing import Iterable
from typing import Dict, Iterable

import numpy as np

from ivory.core.collections import Dict
import ivory.core.collections
from ivory.core.run import Run
from ivory.core.state import State


class Results(Dict, State):
class Results(ivory.core.collections.Dict, State):
def reset(self):
self.index = None
self.output = None
Expand Down Expand Up @@ -41,10 +41,10 @@ def result_dict(self):
return dict(index=self.index, output=self.output, target=self.target)


def concatenate(iterable: Iterable[Results], callback=None):
indexes = []
outputs = []
targets = []
def concatenate(iterable: Iterable[Results], callback=None) -> Results:
indexes: Dict[str, list] = {"val": [], "test": []}
outputs: Dict[str, list] = {"val": [], "test": []}
targets: Dict[str, list] = {"val": [], "test": []}
for results in iterable:
for mode in ["val", "test"]:
if mode not in results:
Expand All @@ -53,10 +53,13 @@ def concatenate(iterable: Iterable[Results], callback=None):
index, output, target = result["index"], result["output"], result["target"]
if callback:
index, output, target = callback(index, output, target)
indexes.append(index)
outputs.append(output)
targets.append(target)
index = np.concatenate(indexes)
output = np.concatenate(outputs)
target = np.concatenate(targets)
return index, output, target
indexes[mode].append(index)
outputs[mode].append(output)
targets[mode].append(target)
results = Results()
for mode in ["val", "test"]:
index = np.concatenate(indexes[mode])
output = np.concatenate(outputs[mode])
target = np.concatenate(targets[mode])
results[mode] = dict(index=index, output=output, target=target)
return results
92 changes: 58 additions & 34 deletions ivory/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import ivory.callbacks.results
from ivory import utils
from ivory.callbacks.results import Results
from ivory.core import default, instance
from ivory.core.base import Base, Experiment
from ivory.core.evaluator import Evaluator
from ivory.core.run import Run
from ivory.utils.tqdm import tqdm

Expand Down Expand Up @@ -55,8 +55,34 @@ def create_study(self, name: str, run_number: Optional[int] = None):
study.set(tuner=self.tuner)
return study

def create_evaluator(self, run_ids=None) -> Evaluator:
return Evaluator(self, run_ids)
def get_run_id(self, name: str, **kwargs) -> str:
run_name = list(kwargs)[0]
run_number = kwargs[run_name]
if run_number == -1:
return next(self.search_run_ids(name, run_name))
else:
experiment_id = self.tracker.get_experiment_id(name)
return self.tracker.get_run_id(experiment_id, run_name, run_number)

def get_run_ids(self, name: str, **kwargs) -> Iterator[str]:
run_name = list(kwargs)[0]
run_numbers = kwargs[run_name]
for run_number in run_numbers:
yield self.get_run_id(name, **{run_name: run_number})

def get_parent_run_id(self, name: str, **kwargs) -> str:
run_id = self.get_run_id(name, **kwargs)
return self.tracker.get_parent_run_id(run_id)

def get_nested_run_ids(self, name: str, **kwargs) -> Iterator[str]:
run_id = self.get_run_id(name, **kwargs)
return self.search_run_ids(name, parent_run_id=run_id)

def set_parent_run_id(self, name, **kwargs):
run_id = self.get_run_id(name, run=kwargs["run"])
parent = {name: number for name, number in kwargs.items() if name != "run"}
parent_run_id = self.get_run_id(name, **parent)
self.tracker.set_parent_run_id(run_id, parent_run_id)

def search_run_ids(
self,
Expand Down Expand Up @@ -96,35 +122,10 @@ def search_run_ids(
)

def search_parent_run_ids(self, name: str = "", **query) -> Iterator[str]:
return self.search_run_ids(name, parent_only=True, **query)
yield from self.search_run_ids(name, parent_only=True, **query)

def search_nested_run_ids(self, name: str = "", **query) -> Iterator[str]:
return self.search_run_ids(name, nested_only=True, **query)

def get_run_id(self, name: str, **kwargs) -> str:
run_name = list(kwargs)[0]
run_number = kwargs[run_name]
if run_number == -1:
return next(self.search_run_ids(name, run_name))
else:
experiment_id = self.tracker.get_experiment_id(name)
return self.tracker.get_run_id(experiment_id, run_name, run_number)

def get_run_ids(self, name: str, **kwargs) -> Iterator[str]:
run_name = list(kwargs)[0]
run_numbers = kwargs[run_name]
for run_number in run_numbers:
yield self.get_run_id(name, **{run_name: run_number})

def get_nested_run_ids(self, name: str, **kwargs) -> Iterator[str]:
run_id = self.get_run_id(name, **kwargs)
return self.search_run_ids(name, parent_run_id=run_id)

def set_parent_run_id(self, run_id: str, parent_run_id: str):
self.tracker.set_parent_run_id(run_id, parent_run_id)

def get_parent_run_id(self, run_id: str) -> str:
return self.tracker.get_parent_run_id(run_id)
yield from self.search_run_ids(name, nested_only=True, **query)

def set_terminated(self, name: str = ""):
for run_id in self.search_run_ids(name):
Expand All @@ -143,7 +144,20 @@ def load_run_by_name(self, name: str, mode: str = "test", **kwargs) -> Run:
def load_instance(self, run_id: str, instance_name: str, mode: str = "test") -> Any:
return self.tracker.load_instance(run_id, instance_name, mode)

def load_results(self, run_ids: Iterable[str], callback=None, verbose: bool = True):
def load_results(
self, run_ids: Iterable[str], callback=None, verbose: bool = True
) -> Results:
"""Loads results from multiple runs and concatenates them.
Args:
run_ids: Multiple run ids to load.
callback (callable): Callback function for each run. This function must take
`(index, output, target)` as arguments and return a tuple of the same form.
verbose: If `True`, tqdm progress bar is displayed.
Returns:
A concatenated results instance.
"""
run_ids = list(run_ids)
it = (self.load_instance(run_id, "results") for run_id in run_ids)
if verbose:
Expand All @@ -164,13 +178,13 @@ def update_params(self, name: str = "", **default):
self.tracker.update_params(experiment.experiment_id, **default)

def remove_deleted_runs(self, name: str = "") -> int:
"""Remove deleted runs from local file system.
"""Removes deleted runs from a local file system.
Args:
name: experiment name pattern for filtering.
name: A regex pattern of experiment name for filtering.
Returns:
number of removed runs.
Number of removed runs.
"""
num_runs = 0
for experiment in self.tracker.list_experiments():
Expand All @@ -183,6 +197,16 @@ def remove_deleted_runs(self, name: str = "") -> int:
def create_client(
directory: str = ".", name: str = "client", tracker: bool = True
) -> Client:
"""Creates an Ivory client.
Args:
directory: A directory where a YAML config file exists.
name: The YAML config file name.
tracker: If `True`, tracking is enabled.
Returns:
The created client.
"""
source_name = utils.path.normpath(name, directory)
if os.path.exists(source_name):
params, _ = utils.path.load_params(source_name)
Expand Down
4 changes: 4 additions & 0 deletions ivory/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def get_index(self, mode: str, fold: int) -> np.ndarray:
return index[self.fold == -1]

def get_input(self, index: Index) -> Value:
"""Returns the input data selected by `index`.
Args:
index: Index used to subscript `self.input`.
"""
return self.input[index]

def get_target(self, index: Index) -> Value:
Expand Down
14 changes: 0 additions & 14 deletions ivory/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,6 @@
import scipy.special


def concat_results(iterable):
outputs = []
targets = []
for results in iterable:
output, target = results.to_dataframe()
outputs.append(output)
targets.append(target)
output = pd.concat(outputs)
target = pd.concat(targets)
output.sort_index(inplace=True)
target.sort_index(inplace=True)
return output, target


def softmax(df):
prob = scipy.special.softmax(df.to_numpy(), axis=1)
return pd.DataFrame(prob, index=df.index)
Expand Down
26 changes: 18 additions & 8 deletions tests/core/test_core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,17 @@ def test_search_run_ids(client):
assert len(list(client.search_nested_run_ids("rfr"))) == 2

assert client.get_run_id("rfr", task=0) == task.id
assert client.get_parent_run_id(a.id) == task.id
a_number = int(a.name.split("#")[1])
b_number = int(b.name.split("#")[1])
assert list(client.get_run_ids("rfr", run=[a_number, b_number])) == [a.id, b.id]
assert client.get_parent_run_id("rfr", run=a_number) == task.id
assert list(client.get_nested_run_ids("rfr", task=0)) == [b.id, a.id]
assert next(client.search_run_ids("rfr", parent_run_id=task.id)) == b.id

run = client.create_run("rfr")
run.start("train")
client.set_parent_run_id(run.id, task.id)
run_number = int(run.name.split("#")[1])
client.set_parent_run_id("rfr", run=run_number, task=0)
assert len(list(client.search_run_ids("rfr", parent_run_id=task.id))) == 3


Expand All @@ -54,11 +59,6 @@ def test_create_study(client):
assert client.create_study("rfr", -1).name == study.name


def test_create_evaluator(client):
evaluator = client.create_evaluator()
assert evaluator.client is client


def test_set_terminated(client):
assert client.set_terminated("rfr") is None

Expand All @@ -77,7 +77,7 @@ def test_load(client):
assert trainer.epoch > 0


def test_load_instance(client, run):
def test_load_instance(client):
run = client.create_run("example")
run.start("train")
results = client.load_instance(run.id, "results", "test")
Expand All @@ -88,6 +88,16 @@ def test_load_instance(client, run):
assert isinstance(model, torch.nn.Module)


def test_load_results(client):
run = client.create_run("example")
run.start("train")
run.start("test")
r = client.load_instance(run.id, "results", "test")
c = client.load_results([run.id, run.id])
assert len(c.val["index"]) == 2 * len(r.val["index"])
assert len(c.test["index"]) == 2 * len(r.test["index"])


def test_without_tracker():
client = create_client(directory="tests", tracker=False)
assert "tracker" not in client
Expand Down
14 changes: 13 additions & 1 deletion tests/utils/test_utils_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,21 @@ def test_softmax():
assert np.allclose(df.sum(axis=1).to_numpy(), [1, 1])


def test_mean():
df = pd.DataFrame([[1, 2], [3, 4], [5, 6]], index=[3, 4, 4])
df = ivory.utils.data.mean(df)
assert list(df.index) == [3, 4]
assert np.allclose(df.to_numpy(), [[1, 2], [4, 5]])

s = pd.Series([1, 2, 3], index=[3, 4, 4])
s = ivory.utils.data.mean(s)
assert isinstance(s, pd.Series)
assert list(s.index) == [3, 4]
assert np.allclose(s.to_numpy(), [1, 2.5])


def test_argmax():
df = pd.DataFrame([[1, 2], [3, 4]], index=[3, 4])
s = ivory.utils.data.argmax(df)
print(s)
assert s.loc[3] == 1
assert s.loc[4] == 1

0 comments on commit 0ad9dd2

Please sign in to comment.