Skip to content
Permalink
Browse files

native tests pass

  • Loading branch information
sytelus committed May 22, 2019
1 parent 32027d2 commit b8c40f3636145b43b802f2fab662f2e5896f7bed
@@ -13,6 +13,7 @@
from .mpl.image_plot import ImagePlot
from .mpl.histogram import Histogram
from .mpl.bar_plot import BarPlot
from .mpl.pie_chart import PieChart
from .visualizer import Visualizer

from .stream import Stream
File renamed without changes.
File renamed without changes.
@@ -42,7 +42,7 @@ def probabilities2classes(probs, topk=5):
for p, c in zip(top_probs[0][0].detach().numpy(), top_probs[1][0].detach().numpy()))

class ImagenetLabels:
def __init__(self, json_path='../../data/imagenet_class_index.json'):
def __init__(self, json_path='imagenet_class_index.json'):
self._idx2label = []
self._idx2cls = []
self._cls2label = {}
@@ -64,7 +64,7 @@ def _get_stream_code(self, event_name, stream_name, stream_index, stream_info)->
lines = []

stream_identifier = 's'+str(stream_index)
lines.append("{} = client.open_stream(stream_name='{}')".format(stream_identifier, stream_name))
lines.append("{} = client.open_stream(name='{}')".format(stream_identifier, stream_name))

vis_identifier = 'v'+str(stream_index)
vis_args_strs = ['stream={}'.format(stream_identifier)]
@@ -52,4 +52,4 @@ def logits2probabilities(logits):
return F.softmax(logits, dim=1)

def tensor2numpy(t):
return t.data().cpu().numpy()
return t.data.cpu().numpy()
@@ -43,7 +43,7 @@ def __init__(self, stream:Stream, vis_type:str=None, host:'Visualizer'=None,

self._host_base.subscribe(stream, show=False, clear_after_end=clear_after_end, clear_after_each=clear_after_each,
history_len=history_len, dim_history=dim_history, opacity=opacity,
only_summary=only_summary if 'summary' != vis_type else True,
only_summary=only_summary if vis_type is None or 'summary' != vis_type else True,
separate_yaxis=separate_yaxis, xtitle=xtitle, ytitle=ytitle, ztitle=ztitle, color=color,
xrange=xrange, yrange=yrange, zrange=zrange,
draw_line=draw_line if vis_type is not None and 'scatter' in vis_type else True,
@@ -63,7 +63,7 @@ def _get_vis_base(self, vis_type, cell:VisBase.widgets.Box, title, hover_images=
if vis_type is None or vis_type in ['line',
'mpl-line', 'mpl-line3d', 'mpl-scatter3d', 'mpl-scatter']:
return mpl.line_plot.LinePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height,
is_3d=vis_type.endswith('3d'), **vis_args)
is_3d=vis_type is not None and vis_type.endswith('3d'), **vis_args)
if vis_type in ['image', 'mpl-image']:
return mpl.image_plot.ImagePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args)
if vis_type in ['bar', 'bar3d']:
@@ -87,7 +87,7 @@ def _clisrv_callback(self, clisrv, clisrv_req): # pylint: disable=unused-argumen
# request = create stream
if clisrv_req.req_type == CliSrvReqTypes.create_stream:
stream_req = clisrv_req.req_data
self.create_stream(stream_name=stream_req.stream_name, devices=stream_req.devices,
self.create_stream(name=stream_req.stream_name, devices=stream_req.devices,
event_name=stream_req.event_name, expr=stream_req.expr, throttle=stream_req.throttle,
vis_args=stream_req.vis_args)
return None # ignore return as we can't send back stream obj
@@ -73,14 +73,14 @@ def create_stream(self, name:str=None, devices:Sequence[str]=None, event_name:st
self._zmq_srvmgmt_sub.add_stream_req(stream_req)

if stream_req.devices is not None:
stream = self.open_stream(stream_name=stream_req.stream_name, devices=stream_req.devices)
stream = self.open_stream(name=stream_req.stream_name, devices=stream_req.devices)
else: # we cannot return remote streams that are not backed by a device
stream = None
return stream

# override to set devices default to tcp
def open_stream(self, name:str=None, devices:Sequence[str]=None)->Stream: # overriden
return super(WatcherClient, self).open_stream(stream_name=name, devices=devices)
return super(WatcherClient, self).open_stream(name=name, devices=devices)


# override to send request to server
@@ -4,8 +4,8 @@
def main():
w = tw.Watcher()
s1 = w.create_stream()
s2 = w.create_stream(stream_name='accuracy', vis_args=tw.VisArgs(vis_type='line', xtitle='X-Axis', clear_after_each=False, history_len=2))
s3 = w.create_stream(stream_name='loss', expr='lambda d:d.loss')
s2 = w.create_stream(name='accuracy', vis_args=tw.VisArgs(vis_type='line', xtitle='X-Axis', clear_after_each=False, history_len=2))
s3 = w.create_stream(name='loss', expr='lambda d:d.loss')
w.make_notebook()

main()
File renamed without changes.
@@ -27,22 +27,22 @@ def show_find_lr():

utils.wait_key()

def plot_grads():
def plot_grads_plotly():
train_cli = tw.WatcherClient()
grads = train_cli.create_stream(event_name='batch',
expr='lambda d:agg_params(d.model, lambda p: p.grad.abs().mean().item())', throttle=1)
expr='lambda d:grads_abs_mean(d.model)', throttle=1)
p = tw.plotly.line_plot.LinePlot('Demo')
p.subscribe(grads, xtitle='Epoch', ytitle='Gradients', history_len=30, new_on_eval=True)
p.subscribe(grads, xtitle='Layer', ytitle='Gradients', history_len=30, new_on_eval=True)
utils.wait_key()


def plot_grads1():
def plot_grads():
train_cli = tw.WatcherClient()

grads = train_cli.create_stream(event_name='batch',
expr='lambda d:agg_params(d.model, lambda p: p.grad.abs().mean().item())', throttle=1)
expr='lambda d:grads_abs_mean(d.model)', throttle=1)
grad_plot = tw.LinePlot()
grad_plot.subscribe(grads, xtitle='Epoch', ytitle='Gradients', clear_after_each=1, history_len=40, dim_history=True)
grad_plot.subscribe(grads, xtitle='Layer', ytitle='Gradients', clear_after_each=1, history_len=40, dim_history=True)
grad_plot.show()

