Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add hyperparameter importance chart #54

Merged
merged 9 commits into from Mar 14, 2021
Merged
44 changes: 42 additions & 2 deletions optuna_dashboard/app.py
Expand Up @@ -8,10 +8,11 @@
from typing import Union, Dict, List, Optional, TypeVar, Callable, Any, cast

from bottle import Bottle, BaseResponse, redirect, request, response, static_file
import optuna
from optuna.exceptions import DuplicatedStudyError
from optuna.storages import BaseStorage
from optuna.trial import FrozenTrial
from optuna.study import StudyDirection, StudySummary
from optuna.trial import FrozenTrial, TrialState
from optuna.study import StudyDirection, StudySummary, Study

from . import serializer

Expand Down Expand Up @@ -94,6 +95,13 @@ def get_trials(
return trials


def get_distribution_name(param_name: str, study: Study) -> str:
for trial in study.trials:
if param_name in trial.distributions:
return trial.distributions[param_name].__class__.__name__
assert False, "Must not reach here."


def create_app(storage: BaseStorage) -> Bottle:
app = Bottle()

Expand Down Expand Up @@ -185,6 +193,38 @@ def get_study_detail(study_id: int) -> BottleViewReturn:
trials = get_trials(storage, study_id)
return serializer.serialize_study_detail(summary, trials)

@app.get("/api/studies/<study_id:int>/param_importances")
@handle_json_api_exception
def get_param_importances(study_id: int) -> BottleViewReturn:
# TODO(chenghuzi): add support for selecting params and targets via query parameters.
response.content_type = "application/json"
study_name = storage.get_study_name_from_id(study_id)
study = Study(study_name=study_name, storage=storage)

trials = [trial for trial in study.trials if trial.state == TrialState.COMPLETE]
if len(trials) == 0:
return ""
evaluator = None
params = None
target = None
importances = optuna.importance.get_param_importances(
c-bata marked this conversation as resolved.
Show resolved Hide resolved
study, evaluator=evaluator, params=params, target=target
)
if target is None:
target_name = "Objective Value"

return {
"target_name": target_name,
"param_importances": [
{
"name": name,
"importance": importance,
"distribution": get_distribution_name(name, study),
}
for name, importance in importances.items()
],
}

@app.get("/static/<filename:path>")
def send_static(filename: str) -> BottleViewReturn:
return static_file(filename, root=STATIC_DIR)
Expand Down
18 changes: 18 additions & 0 deletions optuna_dashboard/static/apiClient.ts
Expand Up @@ -168,3 +168,21 @@ export const deleteStudyAPI = (studyId: number) => {
return {}
})
}

interface ParamImportancesResponse {
target_name: string
param_importances: ParamImportance[]
}

export const getParamImportances = (
studyId: number
): Promise<ParamImportances> => {
return axiosInstance
.get<ParamImportancesResponse>(
`/api/studies/${studyId}/param_importances`,
{}
)
.then((res) => {
return res.data
})
}
88 changes: 88 additions & 0 deletions optuna_dashboard/static/components/HyperparameterImportances.tsx
@@ -0,0 +1,88 @@
import * as plotly from "plotly.js-dist"
import React, { FC, useEffect } from "react"
import { getParamImportances } from "../apiClient"
const plotDomId = "graph-hyperparameter-importances"

// To match colors used by plot_param_importances in optuna.
const plotlyColorsSequentialBlues = [
"rgb(247,251,255)",
"rgb(222,235,247)",
"rgb(198,219,239)",
"rgb(158,202,225)",
"rgb(107,174,214)",
"rgb(66,146,198)",
"rgb(33,113,181)",
"rgb(8,81,156)",
"rgb(8,48,107)",
]

const distributionColors = {
UniformDistribution: plotlyColorsSequentialBlues.slice(-1)[0],
LogUniformDistribution: plotlyColorsSequentialBlues.slice(-1)[0],
DiscreteUniformDistribution: plotlyColorsSequentialBlues.slice(-1)[0],
IntUniformDistribution: plotlyColorsSequentialBlues.slice(-2)[0],
IntLogUniformDistribution: plotlyColorsSequentialBlues.slice(-2)[0],
CategoricalDistribution: plotlyColorsSequentialBlues.slice(-4)[0],
}

