From 68f7b75e1c3f89ba28b86f18024ce20c0c4f4ec8 Mon Sep 17 00:00:00 2001 From: c-bata Date: Wed, 25 May 2022 20:06:29 +0900 Subject: [PATCH 1/5] Fix -inf and nan handling --- optuna_dashboard/_serializer.py | 46 ++++++++++++------- optuna_dashboard/ts/apiClient.ts | 2 +- optuna_dashboard/ts/components/GraphEdf.tsx | 3 +- .../ts/components/GraphHistory.tsx | 4 +- .../ts/components/GraphIntermediateValues.tsx | 4 +- .../ts/components/GraphParallelCoordinate.tsx | 4 +- .../ts/components/GraphParetoFront.tsx | 2 +- optuna_dashboard/ts/components/GraphSlice.tsx | 4 +- .../ts/components/StudyDetail.tsx | 27 ++++++++++- optuna_dashboard/ts/types/index.d.ts | 6 ++- 10 files changed, 76 insertions(+), 26 deletions(-) diff --git a/optuna_dashboard/_serializer.py b/optuna_dashboard/_serializer.py index b2dda478b..685252cc3 100644 --- a/optuna_dashboard/_serializer.py +++ b/optuna_dashboard/_serializer.py @@ -1,5 +1,4 @@ import json -import math from typing import Any from typing import Dict from typing import List @@ -9,14 +8,17 @@ from optuna.distributions import BaseDistribution from optuna.study import StudySummary from optuna.trial import FrozenTrial +import numpy as np from . import _note as note try: from typing import TypedDict + from typing import Literal except ImportError: from typing_extensions import TypedDict + from typing_extensions import Literal MAX_ATTR_LENGTH = 1024 @@ -31,7 +33,7 @@ "IntermediateValue", { "step": int, - "value": Union[float, str], + "value": Union[float, Literal["inf", "-inf", "nan"]], }, ) TrialParam = TypedDict( @@ -56,17 +58,6 @@ def serialize_attrs(attrs: Dict[str, Any]) -> List[Attribute]: return serialized -def serialize_intermediate_values(values: Dict[int, float]) -> List[IntermediateValue]: - return [ - {"step": step, "value": "inf" if math.isinf(value) else value} - for step, value in values.items() - ] - - -def serialize_trial_params(params: Dict[str, Any]) -> List[TrialParam]: - return [{"name": name, "value": str(value)} for name, value in params.items()] - - def serialize_study_summary(summary: StudySummary) -> Dict[str, Any]: serialized = { "study_id": summary._study_id, @@ -118,14 +109,37 @@ def serialize_frozen_trial(study_id: int, trial: FrozenTrial) -> Dict[str, Any]: "study_id": study_id, "number": trial.number, "state": trial.state.name.capitalize(), - "intermediate_values": serialize_intermediate_values(trial.intermediate_values), - "params": serialize_trial_params(trial.params), + "params": [ + {"name": name, "value": str(value)} for name, value in trial.params.items() + ], "user_attrs": serialize_attrs(trial.user_attrs), "system_attrs": serialize_attrs(trial.system_attrs), } + serialized_intermediate_values = [] + for step, value in trial.intermediate_values.items(): + if np.isnan(value): + serialized_value = "nan" + elif np.isposinf(value): + serialized_value = "inf" + elif np.isneginf(value): + serialized_value = "-inf" + else: + serialized_value = value + serialized_intermediate_values.append({"step": step, "value": serialized_value}) + serialized["intermediate_values"] = serialized_intermediate_values + if trial.values is not None: - serialized["values"] = ["inf" if math.isinf(v) else v for v in trial.values] + serialized_values: List[Union[float, Literal["inf", "-inf"]]] = [] + for v in trial.values: + assert not np.isnan(v), "Should not detect nan value" + if np.isposinf(v): + serialized_values.append("inf") + elif np.isneginf(v): + serialized_values.append("-inf") + else: + serialized_values.append(v) + serialized["values"] = serialized_values if trial.datetime_start is not None: serialized["datetime_start"] = trial.datetime_start.isoformat() diff --git a/optuna_dashboard/ts/apiClient.ts b/optuna_dashboard/ts/apiClient.ts index c6fa35e24..d0f300381 100644 --- a/optuna_dashboard/ts/apiClient.ts +++ b/optuna_dashboard/ts/apiClient.ts @@ -7,7 +7,7 @@ interface TrialResponse { study_id: number number: number state: TrialState - values?: (number | "inf")[] + values?: TrialValueNumber[] intermediate_values: TrialIntermediateValue[] datetime_start?: string datetime_complete?: string diff --git a/optuna_dashboard/ts/components/GraphEdf.tsx b/optuna_dashboard/ts/components/GraphEdf.tsx index 583a73ce3..47b82bee2 100644 --- a/optuna_dashboard/ts/components/GraphEdf.tsx +++ b/optuna_dashboard/ts/components/GraphEdf.tsx @@ -66,7 +66,8 @@ const filterFunc = (trial: Trial, objectiveId: number): boolean => { return ( trial.state === "Complete" && trial.values !== undefined && - trial.values[objectiveId] !== "inf" + trial.values[objectiveId] !== "inf" && + trial.values[objectiveId] !== "-inf" ) } diff --git a/optuna_dashboard/ts/components/GraphHistory.tsx b/optuna_dashboard/ts/components/GraphHistory.tsx index d050067c6..eabaa8f11 100644 --- a/optuna_dashboard/ts/components/GraphHistory.tsx +++ b/optuna_dashboard/ts/components/GraphHistory.tsx @@ -180,7 +180,9 @@ const filterFunc = (trial: Trial, objectiveId: number): boolean => { return false } return ( - trial.values.length > objectiveId && trial.values[objectiveId] !== "inf" + trial.values.length > objectiveId && + trial.values[objectiveId] !== "inf" && + trial.values[objectiveId] !== "-inf" ) } diff --git a/optuna_dashboard/ts/components/GraphIntermediateValues.tsx b/optuna_dashboard/ts/components/GraphIntermediateValues.tsx index e19cf070a..96a4a1f31 100644 --- a/optuna_dashboard/ts/components/GraphIntermediateValues.tsx +++ b/optuna_dashboard/ts/components/GraphIntermediateValues.tsx @@ -58,7 +58,9 @@ const plotIntermediateValue = (trials: Trial[], mode: string) => { t.state == "Running" ) const plotData: Partial[] = filteredTrials.map((trial) => { - const values = trial.intermediate_values.filter((iv) => iv.value !== "inf") + const values = trial.intermediate_values.filter( + (iv) => iv.value !== "inf" && iv.value !== "-inf" && iv.value !== "nan" + ) return { x: values.map((iv) => iv.step), y: values.map((iv) => iv.value), diff --git a/optuna_dashboard/ts/components/GraphParallelCoordinate.tsx b/optuna_dashboard/ts/components/GraphParallelCoordinate.tsx index 09b73666d..8667d8b43 100644 --- a/optuna_dashboard/ts/components/GraphParallelCoordinate.tsx +++ b/optuna_dashboard/ts/components/GraphParallelCoordinate.tsx @@ -71,7 +71,9 @@ const filterFunc = (trial: Trial, objectiveId: number): boolean => { return false } return ( - trial.values.length > objectiveId && trial.values[objectiveId] !== "inf" + trial.values.length > objectiveId && + trial.values[objectiveId] !== "inf" && + trial.values[objectiveId] !== "-inf" ) } diff --git a/optuna_dashboard/ts/components/GraphParetoFront.tsx b/optuna_dashboard/ts/components/GraphParetoFront.tsx index 48a8d56c3..0c9457110 100644 --- a/optuna_dashboard/ts/components/GraphParetoFront.tsx +++ b/optuna_dashboard/ts/components/GraphParetoFront.tsx @@ -90,7 +90,7 @@ const filterFunc = (trial: Trial, directions: StudyDirection[]): boolean => { trial.state === "Complete" && trial.values !== undefined && trial.values.length === directions.length && - trial.values.every((v) => v !== "inf") + trial.values.every((v) => v !== "inf" && v !== "-inf") ) } diff --git a/optuna_dashboard/ts/components/GraphSlice.tsx b/optuna_dashboard/ts/components/GraphSlice.tsx index 221bf6df1..d03dacf1a 100644 --- a/optuna_dashboard/ts/components/GraphSlice.tsx +++ b/optuna_dashboard/ts/components/GraphSlice.tsx @@ -130,7 +130,9 @@ const filterFunc = ( return false } return ( - trial.values.length > objectiveId && trial.values[objectiveId] !== "inf" + trial.values.length > objectiveId && + trial.values[objectiveId] !== "inf" && + trial.values[objectiveId] !== "-inf" ) } diff --git a/optuna_dashboard/ts/components/StudyDetail.tsx b/optuna_dashboard/ts/components/StudyDetail.tsx index d442cdc74..21f91de3a 100644 --- a/optuna_dashboard/ts/components/StudyDetail.tsx +++ b/optuna_dashboard/ts/components/StudyDetail.tsx @@ -647,7 +647,32 @@ export const TrialTable: FC<{ studyDetail: StudyDetail | null }> = ({ const collapseIntermediateValueColumns: DataGridColumn[] = [ { field: "step", label: "Step", sortable: true }, - { field: "value", label: "Value", sortable: true }, + { + field: "value", + label: "Value", + sortable: true, + less: (firstEl, secondEl): number => { + const firstVal = firstEl.value + const secondVal = secondEl.value + if (firstVal === secondVal) { + return 0 + } + if ( + firstVal === "-inf" || + secondVal === "nan" || + secondVal === "inf" + ) { + return 1 + } else if ( + secondVal === "-inf" || + firstVal === "nan" || + firstVal === "inf" + ) { + return -1 + } + return firstVal < secondVal ? 1 : -1 + }, + }, ] const collapseAttrColumns: DataGridColumn[] = [ { field: "key", label: "Key", sortable: true }, diff --git a/optuna_dashboard/ts/types/index.d.ts b/optuna_dashboard/ts/types/index.d.ts index 7af525738..951e858e1 100644 --- a/optuna_dashboard/ts/types/index.d.ts +++ b/optuna_dashboard/ts/types/index.d.ts @@ -7,6 +7,8 @@ declare const APP_BAR_TITLE: string declare const API_ENDPOINT: string declare const URL_PREFIX: string +type TrialValueNumber = number | "inf" | "-inf" +type TrialIntermediateValueNumber = number | "inf" | "-inf" | "nan" type TrialState = "Running" | "Complete" | "Pruned" | "Fail" | "Waiting" type StudyDirection = "maximize" | "minimize" | "not_set" type Distribution = @@ -21,7 +23,7 @@ type Distribution = declare interface TrialIntermediateValue { step: number - value: number | "inf" + value: TrialIntermediateValueNumber } declare interface TrialParam { @@ -55,7 +57,7 @@ declare interface Trial { study_id: number number: number state: TrialState - values?: (number | "inf")[] + values?: TrialValueNumber[] intermediate_values: TrialIntermediateValue[] datetime_start?: Date datetime_complete?: Date From f46a1d724dd4811cebe3e21b1f0f119bb64df40d Mon Sep 17 00:00:00 2001 From: c-bata Date: Wed, 25 May 2022 21:16:38 +0900 Subject: [PATCH 2/5] Update visual regression test --- optuna_dashboard/_serializer.py | 7 ++++--- visual_regression_test.py | 18 +++++++++++++----- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/optuna_dashboard/_serializer.py b/optuna_dashboard/_serializer.py index 685252cc3..f1e40a743 100644 --- a/optuna_dashboard/_serializer.py +++ b/optuna_dashboard/_serializer.py @@ -5,20 +5,20 @@ from typing import Tuple from typing import Union +import numpy as np from optuna.distributions import BaseDistribution from optuna.study import StudySummary from optuna.trial import FrozenTrial -import numpy as np from . import _note as note try: - from typing import TypedDict from typing import Literal + from typing import TypedDict except ImportError: - from typing_extensions import TypedDict from typing_extensions import Literal + from typing_extensions import TypedDict MAX_ATTR_LENGTH = 1024 @@ -125,6 +125,7 @@ def serialize_frozen_trial(study_id: int, trial: FrozenTrial) -> Dict[str, Any]: elif np.isneginf(value): serialized_value = "-inf" else: + assert np.isfinite(value) serialized_value = value serialized_intermediate_values.append({"step": step, "value": serialized_value}) serialized["intermediate_values"] = serialized_intermediate_values diff --git a/visual_regression_test.py b/visual_regression_test.py index fbe639bbf..f601a81fd 100644 --- a/visual_regression_test.py +++ b/visual_regression_test.py @@ -87,13 +87,15 @@ def objective_single_dynamic(trial: optuna.Trial) -> float: study.optimize(objective_single_dynamic, n_trials=50) - # Single objective study with 'inf' value + # Single objective study with 'inf', '-inf', or 'nan' value study = optuna.create_study(study_name="single-inf", storage=storage) def objective_single_inf(trial: optuna.Trial) -> float: x = trial.suggest_float("x", -10, 10) - if x > 0: - return math.inf + if trial.number % 3 == 0: + return float("inf") + elif trial.number % 3 == 1: + return float("-inf") else: return x**2 @@ -152,13 +154,19 @@ def objective_prune_without_report(trial: optuna.Trial) -> float: study.optimize(objective_prune_without_report, n_trials=100) - # Single objective pruned after reported 'inf' value + # Single objective pruned after reported 'inf', '-inf', or 'nan' study = optuna.create_study(study_name="single-inf-report", storage=storage) def objective_single_inf_report(trial: optuna.Trial) -> float: x = trial.suggest_float("x", -10, 10) + if trial.number % 3 == 0: + trial.report(float("inf"), 1) + elif trial.number % 3 == 1: + trial.report(float("-inf"), 1) + else: + trial.report(float("nan"), 1) + if x > 0: - trial.report(math.inf, 1) raise optuna.TrialPruned() else: return x**2 From 22d4207c16467759c3e0b7addecca1e2da8426f4 Mon Sep 17 00:00:00 2001 From: c-bata Date: Wed, 25 May 2022 21:28:30 +0900 Subject: [PATCH 3/5] Fix comparator function --- .../ts/components/StudyDetail.tsx | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/optuna_dashboard/ts/components/StudyDetail.tsx b/optuna_dashboard/ts/components/StudyDetail.tsx index 21f91de3a..1ea24675b 100644 --- a/optuna_dashboard/ts/components/StudyDetail.tsx +++ b/optuna_dashboard/ts/components/StudyDetail.tsx @@ -515,13 +515,18 @@ export const TrialTable: FC<{ studyDetail: StudyDetail | null }> = ({ if (firstVal === secondVal) { return 0 - } else if (firstVal && secondVal) { - return firstVal < secondVal ? 1 : -1 - } else if (firstVal) { + } + if (firstVal === undefined) { return -1 - } else { + } else if (secondVal === undefined) { return 1 } + if (firstVal === "-inf" || secondVal === "inf") { + return 1 + } else if (secondVal === "-inf" || firstVal === "inf") { + return -1 + } + return firstVal < secondVal ? 1 : -1 }, toCellValue: (i) => { if (trials[i].values === undefined) { @@ -542,13 +547,18 @@ export const TrialTable: FC<{ studyDetail: StudyDetail | null }> = ({ if (firstVal === secondVal) { return 0 - } else if (firstVal && secondVal) { - return firstVal < secondVal ? 1 : -1 - } else if (firstVal) { + } + if (firstVal === undefined) { return -1 - } else { + } else if (secondVal === undefined) { return 1 } + if (firstVal === "-inf" || secondVal === "inf") { + return 1 + } else if (secondVal === "-inf" || firstVal === "inf") { + return -1 + } + return firstVal < secondVal ? 1 : -1 }, toCellValue: (i) => { if (trials[i].values === undefined) { From d1ced4eb9551d8a0f8b53727af34ebb072966f5f Mon Sep 17 00:00:00 2001 From: c-bata Date: Wed, 25 May 2022 21:31:50 +0900 Subject: [PATCH 4/5] Fix comparator function for TrialIntermediateValue table --- optuna_dashboard/ts/components/StudyDetail.tsx | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/optuna_dashboard/ts/components/StudyDetail.tsx b/optuna_dashboard/ts/components/StudyDetail.tsx index 1ea24675b..feb9a88cc 100644 --- a/optuna_dashboard/ts/components/StudyDetail.tsx +++ b/optuna_dashboard/ts/components/StudyDetail.tsx @@ -667,17 +667,14 @@ export const TrialTable: FC<{ studyDetail: StudyDetail | null }> = ({ if (firstVal === secondVal) { return 0 } - if ( - firstVal === "-inf" || - secondVal === "nan" || - secondVal === "inf" - ) { + if (firstVal === "nan") { + return -1 + } else if (secondVal === "nan") { return 1 - } else if ( - secondVal === "-inf" || - firstVal === "nan" || - firstVal === "inf" - ) { + } + if (firstVal === "-inf" || secondVal === "inf") { + return 1 + } else if (secondVal === "-inf" || firstVal === "inf") { return -1 } return firstVal < secondVal ? 1 : -1 From c9975b9486f809404672a9bffea03807c2753308 Mon Sep 17 00:00:00 2001 From: c-bata Date: Wed, 25 May 2022 21:43:14 +0900 Subject: [PATCH 5/5] Fix lint errors --- optuna_dashboard/_serializer.py | 12 +++--------- visual_regression_test.py | 1 - 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/optuna_dashboard/_serializer.py b/optuna_dashboard/_serializer.py index f1e40a743..e938221fa 100644 --- a/optuna_dashboard/_serializer.py +++ b/optuna_dashboard/_serializer.py @@ -17,7 +17,7 @@ from typing import Literal from typing import TypedDict except ImportError: - from typing_extensions import Literal + from typing_extensions import Literal # type: ignore from typing_extensions import TypedDict @@ -36,13 +36,6 @@ "value": Union[float, Literal["inf", "-inf", "nan"]], }, ) -TrialParam = TypedDict( - "TrialParam", - { - "name": str, - "value": str, - }, -) def serialize_attrs(attrs: Dict[str, Any]) -> List[Attribute]: @@ -116,8 +109,9 @@ def serialize_frozen_trial(study_id: int, trial: FrozenTrial) -> Dict[str, Any]: "system_attrs": serialize_attrs(trial.system_attrs), } - serialized_intermediate_values = [] + serialized_intermediate_values: List[IntermediateValue] = [] for step, value in trial.intermediate_values.items(): + serialized_value: Union[float, Literal["nan", "inf", "-inf"]] if np.isnan(value): serialized_value = "nan" elif np.isposinf(value): diff --git a/visual_regression_test.py b/visual_regression_test.py index f601a81fd..ae62f0134 100644 --- a/visual_regression_test.py +++ b/visual_regression_test.py @@ -1,6 +1,5 @@ import argparse import asyncio -import math import os import sys import threading