Create library for converting pytorch to caffe2 and examples #69
Conversation
onnx_caffe2/pytorch_caffe2.py
Outdated
log = logging.getLogger(__name__)


def run_caffe2_model(init_net, predict_net, inputs):
Already has one here: https://github.com/onnx/onnx-caffe2/blob/master/onnx_caffe2/helper.py#L56
onnx_caffe2/pytorch_caffe2.py
Outdated
def run_caffe2_benchmark(init_net, predict_net, warmup_iters, main_iters, layer_details):
    workspace.ResetWorkspace()
Don't directly use caffe2.python.workspace; use onnx_caffe2.workspace.Workspace instead.
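The point of the suggestion is to scope all blob state to a workspace object instead of resetting the process-global caffe2.python.workspace. The real wrapper lives in onnx_caffe2.workspace; the tiny stand-in class below is hypothetical and only illustrates why a scoped workspace removes the need for a global ResetWorkspace():

```python
# Hypothetical stand-in for onnx_caffe2.workspace.Workspace; it only
# illustrates the pattern, not the real Caffe2 API.
class Workspace:
    """Toy workspace: each instance owns its own blob store."""
    def __init__(self):
        self._blobs = {}

    def FeedBlob(self, name, value):
        self._blobs[name] = value

    def FetchBlob(self, name):
        return self._blobs[name]


# Stands in for the process-global caffe2.python.workspace state.
GLOBAL_BLOBS = {"shared": 1}

def run_benchmark(ws):
    # All feeds go into the scoped workspace; the caller's global
    # state is never reset or clobbered.
    ws.FeedBlob("data", [0.0] * 4)
    return ws.FetchBlob("data")

ws = Workspace()
run_benchmark(ws)
assert GLOBAL_BLOBS == {"shared": 1}  # global state untouched
```

Because each benchmark run gets its own Workspace, two conversions in the same process cannot stomp on each other's blobs.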
onnx_caffe2/pytorch_caffe2.py
Outdated
    return init_net, predict_net


def load_caffe2_net(file):
Maybe move this into helper.py
onnx_caffe2/pytorch_caffe2.py
Outdated
    return net


def save_caffe2_net(net, file, output_txt=False):
Same as above: move this into helper.py.
onnx_caffe2/pytorch_caffe2.py
Outdated
                        pytorch_out, expected_decimal)
    log.info("The converted Caffe2 model achieves {}-decimal precision."
             .format(expected_decimal))
    if not compare_performance:
Split this into a separate benchmark/profile function
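The request is to keep the numeric check self-contained so benchmarking/profiling can live in its own function. A sketch of what the standalone verification half could look like (the helper name is an assumption, only numpy.testing is real):

```python
import numpy as np

def verify_outputs(caffe2_out, pytorch_out, expected_decimal=3):
    """Compare corresponding outputs to the given decimal precision.

    Hypothetical helper: keeps correctness checking separate from any
    benchmark/profile code, as suggested in the review.
    """
    for c2, pt in zip(caffe2_out, pytorch_out):
        np.testing.assert_almost_equal(c2, pt, decimal=expected_decimal)

# Outputs agreeing to 3 decimal places pass the check silently.
verify_outputs([np.array([1.0001, 2.0])], [np.array([1.0, 2.0])],
               expected_decimal=3)
```

A caller can then run verification unconditionally and only invoke the benchmark function when compare_performance is set.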
onnx_caffe2/pytorch_caffe2.py
Outdated
    return init_net, predict_net

    log.info("Starting benchmarking PyTorch.")
    for _i in range(warmup_iters):
Maybe put these into a run_pytorch_benchmark function?
Refactoring the code: I will move most of the functions to onnx_caffe2/helper.py and move the example to the tutorial repo.
onnx_caffe2/helper.py
Outdated
    ws.RunNetOnce(predict_net)

    output_names = predict_net.external_output
    output_values = [ws.FetchBlob(name) for name in output_names]
    return ws, namedtupledict('Outputs', output_names)(*output_values)


def name_inputs(onnx_model, inputs):
You should have the names of the blobs in predict_net.external_input, so maybe just pass the list directly to c2_native_run_net?
init_net contains the fake input, so the assertion len(uninitialized) == len(inputs) fails; it's better to pass a dict.
What I meant is to accept the list directly in c2_native_run_net() and put this code there
onnx_caffe2/helper.py
Outdated
def benchmark_pytorch_model(model, inputs, training=False, warmup_iters=3,
                            main_iters=10, verbose=False):
    for _i in range(warmup_iters):
        ts = time.time()
No need to time the warmup iterations.
Nice catch.
onnx_caffe2/helper.py
Outdated
        ts = time.time()
        model(*inputs)
        te = time.time()
        total_pytorch_time = te - ts + total_pytorch_time
+= :)
Sure. :-)
onnx_caffe2/helper.py
Outdated
        model(*inputs)
        te = time.time()
        total_pytorch_time = te - ts + total_pytorch_time
    log.info("The PyTorch model execution time per iter is {} milliseconds, "
Also return the measured time?
Good point!
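Folding the three review points together (untimed warmup, +=, and returning the measurement), the loop could end up looking like the sketch below. The dummy callable at the bottom stands in for a PyTorch model so the sketch runs anywhere:

```python
import time

def benchmark_pytorch_model(model, inputs, warmup_iters=3, main_iters=10):
    """Sketch applying the review feedback: warmup is not timed,
    the accumulator uses +=, and the per-iteration time in
    milliseconds is returned to the caller.
    """
    for _ in range(warmup_iters):      # no timing during warmup
        model(*inputs)

    total_pytorch_time = 0.0
    for _ in range(main_iters):
        ts = time.time()
        model(*inputs)
        te = time.time()
        total_pytorch_time += te - ts  # reviewer's suggested +=

    ms_per_iter = total_pytorch_time * 1000.0 / main_iters
    return ms_per_iter

# Dummy "model" so the sketch runs without PyTorch.
ms = benchmark_pytorch_model(lambda x: sum(x), ([1, 2, 3],))
assert ms >= 0.0
```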
examples/pytorch_to_caffe2.py
Outdated
    init_net = load_caffe2_net(init_file)
    predict_net = load_caffe2_net(predict_file)

    # Prepare the inputs for Caffe2.
This should be handled by passing an option to Caffe2Backend.
Yep.
Looks good, ship it!
onnx_caffe2/helper.py
Outdated
    ws.RunNetOnce(init_net)
    ws.CreateNet(predict_net)
    results = ws.BenchmarkNet(predict_net.name, warmup_iters, main_iters, layer_details)
    return results[0]
Add del ws to free up memory explicitly.
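The idea is that a Caffe2 workspace can hold large blobs, so dropping the reference as soon as the benchmark numbers are extracted releases that memory rather than keeping it alive for the rest of the process. A runnable toy illustrating the pattern (FakeWorkspace is a hypothetical stand-in; the immediate release relies on CPython reference counting):

```python
class FakeWorkspace:
    """Toy stand-in for a Caffe2 workspace holding large buffers."""
    freed = False

    def __init__(self):
        self.blobs = {"param": [0.0] * 1024}

    def __del__(self):
        FakeWorkspace.freed = True  # blobs would be released here


def run_benchmark():
    ws = FakeWorkspace()
    try:
        result = len(ws.blobs["param"])  # pretend to benchmark
    finally:
        del ws  # reviewer's suggestion: release workspace memory explicitly
    return result

run_benchmark()
# Under CPython, del drops the last reference immediately.
assert FakeWorkspace.freed
```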
Add an example python file to demonstrate how to convert PyTorch models to Caffe2.
06d305e to 60baedf
Based on this example, we can create notebooks that explain how to compare the results and performance.