Add support for HGQ proxy model #914

Open
wants to merge 12 commits into main
107 changes: 107 additions & 0 deletions hls4ml/backends/fpga/passes/hgq_proxy_model.py
@@ -0,0 +1,107 @@
import numpy as np

from hls4ml.backends import Backend
from hls4ml.backends.template import FunctionCallTemplate
from hls4ml.model.layers import Layer
from hls4ml.model.optimizer import OptimizerPass
from hls4ml.model.optimizer.passes.hgq_proxy_model import FixedPointQuantizer, UnaryLUT
from hls4ml.model.types import Source


def to_apfixed(k, b, i, RND, SAT):
u = 'u' if k == 0 else ''
return f'ap_{u}fixed<{b},{i},AP_{RND},AP_{SAT}>'


def to_acfixed(k, b, i, RND, SAT):
k = 'false' if k == 0 else 'true'
return f'ac_fixed<{b},{i},{k},AC_{RND},AC_{SAT}>'


def generate_mask_fn(
name: str, shape: tuple[int, ...], k: np.ndarray, b: np.ndarray, i: np.ndarray, RND: str, SAT: str, backend: str
) -> str:
"""Generate heterogenous quantization mask function, ONLY works for IOType=io_parallel"""
assert k.shape[0] == b.shape[0] == i.shape[0] == 1
assert backend.lower() in ('quartus', 'vivado', 'vitis'), f'Backend {backend} not tested'
Ks, Bs, Is = k[0], b[0], i[0]
Ks, Bs, Is = np.broadcast_to(Ks, shape), np.broadcast_to(Bs, shape), np.broadcast_to(Is, shape)
Ks, Bs, Is = Ks.ravel(), Bs.ravel(), Is.ravel()
masks = []
to_fixed = to_acfixed if backend.lower() == 'quartus' else to_apfixed
for idx, (k, b, i) in enumerate(zip(Ks, Bs, Is)):
if b == 0:
fn = f'out[{idx}] = 0;'
else:
fn = f'out[{idx}] = {to_fixed(k, b, i, RND, SAT)}(inp[{idx}]);'
masks.append(f' {fn}')
body = "\n".join(masks)
mask_fn = f'''
template<typename input_t, typename output_t>
void {name}(input_t *inp, output_t *out) {{
#pragma HLS INLINE

{body}
}}
'''
return mask_fn


class ProcessFixedPointQuantizerLayer(OptimizerPass):
def match(self, node: Layer):
return isinstance(node, FixedPointQuantizer)

def transform(self, model, node: FixedPointQuantizer):
if node.fusible:
model.remove_node(node, rewire=True)
return True

if model.config.config['IOType'] != 'io_parallel':
raise NotImplementedError('Heterogeneous quantization for activations is only supported with IOType=io_parallel')

backend = model.config.config['Backend']

name = node.name

assert node.mask_kbi is not None
k, b, i = node.mask_kbi
RND = node.RND
SAT = node.SAT
mask_fn: str = generate_mask_fn(name, node.get_input_variable().shape, k, b, i, RND, SAT, backend)

node.set_attr('mask_fn_codegen', Source(mask_fn))


class ProcessFixedPointQuantizerCall(FunctionCallTemplate):
def __init__(self):
super().__init__(FixedPointQuantizer, include_header=[])
self.template = 'nnet::{name}<{input_t}, {output_t}>({input}, {output});'

def format(self, node):
params = self._default_function_params(node)

return self.template.format(**params)


class ProcessUnaryLUTCall(FunctionCallTemplate):
def __init__(self):
super().__init__(UnaryLUT, include_header=[])
self.template = 'nnet::unary_lut<{input_t}, {output_t}, {config}>({input}, {output}, {table});'
self.include_header = [
'nnet_utils/nnet_activation.h',
'nnet_utils/nnet_activation_stream.h',
]

def format(self, node):
params = self._default_function_params(node)
node.attributes['result_t'].precision = node.attributes['table_t'].precision
params['config'] = f'unary_lut_config{node.index}'
params['table'] = node.get_weights('table').name

return self.template.format(**params)


def register_hgq_proxy_model(backend: Backend):
backend.register_pass('process_fixed_point_quantizer_layer', ProcessFixedPointQuantizerLayer)
backend.register_template(ProcessFixedPointQuantizerCall)
backend.register_template(ProcessUnaryLUTCall)
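As an illustration, here is a minimal sketch of what `generate_mask_fn` emits for the Vivado backend; the layer name and the k/b/i values below are hypothetical, and the commented output is only approximate in its whitespace:

```python
import numpy as np

from hls4ml.backends.fpga.passes.hgq_proxy_model import generate_mask_fn

# Hypothetical per-element masks; the leading axis of length 1 is required by the assert.
k = np.array([[1, 0]])  # keep_negative (sign bit)
b = np.array([[6, 0]])  # total bits; b == 0 forces the element to 0
i = np.array([[3, 0]])  # integer bits

print(generate_mask_fn('q_relu_1', (2,), k, b, i, 'RND', 'SAT', 'vivado'))
# template<typename input_t, typename output_t>
# void q_relu_1(input_t *inp, output_t *out) {
#     #pragma HLS INLINE
#
#     out[0] = ap_fixed<6,3,AP_RND,AP_SAT>(inp[0]);
#     out[1] = 0;
# }
```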
3 changes: 2 additions & 1 deletion hls4ml/backends/quartus/passes/core_templates.py
@@ -1,6 +1,7 @@
from hls4ml.backends.backend import get_backend
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
from hls4ml.model.layers import Activation, BatchNormalization, Dense, HardActivation, ParametrizedActivation, PReLU, Softmax
from hls4ml.model.optimizer.passes.hgq_proxy_model import UnaryLUT

# Dense templates

@@ -152,7 +153,7 @@ def format(self, node):

class ActivationConfigTemplate(LayerConfigTemplate):
def __init__(self):
super().__init__((Activation, ParametrizedActivation, PReLU))
super().__init__((Activation, ParametrizedActivation, PReLU, UnaryLUT))
self.template = activ_config_template

def format(self, node):
26 changes: 24 additions & 2 deletions hls4ml/backends/quartus/quartus_backend.py
@@ -1,5 +1,6 @@
import os
from contextlib import contextmanager
from warnings import warn

import numpy as np

@@ -73,6 +74,7 @@ def _register_flows(self):
'quartus:inplace_stream_flatten',
'quartus:skip_softmax',
'quartus:fix_softmax_table_size',
'quartus:process_fixed_point_quantizer_layer',
'infer_precision_types',
]
optimization_flow = register_flow('optimize', optimization_passes, requires=[init_flow], backend=self.name)
@@ -265,7 +267,17 @@ def init_conv1d(self, layer):
n_in, n_out = self.get_layer_mult_size(layer)
self.set_target_reuse_factor(layer)
self.set_closest_reuse_factor(layer, n_in, n_out)
layer.set_attr('parallelization', layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', 1))

# Use the user-specified parallelization factor if given; otherwise fall back to the factor embedded in the proxy model, then to 1
user_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', None)
layer_pf = layer.get_attr('parallelization_factor', None)
chosen_pf = user_pf or layer_pf or 1
if user_pf is not None and layer_pf is not None:
if user_pf != layer_pf:
warn(
f'For layer {layer.name}, parallelization factor of {layer_pf} is defined in the proxy-model, but is overridden by the user to {user_pf}.' # noqa: E501
)
layer.set_attr('parallelization', chosen_pf)

# impl_filt_width determines the filter size post-Winograd transformation
layer.set_attr('impl_filt_width', layer.get_attr('filt_width'))
@@ -295,7 +307,17 @@ def init_conv2d(self, layer):
n_in, n_out = self.get_layer_mult_size(layer)
self.set_target_reuse_factor(layer)
self.set_closest_reuse_factor(layer, n_in, n_out)
layer.set_attr('parallelization', layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', 1))

# Use the user-specified parallelization factor if given; otherwise fall back to the factor embedded in the proxy model, then to 1
user_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', None)
layer_pf = layer.get_attr('parallelization_factor', None)
chosen_pf = user_pf or layer_pf or 1
if user_pf is not None and layer_pf is not None:
if user_pf != layer_pf:
warn(
f'For layer {layer.name}, parallelization factor of {layer_pf} is defined in the proxy-model, but is overridden by the user to {user_pf}.' # noqa: E501
)
layer.set_attr('parallelization', chosen_pf)

# impl_filt_width & impl_filt_height determine the filter size post-Winograd transformation
layer.set_attr('impl_filt_height', layer.get_attr('filt_height'))
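In plain terms, the resolution rule added above is: an explicit user `ParallelizationFactor` wins, then the factor recorded on the layer by the proxy model, then a default of 1, with a warning when the two explicit sources disagree. A small sketch with hypothetical values:

```python
from warnings import warn

user_pf = 4      # from the hls4ml layer config (hypothetical)
layer_pf = 8     # recorded on the layer by the proxy model (hypothetical)

chosen_pf = user_pf or layer_pf or 1   # -> 4: the user setting takes precedence
if user_pf is not None and layer_pf is not None and user_pf != layer_pf:
    warn(f'parallelization factor {layer_pf} from the proxy model is overridden by the user to {user_pf}')
```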
3 changes: 2 additions & 1 deletion hls4ml/backends/vivado/passes/core_templates.py
@@ -1,6 +1,7 @@
from hls4ml.backends.backend import get_backend
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
from hls4ml.model.layers import Activation, BatchNormalization, Dense, HardActivation, ParametrizedActivation, PReLU, Softmax
from hls4ml.model.optimizer.passes.hgq_proxy_model import UnaryLUT

# Dense templates

@@ -144,7 +145,7 @@ def format(self, node):

class ActivationConfigTemplate(LayerConfigTemplate):
def __init__(self):
super().__init__((Activation, ParametrizedActivation, PReLU))
super().__init__((Activation, ParametrizedActivation, PReLU, UnaryLUT))
self.template = activ_config_template

def format(self, node):
Expand Down
28 changes: 26 additions & 2 deletions hls4ml/backends/vivado/vivado_backend.py
@@ -1,5 +1,6 @@
import os
import sys
from warnings import warn

import numpy as np

@@ -107,6 +108,7 @@ def _register_flows(self):
'vivado:inplace_stream_flatten',
'vivado:skip_softmax',
'vivado:fix_softmax_table_size',
'vivado:process_fixed_point_quantizer_layer',
'infer_precision_types',
]
optimization_flow = register_flow('optimize', optimization_passes, requires=[init_flow], backend=self.name)
@@ -266,7 +268,17 @@ def init_conv1d(self, layer):
layer.set_attr('strategy', 'latency')

out_width = layer.get_output_variable().shape[0]
chosen_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', 1)

# Use the user-specified parallelization factor if given; otherwise fall back to the factor embedded in the proxy model, then to 1
user_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', None)
layer_pf = layer.get_attr('parallelization_factor', None)
chosen_pf = user_pf or layer_pf or 1
if user_pf is not None and layer_pf is not None:
if user_pf != layer_pf:
warn(
f'For layer {layer.name}, parallelization factor of {layer_pf} is defined in the proxy-model, but is overridden by the user to {user_pf}.' # noqa: E501
)

valid_pf = self.get_valid_conv_partition_splits(1, out_width)
if chosen_pf not in valid_pf:
closest_pf = self.get_closest_reuse_factor(valid_pf, chosen_pf)
@@ -278,6 +290,7 @@ def init_conv1d(self, layer):
else:
closest_pf = chosen_pf
layer.set_attr('n_partitions', out_width // closest_pf)
layer.set_attr('parallelization_factor', closest_pf)

layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower())

@@ -332,7 +345,17 @@ def init_conv2d(self, layer):

out_height = layer.get_output_variable().shape[0]
out_width = layer.get_output_variable().shape[1]
chosen_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', 1)

# Use the user-specified parallelization factor if given; otherwise fall back to the factor embedded in the proxy model, then to 1
user_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', None)
layer_pf = layer.get_attr('parallelization_factor', None)
chosen_pf = user_pf or layer_pf or 1
if user_pf is not None and layer_pf is not None:
if user_pf != layer_pf:
warn(
f'For layer {layer.name}, parallelization factor of {layer_pf} is defined in the proxy-model, but is overridden by the user to {user_pf}.' # noqa: E501
)

valid_pf = self.get_valid_conv_partition_splits(out_height, out_width)
if chosen_pf not in valid_pf:
closest_pf = self.get_closest_reuse_factor(valid_pf, chosen_pf)
@@ -344,6 +367,7 @@ def init_conv2d(self, layer):
else:
closest_pf = chosen_pf
layer.set_attr('n_partitions', out_height * out_width // closest_pf)
layer.set_attr('parallelization_factor', closest_pf)

layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower())

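For intuition, a rough sketch of how the Vivado backend snaps the chosen factor to a valid partition split and writes it back to the layer, as in `init_conv2d` above. The numbers are hypothetical and the nearest-value lookup is only a stand-in for `get_closest_reuse_factor`:

```python
out_height, out_width = 8, 8
chosen_pf = 5                           # resolved from user/proxy-model settings (hypothetical)
valid_pf = [1, 2, 4, 8, 16, 32, 64]     # e.g. what get_valid_conv_partition_splits might return

if chosen_pf not in valid_pf:
    # stand-in for get_closest_reuse_factor: pick the nearest valid split
    closest_pf = min(valid_pf, key=lambda pf: abs(pf - chosen_pf))
else:
    closest_pf = chosen_pf

n_partitions = out_height * out_width // closest_pf   # 64 // 4 = 16
# layer.set_attr('n_partitions', n_partitions)
# layer.set_attr('parallelization_factor', closest_pf)
```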
37 changes: 37 additions & 0 deletions hls4ml/converters/keras/hgq_proxy_model.py
@@ -0,0 +1,37 @@
from hls4ml.converters.keras_to_hls import KerasReader, keras_handler, parse_default_keras_layer


@keras_handler('FixedPointQuantizer', 'HGQ>FixedPointQuantizer')
def fixedpoint_quantizer_handler(keras_layer, input_names, input_shapes, data_reader: KerasReader):
config = parse_default_keras_layer(keras_layer, input_names)

name = config['name']
fusible = keras_layer['config']['fusible']
config['RND'] = keras_layer['config']['RND']
config['SAT'] = keras_layer['config']['SAT']
config['fusible'] = fusible
if not fusible:
k = data_reader.get_weights_data(name, 'keep_negative')
b = data_reader.get_weights_data(name, 'bits')
i = data_reader.get_weights_data(name, 'integers')
config['mask_kbi'] = k, b, i
config['overrides'] = keras_layer['config']['overrides']

layer = config
return layer, input_shapes[0]


@keras_handler('UnaryLUT', 'HGQ>UnaryLUT')
def unary_lut_keras_handler(keras_layer, input_names, input_shapes, data_reader: KerasReader):
config = parse_default_keras_layer(keras_layer, input_names)

table = data_reader.get_weights_data(config['name'], 'table')
k, i, f = keras_layer['config']['kif_out']
k, b, i = k, k + i + f, k + i
config['table_t'] = f'{"" if k else "u"}fixed<{b},{i}>'
config['table'] = table
config['table_size'] = len(table)
config['activation'] = 'unary_lut'

layer = config
return layer, input_shapes[0]
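A small sketch of the `kif_out` to fixed-point mapping used by `unary_lut_keras_handler`, with hypothetical values: `k` is the sign bit, `i` the integer bits excluding sign, and `f` the fractional bits:

```python
k, i, f = 1, 2, 5                  # keep_negative, integer bits (excl. sign), fractional bits
k, b, i = k, k + i + f, k + i      # total width 8, integer width 3 (incl. sign)
table_t = f'{"" if k else "u"}fixed<{b},{i}>'
assert table_t == 'fixed<8,3>'
```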
2 changes: 2 additions & 0 deletions hls4ml/converters/keras_to_hls.py
@@ -205,6 +205,8 @@ def parse_keras_model(model_arch, reader):
'Softmax',
'TernaryTanh',
'HardActivation',
'UnaryLUT',
'HGQ>UnaryLUT',
]
# Recurrent layers
recurrent_layers = ['SimpleRNN', 'LSTM', 'GRU']
1 change: 1 addition & 0 deletions hls4ml/model/optimizer/__init__.py
@@ -44,6 +44,7 @@
'qkeras_factorize_alpha',
'extract_ternary_threshold',
'fuse_consecutive_batch_normalization',
'enforce_proxy_model_embedded_config',
],
) # TODO Maybe not all QKeras optimizers belong here?

Expand Down