Skip to content

Commit

Permalink
Comet integration improvements (#658)
Browse files Browse the repository at this point in the history
* add methods for specifically logging PR curve data

* update writer to use comet pr_data logging methods

* fix raw histogram data logging

* clean up numpy imports

* clean up docstrings

* more docstring clean up

* fix lint issue in comet_utils
  • Loading branch information
DN6 committed Feb 19, 2022
1 parent 30e9024 commit 74d8a35
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 28 deletions.
81 changes: 81 additions & 0 deletions tensorboardX/comet_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import functools
from io import BytesIO
from google.protobuf.json_format import MessageToJson
import numpy as np
from .summary import _clean_tag
try:
Expand Down Expand Up @@ -155,6 +156,22 @@ def log_histogram(self, values, name=None, step=None, epoch=None,
epoch, metadata,
**kwargs)

@_requiresComet
def log_histogram_raw(self, tag, summary, step=None):
"""Log Raw Histogram Data to Comet as an Asset.
Args:
tag: Name given to the logged asset
summary: TensorboardX Summary protocol buffer with histogram data
step: The Global Step for this experiment run. Defaults to None.
"""

histogram_proto = summary.value[0].histo
histogram_raw_data = MessageToJson(histogram_proto)
histogram_raw_data['name'] = tag

self.log_asset_data(data=histogram_raw_data, name=tag, step=step)

@_requiresComet
def log_curve(self, name, x, y, overwrite=False, step=None):
"""Log timeseries data.
Expand Down Expand Up @@ -300,3 +317,67 @@ def log_raw_figure(self, tag, asset_type, step=None, **kwargs):
file_json = kwargs
file_json['asset_type'] = asset_type
self.log_asset_data(file_json, tag, step=step)

@_requiresComet
def log_pr_data(self, tag, summary, num_thresholds, step=None):
"""Logs a Precision-Recall Curve Data as an asset.
Args:
tag: An identifier for the PR curve
summary: TensorboardX Summary protocol buffer.
step: step value to record
"""
tensor_proto = summary.value[0].tensor
shape = [d.size for d in tensor_proto.tensor_shape.dim]

values = np.fromiter(tensor_proto.float_val, dtype=np.float32).reshape(shape)
thresholds = [1.0 / num_thresholds * i for i in range(num_thresholds)]
tp, fp, tn, fn, precision, recall = map(lambda x: x.flatten().tolist(), np.vsplit(values, values.shape[0]))

pr_data = {
'TP': tp,
'FP': fp,
'TN': tn,
'FN': fn,
'precision': precision,
'recall': recall,
'thresholds': thresholds,
'name': tag,
}

self.log_asset_data(pr_data, name=tag, step=step)

@_requiresComet
def log_pr_raw_data(self, tag, true_positive_counts,
false_positive_counts, true_negative_counts,
false_negative_counts, precision, recall,
num_thresholds, weights, step=None):
"""Logs a Precision-Recall Curve Data as an asset.
Args:
tag: An identifier for the PR curve
summary: TensorboardX Summary protocol buffer.
step: step value to record
"""
thresholds = [1.0 / num_thresholds * i for i in range(num_thresholds)]
tp, fp, tn, fn, precision, recall = map(lambda x: x.flatten().tolist(), [
true_positive_counts,
false_positive_counts,
true_negative_counts,
false_negative_counts,
precision,
recall])

pr_data = {
'TP': tp,
'FP': fp,
'TN': tn,
'FN': fn,
'precision': precision,
'recall': recall,
'thresholds': thresholds,
'weights': weights,
'name': tag,
}

self.log_asset_data(pr_data, name=tag, step=step)
52 changes: 24 additions & 28 deletions tensorboardX/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,25 +615,19 @@ def add_histogram_raw(
"""
if len(bucket_limits) != len(bucket_counts):
raise ValueError('len(bucket_limits) != len(bucket_counts), see the document.')
summary = histogram_raw(tag,
min,
max,
num,
sum,
sum_squares,
bucket_limits,
bucket_counts)
self._get_file_writer().add_summary(
histogram_raw(tag,
min,
max,
num,
sum,
sum_squares,
bucket_limits,
bucket_counts),
summary,
global_step,
walltime)
self._get_comet_logger().log_raw_figure(tag, 'histogram_raw', global_step,
min=min,
max=max,
num=num,
sum=sum,
sum_squares=sum_squares,
bucket_limits=bucket_limits,
bucket_counts=bucket_counts)
self._get_comet_logger().log_histogram_raw(tag, summary, step=global_step)

def add_image(
self,
Expand Down Expand Up @@ -1136,10 +1130,13 @@ def add_pr_curve(
"""
from .x2num import make_np
labels, predictions = make_np(labels), make_np(predictions)

summary = pr_curve(tag, labels, predictions, num_thresholds, weights)
self._get_file_writer().add_summary(
pr_curve(tag, labels, predictions, num_thresholds, weights),
summary,
global_step, walltime)
self._get_comet_logger().log_curve(tag, labels, predictions, step=global_step)

self._get_comet_logger().log_pr_data(tag, summary, num_thresholds, step=global_step)

def add_pr_curve_raw(
self,
Expand Down Expand Up @@ -1176,16 +1173,15 @@ def add_pr_curve_raw(
weights),
global_step,
walltime)
self._get_comet_logger().log_raw_figure(tag, 'pr_curve_raw', global_step,
true_positive_counts=true_positive_counts,
false_positive_counts=false_positive_counts,
true_negative_counts=true_negative_counts,
false_negative_counts=false_negative_counts,
precision=precision,
recall=recall,
num_thresholds=num_thresholds,
weights=weights,
walltime=walltime)
self._get_comet_logger().log_pr_raw_data(tag, step=global_step,
true_positive_counts=true_positive_counts,
false_positive_counts=false_positive_counts,
true_negative_counts=true_negative_counts,
false_negative_counts=false_negative_counts,
precision=precision,
recall=recall,
num_thresholds=num_thresholds,
weights=weights)

def add_custom_scalars_multilinechart(
self,
Expand Down

0 comments on commit 74d8a35

Please sign in to comment.