Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
*.pyc
__pycache__
vivado_prj
.vscode
.vscode
my-hls-test
*.tar.gz
8 changes: 5 additions & 3 deletions hls-writer/hls_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import print_function
import six
import re
import numpy as np
from enum import Enum
from collections import OrderedDict
Expand Down Expand Up @@ -81,7 +82,7 @@ def get_input_variables(self):
def get_output_variables(self):
variables = []
for out in self.outputs:
variables.append(self.graph[out].get_output_variable())
variables.append(self.output_vars[out])
return variables

def get_layer_output_variable(self, output_name):
Expand All @@ -92,6 +93,7 @@ def __init__(self, var_name, type_name, precision, **kwargs):
self.name = var_name.format(**kwargs)
self.type = type_name.format(**kwargs)
self.precision = precision
self.cppname = re.sub(r'\W|^(?=\d)','_', self.name)

class ArrayVariable(Variable):
def __init__(self, shape, dim_names, var_name='layer{index}', type_name='layer{index}_t', precision=None, pragma='partition', **kwargs):
Expand Down Expand Up @@ -133,7 +135,7 @@ def get_shape(self):

def definition_cpp(self):
array_shape = '*'.join([str(k) for k in self.dim_names])
return '{type} {name}[{shape}]'.format(type=self.type, name=self.name, shape=array_shape)
return '{type} {name}[{shape}]'.format(type=self.type, name=self.cppname, shape=array_shape)

def size(self):
nelem = 1
Expand Down Expand Up @@ -312,7 +314,7 @@ class Dense(Layer):
def initialize(self):
shape = [self.attributes['n_out']]
dims = ['N_LAYER_{}'.format(self.index)]
quantize = self.get_attr('quantize')
quantize = self.get_attr('quantize', default=0)
self.add_output_variable(shape, dims)
self.add_weights(quantize=quantize)
self.add_bias(quantize=quantize)
Expand Down
14 changes: 7 additions & 7 deletions hls-writer/hls_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,11 @@ def write_project_cpp(model):
#Add input/output type
elif '//hls-fpga-machine-learning insert IO' in line:
newline = line
all_inputs = [i.name for i in model_inputs]
all_outputs = [o.name for o in model_outputs]
all_inputs = [i.cppname for i in model_inputs]
all_outputs = [o.cppname for o in model_outputs]
if model.get_config_value("IOType") == "io_parallel":
for i in model_inputs: newline += indent + '#pragma HLS ARRAY_RESHAPE variable={} complete dim=0 \n'.format(i.name)
for o in model_outputs: newline += indent + '#pragma HLS ARRAY_RESHAPE variable={} complete dim=0 \n'.format(o.name)
for i in model_inputs: newline += indent + '#pragma HLS ARRAY_RESHAPE variable={} complete dim=0 \n'.format(i.cppname)
for o in model_outputs: newline += indent + '#pragma HLS ARRAY_RESHAPE variable={} complete dim=0 \n'.format(o.cppname)
newline += indent + '#pragma HLS INTERFACE ap_vld port={},{} \n'.format(','.join(all_inputs), ','.join(all_outputs))
newline += indent + '#pragma HLS PIPELINE \n'
if model.get_config_value("IOType") == "io_serial":
Expand Down Expand Up @@ -272,15 +272,15 @@ def write_test_bench(model):
output_size_vars = ','.join(['size_out{}'.format(o) for o in range(1, len(model.get_output_variables()) + 1)])
newline += size_str.format(input_size_vars, output_size_vars)

input_vars = ','.join([i.name for i in model.get_input_variables()])
output_vars = ','.join([o.name for o in model.get_output_variables()])
input_vars = ','.join([i.cppname for i in model.get_input_variables()])
output_vars = ','.join([o.cppname for o in model.get_output_variables()])
top_level = ' {}({},{},{},{});\n'.format(model.get_project_name(), input_vars, output_vars, input_size_vars, output_size_vars)
newline += top_level
elif '//hls-fpga-machine-learning insert output' in line:
newline = line
for out in model.get_output_variables():
newline += ' for(int i = 0; i < {}; i++) {{\n'.format(out.size_cpp())
newline += ' std::cout << {}[i] << " ";\n'.format(out.name)
newline += ' std::cout << {}[i] << " ";\n'.format(out.cppname)
newline += ' }\n'
newline += ' std::cout << std::endl;\n'
else:
Expand Down
29 changes: 29 additions & 0 deletions onnx-to-hls/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# ONNX to HLS

Reads ONNX model file which contains architecture as well as weights and biases.

# Instructions to run

```python onnx-to-hls.py -c onnx-config.yml```

# Configuration

Configuration options for the HLS translation of Keras models.

*OnnxModel*: For ONNX translation, you are required to provide `onnx` model file.
Examples are in the directory: `example-onnx-model-files`

*OutputDir*: Directory where your HLS project will go.

*IOType*: We provide 2 options for the way inputs are input to the architecture, serially or in parallel. The keywords are `io_serial` or `io_parallel`.

*ReuseFactor*: For the running mode `io_parallel`, the calculations do not have to be fully parallelized but resources can be reused at the cost of higher latency. A `ReuseFactor: 1` means fully parallelized and no resources are reused.

*DefaultPrecision*: This is the default type of the weights, biases, accumulators, input and output vectors. This can then be further modified by the `firmware/parameters.h` file generated in your HLS project.

# Running HLS

```
cd my-hls-test
vivado_hls -f build_prj.tcl
```
45 changes: 45 additions & 0 deletions onnx-to-hls/converters/keras-to-onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import onnxmltools
from onnx import checker, shape_inference, optimizer
from onnx.utils import polish_model
from keras.models import Model
from keras.models import model_from_json
import argparse

def keras_to_onnx():
parser = argparse.ArgumentParser(description='')
parser.add_argument("-m", action='store', dest='model',
help="Keras model file (.json).")
parser.add_argument("-w", action='store', dest='weights',
help="Keras model weights (.h5).")
parser.add_argument("-o", action='store', dest='output',
help="Output file name (.onnx).")
args = parser.parse_args()

if not args.model: parser.error('Model file needs to be specified.')
if not args.weights: parser.error('Weights file needs to be specified.')
if not args.weights: parser.error('Output file needs to be specified.')

# Load Keras model and its weights
with open(args.model, 'r') as json_file:
keras_model = model_from_json(json_file.read())

keras_model.load_weights(args.weights)
#keras_model.summary()

# Save to ONNX format
onnx_model = onnxmltools.convert_keras(keras_model)

# Check model
checker.check_model(onnx_model)

# Infer shape
onnx_model = shape_inference.infer_shapes(onnx_model)

passes = ['fuse_matmul_add_bias_into_gemm', 'fuse_consecutive_transposes', 'fuse_transpose_into_gemm']
onnx_model = optimizer.optimize(onnx_model, passes)
onnx_model = polish_model(onnx_model)
onnxmltools.utils.save_model(onnx_model, args.output)


if __name__ == "__main__":
keras_to_onnx()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
9 changes: 9 additions & 0 deletions onnx-to-hls/onnx-config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
OnnxModel: example-onnx-model-files/three_layer_keras.onnx
OutputDir: my-hls-test
ProjectName: myproject
XilinxPart: xcku115-flvb2104-2-i
ClockPeriod: 5

IOType: io_parallel # options: io_serial/io_parallel
ReuseFactor: 1
DefaultPrecision: ap_fixed<16,6>
Loading