Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug when given -inf or nan values #237

Merged
merged 5 commits into from
May 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 32 additions & 23 deletions optuna_dashboard/_serializer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import json
import math
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union

import numpy as np
from optuna.distributions import BaseDistribution
from optuna.study import StudySummary
from optuna.trial import FrozenTrial
Expand All @@ -14,8 +14,10 @@


try:
from typing import Literal
from typing import TypedDict
except ImportError:
from typing_extensions import Literal # type: ignore
from typing_extensions import TypedDict


Expand All @@ -31,14 +33,7 @@
"IntermediateValue",
{
"step": int,
"value": Union[float, str],
},
)
TrialParam = TypedDict(
"TrialParam",
{
"name": str,
"value": str,
"value": Union[float, Literal["inf", "-inf", "nan"]],
},
)

Expand All @@ -56,17 +51,6 @@ def serialize_attrs(attrs: Dict[str, Any]) -> List[Attribute]:
return serialized


def serialize_intermediate_values(values: Dict[int, float]) -> List[IntermediateValue]:
return [
{"step": step, "value": "inf" if math.isinf(value) else value}
for step, value in values.items()
]


def serialize_trial_params(params: Dict[str, Any]) -> List[TrialParam]:
return [{"name": name, "value": str(value)} for name, value in params.items()]


def serialize_study_summary(summary: StudySummary) -> Dict[str, Any]:
serialized = {
"study_id": summary._study_id,
Expand Down Expand Up @@ -118,14 +102,39 @@ def serialize_frozen_trial(study_id: int, trial: FrozenTrial) -> Dict[str, Any]:
"study_id": study_id,
"number": trial.number,
"state": trial.state.name.capitalize(),
"intermediate_values": serialize_intermediate_values(trial.intermediate_values),
"params": serialize_trial_params(trial.params),
"params": [
{"name": name, "value": str(value)} for name, value in trial.params.items()
],
"user_attrs": serialize_attrs(trial.user_attrs),
"system_attrs": serialize_attrs(trial.system_attrs),
}

serialized_intermediate_values: List[IntermediateValue] = []
for step, value in trial.intermediate_values.items():
serialized_value: Union[float, Literal["nan", "inf", "-inf"]]
if np.isnan(value):
serialized_value = "nan"
elif np.isposinf(value):
serialized_value = "inf"
elif np.isneginf(value):
serialized_value = "-inf"
else:
assert np.isfinite(value)
serialized_value = value
serialized_intermediate_values.append({"step": step, "value": serialized_value})
serialized["intermediate_values"] = serialized_intermediate_values

if trial.values is not None:
serialized["values"] = ["inf" if math.isinf(v) else v for v in trial.values]
serialized_values: List[Union[float, Literal["inf", "-inf"]]] = []
for v in trial.values:
assert not np.isnan(v), "Should not detect nan value"
if np.isposinf(v):
serialized_values.append("inf")
elif np.isneginf(v):
serialized_values.append("-inf")
else:
serialized_values.append(v)
serialized["values"] = serialized_values

if trial.datetime_start is not None:
serialized["datetime_start"] = trial.datetime_start.isoformat()
Expand Down
2 changes: 1 addition & 1 deletion optuna_dashboard/ts/apiClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ interface TrialResponse {
study_id: number
number: number
state: TrialState
values?: (number | "inf")[]
values?: TrialValueNumber[]
intermediate_values: TrialIntermediateValue[]
datetime_start?: string
datetime_complete?: string
Expand Down
3 changes: 2 additions & 1 deletion optuna_dashboard/ts/components/GraphEdf.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ const filterFunc = (trial: Trial, objectiveId: number): boolean => {
return (
trial.state === "Complete" &&
trial.values !== undefined &&
trial.values[objectiveId] !== "inf"
trial.values[objectiveId] !== "inf" &&
trial.values[objectiveId] !== "-inf"
)
}

Expand Down
4 changes: 3 additions & 1 deletion optuna_dashboard/ts/components/GraphHistory.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ const filterFunc = (trial: Trial, objectiveId: number): boolean => {
return false
}
return (
trial.values.length > objectiveId && trial.values[objectiveId] !== "inf"
trial.values.length > objectiveId &&
trial.values[objectiveId] !== "inf" &&
trial.values[objectiveId] !== "-inf"
)
}

Expand Down
4 changes: 3 additions & 1 deletion optuna_dashboard/ts/components/GraphIntermediateValues.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ const plotIntermediateValue = (trials: Trial[], mode: string) => {
t.state == "Running"
)
const plotData: Partial<plotly.PlotData>[] = filteredTrials.map((trial) => {
const values = trial.intermediate_values.filter((iv) => iv.value !== "inf")
const values = trial.intermediate_values.filter(
(iv) => iv.value !== "inf" && iv.value !== "-inf" && iv.value !== "nan"
)
return {
x: values.map((iv) => iv.step),
y: values.map((iv) => iv.value),
Expand Down
4 changes: 3 additions & 1 deletion optuna_dashboard/ts/components/GraphParallelCoordinate.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ const filterFunc = (trial: Trial, objectiveId: number): boolean => {
return false
}
return (
trial.values.length > objectiveId && trial.values[objectiveId] !== "inf"
trial.values.length > objectiveId &&
trial.values[objectiveId] !== "inf" &&
trial.values[objectiveId] !== "-inf"
)
}

Expand Down
2 changes: 1 addition & 1 deletion optuna_dashboard/ts/components/GraphParetoFront.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ const filterFunc = (trial: Trial, directions: StudyDirection[]): boolean => {
trial.state === "Complete" &&
trial.values !== undefined &&
trial.values.length === directions.length &&
trial.values.every((v) => v !== "inf")
trial.values.every((v) => v !== "inf" && v !== "-inf")
)
}

Expand Down
4 changes: 3 additions & 1 deletion optuna_dashboard/ts/components/GraphSlice.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ const filterFunc = (
return false
}
return (
trial.values.length > objectiveId && trial.values[objectiveId] !== "inf"
trial.values.length > objectiveId &&
trial.values[objectiveId] !== "inf" &&
trial.values[objectiveId] !== "-inf"
)
}

