Adds timing info for model in GPU (#527)
lanpa committed Oct 28, 2019
1 parent 08e4703 commit 656037e
Showing 2 changed files with 47 additions and 19 deletions.
60 changes: 44 additions & 16 deletions tensorboardX/pytorch_graph.py
@@ -194,9 +194,13 @@ def find_time_for(node_name):
             for i, n in enumerate(profile_result):
                 if n.key == node_name:
                     profile_result.pop(i)
-                    time_we_want = n.cpu_time_total
-                    return int(time_we_want)
+                    time_we_want_cpu = n.cpu_time_total
+                    time_we_want_cuda = n.cuda_time_total
+
+                    return int(time_we_want_cpu), int(time_we_want_cuda)
+            return None, None
 
+        should_show_warning = False
         for v in self.nodes_io.values():
             nodes.append(node_proto(v.debugName,
                                     input=v.inputs,
@@ -208,19 +212,29 @@ def find_time_for(node_name):
             # prim:: and Parameter
             if 'aten' in v.kind and self.profile_result is not None:
                 opname = v.kind.split('::')[1]
-                exe_time = find_time_for(opname)
-                node_stats.append(
-                    NodeExecStats(node_name=v.debugName,
-                                  all_start_micros=int(time.time() * 1e7),
-                                  all_end_rel_micros=exe_time))
+                exe_time_cpu, exe_time_cuda = find_time_for(opname)
+                if exe_time_cpu is not None:
+                    total_time = exe_time_cpu + exe_time_cuda
+
+                    # assume that the operation will not be executed on both devices simultaneously
+                    if total_time - max(exe_time_cpu, exe_time_cuda) > 0.01:
+                        should_show_warning = True
+
+                    node_stats.append(
+                        NodeExecStats(node_name=v.debugName,
+                                      all_start_micros=int(time.time() * 1e7),
+                                      all_end_rel_micros=total_time))
 
             if v.tensor_size and len(v.tensor_size) > 0:  # assume data is float32, only parameter is counted
                 node_stats.append(
                     NodeExecStats(node_name=v.debugName,
                                   all_start_micros=int(time.time() * 1e7),
                                   all_end_rel_micros=42,
-                                  memory=[AllocatorMemoryUsed(allocator_name="cpu",
+                                  memory=[AllocatorMemoryUsed(allocator_name="unknown",
                                                               total_bytes=int(np.prod(v.tensor_size)) * 4)]))
+        if should_show_warning:
+            logging.warning('time cost for node is the sum of CPU + GPU.')
 
         return nodes, node_stats
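
For reference, the profile_result entries that find_time_for() iterates are the PyTorch profiler's per-op event records, which carry both CPU and CUDA totals in microseconds. A minimal sketch of inspecting them directly (the model and input below are illustrative, not part of this commit):

    import torch

    model = torch.nn.Linear(8, 4)   # illustrative model
    x = torch.randn(2, 8)           # illustrative input

    with torch.autograd.profiler.profile(use_cuda=torch.cuda.is_available()) as prof:
        model(x)

    for evt in prof.function_events:
        # evt.key is the op name ('addmm', ...) that find_time_for() matches;
        # these two totals are what it now returns as a (cpu, cuda) pair
        print(evt.key, evt.cpu_time_total, evt.cuda_time_total)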


@@ -281,8 +295,18 @@ def parse(graph, args=None, profile_result=None):
     nodes_py.populate_namespace_from_OP_to_IO()
     return nodes_py.to_proto()
 
+def recursive_to_cuda(x):
+    """
+    Recursively convert tensors in a tuple or list to GPU tensors.
+    """
+    import torch
+
+    if isinstance(x, torch.Tensor):
+        return x.cuda()
+    else:
+        return [recursive_to_cuda(_x) for _x in x]
 
-def graph(model, args, verbose=False, **kwargs):
+def graph(model, args, verbose=False, use_cuda=False, **kwargs):
     """
     This method processes a PyTorch model and produces a `GraphDef` proto
     that can be logged to TensorBoard.
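
A quick sketch of what recursive_to_cuda does to a nested input structure (the tensors are illustrative, and a CUDA device is required). Tuples come back as lists, which is harmless here because the only caller unpacks the result with model(*args):

    import torch

    args = (torch.zeros(2, 3), [torch.ones(4), torch.ones(4)])
    cuda_args = recursive_to_cuda(args)  # [gpu tensor, [gpu tensor, gpu tensor]]
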
Expand Down Expand Up @@ -313,12 +337,11 @@ def graph(model, args, verbose=False, **kwargs):
# TensorBoard logged data.

try:
with torch.autograd.profiler.profile(record_shapes=True) as prof:
if len(args) == 1 and isinstance(args, tuple) or isinstance(args, list):
args = args[0]
result = model(args)
else:
result = model(*args)
if use_cuda:
model.cuda()
args = recursive_to_cuda(args)
with torch.autograd.profiler.profile(record_shapes=True, use_cuda=use_cuda) as prof:
result = model(*args)

except RuntimeError as e:
print('profiler execution failed')
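
One behavioral detail: the replaced branch used to unwrap single-element input tuples, while the new code always calls model(*args), so inputs should be passed as a tuple or list even when there is just one tensor. A hypothetical direct call for illustration:

    # wrap a single input in a tuple; use_cuda moves the model and inputs to the GPU first
    graph(model, (torch.randn(2, 8),), use_cuda=torch.cuda.is_available())
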
@@ -337,6 +360,11 @@
     # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
     # and
     # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
-    stepstats = RunMetadata(step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0",
+
+    if use_cuda:
+        device = "/device:GPU:0"
+    else:
+        device = "/device:CPU:0"
+    stepstats = RunMetadata(step_stats=StepStats(dev_stats=[DeviceStepStats(device=device,
                                                                             node_stats=node_stats)]))
     return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats
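
For a sense of the metadata this function now returns alongside the GraphDef, here is a minimal sketch of the proto structure built above, assuming the classes come from tensorboardX.proto.step_stats_pb2 as imported at the top of this module:

    from tensorboardX.proto.step_stats_pb2 import (
        DeviceStepStats, NodeExecStats, RunMetadata, StepStats)

    # hypothetical stats for a single op, attributed to the GPU device
    node = NodeExecStats(node_name='Linear/aten::addmm',
                         all_start_micros=0, all_end_rel_micros=120)
    stepstats = RunMetadata(step_stats=StepStats(
        dev_stats=[DeviceStepStats(device='/device:GPU:0', node_stats=[node])]))
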
6 changes: 3 additions & 3 deletions tensorboardX/writer.py
@@ -142,7 +142,7 @@ def add_graph(self, graph_profile, walltime=None):
         self.add_event(event, None, walltime)
 
         trm = event_pb2.TaggedRunMetadata(
-            tag='step1', run_metadata=stepstats.SerializeToString())
+            tag='profiler', run_metadata=stepstats.SerializeToString())
         event = event_pb2.Event(tagged_run_metadata=trm)
         self.add_event(event, None, walltime)
@@ -767,7 +767,7 @@ def add_openvino_graph(self, xmlname):
         """
         self._get_file_writer().add_openvino_graph(load_openvino_graph(xmlname))
 
-    def add_graph(self, model, input_to_model=None, verbose=False, **kwargs):
+    def add_graph(self, model, input_to_model=None, verbose=False, profile_with_cuda=False, **kwargs):
         # prohibit second call?
         # no, let tensorboard handle it and show its warning message.
         """Add graph data to summary.
@@ -796,7 +796,7 @@ def add_graph(self, model, input_to_model=None, verbose=False, **kwargs):
             if not hasattr(torch.autograd.Variable, 'grad_fn'):
                 print('add_graph() only supports PyTorch v0.2.')
                 return
-            self._get_file_writer().add_graph(graph(model, input_to_model, verbose, **kwargs))
+            self._get_file_writer().add_graph(graph(model, input_to_model, verbose, profile_with_cuda, **kwargs))
         else:
             # Caffe2 models do not have the 'forward' method
             from caffe2.proto import caffe2_pb2
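
Putting the two files together, GPU-timed graph profiling from user code would look roughly like this (a hedged end-to-end sketch; the model and input are illustrative, and profile_with_cuda needs a CUDA-capable device):

    import torch
    import torchvision
    from tensorboardX import SummaryWriter

    model = torchvision.models.resnet18()
    dummy_input = (torch.randn(1, 3, 224, 224),)  # inputs as a tuple

    writer = SummaryWriter()
    writer.add_graph(model, dummy_input, profile_with_cuda=torch.cuda.is_available())
    writer.close()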
