Run PyTorch fast-neural-style (FNS) in web browsers using ONNX.js
This repository is for anyone interested in running the PyTorch fast-neural-style example in web browsers. The performance is far from optimal, because of the many workarounds needed for issues and limitations in the conversion process and for the different operator/layer support levels of PyTorch and ONNX.js. Still, it serves its purpose: showing what it takes to go through the entire process.
It is an example for practicing and learning what it takes to make PyTorch-generated models portable to other deep learning frameworks. ONNX.js was chosen as the target framework because it is very new, and hence still primitive.
This project is roughly based on the following open source projects:
- PyTorch v1.0.0 - fast-neural-style example for exporting ONNX model files (.onnx).
- ONNX.js v0.1.3 - add example for JavaScript inference on the web.
This repo:
- Go to PyTorch fast-neural-style web for a quick, working web style transfer using ONNX.js.
- Go to PyTorch fast-neural-style web benchmark for a quick tool to measure performance in your browser.
How-to guides on tweaking the PyTorch fast-neural-style repo:
- See PyTorch fast-neural-style for web for making it work with ONNX.js.
- See Making the PyTorch to ONNX.js conversion work if you are interested in more technical details.
The objective is simple:
PyTorch FNS example --> PyTorch model files (.pth) --> ONNX model files --> ONNX.js on web browsers
There are many style transfer implementations. PyTorch's fast-neural-style example is the most fascinating one, partly because the way it is implemented produces much finer style-transferred images. To run inference in the browser, the following 3 major steps are taken:
- Use PyTorch to train the model (this repository uses the 4 pre-trained models.)
- Use PyTorch's built-in ONNX export feature to export model files (.onnx)
- Load the ONNX model files (.onnx) and run inference using ONNX.js in web browsers.
Sounds straightforward? Read on...
These steps may seem easy, but in practice the process is much more complicated.
The following are a few of the major obstacles encountered during the process, to give an idea of the kinds of issues that may come up:
- Operator/layer support levels are very different.
  - PyTorch nn layers - https://pytorch.org/docs/stable/nn.html
  - PyTorch ONNX export operators - https://pytorch.org/docs/stable/onnx.html#supported-operators
  - ONNX.js operators - https://github.com/Microsoft/onnxjs/blob/master/docs/operators.md
  - PyTorch nn.InstanceNorm2d() is exported as ONNX InstanceNormalization(), which is not supported by ONNX.js.
  - PyTorch nn.functional.interpolate() is exported as ONNX Upsample(), which is not supported by ONNX.js.
  - At the time of writing (Jan 2019), PyTorch ONNX export uses opset version 9 by default, while ONNX.js is at ONNX opset version 7.
- Base tensor opset levels are different between PyTorch, PyTorch ONNX export and ONNX.js.
  - PyTorch ONNX export only supports reduction operations, such as mean(), along 1 axis, i.e. torch.mean(t, [2,3]) is not supported by PyTorch ONNX export (although both PyTorch and ONNX.js support multi-axis reduction ops).
- ONNX.js has quite a few issues.
  - Same input values result in an exception error. (ONNX.js issue #53)
  - Some ops are slow, such as Reshape(), which is converted from PyTorch's view().
  - pow() + mean() produces NaN values in JavaScript; the pow() op is very buggy.
- Models with dynamic tensor shapes exported by PyTorch ONNX export are very large and hog memory like hell.
  - If any op node depends dynamically on the input/output tensor shape during inference, the resulting ONNX model graph can be absurdly large (a .onnx file of ~350MB) and highly complex (composed of multiple Reshape and Gather ops). Although it still works, it is not practical to use such model files in web browsers.
- ONNX.js support for mobile devices, such as Android, is still not stable.
  - The web site has mosaic zero-pad as the first model, because some mobile devices do not run the pad op correctly using the webgl backend. All models should work correctly on desktop browsers.
  - Note that 'zero-pad' generates a lower-quality stylized output.
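A quick way to spot the operator mismatches listed above before loading a model in the browser is to scan the exported graph's op types. The sketch below works on a plain list of op-type names so it runs without extra packages; with the onnx Python package installed, such a list could come from [node.op_type for node in onnx.load(path).graph.node]. The UNSUPPORTED set holds only the two ops this README reports as unsupported by ONNX.js (as of Jan 2019), so it is illustrative rather than a complete compatibility table, and report_unsupported is a hypothetical helper.

```python
from collections import Counter

# The two ops this README reports as unsupported by ONNX.js (Jan 2019).
# Illustrative only; check the ONNX.js operator list for the full picture.
UNSUPPORTED = {"InstanceNormalization", "Upsample"}


def report_unsupported(op_types, unsupported=UNSUPPORTED):
    """Return {op_type: count} for every op type in `unsupported` found in op_types."""
    counts = Counter(op for op in op_types if op in unsupported)
    return dict(sorted(counts.items()))


if __name__ == "__main__":
    ops = ["Conv", "InstanceNormalization", "Relu", "Upsample", "InstanceNormalization"]
    print(report_unsupported(ops))  # {'InstanceNormalization': 2, 'Upsample': 1}
```

An empty result only means none of the listed ops appear; it says nothing about ops missing from the set.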
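Because nn.functional.interpolate() exports as Upsample, one possible workaround (an assumption here, not necessarily what the repository does) is to express the upsampling through other ops. Assuming the example upsamples with mode='nearest' and a scale factor of 2, the arithmetic is plain repetition, as this sketch on an [H][W] grid of Python lists shows; upsample_nearest_2x is a hypothetical name, and a tensor version would need the target runtime to support whatever repeat/reshape-style ops it is built from.

```python
def upsample_nearest_2x(grid):
    """Nearest-neighbour 2x upsampling of an [H][W] grid, returning [2H][2W].

    Equivalent to output[i][j] = grid[i // 2][j // 2], built by repeating each
    value twice along W and then each widened row twice along H.
    """
    out = []
    for row in grid:
        wide = [v for v in row for _ in range(2)]  # repeat along W
        out.append(wide)
        out.append(list(wide))  # repeat along H as a separate copy
    return out
```

For example, [[1, 2], [3, 4]] becomes [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]].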
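The reduction and pow() issues above suggest writing instance normalization by hand. A minimal sketch, assuming no affine scale/shift and PyTorch's default eps of 1e-5: the spatial mean is taken as a mean over W followed by a mean over H (in PyTorch terms, something like x.mean(3, keepdim=True).mean(2, keepdim=True) instead of torch.mean(x, [2, 3])), and the variance squares with d * d instead of pow(). It uses plain [C][H][W] Python lists so the arithmetic can be checked without PyTorch; the function names are hypothetical, and this is not necessarily how the repository implements it.

```python
import math


def mean_over_last_axis(rows):
    """Single-axis reduction: the mean of each row (one value per row)."""
    return [sum(row) / len(row) for row in rows]


def spatial_mean(channel):
    """Mean over H and W as two single-axis reductions: over W first, then over H."""
    row_means = mean_over_last_axis(channel)  # reduce W
    return sum(row_means) / len(row_means)    # reduce H


def instance_norm_single_axis(x, eps=1e-5):
    """Instance normalization (no affine scale/shift) of x given as [C][H][W] lists.

    Uses only single-axis means and squares with d * d rather than pow(d, 2).
    """
    out = []
    for channel in x:
        mu = spatial_mean(channel)
        diffs = [[v - mu for v in row] for row in channel]
        var = spatial_mean([[d * d for d in row] for row in diffs])  # biased variance
        std = math.sqrt(var + eps)
        out.append([[d / std for d in row] for row in diffs])
    return out
```

For a channel [[1, 2], [3, 4]] (mean 2.5, variance 1.25) the output is roughly [[-1.34, -0.45], [0.45, 1.34]]. Averaging over W and then over H equals a joint H×W mean because every row has the same length, which always holds for a tensor.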
For more details on tweaking and working around the differences between PyTorch, the PyTorch ONNX exporter and ONNX.js, please see PyTorch fast-neural-style for web.
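One way to avoid the dynamic-shape blow-up described above is to export at a fixed input resolution, so that every intermediate shape is a constant rather than something computed from the input at inference time. The sketch below walks a fixed input size through a hypothetical layer list loosely modeled on the example's transformer network (reflection-padded convolutions, two stride-2 downsampling layers, two 2x nearest upsampling layers; residual blocks keep the size and are omitted). Treat the layer configuration as an assumption; the point is that all sizes can be computed, and hard-coded, before export.

```python
def conv_out(size, kernel, stride):
    """Output size of a convolution preceded by reflection padding of kernel // 2."""
    pad = kernel // 2
    return (size + 2 * pad - kernel) // stride + 1


# (name, kernel, stride, upsample factor applied before the conv).
# Hypothetical layer list loosely modeled on the example's transformer network;
# the residual blocks (3x3, stride 1) keep the size and are left out.
LAYERS = [
    ("conv1", 9, 1, 1),
    ("conv2", 3, 2, 1),
    ("conv3", 3, 2, 1),
    ("deconv1", 3, 1, 2),
    ("deconv2", 3, 1, 2),
    ("deconv3", 9, 1, 1),
]


def static_shapes(height, width, layers=LAYERS):
    """Walk a fixed input size through `layers`, returning (name, H, W) per layer."""
    shapes = []
    for name, kernel, stride, up in layers:
        height, width = height * up, width * up  # nearest upsampling, if any
        height = conv_out(height, kernel, stride)
        width = conv_out(width, kernel, stride)
        shapes.append((name, height, width))
    return shapes
```

Under these assumptions static_shapes(256, 256) ends at ('deconv3', 256, 256), while a 255-pixel side comes back as 256: only multiples of 4 pass through the two stride-2 layers and the two 2x upsamplings unchanged.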
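One plausible reason the zero-pad model looks worse: the example's convolution layers use reflection padding (ReflectionPad2d), which mirrors interior pixels across the border, whereas zero padding surrounds every feature map with a frame of zeros. A plain-Python sketch of both schemes on a single row, with hypothetical helper names:

```python
def reflection_pad(row, pad):
    """Mirror `pad` values at each end without repeating the edge (like ReflectionPad1d)."""
    if not 0 <= pad < len(row):
        raise ValueError("reflection padding needs 0 <= pad < len(row)")
    left = row[pad:0:-1]          # row[pad], ..., row[1]
    right = row[-2:-pad - 2:-1]   # row[-2], ..., row[-pad - 1]
    return left + row + right


def zero_pad(row, pad):
    """Surround the row with `pad` zeros on each side."""
    return [0] * pad + row + [0] * pad
```

For the row [1, 2, 3, 4] with pad 2, reflection padding gives [3, 2, 1, 2, 3, 4, 3, 2] and zero padding gives [0, 0, 1, 2, 3, 4, 0, 0].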
ONNX.js can be served locally by node.js via npm.
Windows npm installer: https://nodejs.org/en/#download
Ubuntu npm installation:
sudo apt install nodejs npm
Install the node.js http-server module:
npm install http-server -g
Run node.js server locally:
http-server . -c-1 -p 3000
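If Node.js isn't installed, Python's standard library can serve the same directory. The sketch below is an assumption, not part of this repository's instructions: a static file server whose responses carry Cache-Control: no-store, mirroring the intent of http-server's -c-1 flag (caching disabled) so that a re-exported .onnx file is not served stale. NoCacheHandler and make_server are hypothetical names.

```python
import functools
import http.server


class NoCacheHandler(http.server.SimpleHTTPRequestHandler):
    """Serves static files and tells the browser not to cache them."""

    def end_headers(self):
        self.send_header("Cache-Control", "no-store")
        super().end_headers()


def make_server(directory=".", port=3000):
    """Build (but do not start) a threaded static file server bound to 127.0.0.1."""
    handler = functools.partial(NoCacheHandler, directory=directory)
    return http.server.ThreadingHTTPServer(("127.0.0.1", port), handler)
```

Calling make_server().serve_forever() from the repository root plays roughly the role of http-server . -c-1 -p 3000 (Python 3.7 or later is needed for the directory argument and ThreadingHTTPServer).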
Open a browser and go to the following URL:
http://localhost:3000