In [65]:
import numpy as np
from scipy.linalg import solve_discrete_are
from scipy.integrate import solve_ivp

In [66]:
class ComplexRoad:
    """Linear-quadratic control model of a road built from typed segments.

    The road is described by a string of segment codes:

    - ``'n'``: the single entry segment (must be first), fed at rate ``alpha``;
    - ``'m'``: a merge segment, each with its own queue (fed by a ``beta``), carrying
      capacity ``c`` and one control input;
    - ``'r'``: a plain road segment;
    - ``'l'``: the single exit segment (must be last).

    The state vector is laid out as ``[alpha, betas..., kappas..., <road part>...]``. In the
    road part, every ``'m'`` segment's queue entry immediately precedes the segment's own entry.
    """

    def __init__(self, structure: str, gammas: np.ndarray, cs: np.ndarray, betas: np.ndarray, alpha: float):
        """Validate the road description and build the state-vector layout.

        Parameters
        ----------
        structure : str
            Segment codes, e.g. ``"nmrrl"``. Must start with the only ``'n'`` and end with
            the only ``'l'``.
        gammas : np.ndarray
            Transition rate out of each segment into the next; ``len(structure) - 1`` values.
        cs : np.ndarray
            Carrying capacity of each ``'m'`` segment.
        betas : np.ndarray
            One value per ``'m'`` segment, feeding that segment's queue.
        alpha : float
            Rate feeding the entry (``'n'``) segment.

        Raises
        ------
        ValueError
            If ``structure`` has an unknown code or misplaces ``'n'``/``'l'``, or if a parameter
            array has the wrong length.
        """
        # Validate inputs
        for i in structure:
            if i not in {'n', 'm', 'r', 'l'}:
                raise ValueError(f"Unknown segment type {i}")
        if structure.find("n") != 0 or structure.find("n", 1) != -1:
            raise ValueError("Cars must be added at the start")
        # find() returns the first 'l', so this also rejects any 'l' before the end
        if structure.find("l") != len(structure) - 1:
            raise ValueError("Cars must leave at the end")
        if len(gammas) != len(structure) - 1:
            raise ValueError(f"Must have {len(structure) - 1} gamma values for {len(structure)} road segments")
        self._m_count = sum(np.array(list(structure)) == "m")
        if self._m_count != len(cs):
            raise ValueError(f"Must have a carrying capacity for each m segment")
        if self._m_count != len(betas):
            raise ValueError(f"Must have a beta for each m segment")

        # Store incoming parameters
        self.structure = np.array(list(structure))
        self.betas = betas
        self.gammas = gammas
        self.cs = cs
        self.alpha = alpha

        # Determine the structure layout: consecutive index ranges for alpha, betas, kappas.
        # `range(i, i := i + k)` evaluates the old `i` first, then advances it by k.
        i = 0
        self._i_alpha    = list(range(i, i := i + 1))
        self._i_betas    = list(range(i, i := i + len(betas)))
        self._i_kappas   = list(range(i, i := i + len(betas)))
        self._i_segments = []
        self._i_queues   = []

        # Determine assignments for each road portion; a merge's queue precedes its segment
        i -= 1
        for seg in structure:
            if seg == "m":
                self._i_queues.append(i := i + 1)
            self._i_segments.append(i := i + 1)

        # Determine the total number of entries
        self._n_entries = i + 1
        self._n_queues = len(self._i_queues)
        self._n_segments = len(self._i_segments)

        # Compute the default costs
        self.state_costs = None
        self.control_costs = None
        self.reconstruct_costs()

    def reconstruct_costs(self, q_penalty=4., n_penalty=3., m_penalty=4., r_penalty=2., l_penalty=0., seg_penalty=None, u_penalty=1.):
        """Rebuild the diagonal LQR state and control cost weights.

        Parameters
        ----------
        q_penalty : scalar or sequence
            Cost on every queue, or one cost per queue (in merge order).
        n_penalty, m_penalty, r_penalty, l_penalty : scalar
            Cost on segments of type ``'n'``, ``'m'``, ``'r'`` and ``'l'``. Ignored when
            ``seg_penalty`` is given.
        seg_penalty : sequence or None
            Explicit cost for every segment (in structure order), overriding the per-type costs.
        u_penalty : scalar or sequence
            Cost on every control input, or one cost per input.

        The alpha/beta/kappa entries of the state always receive zero cost.
        Python and numpy scalars are both accepted wherever a scalar is allowed.
        """
        type_penalty = {"n": n_penalty, "m": m_penalty, "r": r_penalty, "l": l_penalty}
        costs = [0.] * self._n_entries

        # Queue entries: shared cost, or individual cost by queue position
        for q_pos, idx in enumerate(self._i_queues):
            costs[idx] = q_penalty if np.ndim(q_penalty) == 0 else q_penalty[q_pos]

        # Segment entries: cost by segment type, or individual cost by segment position
        for seg_pos, idx in enumerate(self._i_segments):
            if seg_penalty is None:
                costs[idx] = type_penalty[self.structure[seg_pos]]
            else:
                costs[idx] = seg_penalty[seg_pos]

        # Store the computed costs
        self.state_costs = costs
        self.control_costs = np.ones(self._m_count) * u_penalty if np.ndim(u_penalty) == 0 else u_penalty

    def _get_evolution(self, merge_init: np.ndarray):
        """Build the linearized evolution matrices ``A`` (state) and ``B`` (control).

        Parameters
        ----------
        merge_init : np.ndarray
            Value of each ``'m'`` segment at the linearization point, in merge order.

        Returns
        -------
        (A, B) : tuple of np.ndarray
            ``A`` has shape ``(n_entries, n_entries)`` and ``B`` has shape
            ``(n_entries, n_merges)``, with one control input per merge. The alpha/beta/kappa
            rows are left zero, so those entries stay constant.
        """
        # Create a matrix of zeros so everything is constant by default
        A = np.zeros((self._n_entries, self._n_entries))
        B = np.zeros((self._n_entries, self._m_count))

        # Loop through each segment except the exit, which has no outflow
        merge_index = 0
        for seg_i, seg in enumerate(self.structure[:-1]):
            # Get this road segment's index and the next index, and the transition rate
            seg_mat_i = self._i_segments[seg_i]
            gamma = self.gammas[seg_i]
            nxt_seg_mat_i = self._i_segments[seg_i + 1]

            # Handle input segments: fed by alpha, drain into the next segment
            if seg == "n":
                A[seg_mat_i, self._i_alpha] = 1
                A[seg_mat_i, seg_mat_i] = -gamma
                A[nxt_seg_mat_i, seg_mat_i] = gamma

            # Handle road segments: drain into the next segment
            elif seg == "r":
                A[seg_mat_i, seg_mat_i] = -gamma
                A[nxt_seg_mat_i, seg_mat_i] = gamma

            # Handle merge segments ('m'; validation guarantees no other code reaches here)
            else:
                # Linearization slope: d/dx [gamma*x*(1 + x/c)] at x0.
                # NOTE(review): confirm the sign of the x/c term against the intended flow model.
                lt = gamma + (2 * gamma * merge_init[merge_index]) / self.cs[merge_index]

                # Queue is fed by its beta and drained by the control input
                A[self._i_queues[merge_index], self._i_betas[merge_index]] = 1
                B[self._i_queues[merge_index], merge_index] = -1

                # Merge segment receives the control input and loses the linearized flow
                A[seg_mat_i, self._i_kappas[merge_index]] = -1
                A[seg_mat_i, seg_mat_i] = -lt
                B[seg_mat_i, merge_index] = 1

                # Next segment gains the same linearized flow
                A[nxt_seg_mat_i, self._i_kappas[merge_index]] = 1
                A[nxt_seg_mat_i, seg_mat_i] = lt

                merge_index += 1

        return A, B

    def _get_costs(self):
        """Return the diagonal cost matrices ``(Q, R)`` built from the stored cost weights."""
        return np.diag(self.state_costs), np.diag(self.control_costs)

    def single_step(self, init_roads, init_queues, time_span, r_inv = None):
        """Simulate the LQR closed loop over one time span, linearized at ``init_roads``.

        Parameters
        ----------
        init_roads : array_like
            Initial value of every segment, in structure order.
        init_queues : array_like
            Initial value of every merge queue.
        time_span : (float, float)
            Integration interval ``(t0, t1)``.
        r_inv : np.ndarray, optional
            Precomputed inverse of the control cost matrix ``R``.

        Returns
        -------
        get_sol : callable
            ``get_sol(t) -> (segments, queues)``, evaluated from the dense ODE solution.
        get_control : callable
            ``get_control(t)`` returns the feedback control. The shape is ``(n_merges, 1)`` for
            a scalar ``t`` and ``(n_merges, len(t))`` for an array of times.
        """
        init_roads = np.asarray(init_roads, dtype=float)

        # NOTE(review): kappas are a placeholder fixed at 1 (broadcast to every merge) rather
        # than values computed from the linearization point -- TODO confirm intended values.
        kappas = np.array([1])

        # Construct the initial state vector
        init_state = np.zeros(self._n_entries)
        init_state[self._i_alpha] = self.alpha
        init_state[self._i_betas] = self.betas
        init_state[self._i_kappas] = kappas
        init_state[self._i_queues] = init_queues
        init_state[self._i_segments] = init_roads

        # Get A, B, Q, and R, linearizing around the current merge-segment values
        A, B = self._get_evolution(init_roads[self.structure == "m"])
        Q, R = self._get_costs()
        r_inv = np.linalg.inv(R) if r_inv is None else r_inv

        # NOTE(review): the dynamics below are integrated in continuous time with the
        # continuous-time LQR gain R^-1 B^T P, yet P comes from the *discrete* Riccati equation.
        # Swapping in solve_continuous_are is likely not a drop-in fix: the alpha/beta/kappa rows
        # of A and B are all zero, so those states sit at uncontrollable zero eigenvalues.
        # Confirm the intended formulation.
        P = solve_discrete_are(A, B, Q, R)
        gain = r_inv @ B.T @ P
        closed_loop = A - B @ gain

        # Set up the evolution equation with the optimal control
        def _system(t, y):
            return closed_loop @ y

        # Solve the optimal state evolution using the DOP853 solver
        sol = solve_ivp(_system, time_span, init_state, dense_output=True, method="DOP853")

        def _get_sol(t):
            res = sol.sol(t)
            return res[self._i_segments], res[self._i_queues]

        def _get_control(t):
            states = sol.sol(t)
            if states.ndim == 1:
                # Scalar t: return a single (n_merges, 1) column
                states = states.reshape(-1, 1)
            # Array t: states is (n_entries, len(t)); keep the time axis rather than flattening it
            return -gain @ states

        return _get_sol, _get_control

    def multi_step(self, init_roads, init_queues, time_span, num_intervals = 10, samples_per_interval = 1001):
        """Chain ``single_step`` over equal sub-intervals, re-linearizing at the start of each.

        Parameters
        ----------
        init_roads, init_queues : array_like
            Initial segment and queue values, as for ``single_step``.
        time_span : (float, float)
            Overall interval ``(t0, t1)``, split into ``num_intervals`` equal pieces.
        num_intervals : int
            Number of sub-intervals (re-linearizations).
        samples_per_interval : int
            Samples taken on each sub-interval, including both endpoints.

        Returns
        -------
        roads, queues, control : np.ndarray
            Shapes ``(n_segments, T)``, ``(n_merges, T)`` and ``(n_merges, T)``, where
            ``T = num_intervals * samples_per_interval``. Interior boundary times appear twice:
            once at the end of one sub-interval and once at the start of the next.
        """
        # Solver parameters
        _, R = self._get_costs()
        R_inv = np.linalg.inv(R)
        n_count = samples_per_interval
        time_intervals = np.linspace(*time_span, num_intervals + 1)

        # Create variables to store the found states
        total_entries = n_count * num_intervals
        roads = np.zeros((self._n_segments, total_entries))
        queues = np.zeros((self._n_queues, total_entries))
        control = np.zeros((self._n_queues, total_entries))

        for i in range(num_intervals):
            # Consecutive pair of boundaries for this sub-interval (a slice, not a 2-D index)
            interval = time_intervals[i:i + 2]
            t_space = np.linspace(*interval, n_count)
            i0, i1 = i * n_count, (i + 1) * n_count

            # Solve on the interval, storing the states & controls
            sol_poly, ctrl_poly = self.single_step(init_roads, init_queues, interval, r_inv=R_inv)
            roads[:, i0:i1], queues[:, i0:i1] = sol_poly(t_space)
            control[:, i0:i1] = ctrl_poly(t_space)

            # Continue from this sub-interval's last sample (its end time). Column i1 belongs to
            # the next sub-interval and does not exist after the final one.
            init_roads, init_queues = roads[:, i1 - 1].copy(), queues[:, i1 - 1].copy()

        return roads, queues, control


In [68]:
# Example road: entry 'n', five merges 'm', six plain segments 'r', exit 'l'.
m = ComplexRoad(
    "nmrrrmmmmrrrl",
    np.array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.]),  # gammas: one per segment boundary (12 for 13 segments)
    np.array([1., 1., 1., 1., 1.]),                                  # cs: carrying capacity of each merge
    np.array([1.7, 3., 2.7, 1., 1.]),                                # betas: one per merge
    1.,                                                              # alpha
)

# Inspect the evolution matrix linearized around empty merge segments.
# All gammas are whole numbers here, so the integer cast is exact.
A, B = m._get_evolution(np.zeros(5))
A.astype(int)