Skip to content
This repository has been archived by the owner on Mar 12, 2024. It is now read-only.

Commit

Permalink
Fix encoding in hiplot-render for windows users (#220)
Browse files Browse the repository at this point in the history
Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
  • Loading branch information
danthe3rd and danthe3rd committed Nov 2, 2021
1 parent eaa37fc commit c7f0cd4
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 12 deletions.
23 changes: 13 additions & 10 deletions hiplot/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
# LICENSE file in the root directory of this source tree.

import csv
import enum
import uuid
import json
import codecs
import warnings
from abc import ABCMeta, abstractmethod
from enum import Enum
Expand All @@ -18,6 +18,8 @@
from .streamlit_helpers import ExperimentStreamlitComponent
import optuna

TextWriterIO = tp.Union[tp.IO[str], codecs.StreamWriter]

DisplayableType = tp.Union[bool, int, float, str]


Expand Down Expand Up @@ -345,7 +347,7 @@ def get_experiment():
To render an experiment to HTML, use `experiment.to_html(file_name)` or `html_page = experiment.to_html()`""")
return streamlit_helpers.ExperimentStreamlitComponent(self, key=key, ret=ret)

def to_html(self, file: tp.Optional[tp.Union[Path, str, tp.IO[str]]] = None, **kwargs: tp.Any) -> str:
def to_html(self, file: tp.Optional[tp.Union[Path, str, TextWriterIO]] = None, **kwargs: tp.Any) -> str:
"""
Returns the content of a standalone .html file that displays this experiment
without any dependency to HiPlot server or static files.
Expand All @@ -368,7 +370,7 @@ def to_html(self, file: tp.Optional[tp.Union[Path, str, tp.IO[str]]] = None, **k
file.write(html)
return html

def to_csv(self, file: tp.Union[Path, str, tp.IO[str]]) -> None:
def to_csv(self, file: tp.Union[Path, str, TextWriterIO]) -> None:
"""
Dumps this Experiment as a .csv file.
Information about display_data, parameters definition will be lost.
Expand All @@ -381,7 +383,7 @@ def to_csv(self, file: tp.Union[Path, str, tp.IO[str]]) -> None:
else:
return self._to_csv(file)

def _to_csv(self, fh: tp.IO[str]) -> None:
def _to_csv(self, fh: TextWriterIO) -> None:
fieldnames: tp.Set[str] = set()
for dp in self.datapoints:
for f in dp.values.keys():
Expand Down Expand Up @@ -512,20 +514,23 @@ def from_optuna(study: "optuna.study.Study") -> "Experiment": # No type hint to
:param study: Optuna Study
"""


# Create a list of dictionary objects using study trials
# All parameters are taken using params.copy()
# pylint: disable=redefined-outer-name
import optuna

hyper_opt_data = []
for each_trial in study.get_trials(states=(optuna.trial.TrialState.COMPLETE, )):
trial_params = {}
if not each_trial.values: # This checks if the trial was fully completed - the value will be None if the trial was interrupted halfway (e.g. via KeyboardInterrupt)
# This checks if the trial was fully completed
# the value will be None if the trial was interrupted halfway (e.g. via KeyboardInterrupt)
if not each_trial.values:
continue
num_objectives = len(each_trial.values)

if num_objectives == 1:
trial_params["value"] = each_trial.value # name = value, as it could be RMSE / accuracy, or any value that the user selects for tuning
# name = value, as it could be RMSE / accuracy, or any value that the user selects for tuning
trial_params["value"] = each_trial.value
else:
for objective_id, value in enumerate(each_trial.values):
trial_params[f"value_{objective_id}"] = value
Expand All @@ -537,8 +542,6 @@ def from_optuna(study: "optuna.study.Study") -> "Experiment": # No type hint to

return experiment



@staticmethod
def merge(xp_dict: tp.Dict[str, "Experiment"]) -> "Experiment":
"""
Expand Down
6 changes: 4 additions & 2 deletions hiplot/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import json
from typing import Any, Dict
from pathlib import Path
import codecs

from . import fetchers

Expand Down Expand Up @@ -100,10 +101,11 @@ def hiplot_render_main() -> int:

exp = fetchers.load_xp_with_fetchers(fetchers.get_fetchers(args.fetchers), args.experiment_uri)
exp.validate()
stdout_writer = codecs.getwriter("utf-8")(sys.stdout.buffer)
if args.format == 'csv':
exp.to_csv(sys.stdout)
exp.to_csv(stdout_writer)
elif args.format == 'html':
exp.to_html(sys.stdout)
exp.to_html(stdout_writer)
else:
assert False, args.format
return 0

0 comments on commit c7f0cd4

Please sign in to comment.