diff --git a/optuna_dashboard/ts/action.ts b/optuna_dashboard/ts/action.ts index 6e112f05b..fdf8b6abd 100644 --- a/optuna_dashboard/ts/action.ts +++ b/optuna_dashboard/ts/action.ts @@ -23,6 +23,7 @@ import { artifactIsAvailable, reloadIntervalState, } from "./state" +import { getDominatedTrials } from "./dominatedTrials" const localStorageGraphVisibility = "graphVisibility" const localStorageReloadInterval = "reloadInterval" @@ -119,6 +120,34 @@ export const actionCreator = () => { newTrials[index] = newTrial const newStudy: StudyDetail = Object.assign({}, studyDetails[studyId]) newStudy.trials = newTrials + + // Update Best Trials + if (state === "Complete" && newStudy.directions.length === 1) { + // Single objective optimization + const bestValue = newStudy.best_trials.at(0)?.values?.at(0) + const currentValue = values?.at(0) + if (newStudy.best_trials.length === 0) { + newStudy.best_trials.push(newTrial) + } else if (bestValue !== undefined && currentValue !== undefined) { + if (newStudy.directions[0] === "minimize" && currentValue < bestValue) { + newStudy.best_trials = [newTrial] + } else if ( + newStudy.directions[0] === "maximize" && + currentValue > bestValue + ) { + newStudy.best_trials = [newTrial] + } else if (currentValue == bestValue) { + newStudy.best_trials.push(newTrial) + } + } + } else if (state === "Complete") { + // Multi objective optimization + newStudy.best_trials = getDominatedTrials( + newStudy.trials, + newStudy.directions + ) + } + setStudyDetailState(studyId, newStudy) } diff --git a/optuna_dashboard/ts/dominatedTrials.ts b/optuna_dashboard/ts/dominatedTrials.ts new file mode 100644 index 000000000..632e2cf73 --- /dev/null +++ b/optuna_dashboard/ts/dominatedTrials.ts @@ -0,0 +1,40 @@ +const filterFunc = (trial: Trial, directions: StudyDirection[]): boolean => { + return ( + trial.state === "Complete" && + trial.values !== undefined && + trial.values.length === directions.length && + trial.values.every((v) => v !== "inf" && v !== "-inf") + ) +} + +export const getDominatedTrials = ( + trials: Trial[], + directions: StudyDirection[] +): Trial[] => { + // TODO(c-bata): Use log-linear algorithm like Optuna. + // TODO(c-bata): Use this function at GraphParetoFront. + const filteredTrials = trials.filter((t: Trial) => filterFunc(t, directions)) + + const normalizedValues: number[][] = [] + filteredTrials.forEach((t) => { + if (t.values && t.values.length === directions.length) { + const trialValues = t.values.map((v, i) => { + return directions[i] === "minimize" ? (v as number) : (-v as number) + }) + normalizedValues.push(trialValues) + } + }) + const dominatedTrials: boolean[] = [] + normalizedValues.forEach((values0: number[], i: number) => { + const dominated = normalizedValues.some((values1: number[], j: number) => { + if (i === j) { + return false + } + return values0.every((value0: number, k: number) => { + return values1[k] <= value0 + }) + }) + dominatedTrials.push(dominated) + }) + return filteredTrials.filter((_, i) => !dominatedTrials.at(i)) +}