Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support negative_slope in quantized_relu #987

Merged
merged 4 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions hls4ml/converters/keras/qkeras.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ def parse_qactivation_layer(keras_layer, input_names, input_shapes, data_reader)
layer['slope_prec'] = FixedPrecisionType(width=2, integer=0, signed=False)
layer['shift_prec'] = FixedPrecisionType(width=2, integer=0, signed=False)
layer['activation'] = activation_config['class_name'].replace('quantized_', 'hard_')
elif activation_config['class_name'] == 'quantized_relu' and activation_config['config']['negative_slope'] != 0:
layer['class_name'] = 'LeakyReLU'
layer['activation'] = activation_config['class_name'].replace('quantized_', 'leaky_')
layer['activ_param'] = activation_config['config']['negative_slope']
else:
layer['class_name'] = 'Activation'
layer['activation'] = activation_config['class_name'].replace('quantized_', '')
Expand Down
6 changes: 3 additions & 3 deletions hls4ml/model/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,10 +588,10 @@ def get_ymodel_keras(keras_model, X):
# Note that if the layer is a standalone activation layer then skip this
name = layer.name
if (
hasattr(layer, "activation")
and hasattr(layer.activation, "__name__")
and layer.activation.__name__ != "linear"
hasattr(layer, 'activation')
and layer.activation is not None
and not isinstance(layer, (keras.layers.Activation, qkeras.qlayers.QActivation))
and layer.activation.__name__ != 'linear'
):
tmp_activation = layer.activation
layer.activation = None
Expand Down
7 changes: 5 additions & 2 deletions hls4ml/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,11 @@ def _get_precision_from_quantizer(quantizer):
rnd = "AP_RND_CONV"
overflow = "AP_SAT"
if quantizer['class_name'] in ('quantized_relu', 'quantized_relu_po2'):
signed = False
integer -= 1
if quantizer['config']['negative_slope'] != 0.0:
signed = True
else:
signed = False
integer -= 1
elif quantizer['class_name'] == 'quantized_tanh':
overflow = "AP_SAT_SYM" if quantizer['config']['symmetric'] else "AP_SAT"
integer = 1
Expand Down
38 changes: 38 additions & 0 deletions test/pytest/test_qkeras.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,44 @@ def test_quantizer(randX_1000_1, quantizer, backend, io_type):
np.testing.assert_array_equal(y_qkeras, y_hls4ml)


@pytest.mark.parametrize(
'quantizer',
[
(quantized_relu(4, negative_slope=0.5)),
(quantized_relu(8, 4, negative_slope=1.0)),
(quantized_relu(10, 2, negative_slope=0.25)),
],
)
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
def test_relu_negative_slope(randX_1000_1, quantizer, backend, io_type):
'''
Test a a transformation of quantized_relu with negative_slope to leaky_relu activation layer.
'''
X = randX_1000_1
X = -X # Make it negative so leaky relu does something
X = np.round(X * 2**10) * 2**-10 # make it an exact ap_fixed<16,6>
model = Sequential()
model.add(QActivation(input_shape=(1,), activation=quantizer, name='quantizer'))
model.compile()

config = hls4ml.utils.config_from_keras_model(model, granularity='name')
output_dir = str(
test_root_path
/ 'hls4mlprj_qkeras_leaky_relu_{}_{}_neg_slope_{}_{}_{}'.format(
quantizer.bits, quantizer.integer, quantizer.negative_slope, backend, io_type
)
)
hls_model = hls4ml.converters.convert_from_keras_model(
model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type
)
hls_model.compile()

y_qkeras = model.predict(X)
y_hls4ml = hls_model.predict(X)
np.testing.assert_allclose(y_hls4ml, y_qkeras, rtol=1e-5, atol=0)


@pytest.mark.parametrize(
'weight_quantizer,activation_quantizer,',
[
Expand Down
Loading