Skip to content

Commit

Permalink
fix #639 (#642)
Browse files Browse the repository at this point in the history
* fix #639

* fix test dependency
  • Loading branch information
lanpa committed Sep 12, 2021
1 parent 4f13678 commit 054f1f3
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 50 deletions.
2 changes: 2 additions & 0 deletions examples/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,5 @@
writer.add_video('video_1_fps', vid_tensor=vid, fps=1)

writer.close()

writer.add_scalar('implicit reopen writer', 100, 0)
8 changes: 5 additions & 3 deletions examples/demo_comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

resnet18 = models.resnet18(False)
writer = SummaryWriter(comet_config={"disabled": False,
"workspace": 'myworkspace',
"project_name": 'tensorboardx',
"api_key": "xxxxxxxx"})
"workspace": 'tensorboardx-test',
"project_name": 'tbx-ci',
"api_key": "KOSKkXJ52qFZxkxYHlRJ7wOEk"})
sample_rate = 44100
freqs = [262, 294, 330, 349, 392, 440, 440, 440, 440, 440, 440]

Expand Down Expand Up @@ -97,3 +97,5 @@
writer.add_video('video_1_fps', vid_tensor=vid, fps=1)

writer.close()

writer.add_scalar('implicit reopen writer', 100, 0)
90 changes: 44 additions & 46 deletions tensorboardX/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,16 +291,15 @@ def __init__(
self._flush_secs = flush_secs
self._filename_suffix = filename_suffix
self._write_to_disk = write_to_disk
self._comet_config = comet_config
self._comet_logger = None
self.kwargs = kwargs

# Initialize the file writers, but they can be cleared out on close
# and recreated later as needed.
self.file_writer = self.all_writers = None
self._get_file_writer()

# Initialize the Comet Logger
self.comet_logger = CometLogger(comet_config)

# Create default bins for histograms, see generate_testdata.py in tensorflow/tensorboard
v = 1E-12
buckets = []
Expand Down Expand Up @@ -360,6 +359,12 @@ def _get_file_writer(self):
self.all_writers = {self.file_writer.get_logdir(): self.file_writer}
return self.file_writer

def _get_comet_logger(self):
"""Returns a comet logger instance. Recreates it if closed."""
if self._comet_logger is None:
self._comet_logger = CometLogger(self._comet_config)
return self._comet_logger

def add_hparams(
self,
hparam_dict: Dict[str, Union[bool, str, float, int]],
Expand Down Expand Up @@ -406,7 +411,7 @@ def add_hparams(
w_hp.file_writer.add_summary(sei)
for k, v in metric_dict.items():
w_hp.add_scalar(k, v, global_step)
self.comet_logger.log_parameters(hparam_dict, step=global_step)
self._get_comet_logger().log_parameters(hparam_dict, step=global_step)

def add_scalar(
self,
Expand Down Expand Up @@ -450,7 +455,7 @@ def add_scalar(
raise TypeError("Input value: \"{}\" is not a scalar".format(scalar_value))
self._get_file_writer().add_summary(
scalar(tag, scalar_value, display_name, summary_description), global_step, walltime)
self.comet_logger.log_metric(tag, display_name, scalar_value, global_step)
self._get_comet_logger().log_metric(tag, display_name, scalar_value, global_step)

def add_scalars(
self,
Expand Down Expand Up @@ -502,8 +507,7 @@ def add_scalars(
global_step, walltime)
self.__append_to_scalar_dict(
fw_tag, scalar_value, global_step, walltime)
self.comet_logger.log_metrics(tag_scalar_dict, main_tag,
step=global_step)
self._get_comet_logger().log_metrics(tag_scalar_dict, main_tag, step=global_step)

def export_scalars_to_json(self, path):
"""Exports to the given path an ASCII file containing all the scalars written
Expand Down Expand Up @@ -557,7 +561,7 @@ def add_histogram(
bins = self.default_bins
self._get_file_writer().add_summary(
histogram(tag, values, bins, max_bins=max_bins), global_step, walltime)
self.comet_logger.log_histogram(values, tag, global_step)
self._get_comet_logger().log_histogram(values, tag, global_step)

def add_histogram_raw(
self,
Expand Down Expand Up @@ -622,14 +626,14 @@ def add_histogram_raw(
bucket_counts),
global_step,
walltime)
self.comet_logger.log_raw_figure(tag, 'histogram_raw', global_step,
min=min,
max=max,
num=num,
sum=sum,
sum_squares=sum_squares,
bucket_limits=bucket_limits,
bucket_counts=bucket_counts)
self._get_comet_logger().log_raw_figure(tag, 'histogram_raw', global_step,
min=min,
max=max,
num=num,
sum=sum,
sum_squares=sum_squares,
bucket_limits=bucket_limits,
bucket_counts=bucket_counts)

def add_image(
self,
Expand Down Expand Up @@ -689,8 +693,7 @@ def add_image(
encoded_image_string = summary.value[0].image.encoded_image_string
self._get_file_writer().add_summary(
summary, global_step, walltime)
self.comet_logger.log_image_encoded(encoded_image_string, tag,
step=global_step)
self._get_comet_logger().log_image_encoded(encoded_image_string, tag, step=global_step)

def add_images(
self,
Expand Down Expand Up @@ -755,8 +758,7 @@ def add_images(
encoded_image_string = summary.value[0].image.encoded_image_string
self._get_file_writer().add_summary(
summary, global_step, walltime)
self.comet_logger.log_image_encoded(encoded_image_string, tag,
step=global_step)
self._get_comet_logger().log_image_encoded(encoded_image_string, tag, step=global_step)

def add_image_with_boxes(
self,
Expand Down Expand Up @@ -800,8 +802,7 @@ def add_image_with_boxes(
encoded_image_string = summary.value[0].image.encoded_image_string
self._get_file_writer().add_summary(
summary, global_step, walltime)
self.comet_logger.log_image_encoded(encoded_image_string, tag,
step=global_step)
self._get_comet_logger().log_image_encoded(encoded_image_string, tag, step=global_step)

def add_figure(
self,
Expand Down Expand Up @@ -853,8 +854,7 @@ def add_video(
encoded_image_string = summary.value[0].image.encoded_image_string
self._get_file_writer().add_summary(
summary, global_step, walltime)
self.comet_logger.log_image_encoded(encoded_image_string, tag,
step=global_step)
self._get_comet_logger().log_image_encoded(encoded_image_string, tag, step=global_step)

def add_audio(
self,
Expand All @@ -880,8 +880,7 @@ def add_audio(
snd_tensor = workspace.FetchBlob(snd_tensor)
self._get_file_writer().add_summary(
audio(tag, snd_tensor, sample_rate=sample_rate), global_step, walltime)
self.comet_logger.log_audio(snd_tensor, sample_rate, tag,
step=global_step)
self._get_comet_logger().log_audio(snd_tensor, sample_rate, tag, step=global_step)

def add_text(
self,
Expand All @@ -903,7 +902,7 @@ def add_text(
"""
self._get_file_writer().add_summary(
text(tag, text_string), global_step, walltime)
self.comet_logger.log_text(text_string, global_step)
self._get_comet_logger().log_text(text_string, global_step)

def add_onnx_graph(
self,
Expand All @@ -914,7 +913,7 @@ def add_onnx_graph(
onnx_model_file (string): The path to the onnx model.
"""
self._get_file_writer().add_onnx_graph(load_onnx_graph(onnx_model_file))
self.comet_logger.log_asset(onnx_model_file)
self._get_comet_logger().log_asset(onnx_model_file)

def add_openvino_graph(
self,
Expand All @@ -925,7 +924,7 @@ def add_openvino_graph(
xmlname (string): The path to the openvino model. (the xml file)
"""
self._get_file_writer().add_openvino_graph(load_openvino_graph(xmlname))
self.comet_logger.log_asset(xmlname)
self._get_comet_logger().log_asset(xmlname)

def add_graph(
self,
Expand Down Expand Up @@ -1096,7 +1095,7 @@ def add_embedding(
else:
template_filename = None

self.comet_logger.log_embedding(mat, metadata, label_img, template_filename=template_filename)
self._get_comet_logger().log_embedding(mat, metadata, label_img, template_filename=template_filename)

def add_pr_curve(
self,
Expand Down Expand Up @@ -1140,8 +1139,7 @@ def add_pr_curve(
self._get_file_writer().add_summary(
pr_curve(tag, labels, predictions, num_thresholds, weights),
global_step, walltime)
self.comet_logger.log_curve(tag, labels, predictions,
step=global_step)
self._get_comet_logger().log_curve(tag, labels, predictions, step=global_step)

def add_pr_curve_raw(
self,
Expand Down Expand Up @@ -1178,16 +1176,16 @@ def add_pr_curve_raw(
weights),
global_step,
walltime)
self.comet_logger.log_raw_figure(tag, 'pr_curve_raw', global_step,
true_positive_counts=true_positive_counts,
false_positive_counts=false_positive_counts,
true_negative_counts=true_negative_counts,
false_negative_counts=false_negative_counts,
precision=precision,
recall=recall,
num_thresholds=num_thresholds,
weights=weights,
walltime=walltime)
self._get_comet_logger().log_raw_figure(tag, 'pr_curve_raw', global_step,
true_positive_counts=true_positive_counts,
false_positive_counts=false_positive_counts,
true_negative_counts=true_negative_counts,
false_negative_counts=false_negative_counts,
precision=precision,
recall=recall,
num_thresholds=num_thresholds,
weights=weights,
walltime=walltime)

def add_custom_scalars_multilinechart(
self,
Expand Down Expand Up @@ -1289,8 +1287,8 @@ def add_mesh(
"""
self._get_file_writer().add_summary(mesh(tag, vertices, colors, faces, config_dict), global_step, walltime)
self.comet_logger.log_mesh(tag, vertices, colors, faces,
config_dict, global_step, walltime)
self._get_comet_logger().log_mesh(tag, vertices, colors, faces,
config_dict, global_step, walltime)

def close(self):
"""Close the current SummaryWriter. This call flushes the unfinished write operation.
Expand All @@ -1302,8 +1300,8 @@ def close(self):
writer.flush()
writer.close()
self.file_writer = self.all_writers = None
self.comet_logger.end()
self.comet_logger = None
self._get_comet_logger().end()
self._comet_logger = None

def flush(self):
"""Force the data in memory to be flushed to disk. Use this call if tensorboard does not update reqularly.
Expand Down
2 changes: 1 addition & 1 deletion test-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
pytest
torch
six
protobuf==3.8
protobuf==3.15
numpy==1.18
pillow
tensorboard
Expand Down

0 comments on commit 054f1f3

Please sign in to comment.