Skip to content

Commit

Permalink
Merge pull request #54 from chenghuzi/plot_param_importances
Browse files Browse the repository at this point in the history
Add hyperparameter importance chart
  • Loading branch information
c-bata committed Mar 14, 2021
2 parents 4e1f50b + be2af8e commit 03e176b
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 2 deletions.
44 changes: 42 additions & 2 deletions optuna_dashboard/app.py
Expand Up @@ -8,10 +8,11 @@
from typing import Union, Dict, List, Optional, TypeVar, Callable, Any, cast

from bottle import Bottle, BaseResponse, redirect, request, response, static_file
import optuna
from optuna.exceptions import DuplicatedStudyError
from optuna.storages import BaseStorage
from optuna.trial import FrozenTrial
from optuna.study import StudyDirection, StudySummary
from optuna.trial import FrozenTrial, TrialState
from optuna.study import StudyDirection, StudySummary, Study

from . import serializer

Expand Down Expand Up @@ -94,6 +95,13 @@ def get_trials(
return trials


def get_distribution_name(param_name: str, study: Study) -> str:
for trial in study.trials:
if param_name in trial.distributions:
return trial.distributions[param_name].__class__.__name__
assert False, "Must not reach here."


def create_app(storage: BaseStorage) -> Bottle:
app = Bottle()

Expand Down Expand Up @@ -185,6 +193,38 @@ def get_study_detail(study_id: int) -> BottleViewReturn:
trials = get_trials(storage, study_id)
return serializer.serialize_study_detail(summary, trials)

@app.get("/api/studies/<study_id:int>/param_importances")
@handle_json_api_exception
def get_param_importances(study_id: int) -> BottleViewReturn:
# TODO(chenghuzi): add support for selecting params and targets via query parameters.
response.content_type = "application/json"
study_name = storage.get_study_name_from_id(study_id)
study = Study(study_name=study_name, storage=storage)

trials = [trial for trial in study.trials if trial.state == TrialState.COMPLETE]
if len(trials) == 0:
return ""
evaluator = None
params = None
target = None
importances = optuna.importance.get_param_importances(
study, evaluator=evaluator, params=params, target=target
)
if target is None:
target_name = "Objective Value"

return {
"target_name": target_name,
"param_importances": [
{
"name": name,
"importance": importance,
"distribution": get_distribution_name(name, study),
}
for name, importance in importances.items()
],
}

@app.get("/static/<filename:path>")
def send_static(filename: str) -> BottleViewReturn:
return static_file(filename, root=STATIC_DIR)
Expand Down
18 changes: 18 additions & 0 deletions optuna_dashboard/static/apiClient.ts
Expand Up @@ -168,3 +168,21 @@ export const deleteStudyAPI = (studyId: number) => {
return {}
})
}

interface ParamImportancesResponse {
target_name: string
param_importances: ParamImportance[]
}

export const getParamImportances = (
studyId: number
): Promise<ParamImportances> => {
return axiosInstance
.get<ParamImportancesResponse>(
`/api/studies/${studyId}/param_importances`,
{}
)
.then((res) => {
return res.data
})
}
88 changes: 88 additions & 0 deletions optuna_dashboard/static/components/HyperparameterImportances.tsx
@@ -0,0 +1,88 @@
import * as plotly from "plotly.js-dist"
import React, { FC, useEffect } from "react"
import { getParamImportances } from "../apiClient"
const plotDomId = "graph-hyperparameter-importances"

// To match colors used by plot_param_importances in optuna.
const plotlyColorsSequentialBlues = [
"rgb(247,251,255)",
"rgb(222,235,247)",
"rgb(198,219,239)",
"rgb(158,202,225)",
"rgb(107,174,214)",
"rgb(66,146,198)",
"rgb(33,113,181)",
"rgb(8,81,156)",
"rgb(8,48,107)",
]

const distributionColors = {
UniformDistribution: plotlyColorsSequentialBlues.slice(-1)[0],
LogUniformDistribution: plotlyColorsSequentialBlues.slice(-1)[0],
DiscreteUniformDistribution: plotlyColorsSequentialBlues.slice(-1)[0],
IntUniformDistribution: plotlyColorsSequentialBlues.slice(-2)[0],
IntLogUniformDistribution: plotlyColorsSequentialBlues.slice(-2)[0],
CategoricalDistribution: plotlyColorsSequentialBlues.slice(-4)[0],
}

