diff --git a/source/tests/test_fitting_ener_type.py b/source/tests/test_fitting_ener_type.py index 1eb1002147..cf4a891f96 100644 --- a/source/tests/test_fitting_ener_type.py +++ b/source/tests/test_fitting_ener_type.py @@ -81,6 +81,7 @@ def test_fitting(self): type_embedding = type_embedding.reshape([ntypes,-1]) atom_ener = fitting.build(tf.convert_to_tensor(dout), t_natoms, + tf.constant(numb_test, dtype=tf.int32), {'type_embedding':tf.convert_to_tensor(type_embedding)}, reuse = False, suffix = "se_a_type_fit_")