tw.plt_loop()
@@ -51,9 +51,9 @@ def plot_weight():
train_cli = tw.WatcherClient()

params = train_cli.create_stream(event_name='batch',
expr='lambda d:agg_params(d.model, lambda p: p.abs().mean().item())', throttle=1)
expr='lambda d:weights_abs_mean(d.model)', throttle=1)
params_plot = tw.LinePlot()
params_plot.subscribe(params, xtitle='Epoch', ytitle='avg |params|', clear_after_each=1, history_len=40, dim_history=True)
params_plot.subscribe(params, xtitle='Layer', ytitle='avg |params|', clear_after_each=1, history_len=40, dim_history=True)
params_plot.show()

tw.plt_loop()
@@ -89,22 +89,22 @@ def batch_stats():
# vis=train_loss, vis_type='mpl-line')

train_loss.show()
tw.image_utils.plt_loop()
tw.plt_loop()

def text_stats():
train_cli = tw.WatcherClient()
stream = train_cli.create_stream(event_name="batch",
expr='lambda d:(d.x, d.metrics.batch_loss)')
expr='lambda d:(d.metrics.epoch_index, d.metrics.batch_loss)')

trl = tw.Visualizer(stream, vis_type=None)
trl = tw.Visualizer(stream, vis_type='text')
trl.show()
input('Paused...')



#epoch_stats()
#plot_weight()
#plot_grads1()
img_in_class()
#text_stats()
#plot_grads()
#img_in_class()
text_stats()
#batch_stats()
@@ -2,7 +2,7 @@
from tensorwatch import image_utils, imagenet_utils, pytorch_utils

