In [1]:
import torch

In [10]:
def FusedStep(x, I, delta_t, theta, A=None):
    """
    Update the state of the LTC neural network for a single time step using the Fused ODE Solver.

    Implements the fused solver update of Hasani et al., "Liquid
    Time-constant Networks":

        f      = tanh(gamma @ I + r @ x + mu)
        x_next = (x + delta_t * f * A) / (1 + delta_t * (1 / tau + f))

    Args:
        x (torch.Tensor): Current state of the neural network, shape (N, 1).
        I (torch.Tensor): Input for the current time step, shape (M, 1).
            A 1-D input of shape (M,) is also accepted when ``x`` is 2-D and
            is treated as a column vector.
        delta_t (float): Step size.
        theta (tuple): Tuple containing parameters: tau (time-constant), gamma (weights), r (recurrent weights), mu (biases).
        A (torch.Tensor, optional): Bias vector, shape (N, 1). Defaults to a
            tensor of ones shaped like ``mu``.

    Returns:
        torch.Tensor: Updated state of the neural network, shape (N, 1).
    """
    # Unpack parameters
    tau, gamma, r, mu = theta
    if A is None:
        A = torch.ones_like(mu)

    # A 1-D input makes gamma @ I shape (N,), which would broadcast against the
    # (N, 1) recurrent term into an (N, N) matrix; use a column vector instead.
    if I.dim() == 1 and x.dim() == 2:
        I = I.unsqueeze(1)

    # Compute f(x(t), I(t), t, θ)
    f_x = torch.tanh(torch.matmul(gamma, I) + torch.matmul(r, x) + mu)

    # f is evaluated explicitly at x(t), while the state-proportional term
    # (1/tau + f) * x is taken at x(t + Δt). Solving that linear relation for
    # x(t + Δt) gives this closed form, with no iterative solve.
    x_next = (x + delta_t * f_x * A) / (1 + delta_t * (1 / tau + f_x))

    return x_next

In [11]:
def LTC_update_by_fused_ODE_Solver(theta, A, L, delta_t, input_sequence, initial_state):
    """
    Perform LTC update using the Fused ODE Solver.

    Runs ``L`` fused-solver steps. Step ``i`` consumes column ``i`` of
    ``input_sequence`` as an (M, 1) input ``I_i``:

        f = tanh(gamma @ I_i + r @ x + mu)
        x = (x + delta_t * f * A) / (1 + delta_t * (1 / tau + f))

    Args:
        theta (tuple): Tuple containing parameters: tau (time-constant), gamma (weights), r (recurrent weights), mu (biases).
        A (torch.Tensor): Bias vector, shape (N, 1).
        L (int): Number of unfolding steps.
        delta_t (float): Step size.
        input_sequence (torch.Tensor): Input sequence, shape (M, L).
        initial_state (torch.Tensor): Initial state of the neural network, shape (N, 1).

    Returns:
        torch.Tensor: Next state of the neural network after L unfolding steps, shape (N, 1).

    Raises:
        IndexError: If ``L`` exceeds the number of time steps (columns) in
            ``input_sequence``.
    """
    tau, gamma, r, mu = theta

    # Column slicing below would silently yield empty tensors past the end,
    # so check the bound explicitly.
    available = input_sequence.shape[1]
    if L > available:
        raise IndexError(
            f"L={L} unfolding steps requested but input_sequence has only {available} time steps"
        )

    # Initialize current state
    x = initial_state

    # Perform L unfolding steps. The fused update is computed here so that
    # the caller-supplied A is applied at every step.
    for step in range(L):
        # Slice step:step+1 (not [:, step]) to keep I an (M, 1) column. A 1-D
        # slice would broadcast the state to (N, N).
        I = input_sequence[:, step:step + 1]
        f_x = torch.tanh(torch.matmul(gamma, I) + torch.matmul(r, x) + mu)
        x = (x + delta_t * f_x * A) / (1 + delta_t * (1 / tau + f_x))

    return x