Expand Down
50 changes: 41 additions & 9 deletions optuna_dashboard/ts/components/StudyDetail.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -515,13 +515,18 @@ export const TrialTable: FC<{ studyDetail: StudyDetail | null }> = ({

if (firstVal === secondVal) {
return 0
} else if (firstVal && secondVal) {
return firstVal < secondVal ? 1 : -1
} else if (firstVal) {
}
if (firstVal === undefined) {
return -1
} else {
} else if (secondVal === undefined) {
return 1
}
if (firstVal === "-inf" || secondVal === "inf") {
return 1
} else if (secondVal === "-inf" || firstVal === "inf") {
return -1
}
return firstVal < secondVal ? 1 : -1
},
toCellValue: (i) => {
if (trials[i].values === undefined) {
Expand All @@ -542,13 +547,18 @@ export const TrialTable: FC<{ studyDetail: StudyDetail | null }> = ({

if (firstVal === secondVal) {
return 0
} else if (firstVal && secondVal) {
return firstVal < secondVal ? 1 : -1
} else if (firstVal) {
}
if (firstVal === undefined) {
return -1
} else {
} else if (secondVal === undefined) {
return 1
}
if (firstVal === "-inf" || secondVal === "inf") {
return 1
} else if (secondVal === "-inf" || firstVal === "inf") {
return -1
}
return firstVal < secondVal ? 1 : -1
},
toCellValue: (i) => {
if (trials[i].values === undefined) {
Expand Down Expand Up @@ -647,7 +657,29 @@ export const TrialTable: FC<{ studyDetail: StudyDetail | null }> = ({
const collapseIntermediateValueColumns: DataGridColumn<TrialIntermediateValue>[] =
[
{ field: "step", label: "Step", sortable: true },
{ field: "value", label: "Value", sortable: true },
{
field: "value",
label: "Value",
sortable: true,
less: (firstEl, secondEl): number => {
const firstVal = firstEl.value
const secondVal = secondEl.value
if (firstVal === secondVal) {
return 0
}
if (firstVal === "nan") {
return -1
} else if (secondVal === "nan") {
return 1
}
if (firstVal === "-inf" || secondVal === "inf") {
return 1
} else if (secondVal === "-inf" || firstVal === "inf") {
return -1
}
return firstVal < secondVal ? 1 : -1
},
},
]
const collapseAttrColumns: DataGridColumn<Attribute>[] = [
{ field: "key", label: "Key", sortable: true },
Expand Down
6 changes: 4 additions & 2 deletions optuna_dashboard/ts/types/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ declare const APP_BAR_TITLE: string
declare const API_ENDPOINT: string
declare const URL_PREFIX: string

type TrialValueNumber = number | "inf" | "-inf"
type TrialIntermediateValueNumber = number | "inf" | "-inf" | "nan"
type TrialState = "Running" | "Complete" | "Pruned" | "Fail" | "Waiting"
type StudyDirection = "maximize" | "minimize" | "not_set"
type Distribution =
Expand All @@ -21,7 +23,7 @@ type Distribution =

declare interface TrialIntermediateValue {
step: number
value: number | "inf"
value: TrialIntermediateValueNumber
}

declare interface TrialParam {
Expand Down Expand Up @@ -55,7 +57,7 @@ declare interface Trial {
study_id: number
number: number
state: TrialState
values?: (number | "inf")[]
values?: TrialValueNumber[]
intermediate_values: TrialIntermediateValue[]
datetime_start?: Date
datetime_complete?: Date
Expand Down
19 changes: 13 additions & 6 deletions visual_regression_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse
import asyncio
import math
import os
import sys
import threading
Expand Down Expand Up @@ -87,13 +86,15 @@ def objective_single_dynamic(trial: optuna.Trial) -> float:

study.optimize(objective_single_dynamic, n_trials=50)

# Single objective study with 'inf' value
# Single objective study with 'inf', '-inf', or 'nan' value
study = optuna.create_study(study_name="single-inf", storage=storage)

def objective_single_inf(trial: optuna.Trial) -> float:
x = trial.suggest_float("x", -10, 10)
if x > 0:
return math.inf
if trial.number % 3 == 0:
return float("inf")
elif trial.number % 3 == 1:
return float("-inf")
else:
return x**2

Expand Down Expand Up @@ -152,13 +153,19 @@ def objective_prune_without_report(trial: optuna.Trial) -> float:

study.optimize(objective_prune_without_report, n_trials=100)

# Single objective pruned after reported 'inf' value
# Single objective pruned after reported 'inf', '-inf', or 'nan'
study = optuna.create_study(study_name="single-inf-report", storage=storage)

def objective_single_inf_report(trial: optuna.Trial) -> float:
x = trial.suggest_float("x", -10, 10)
if trial.number % 3 == 0:
trial.report(float("inf"), 1)
elif trial.number % 3 == 1:
trial.report(float("-inf"), 1)
else:
trial.report(float("nan"), 1)

if x > 0:
trial.report(math.inf, 1)
raise optuna.TrialPruned()
else:
return x**2
Expand Down