model = pytorch_utils.get_model('resnet50')
raw_input, input, target_class = pytorch_utils.image_class2tensor('../data/dogs.png', 240, #'../data/elephant.png', 101,
raw_input, input, target_class = pytorch_utils.image_class2tensor('../data/test_images/dogs.png', 240, #'../data/elephant.png', 101,
image_transform=imagenet_utils.get_image_transform(), image_convert_mode='RGB')

results = saliency.get_image_saliency_results(model, raw_input, input, target_class)
@@ -20,7 +20,7 @@ def show_mpl():
def show_text():
cli = tw.WatcherClient()
s1 = cli.create_stream(expr='lambda v:(v.i, v.sum)')
text = tw.Visualizer(s1)
text = tw.Visualizer(s1, vis_type='text')
text.show()
input('Waiting')

File renamed without changes.
@@ -5,7 +5,7 @@
<ProjectGuid>9a7fe67e-93f0-42b5-b58f-77320fc639e4</ProjectGuid>
<ProjectHome>
</ProjectHome>
<StartupFile>files\file_stream.py</StartupFile>
<StartupFile>mnist\cli_mnist.py</StartupFile>
<SearchPath>
</SearchPath>
<WorkingDirectory>.</WorkingDirectory>
@@ -24,6 +24,7 @@
<EnableUnmanagedDebugging>false</EnableUnmanagedDebugging>
</PropertyGroup>
<ItemGroup>
<Compile Include="deps\live_graph.py" />
<Compile Include="visualizations\arr_img_plot.py">
<SubType>Code</SubType>
</Compile>
@@ -70,7 +71,7 @@
<Compile Include="visualizations\histogram.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="visualizations\line_plot.py">
<Compile Include="visualizations\line3d_plot.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="components\notebook_maker.py">
@@ -98,7 +99,7 @@
<Compile Include="components\stream.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="zmq\zmq_stream_pub.py">
<Compile Include="zmq\zmq_stream.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="zmq\zmq_watcher_client.py">
@@ -107,11 +108,11 @@
<Compile Include="zmq\zmq_watcher_server.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="zmq\zmq_stream_sub.py">
<Compile Include="zmq\zmq_sub.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="simple_log\srv_ij.py" />
<Compile Include="zmq\zmq_srv.py">
<Compile Include="zmq\zmq_pub.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="deps\thread.py">
@@ -1,6 +1,8 @@
import tensorwatch as tw
import random, time

# TODO: resolve problem with Axis3D?

def static_line3d():
w = tw.Watcher()
s = w.create_stream()
@@ -1,5 +1,5 @@
from tensorwatch.watcher_base import WatcherBase
from tensorwatch import LinePlot
from tensorwatch.mpl.line_plot import LinePlot
from tensorwatch.image_utils import plt_loop
from tensorwatch.stream import Stream
from tensorwatch.lv_types import StreamItem
@@ -6,7 +6,7 @@
utils.set_debug_verbosity(10)

def clisrv_callback(clisrv, msg):
print(msg)
print('from clisrv', msg)

stream = ZmqWrapper.Publication(port = 40859)
clisrv = ZmqWrapper.ClientServer(40860, True, clisrv_callback)
@@ -1,13 +1,11 @@
from tensorwatch.watcher_base import WatcherBase
from tensorwatch.stream import Stream
from tensorwatch.zmq_stream import ZmqStream

def main():
watcher = WatcherBase()
zmq_pub = ZmqStream(for_write=True, stream_name = 'ZmqPub', console_debug=True)
zmq_sub = ZmqStream(for_write=False, stream_name = 'ZmqSub', console_debug=True)

stream = watcher.create_stream(expr='lambda vars:vars.x**2')

zmq_pub = ZmqStream(for_write=True, stream_name = 'ZmqPub', console_debug=True)
zmq_pub.subscribe(stream)

for i in range(5):
@@ -3,11 +3,15 @@
from tensorwatch.zmq_wrapper import ZmqWrapper
from tensorwatch import utils

def on_event(obj):
print(obj)
class A:
def on_event(self, obj):
print(obj)

a = A()


utils.set_debug_verbosity(10)
sub = ZmqWrapper.Subscription(40859, "Topic1", on_event)
sub = ZmqWrapper.Subscription(40859, "Topic1", a.on_event)
print("subscriber is waiting")

clisrv = ZmqWrapper.ClientServer(40860, False)

0 comments on commit b8c40f3

Please sign in to comment.
You can’t perform that action at this time.