export const HyperparameterImportances: FC<{
studyId: number
numOfTrials: number
}> = ({ studyId, numOfTrials = 0 }) => {
useEffect(() => {
async function fetchAndPlotParamImportances(studyId: number) {
const paramsImportanceData = await getParamImportances(studyId)
plotParamImportances(paramsImportanceData)
}
fetchAndPlotParamImportances(studyId)
}, [numOfTrials])
return <div id={plotDomId} />
}

const plotParamImportances = (paramsImportanceData: ParamImportances) => {
if (document.getElementById(plotDomId) === null) {
return
}
const param_importances = paramsImportanceData.param_importances.reverse()
const importance_values = param_importances.map((p) => p.importance)
const param_names = param_importances.map((p) => p.name)
const param_colors = param_importances.map(
(p) => distributionColors[p.distribution]
)
const param_hover_templates = param_importances.map(
(p) => `${p.name} (${p.distribution}): ${p.importance} <extra></extra>`
)

const layout: Partial<plotly.Layout> = {
title: "Hyperparameter Importance",
xaxis: {
title: `Importance for ${paramsImportanceData.target_name}`,
},
yaxis: {
title: "Hyperparameter",
},
margin: {
l: 50,
r: 50,
b: 50,
},
showlegend: false,
}

const plotData: Partial<plotly.PlotData>[] = [
{
type: "bar",
orientation: "h",
x: importance_values,
y: param_names,
text: importance_values.map((v) => String(v.toFixed(2))),
textposition: "outside",
hovertemplate: param_hover_templates,
marker: {
color: param_colors,
},
},
]

plotly.react(plotDomId, plotData, layout)
}
15 changes: 15 additions & 0 deletions optuna_dashboard/static/components/StudyDetail.tsx
Expand Up @@ -20,6 +20,7 @@ import { Home, Cached } from "@material-ui/icons"

import { DataGridColumn, DataGrid } from "./DataGrid"
import { GraphParallelCoordinate } from "./GraphParallelCoordinate"
import { HyperparameterImportances } from "./HyperparameterImportances"
import { GraphIntermediateValues } from "./GraphIntermediateValues"
import { GraphSlice } from "./GraphSlice"
import { GraphHistory } from "./GraphHistory"
Expand Down Expand Up @@ -197,6 +198,20 @@ export const StudyDetail: FC = () => {
</Grid>
</Grid>
) : null}
{studyDetail !== null && isSingleObjectiveStudy(studyDetail) ? (
<Grid container direction="row">
<Grid item xs={6}>
<Card className={classes.card}>
<CardContent>
<HyperparameterImportances
studyId={studyIdNumber}
numOfTrials={trials.length}
/>
</CardContent>
</Card>
</Grid>
</Grid>
) : null}
{studyDetail !== null ? (
<Card className={classes.card}>
<CardContent>
Expand Down
18 changes: 18 additions & 0 deletions optuna_dashboard/static/types/index.d.ts
Expand Up @@ -9,6 +9,13 @@ declare const URL_PREFIX: string

type TrialState = "Running" | "Complete" | "Pruned" | "Fail" | "Waiting"
type StudyDirection = "maximize" | "minimize" | "not_set"
type Distribution =
| "UniformDistribution"
| "LogUniformDistribution"
| "DiscreteUniformDistribution"
| "IntUniformDistribution"
| "IntLogUniformDistribution"
| "CategoricalDistribution"

declare interface TrialIntermediateValue {
step: number
Expand All @@ -20,6 +27,12 @@ declare interface TrialParam {
value: string
}

declare interface ParamImportance {
name: string
importance: number
distribution: Distribution
}

declare interface Attribute {
key: string
value: string
Expand Down Expand Up @@ -60,3 +73,8 @@ declare interface StudyDetail {
declare interface StudyDetails {
[study_id: string]: StudyDetail
}

declare interface ParamImportances {
target_name: string
param_importances: ParamImportance[]
}
1 change: 1 addition & 0 deletions setup.cfg
Expand Up @@ -30,6 +30,7 @@ install_requires =
optuna>=2.4
bottle
typing-extensions;python_version<'3.8'
scikit-learn

[options.extras_require]
lint =
Expand Down

0 comments on commit 03e176b

Please sign in to comment.