Skip to content

Commit

Permalink
Merge pull request #29 from ruotianluo/master
Browse files Browse the repository at this point in the history
improvement on #26
  • Loading branch information
lanpa committed Sep 23, 2017
2 parents 35a8188 + 2397e35 commit 7108322
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
10 changes: 9 additions & 1 deletion tensorboardX/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from .src.summary_pb2 import SummaryMetadata
from .src.tensor_pb2 import TensorProto
from .src.tensor_shape_pb2 import TensorShapeProto
from .x2num import makenp

_INVALID_TAG_CHARACTERS = _re.compile(r'[^-/\w\.]')

Expand Down Expand Up @@ -84,6 +85,8 @@ def scalar(name, scalar, collections=None):
ValueError: If tensor has the wrong shape or type.
"""
name = _clean_tag(name)
scalar = makenp(scalar)
assert(scalar.squeeze().ndim==0), 'scalar should be 0D'
scalar = float(scalar)
return Summary(value=[Summary.Value(tag=name, simple_value=scalar)])

Expand All @@ -107,6 +110,7 @@ def histogram(name, values, bins, collections=None):
buffer.
"""
name = _clean_tag(name)
values = makenp(values)
hist = make_histogram(values.astype(float), bins)
return Summary(value=[Summary.Value(tag=name, histo=hist)])

Expand Down Expand Up @@ -151,7 +155,7 @@ def image(tag, tensor):
buffer.
"""
tag = _clean_tag(tag)
assert isinstance(tensor, np.ndarray), 'input tensor should be numpy.ndarray'
tensor = makenp(tensor, 'IMG')
tensor = tensor.astype(np.float32)
tensor = (tensor*255).astype(np.uint8)
image = make_image(tensor)
Expand All @@ -173,6 +177,10 @@ def make_image(tensor):
encoded_image_string=image_string)

def audio(tag, tensor, sample_rate=44100):
tensor = makenp(tensor)
tensor = tensor.squeeze()
assert(tensor.ndim==1), 'input tensor should be 1 dimensional.'

tensor_list = [int(32767.0*x) for x in tensor]
import io
import wave
Expand Down
9 changes: 0 additions & 9 deletions tensorboardX/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from .summary import scalar, histogram, image, audio, text
from .graph import graph
from .embedding import make_mat, make_sprite, make_tsv, append_pbtxt
from .x2num import makenp

class SummaryToEventTransformer(object):
"""Abstractly implements the SummaryWriter API.
Expand Down Expand Up @@ -248,8 +247,6 @@ def add_scalar(self, tag, scalar_value, global_step=None):
global_step (int): Global step value to record
"""
scalar_value = makenp(scalar_value)
assert(scalar_value.squeeze().ndim==0), 'input of add_scalar should be 0D'
self.file_writer.add_summary(scalar(tag, scalar_value), global_step)

def add_histogram(self, tag, values, global_step=None, bins='tensorflow'):
Expand All @@ -264,7 +261,6 @@ def add_histogram(self, tag, values, global_step=None, bins='tensorflow'):
"""
if bins=='tensorflow':
bins = self.default_bins
values = makenp(values)
self.file_writer.add_summary(histogram(tag, values, bins), global_step)

def add_image(self, tag, img_tensor, global_step=None):
Expand All @@ -277,7 +273,6 @@ def add_image(self, tag, img_tensor, global_step=None):
Shape:
img_tensor: :math:`(3, H, W)`. Use ``torchvision.utils.make_grid()`` to prepare it is a good idea.
"""
img_tensor = makenp(img_tensor, 'IMG')
self.file_writer.add_summary(image(tag, img_tensor), global_step)
def add_audio(self, tag, snd_tensor, global_step=None):
"""Add audio data to summary.
Expand All @@ -290,10 +285,6 @@ def add_audio(self, tag, snd_tensor, global_step=None):
Shape:
snd_tensor: :math:`(1, L)`. The values should between [-1, 1]. The sample rate is currently fixed at 44100 KHz.
"""
snd_tensor = makenp(snd_tensor)
snd_tensor = snd_tensor.squeeze()
assert(snd_tensor.ndim==1), 'input tensor should be 1 dimensional.'

self.file_writer.add_summary(audio(tag, snd_tensor), global_step)
def add_text(self, tag, text_string, global_step=None):
"""Add text data to summary.
Expand Down

0 comments on commit 7108322

Please sign in to comment.