export const HyperparameterImportances: FC<{
studyId: number
numOfTrials: number
}> = ({ studyId, numOfTrials = 0 }) => {
useEffect(() => {
async function fetchAndPlotParamImportances(studyId: number) {
const paramsImportanceData = await getParamImportances(studyId)
plotParamImportances(paramsImportanceData)
}
fetchAndPlotParamImportances(studyId)
}, [numOfTrials])
return <div id={plotDomId} />
}

const plotParamImportances = (paramsImportanceData: ParamImportances) => {
if (document.getElementById(plotDomId) === null) {
return
}
const param_importances = paramsImportanceData.param_importances.reverse()
const importance_values = param_importances.map((p) => p.importance)
const param_names = param_importances.map((p) => p.name)
const param_colors = param_importances.map(
(p) => distributionColors[p.distribution]
)
const param_hover_templates = param_importances.map(
(p) => `${p.name} (${p.distribution}): ${p.importance} <extra></extra>`
)

const layout: Partial<plotly.Layout> = {
title: "Hyperparameter Importance",
xaxis: {
title: `Importance for ${paramsImportanceData.target_name}`,
},
yaxis: {
title: "Hyperparameter",
},
margin: {
l: 50,
r: 50,
b: 50,
},
showlegend: false,
}

const plotData: Partial<plotly.PlotData>[] = [
{
type: "bar",
orientation: "h",
x: importance_values,
y: param_names,
text: importance_values.map((v) => String(v.toFixed(2))),
textposition: "outside",
hovertemplate: param_hover_templates,
marker: {
color: param_colors,
},
},
]

plotly.react(plotDomId, plotData, layout)
}
15 changes: 15 additions & 0 deletions optuna_dashboard/static/components/StudyDetail.tsx
Expand Up @@ -20,6 +20,7 @@ import { Home, Cached } from "@material-ui/icons"

import { DataGridColumn, DataGrid } from "./DataGrid"
import { GraphParallelCoordinate } from "./GraphParallelCoordinate"
import { HyperparameterImportances } from "./HyperparameterImportances"
import { GraphIntermediateValues } from "./GraphIntermediateValues"
import { GraphSlice } from "./GraphSlice"
import { GraphHistory } from "./GraphHistory"
Expand Down Expand Up @@ -196,6 +197,20 @@ export const StudyDetail: FC = () => {
</Grid>
</Grid>
) : null}
{studyDetail !== null && isSingleObjectiveStudy(studyDetail) ? (
<Grid container direction="row">
<Grid item xs={6}>
<Card className={classes.card}>
<CardContent>
<HyperparameterImportances
studyId={studyIdNumber}
numOfTrials={trials.length}
/>
</CardContent>
</Card>
</Grid>
</Grid>
) : null}
{studyDetail !== null ? (
<Card className={classes.card}>
<CardContent>
Expand Down
18 changes: 18 additions & 0 deletions optuna_dashboard/static/types/index.d.ts
Expand Up @@ -9,6 +9,13 @@ declare const URL_PREFIX: string

type TrialState = "Running" | "Complete" | "Pruned" | "Fail" | "Waiting"
type StudyDirection = "maximize" | "minimize" | "not_set"
type Distribution =
| "UniformDistribution"
| "LogUniformDistribution"
| "DiscreteUniformDistribution"
| "IntUniformDistribution"
| "IntLogUniformDistribution"
| "CategoricalDistribution"

declare interface TrialIntermediateValue {
step: number
Expand All @@ -20,6 +27,12 @@ declare interface TrialParam {
value: string
}

declare interface ParamImportance {
name: string
importance: number
distribution: Distribution
}

declare interface Attribute {
key: string
value: string
Expand Down Expand Up @@ -60,3 +73,8 @@ declare interface StudyDetail {
declare interface StudyDetails {
[study_id: string]: StudyDetail
}

declare interface ParamImportances {
target_name: string
param_importances: ParamImportance[]
}
1 change: 1 addition & 0 deletions setup.cfg
Expand Up @@ -30,6 +30,7 @@ install_requires =
optuna>=2.4
bottle
typing-extensions;python_version<'3.8'
scikit-learn

[options.extras_require]
lint =
Expand Down