Skip to content

Commit

Permalink
support TF converter
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangxiang1993 committed Dec 11, 2018
1 parent e119e10 commit eb3bcef
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 39 deletions.
4 changes: 2 additions & 2 deletions Tools/WinMLDashboard/ThirdPartyNotice.txt
Expand Up @@ -364,7 +364,7 @@ SOFTWARE.


================================================
typescript 2.9.2
typescript 3.2.2
Copyright (c) Microsoft Corporation. All rights reserved.

electron-squirrel-startup 1.0.0
Expand Down Expand Up @@ -392,7 +392,7 @@ and limitations under the License.
@types/prop-types 15.5.4
@types/react 16.4.6
@types/react-dom 16.0.6
@types/react-redux 6.0.4
@types/react-redux 6.0.11
@types/yauzl 2.9.0
@types/webdriverio 4.13.0
=====
Expand Down
4 changes: 2 additions & 2 deletions Tools/WinMLDashboard/package.json
Expand Up @@ -39,7 +39,7 @@
"@types/prop-types": "^15.5.4",
"@types/react": "^16.4.7",
"@types/react-dom": "^16.0.6",
"@types/react-redux": "^6.0.5",
"@types/react-redux": "^6.0.11",
"@types/react-select": "^2.0.6",
"@types/webdriverio": "^4.13.0",
"@types/yauzl": "^2.9.0",
Expand All @@ -61,7 +61,7 @@
"react-select": "^2.1.1",
"redux": "^4.0.0",
"spectron": "^5.0.0",
"typescript": "^2.9.2",
"typescript": "^3.2.2",
"yauzl": "^2.10.0"
}
}
27 changes: 25 additions & 2 deletions Tools/WinMLDashboard/public/convert.py
Expand Up @@ -6,8 +6,9 @@

def parse_args():
parser = argparse.ArgumentParser(description='Convert model to ONNX.')
parser.add_argument('source', help='source CoreML or Keras model')
parser.add_argument('source', help='source model')
parser.add_argument('framework', help='source framework model comes from')
parser.add_argument('outputNames', help='names of output nodes')
parser.add_argument('destination', help='destination ONNX model (ONNX or prototxt extension)')
parser.add_argument('--name', default='WimMLDashboardConvertedModel', help='(ONNX output only) model name')
return parser.parse_args()
Expand Down Expand Up @@ -70,6 +71,28 @@ def libSVM_converter(args):
input_features=[('input', FloatTensorType([1, 'None']))])
save_onnx(onnx_model, args.destination)

def convert_tensorflow_file(filename, output_names, destination, debug=True):
from tensorflow.core.framework import graph_pb2
import tensorflow as tf
import tf2onnx
import onnx
graph_def = graph_pb2.GraphDef()
with open(filename, 'rb') as file:
graph_def.ParseFromString(file.read())
g = tf.import_graph_def(graph_def, name='')
with tf.Session(graph=g) as sess:
converted_model = winmltools.convert_tensorflow(sess.graph, continue_on_error=True, verbose=True, output_names=output_names, build_number=17763)
onnx.checker.check_model(converted_model)
if debug:
with open(destination, 'wb') as file:
file.write(converted_model.SerializeToString())
tf.reset_default_graph()

def tensorFlow_converter(args):
convert_tensorflow_file(args.source, args.outputNames.split(), args.destination)



def onnx_converter(args):
onnx_model = winmltools.load_model(args.source)
save_onnx(onnx_model, args.destination)
Expand All @@ -80,7 +103,7 @@ def onnx_converter(args):
'scikit-learn': scikit_learn_converter,
'xgboost': xgboost_converter,
'libsvm': libSVM_converter,
#'tensorflow': TensorFlow_converter
'tensorflow': tensorFlow_converter
}

suffix_converters = {
Expand Down
17 changes: 5 additions & 12 deletions Tools/WinMLDashboard/public/requirements.txt
@@ -1,19 +1,12 @@
git+https://github.com/apple/coremltools
lightgbm==2.2.1
winmltools==1.2.0.912
tensorflow==1.11.0
xgboost==0.80
h5py==2.8.0
Keras==2.2.2
Keras-Applications==1.0.4
Keras-Preprocessing==1.0.2
numpy==1.15.0
tensorflow==1.12.0
Keras==2.2.4
onnx==1.2.3
onnxmltools==1.2.0.116
protobuf==3.6.0
Keras-Applications==1.0.6
Keras-Preprocessing==1.0.5
protobuf==3.6.1
PyYAML==3.13
scikit-learn==0.19.2
scipy==1.1.0
six==1.10.0
typing==3.6.4
typing-extensions==3.6.5
18 changes: 13 additions & 5 deletions Tools/WinMLDashboard/src/view/convert/View.css
Expand Up @@ -6,10 +6,14 @@
flex-direction: column;
}

.ModelConvertBrowser {
.ModelConvert {
width: 100%;
}

.FrameworkOptions, .ONNXVersionOptions, .inputShape {
width: 20%;
}

.ms-TextField, .ConvertViewControls {
flex: 1;
}
Expand All @@ -25,10 +29,14 @@
font-size: 12px;
}

#ConverterModelInputBrowse {
margin: 24px 5px 0px;
}

#ConvertButton {
margin: 5px 0px;
}

.label {
width: 10%;
}

.hidden {
display: none;
}
54 changes: 38 additions & 16 deletions Tools/WinMLDashboard/src/view/convert/View.tsx
Expand Up @@ -4,7 +4,7 @@ import { connect } from 'react-redux';

