Skip to content

Commit

Permalink
Merge pull request #537 from JochiSt/master
Browse files Browse the repository at this point in the history
returning integer from _compute_n_samples
  • Loading branch information
jmduarte committed Jun 19, 2022
2 parents b1b30f3 + ad74c80 commit eb9f3b4
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
32 changes: 18 additions & 14 deletions hls4ml/model/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ def _compute_n_samples(self, x):
if not all([n_samples[i] == n_samples[i+1] for i in range(len(xlist)-1)]):
raise Exception('Input size mismatch, not all inputs match')

return n_sample
return int(n_sample)

def predict(self, x):
top_function, ctype = self._get_top_function(x)
Expand All @@ -651,9 +651,9 @@ def predict(self, x):
for i in range(n_samples):
predictions = [np.zeros(yj.size(), dtype=ctype) for yj in self.get_output_variables()]
if n_inputs == 1:
inp = [x[i]]
inp = [np.asarray(x[i])]
else:
inp = [xj[i] for xj in x]
inp = [np.asarray(xj[i]) for xj in x]
argtuple = inp
argtuple += predictions
argtuple = tuple(argtuple)
Expand Down Expand Up @@ -683,6 +683,7 @@ def trace(self, x):
top_function, ctype = self._get_top_function(x)
n_samples = self._compute_n_samples(x)
n_inputs = len(self.get_input_variables())
n_outputs = len(self.get_output_variables())

class TraceData(ctypes.Structure):
_fields_ = [('name', ctypes.c_char_p),
Expand Down Expand Up @@ -721,16 +722,15 @@ class TraceData(ctypes.Structure):
alloc_func(ctypes.sizeof(ctype))

for i in range(n_samples):
predictions = np.zeros(self.get_output_variables()[0].size(), dtype=ctype)
predictions = [np.zeros(yj.size(), dtype=ctype) for yj in self.get_output_variables()]
if n_inputs == 1:
top_function(x[i], predictions, ctypes.byref(ctypes.c_ushort()), ctypes.byref(ctypes.c_ushort()))
inp = [np.asarray(x[i])]
else:
inp = [xj[i] for xj in x]
argtuple = inp
argtuple += [predictions]
argtuple += [ctypes.byref(ctypes.c_ushort()) for k in range(len(inp)+1)]
argtuple = tuple(argtuple)
top_function(*argtuple)
inp = [np.asarray(xj[i]) for xj in x]
argtuple = inp
argtuple += predictions
argtuple = tuple(argtuple)
top_function(*argtuple)
output.append(predictions)
collect_func(trace_data)
for trace in trace_data:
Expand All @@ -742,15 +742,19 @@ class TraceData(ctypes.Structure):
for key in trace_output.keys():
trace_output[key] = np.asarray(trace_output[key])

#Convert to numpy array
output = np.asarray(output)
# Convert to list of numpy arrays (one for each output)
output = [np.asarray([output[i_sample][i_output] for i_sample in range(n_samples)]) for i_output in range(n_outputs)]

free_func()
finally:
os.chdir(curr_dir)

if n_samples == 1:
if n_samples == 1 and n_outputs == 1:
return output[0][0], trace_output
elif n_outputs == 1:
return output[0], trace_output
elif n_samples == 1:
return [output_i[0] for output_i in output], trace_output
else:
return output, trace_output

Expand Down
4 changes: 3 additions & 1 deletion test/pytest/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ def test_multiple_outputs(batch):
X1 = np.random.randint(0, 100, size=(batch, 10)).astype(float)
y = model.predict(X1)
y_hls = hls_model.predict(X1)
# test trace as well
y_hls, hls_trace = hls_model.trace(X1)
for y_i, y_hls_i in zip(y, y_hls):
y_hls_i = y_hls_i.reshape(y_i.shape)
np.testing.assert_allclose(y_i, y_hls_i, rtol=0)
np.testing.assert_allclose(y_i, y_hls_i, rtol=0)

0 comments on commit eb9f3b4

Please sign in to comment.