#### SafeIO

As the previous example showed, defining System objects is straightforward. But a subtle Pythonic problem remains: what if the state dynamics depend only on x, that is, ``f(x)`` instead of ``f(x, u, t)``? Any other class or module that composes System objects would expect ``System.f`` to take the arguments x, u, and t, in that order. We could define every function with the full signature and simply ignore the arguments it does not need, but that breaks compatibility with existing tools, which follow their own conventions. For instance, SciPy's ``solve_ivp`` expects ``f(t, x)``, while JAX-based code more commonly uses ``f(x, u, t)`` or ``f(x, t)``.
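
To make the cost of the manual approach concrete, here is a minimal sketch of the adapters one would otherwise write by hand for each calling convention. The helper names ``as_time_state`` and ``as_state_input_time`` are hypothetical and are not part of any library; plain lists stand in for arrays to keep the example self-contained.

```python
def f(x):
    """State-only dynamics: x = [position, velocity]."""
    return [x[1], -x[0]]


def as_time_state(func):
    """Adapt a state-only f(x) to the (t, x) calling convention."""
    def wrapped(t, x):
        return func(x)  # t is accepted but ignored
    return wrapped


def as_state_input_time(func):
    """Adapt a state-only f(x) to the (x, u, t) calling convention."""
    def wrapped(x, u, t):
        return func(x)  # u and t are accepted but ignored
    return wrapped


f_tx = as_time_state(f)
f_xut = as_state_input_time(f)

print(f_tx(0.0, [1.0, 0.0]))         # [0.0, -1.0]
print(f_xut([1.0, 0.0], None, 0.0))  # [0.0, -1.0]
```

Every new convention needs another wrapper, and nothing checks that the wrapped function really has the signature the wrapper assumes.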

Furthermore, the flexibility of Python functions can be a double-edged sword. While it enables rapid prototyping, it also allows functions with inconsistent signatures, missing type annotations, ambiguous return types, or dynamic behavior that can silently break downstream logic.

Consider the pendulum example written in C:

.. code-block:: c

   // C: fixed signature, explicit memory, compiler-enforced interface
   void dynamics(const double *x, const double *u, double t, double *x_dot_out) {
       x_dot_out[0] = x[1];
       x_dot_out[1] = -x[0];
   }

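In C, a function must declare exactly which inputs it takes and what it returns.
Argument counts and types are checked against the prototype at compile time, so passing
the wrong number of arguments or a mismatched type produces an immediate, traceable
diagnostic. The compiler does not check the length of the arrays behind the pointers,
but the calling convention itself cannot drift.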

In Python, however, the same function might look like:

.. code-block:: python

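   import numpy as np  # assumed to be imported earlier in the notebook

   def f(x, u, t):
       return np.vstack([x[1], -x[0]])  # u and t are accepted but never used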

We are at the mercy of callers to pass in correct values, and nothing guarantees that the output has the right type or shape. These mismatches are notoriously difficult to trace, especially when using external libraries like TensorFlow, PyTorch, or JAX that wrap or recompile functions dynamically.

Without a consistent interface, chaos creeps in. 

Hence, a utility class was created: ``SafeIO``.

This class validates user-defined functions when they are registered, not only when they are called. It also injects only the arguments a function actually declares, along with any extra keyword arguments it asks for, and enforces that every returned value is properly typed and shaped. All of this goes through the ``smart_call`` method.
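
As an illustration of what a registration-time check could look like, here is a minimal sketch built on the standard-library ``inspect`` module. The function ``validate_signature`` and the two rules it enforces are assumptions made for this example; they are not taken from the ``SafeIO`` implementation.

```python
import inspect


def validate_signature(func):
    """Hypothetical registration-time check: runs before func is ever called."""
    sig = inspect.signature(func)

    # Rule 1 (assumed): the function must declare what it returns.
    if sig.return_annotation is inspect.Signature.empty:
        raise TypeError(f"`{func.__name__}` is missing a return annotation")

    # Rule 2 (assumed): every parameter must be fillable by name,
    # so positional-only parameters and *args are rejected.
    for name, param in sig.parameters.items():
        if param.kind in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.VAR_POSITIONAL,
        ):
            raise TypeError(
                f"`{func.__name__}` parameter `{name}` cannot be passed by keyword"
            )
    return func
```

Because the check only inspects the signature, a malformed function can be rejected when it is attached to a system rather than many steps into a simulation.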

In [None]:
for key, value in sys.safeio.__dict__.items():
    print(f"{key}: {value}")

parent: <pykal_core.control_system.system.System object at 0x7b5ac8998920>
_aliases_for_x: ['x', 'x_k', 'xk', 'state']
_aliases_for_u: ['u', 'u_k', 'uk', 'input']
_aliases_for_t: ['t', 't_k', 'tk', 'time', 'tau']
smart_call: <bound method System.SafeIO.smart_call of <pykal_core.control_system.system.System.SafeIO object at 0x7b5ac8953710>>
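
The alias lists printed above hint at how argument injection can work: look at the parameter names a function declares and pass each one the matching canonical value. The following is a minimal sketch of that idea, assuming name-based matching via ``inspect.signature``. ``smart_call_sketch`` is a hypothetical stand-in, not the actual ``smart_call`` implementation, and it performs no type or shape validation.

```python
import inspect

# Alias lists copied from the printed attributes above.
ALIASES = {
    "x": ("x", "x_k", "xk", "state"),
    "u": ("u", "u_k", "uk", "input"),
    "t": ("t", "t_k", "tk", "time", "tau"),
}


def smart_call_sketch(func, x=None, u=None, t=None, **extra):
    """Call func with only the arguments its signature asks for."""
    canonical = {"x": x, "u": u, "t": t}
    kwargs = {}
    for name in inspect.signature(func).parameters:
        for key, names in ALIASES.items():
            if name in names:
                kwargs[name] = canonical[key]  # aliased state, input, or time
                break
        else:
            # Not an alias: forward a same-named extra keyword argument if given.
            if name in extra:
                kwargs[name] = extra[name]
    return func(**kwargs)


def aliased(time, state):
    return (state, time)


print(smart_call_sketch(aliased, x=1.0, u=2.0, t=3.0))  # (1.0, 3.0)
```

Matching by name rather than by position is what makes argument order irrelevant and lets unused values be dropped silently.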


In [None]:
# 1. returns_scalar – should raise TypeError
try:
    def returns_scalar(x: NDArray) -> float:
        return 3.14

    x = np.array([[1.0]])
    sys.safeio.smart_call(returns_scalar, x=x)
except TypeError as e:
    print(f"TypeError: {e}")

TypeError: In `returns_scalar`, return type must be NDArray[...] or a tuple of NDArrays, but got <class 'float'>


In [None]:
# 2. wrong_output_shape – should raise ValueError
try:
    def wrong_output_shape(x: NDArray) -> NDArray:
        return np.zeros((1, 1))

    sys.safeio.smart_call(wrong_output_shape, x=np.zeros((2, 1)), expected_shape=(2, 1))
except ValueError as e:
    print(f"ValueError: {e}")

ValueError: Output shape mismatch. Expected (2, 1), got (1, 1)
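
The two failures above correspond to a type check and a shape check on the returned value. A minimal sketch of such a post-call check follows; ``check_output`` is a hypothetical helper written for illustration, and its messages only approximate the ones ``smart_call`` prints.

```python
import numpy as np


def check_output(func_name, out, expected_shape=None):
    """Hypothetical post-call check on a function's return value."""
    # Treat a single array and a tuple of arrays uniformly.
    outputs = out if isinstance(out, tuple) else (out,)
    for item in outputs:
        if not isinstance(item, np.ndarray):
            raise TypeError(
                f"In `{func_name}`, expected an ndarray but got {type(item)}"
            )
        if expected_shape is not None and item.shape != expected_shape:
            raise ValueError(
                f"Output shape mismatch. Expected {expected_shape}, got {item.shape}"
            )
    return out


# Passes: a (2, 1) array with the expected shape is returned unchanged.
check_output("ok", np.zeros((2, 1)), expected_shape=(2, 1))
```

Raising ``TypeError`` for the wrong kind of object and ``ValueError`` for the wrong shape mirrors the distinction between the two cells above.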


In [None]:
# 3. no_arguments
def no_arguments() -> NDArray:
    return np.ones((2, 1))

print(sys.safeio.smart_call(no_arguments))

[[1.]
 [1.]]


In [None]:
# 4. state_only
def state_only(x: NDArray) -> NDArray:
    return x + 1


x = np.array([[1.0], [2.0]])
print(sys.safeio.smart_call(state_only, x=x))

[[2.]
 [3.]]


In [None]:
# 5. input_only
def input_only(u: NDArray) -> NDArray:
    return u * 2

u = np.array([[0.5], [1.0]])
print(sys.safeio.smart_call(input_only, u=u))

[[1.]
 [2.]]


In [None]:
# 6. time_only
def time_only(t: float) -> NDArray:
    return np.array([[t], [t]])

print(sys.safeio.smart_call(time_only, t=3.0))

[[3.]
 [3.]]


In [None]:
# 7. state_input
def state_input(x: NDArray, u: NDArray) -> NDArray:
    return x + u

print(sys.safeio.smart_call(state_input, x=x, u=u))

[[1.5]
 [3. ]]


In [None]:
# 8. state_time
def state_time(x: NDArray, t: float) -> NDArray:
    return x * t

print(sys.safeio.smart_call(state_time, x=x, t=2.0))

[[2.]
 [4.]]


In [None]:
# 9. input_time
def input_time(u: NDArray, t: float) -> NDArray:
    return u + t

print(sys.safeio.smart_call(input_time, u=u, t=1.0))

[[1.5]
 [2. ]]


In [None]:
# 10. all_arguments
def all_arguments(x: NDArray, u: NDArray, t: float) -> NDArray:
    return x + u + t

print(sys.safeio.smart_call(all_arguments, x=x, u=u, t=1.0))

[[2.5]
 [4. ]]


In [None]:
# 11. reordered_arguments
def reordered_arguments(t: float, u: NDArray, x: NDArray) -> NDArray:
    return x + u + t

print(sys.safeio.smart_call(reordered_arguments, x=x, u=u, t=1.0))

[[2.5]
 [4. ]]


In [None]:
# 12. aliased_names
def aliased_names(state: NDArray, input: NDArray, time: float) -> NDArray:
    return state + input + time

print(sys.safeio.smart_call(aliased_names, x=x, u=u, t=1.0))

[[2.5]
 [4. ]]


In [None]:
# 13. extra keyword arguments are ignored when not declared (as are the unused u and t)
def just_x(x: NDArray) -> NDArray:
    return x * 2

print(sys.safeio.smart_call(just_x, x=x, u=u, t=2.0, extra_arg=3))

[[2.]
 [4.]]


In [None]:
# 14. aliased state name plus a declared extra keyword argument
def just_x_and_extra_kwarg(xk: NDArray, extra_arg: float) -> NDArray:
    return xk * extra_arg

print(sys.safeio.smart_call(just_x_and_extra_kwarg, x=x, u=u, t=2.0, extra_arg=3))

[[3.]
 [6.]]