import { DefaultButton } from 'office-ui-fabric-react/lib/Button';
import { ChoiceGroup, IChoiceGroupOption } from 'office-ui-fabric-react/lib/ChoiceGroup';
// import { MessageBar, MessageBarType } from 'office-ui-fabric-react/lib/MessageBar';
import { MessageBar, MessageBarType } from 'office-ui-fabric-react/lib/MessageBar';
import { Spinner } from 'office-ui-fabric-react/lib/Spinner';
import { TextField } from 'office-ui-fabric-react/lib/TextField';
import Select from 'react-select';
Expand Down Expand Up @@ -47,6 +47,7 @@ interface IComponentState {
currentStep: Step,
error?: Error | string,
framework: string,
outputNames: string,
source?: string,
}

Expand All @@ -56,11 +57,12 @@ class ConvertView extends React.Component<IComponentProperties, IComponentState>
constructor(props: IComponentProperties) {
super(props);
const error = isWeb() ? "The converter can't be run in the web interface" : undefined;
this.state = {
this.state = {
console: '',
currentStep: Step.Idle,
error,
framework: '',
outputNames: '',
};
log.info("Convert view is created.");
}
Expand Down Expand Up @@ -98,6 +100,13 @@ class ConvertView extends React.Component<IComponentProperties, IComponentState>
}

private getView() {
const { error } = this.state;
if (error) {
const message = typeof error === 'string' ? error : (`${error.stack ? `${error.stack}: ` : ''}${error.message}`);
// tslint:disable-next-line:no-console
console.log(message);
return <MessageBar messageBarType={MessageBarType.error}>{message}</MessageBar>
}
switch (this.state.currentStep) {
case Step.Downloading:
return <Spinner label="Downloading Python..." />;
Expand Down Expand Up @@ -161,6 +170,8 @@ class ConvertView extends React.Component<IComponentProperties, IComponentState>
await downloadPip(this.outputListener);
log.info("start downloading python environment.");
await pip(['install', packagedFile('libsvm-3.22-cp36-cp36m-win_amd64.whl')], this.outputListener);
await pip(['install', packagedFile('winmltools-1.3.0a0-py2.py3-none-any.whl')], this.outputListener);
await pip(['install', packagedFile('tf2onnx-0.4.0-py3-none-any.whl')], this.outputListener);
this.setState({ currentStep: Step.InstallingRequirements });
await pip(['install', '-r', packagedFile('requirements.txt'), '--no-warn-script-location'], this.outputListener);
this.setState({ currentStep: Step.Idle });
Expand All @@ -183,27 +194,37 @@ class ConvertView extends React.Component<IComponentProperties, IComponentState>
}

private converterView = () => {
const options = [
const frameworkOptions = [
{ value: 'Coreml', label: 'Coreml' },
{ value: 'Keras', label: 'Keras' },
{ value: 'scikit-learn', label: 'scikit-learn' },
{ value: "xgboost", label: 'xgboost' },
{ value: 'libSVM', label: 'libSVM' }
{ value: 'libSVM', label: 'libSVM' },
{ value: 'TensorFlow', label: 'TensorFlow' },
];
return (
<div>
<div className='DisplayFlex ModelConvertBrowser'>
<TextField id='modelToConvert' placeholder='Path' value={this.state.source} label='Model to convert' onChanged={this.setSource} />
<div className="ModelConvert">
<div className='DisplayFlex'>
<label className='label'>Model to convert: </label>
<TextField id='modelToConvert' placeholder='Path' value={this.state.source} onChanged={this.setSource} />
<DefaultButton id='ConverterModelInputBrowse' text='Browse' onClick={this.browseSource}/>
</div>
<div className='Frameworks'>
<p>Source Framework: </p>
<Select
<br />
<div className='DisplayFlex'>
<label className='label'>Source Framework: </label>
<Select className='FrameworkOptions'
value={this.newOption(this.state.framework)}
onChange={this.setFramework}
options={options}
options={frameworkOptions}
/>
</div>
<br />
<div className={this.state.framework === 'TensorFlow' ? ' ' : 'hidden'}>
<div className='DisplayFlex'>
<label className='label'>Output Names: </label>
<TextField id='outputNames' className='outputNames' placeholder='output:0, output:1' value={this.state.outputNames} onChanged={this.setOutputNames} />
</div>
</div>
<DefaultButton id='ConvertButton' text='Convert' disabled={!this.state.source || !this.state.framework} onClick={this.convert}/>
</div>
);
Expand All @@ -212,9 +233,12 @@ class ConvertView extends React.Component<IComponentProperties, IComponentState>
private newOption = (framework: string):ISelectOpition => {
return {
label: framework,
value: framework
value: framework,
}
}
private setOutputNames = (outputNames: string) => {
this.setState({outputNames})
}

private setFramework = (framework: ISelectOpition) => {
this.setState({framework: framework.value})
Expand Down Expand Up @@ -254,10 +278,8 @@ class ConvertView extends React.Component<IComponentProperties, IComponentState>

private convert = async () => {
this.initializeState();
const source = this.state.source!;
const framework = this.state.framework;

if (!framework) {
if (!this.state.framework) {
return;
}
const convertDialogOptions = {
Expand All @@ -268,7 +290,7 @@ class ConvertView extends React.Component<IComponentProperties, IComponentState>

this.setState({ currentStep: Step.Converting });
try {
await python([packagedFile('convert.py'), source, framework, packagedFile('tempConvertResult.onnx')], {}, this.outputListener);
await python([packagedFile('convert.py'), this.state.source!, this.state.framework, this.state.outputNames, packagedFile('tempConvertResult.onnx')], {}, this.outputListener);
} catch (e) {
this.logError(e);
this.printMessage("\n------------------------------------\nConversion failed!\n")
Expand Down

0 comments on commit eb3bcef

Please sign in to comment.