Skip to content

Commit

Permalink
fixes #478 (#481)
Browse files Browse the repository at this point in the history
* fixes #478
  • Loading branch information
lanpa committed Aug 5, 2019
1 parent b5ab572 commit 366bc8f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
13 changes: 12 additions & 1 deletion tensorboardX/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def _draw_single_box(image, xmin, ymin, xmax, ymax, display_str, color='black',
def hparams(hparam_dict=None, metric_dict=None):
from tensorboardX.proto.plugin_hparams_pb2 import HParamsPluginData, SessionEndInfo, SessionStartInfo
from tensorboardX.proto.api_pb2 import Experiment, HParamInfo, MetricInfo, MetricName, Status
from six import string_types

PLUGIN_NAME = 'hparams'
PLUGIN_DATA_VERSION = 0
Expand All @@ -94,7 +95,17 @@ def hparams(hparam_dict=None, metric_dict=None):

ssi = SessionStartInfo()
for k, v in hparam_dict.items():
ssi.hparams[k].number_value = v
if isinstance(v, string_types):
ssi.hparams[k].string_value = v
continue

if isinstance(v, bool):
ssi.hparams[k].bool_value = v
continue

if not isinstance(v, int) or not isinstance(v, float):
v = make_np(v)[0]
ssi.hparams[k].number_value = v

content = HParamsPluginData(session_start_info=ssi, version=PLUGIN_DATA_VERSION)
smd = SummaryMetadata(plugin_data=SummaryMetadata.PluginData(plugin_name=PLUGIN_NAME,
Expand Down
5 changes: 4 additions & 1 deletion tests/test_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,7 @@ def test_hparams(self):
def test_hparams_smoke(self):
hp = {'lr': 0.1, 'bsize': 4}
mt = {'accuracy': 0.1, 'loss': 10}
summary.hparams(hp, mt)
summary.hparams(hp, mt)

hp = {'string': "1b", 'use magic': True}
summary.hparams(hp, mt)

0 comments on commit 366bc8f

Please sign in to comment.