diff --git a/flashlight/fl/nn/modules/Module.h b/flashlight/fl/nn/modules/Module.h index 98421e0d9..1c9007239 100644 --- a/flashlight/fl/nn/modules/Module.h +++ b/flashlight/fl/nn/modules/Module.h @@ -221,10 +221,7 @@ class FL_API BinaryModule : public Module { FL_SAVE_LOAD_WITH_BASE(Module) }; -template < - typename... Args, - typename = - std::enable_if_t<(std::is_same_v> && ...)>> +template auto Module::forward(Args&&... inputs) { if constexpr (sizeof...(Args) == 1) { if (auto unaryModulePtr = dynamic_cast(this)) { @@ -233,7 +230,7 @@ auto Module::forward(Args&&... inputs) { auto output = forward(std::vector{std::forward(inputs)...}); if (output.size() > 1) { throw std::runtime_error( - "Forward interface expects 1 output argument. Wrap the input arguments in a vector to avoid using the unary interface."); + "Forward interface expects 1 output argument. Wrap the input argument in a vector to avoid using the unary interface."); } return std::move(output.front()); } else {