diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 43a924fbf..1a5711f57 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -632,7 +632,7 @@ def _compute_n_samples(self, x): if not all([n_samples[i] == n_samples[i+1] for i in range(len(xlist)-1)]): raise Exception('Input size mismatch, not all inputs match') - return n_sample + return int(n_sample) def predict(self, x): top_function, ctype = self._get_top_function(x) @@ -651,9 +651,9 @@ def predict(self, x): for i in range(n_samples): predictions = [np.zeros(yj.size(), dtype=ctype) for yj in self.get_output_variables()] if n_inputs == 1: - inp = [x[i]] + inp = [np.asarray(x[i])] else: - inp = [xj[i] for xj in x] + inp = [np.asarray(xj[i]) for xj in x] argtuple = inp argtuple += predictions argtuple = tuple(argtuple) @@ -683,6 +683,7 @@ def trace(self, x): top_function, ctype = self._get_top_function(x) n_samples = self._compute_n_samples(x) n_inputs = len(self.get_input_variables()) + n_outputs = len(self.get_output_variables()) class TraceData(ctypes.Structure): _fields_ = [('name', ctypes.c_char_p), @@ -721,16 +722,15 @@ class TraceData(ctypes.Structure): alloc_func(ctypes.sizeof(ctype)) for i in range(n_samples): - predictions = np.zeros(self.get_output_variables()[0].size(), dtype=ctype) + predictions = [np.zeros(yj.size(), dtype=ctype) for yj in self.get_output_variables()] if n_inputs == 1: - top_function(x[i], predictions, ctypes.byref(ctypes.c_ushort()), ctypes.byref(ctypes.c_ushort())) + inp = [np.asarray(x[i])] else: - inp = [xj[i] for xj in x] - argtuple = inp - argtuple += [predictions] - argtuple += [ctypes.byref(ctypes.c_ushort()) for k in range(len(inp)+1)] - argtuple = tuple(argtuple) - top_function(*argtuple) + inp = [np.asarray(xj[i]) for xj in x] + argtuple = inp + argtuple += predictions + argtuple = tuple(argtuple) + top_function(*argtuple) output.append(predictions) collect_func(trace_data) for trace in trace_data: @@ -742,15 +742,19 @@ class TraceData(ctypes.Structure): for key in trace_output.keys(): trace_output[key] = np.asarray(trace_output[key]) - #Convert to numpy array - output = np.asarray(output) + # Convert to list of numpy arrays (one for each output) + output = [np.asarray([output[i_sample][i_output] for i_sample in range(n_samples)]) for i_output in range(n_outputs)] free_func() finally: os.chdir(curr_dir) - if n_samples == 1: + if n_samples == 1 and n_outputs == 1: + return output[0][0], trace_output + elif n_outputs == 1: return output[0], trace_output + elif n_samples == 1: + return [output_i[0] for output_i in output], trace_output else: return output, trace_output diff --git a/test/pytest/test_graph.py b/test/pytest/test_graph.py index a1f3d5340..3b79981af 100644 --- a/test/pytest/test_graph.py +++ b/test/pytest/test_graph.py @@ -200,6 +200,8 @@ def test_multiple_outputs(batch): X1 = np.random.randint(0, 100, size=(batch, 10)).astype(float) y = model.predict(X1) y_hls = hls_model.predict(X1) + # test trace as well + y_hls, hls_trace = hls_model.trace(X1) for y_i, y_hls_i in zip(y, y_hls): y_hls_i = y_hls_i.reshape(y_i.shape) - np.testing.assert_allclose(y_i, y_hls_i, rtol=0) \ No newline at end of file + np.testing.assert_allclose(y_i, y_hls_i, rtol=0)