## Python code

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, lax, vmap
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
from collections import deque

# ===============================
# 1. Parameters & Setup
# ===============================
N, num_agents = 100, 10          # control horizon (steps) and number of agents
dt = 1.0                          # Euler step used by step_unicycle
Gamma, Gamma_all = 2, 1           # soft-min/max sharpness: per-agent/group terms, final conjunction
alpha0, tol = 1.0, 1e-4           # initial line-search step; success tolerance on the smooth rho
mu_meet = 0.25                    # a group "meets" when its max pairwise distance is below mu_meet
d_min = 0.01                      # minimum inter-agent distance
w_v, w_w = 1.0, 1.0               # control-effort weights for the two inputs (squared in `weights`)
maxIter, maxLineIters = 5000, 50000
armijo_beta, armijo_c = 0.5, 0.5  # backtracking factor; armijo_c is not referenced in this notebook
stall_window, stall_tol = 10, 1e-1
stall_window_rho, stall_tol_rho = 10, 1e-3
lam, global_iter = 1, 0           # penalty weight (grown by the outer loop); epoch counter
h1, h2 = 1000, 1000               # H = diag(h1, h2)
h_diag = jnp.array([h1, h2])


weights = jnp.array([w_v**2, w_w**2])   # per-input quadratic control-cost weights
agent_range = jnp.arange(num_agents)

# Initial unicycle states (x, y, heading), one row per agent.
x0_list = jnp.array([
    [1.5, 10, 0], [1.5, 15, 0], [11.5, 10, 0], [13.5, 15, 0], [17, 7, 0],
    [31.5, 15, 0], [33.5, 10, 0], [36.5, 6, 0], [47.0, 10, 0], [47.0, 15, 0]
])

# Circular obstacles: centers and radii.
O_centers = jnp.array([[6.0, 10.0], [24.0, 10.0], [42.0, 10.0]])
r_obs = jnp.array([3.8, 3.8, 3.8])
# Region centers for agent k: C_k at (1.5 + 5k, 1.5) and D_k at (1.5 + 5k, 19.0).
T_centers = jnp.array([1.5, 1.5]) + jnp.arange(num_agents)[:, None] * jnp.array([5.0, 0.0])
G_centers = jnp.array([1.5, 19.0]) + jnp.arange(num_agents)[:, None] * jnp.array([5.0, 0.0])
r_target, r_goal = jnp.full(num_agents, 0.8), jnp.full(num_agents, 0.8)

# Each row lists the (0-based) agents that must meet; -1 pads rows with fewer than 3 members.
meeting_groups_mat = jnp.array([
    [0, 1, 2], [2, 3, -1], [0, 4, -1], [3, 4, -1], [3, 6, -1],
    [4, 5, -1], [6, 7, -1], [5, 7, -1], [5, 8, -1], [8, 9, -1], [7, 9, -1]
])

# Time windows below are inclusive and measured in steps.
# t1_* and g1_* are not referenced elsewhere in this notebook.
t1_low, t1_high = 11, 51
g1_low, g1_high = max(1, N-30), N
meet_low, meet_high = 0, 70
# --- Until transition window: candidate switch times t* in [until_low, until_high] ---
until_low, until_high = 1, 51

# --- Phase 1: Reach Ci (relative to current time t) ---
# Goal: visit Ci between 11 and 51 steps (inclusive) after t.
ci_rel_low, ci_rel_high = 11, 51

# --- Phase 2: Reach Di (relative to transition time t*) ---
# Goal: visit Di between N-30 and N steps (70 to 100 for N = 100) after t*.
di_rel_low, di_rel_high = N-30, N


# ===============================
# 2. Differentiable Smooth Semantics
# ===============================
def smooth_min(vec, G):
    """Soft minimum of a 1-D array via a stabilized log-sum-exp.

    Computes -(1/G) * log(sum(exp(-G * vec))), shifting by min(vec) first so
    the exponentials cannot overflow. Mathematically the result is <= min(vec)
    and approaches it as G grows.
    """
    anchor = jnp.min(vec)
    scaled = -G * (vec - anchor)
    log_mass = jnp.log(jnp.sum(jnp.exp(scaled)))
    return -(1/G) * (log_mass - G * anchor)

def smooth_max(vec, G):
    """Soft maximum of a 1-D array, shifted down to under-approximate max(vec).

    LSE(G * vec) / G lies in [max(vec), max(vec) + log(n)/G], so subtracting
    log(n)/G (n = vec.shape[0], counting every entry, including any masked
    sentinel entries) gives a value that is mathematically <= max(vec).
    """
    peak = jnp.max(vec)
    count = vec.shape[0]
    log_mass = jnp.log(jnp.sum(jnp.exp(G * (vec - peak))))
    return (1/G) * (log_mass + G * peak) - (jnp.log(count)/G)

def step_unicycle(carry, u):
    """One forward-Euler step of the unicycle model, shaped as a lax.scan body.

    carry is the state (x, y, heading); u holds the two inputs u[0] (speed)
    and u[1] (turn rate). Returns (next_state, carry): the new state is carried
    forward and the pre-step state is emitted as this step's scan output.
    """
    x, y, th = carry
    v, omega = u[0], u[1]
    next_state = jnp.array([
        x + dt * v * jnp.cos(th),
        y + dt * v * jnp.sin(th),
        th + dt * omega,
    ])
    return next_state, carry

def simulate_agent_scan(x0, u_seq):
    """Roll out one agent from x0 under controls u_seq (inputs along axis 0, time along axis 1).

    Returns the trajectory with states as rows 0..2 and time along axis 1,
    including both the initial state and the final state (T + 1 columns for
    T control steps).
    """
    final_state, visited = lax.scan(step_unicycle, x0, jnp.transpose(u_seq))
    trajectory = jnp.concatenate([visited, final_state[None, :]], axis=0)
    return trajectory.T

@jit
def total_rho(U, G, Gall):
    """Smooth robustness of the RURAMCA specification.

    Per agent i, the conjunction of:
      * obstacle avoidance and inter-agent separation (d_min) at every step;
      * (Eventually C_i) U (Eventually D_i), with the switch time t* ranging
        over [until_low, until_high]. "Eventually C_i" at time t looks at
        [t + ci_rel_low, t + ci_rel_high]; "Eventually D_i" at t* looks at
        [t* + di_rel_low, t* + di_rel_high].
    These are conjoined with one term per row of meeting_groups_mat requiring
    the group to come within mu_meet at some step in [meet_low, meet_high].

    Args:
        U: controls indexed U[agent, input, step]; rolled out from x0_list.
        G: sharpness of the soft min/max inside each agent/group term.
        Gall: sharpness of the final soft min over all agent and group terms.

    Returns:
        Scalar smooth robustness; larger is better.
    """
    # Xall: (num_agents, 3, N + 1) -- x, y, heading over time for every agent.
    Xall = vmap(simulate_agent_scan)(x0_list, U)
    t_idx = jnp.arange(N + 1)
    # All agents' positions laid out as (N + 1, num_agents, 2). Independent of
    # the agent being evaluated, so computed once instead of per agent.
    all_agent_pos_T = Xall[:, :2, :].transpose(2, 0, 1)

    def agent_metrics(idx, Tc, Gc, rt, rg):
        """Robustness of agent idx's safety and sequential-task requirements."""
        pos = Xall[idx, :2, :].T  # (N + 1, 2)

        # --- A. Safety: obstacles and inter-agent separation ---
        d_obs = jnp.sqrt(jnp.sum((pos[:, None, :] - O_centers[None, :, :])**2, axis=-1) + 1e-9) - r_obs
        m_obs = vmap(lambda v: smooth_min(v, G))(d_obs)

        dist_agents = jnp.sqrt(jnp.sum((pos[:, None, :] - all_agent_pos_T)**2, axis=-1) + 1e-9)
        # Large sentinel on the agent's own column so it never binds the min.
        dist_agents_masked = jnp.where((agent_range == idx)[None, :], 1e5, dist_agents - d_min)
        m_agents = vmap(lambda v: smooth_min(v, G))(dist_agents_masked)

        # --- B. Sequential Until logic ---
        # Distance to each region's center at every time step.
        dist_to_Ci = jnp.sqrt(jnp.sum((pos - Tc)**2, axis=-1) + 1e-9)
        dist_to_Di = jnp.sqrt(jnp.sum((pos - Gc)**2, axis=-1) + 1e-9)

        # "Eventually" robustness for every possible start time t; steps outside
        # the window get a -1e5 sentinel so they do not contribute to the max.
        def get_eventual_Ci(t):
            mask = (t_idx >= t + ci_rel_low) & (t_idx <= t + ci_rel_high)
            return smooth_max(jnp.where(mask, rt - dist_to_Ci, -1e5), G)

        def get_eventual_Di(t):
            mask = (t_idx >= t + di_rel_low) & (t_idx <= t + di_rel_high)
            return smooth_max(jnp.where(mask, rg - dist_to_Di, -1e5), G)

        rho_Ci_at_t = vmap(get_eventual_Ci)(t_idx)
        rho_Di_at_t = vmap(get_eventual_Di)(t_idx)

        def until_at_tstar(t_star):
            """min(Eventually-D_i at t*, min over t <= t* of Eventually-C_i at t)."""
            rhs = rho_Di_at_t[t_star]
            # +1e5 sentinel removes times after t* from the soft-min.
            lhs_history = jnp.where(t_idx <= t_star, rho_Ci_at_t, 1e5)
            lhs = smooth_min(lhs_history, G)
            return smooth_min(jnp.array([rhs, lhs]), G)

        # Soft max over transition times t* in [until_low, until_high] (inclusive).
        u_window = jnp.arange(until_low, until_high + 1)
        m_until = smooth_max(vmap(until_at_tstar)(u_window), G)

        # Conjunction of safety and the sequential task.
        return smooth_min(jnp.array([
            smooth_min(m_obs, G),
            smooth_min(m_agents, G),
            m_until
        ]), G)

    rhos_agents = vmap(agent_metrics)(agent_range, T_centers, G_centers, r_target, r_goal)

    # --- C. Collaborative meeting tasks ---
    def compute_group_rho(member_indices):
        """Soft max over [meet_low, meet_high] of mu_meet minus the group's max pairwise distance."""
        group_size = member_indices.shape[0]
        mask = (member_indices != -1)
        # -1 padding is mapped to agent 0 for indexing and then masked out below.
        safe_idx = jnp.where(mask, member_indices, 0)
        group_pos = Xall[safe_idx, :2, :]  # (group_size, 2, N + 1)

        diff = group_pos[:, None, :, :] - group_pos[None, :, :, :]
        dists = jnp.sqrt(jnp.sum(diff**2, axis=2) + 1e-9)  # (group_size, group_size, N + 1)

        # Valid pairs only: both members real and not the same slot.
        pair_mask = mask[:, None] * mask[None, :] * (jnp.eye(group_size) == 0)

        max_dist_at_t = jnp.max(jnp.where(pair_mask[:, :, None], dists, 0.0), axis=(0, 1))
        rho_at_t = mu_meet - max_dist_at_t

        rho_window = jnp.where((t_idx >= meet_low) & (t_idx <= meet_high), rho_at_t, -1e5)
        return smooth_max(rho_window, G)

    rhos_groups = vmap(compute_group_rho)(meeting_groups_mat)

    # --- D. Global conjunction ---
    return smooth_min(jnp.concatenate([rhos_agents, rhos_groups]), Gall)

@jit
def compute_full_objective(U, lam, G, Gall):
    """Penalized objective: weighted control effort + lam * max(0, -rho)**2.

    rho is total_rho(U, G, Gall); the penalty is zero once rho >= 0.
    """
    effort = jnp.sum(weights[:, None] * U**2)
    violation = jnp.maximum(0.0, -total_rho(U, G, Gall))
    return effort + lam * violation**2

@jit
def compute_control_cost(U):
    """Weighted quadratic control effort.

    Sums weights[i] * U[..., i, t]**2 over all entries; the second-to-last
    axis of U indexes the two control inputs.
    """
    weighted_sq = weights[:, None] * U**2
    return jnp.sum(weighted_sq)

@jit
def get_block_data(U, lam, G, Gall):
    """Objective, robustness and their gradients needed for one BCGD step.

    Returns:
        J: compute_full_objective(U, lam, G, Gall).
        g_J: gradient of J with respect to U.
        rho: total_rho(U, G, Gall).
        full_grad_R: gradient of R(U) = max(0, -rho)**2 with respect to U.
    """
    # value_and_grad yields each value together with its gradient, instead of
    # running a separate forward pass and then differentiating again.
    J, g_J = jax.value_and_grad(compute_full_objective)(U, lam, G, Gall)
    rho, g_rho = jax.value_and_grad(total_rho)(U, G, Gall)
    # Chain rule: d/dU max(0, -rho)^2 = -2 * max(0, -rho) * d(rho)/dU.
    full_grad_R = -2 * jnp.maximum(0.0, -rho) * g_rho
    return J, g_J, rho, full_grad_R

# ===============================
# 3. BCGD Optimization Loop
# ===============================
# Block coordinate gradient descent over agents: each epoch visits the agents
# in random order and updates one control block U[blk] at a time, accepting
# the step with an Armijo-type test. The outer loop multiplies the penalty
# weight `lam` by 5 (capped at 1e10) until the sharp robustness
# total_rho(U, 1e6, 1e6) is positive.
U = jnp.zeros((num_agents, 2, N))
rho_hist, cost_hist = [], []
cost_window = deque(maxlen=stall_window)
rho_window = deque(maxlen=stall_window_rho)

# Line-search constants; they never change between trials, so set them once.
σ = 0.5          # fraction of the predicted change Delta_k that must be realized
γ = 0.995        # weight of the quadratic term d^T diag(h_diag) d in Delta_k
max_step = 0.05  # directions shorter than this are rescaled up to this norm

print(f"{'Iter X 10':<8} | {'rho (smooth)':<12} | {'rho (true)': <12} | {'Objective':<12} | {'lam':<8}")
print("-" * 60)

try:
    for out_it in range(25):
        cost_window.clear()
        # Diagonal scaling for the block direction. It depends on lam, so it is
        # rebuilt whenever lam changes (it used to stay frozen at lam = 1).
        H_inv = jnp.array([1.0/(2*(w_v**2) + h1*lam), 1.0/(2*(w_w**2) + h2*lam)])
        if total_rho(U, 1e6, 1e6) > 0:
            break
        for epoch in range(maxIter):
            indices = np.random.permutation(num_agents)

            for blk in indices:
                J_base, full_grad, rho_curr, full_grad_R = get_block_data(U, lam, Gamma, Gamma_all)
                g_blk = full_grad[blk]
                g_R_blk = full_grad_R[blk]
                U_old_blk = U[blk]

                # Per-entry minimizer of lam*(g_R.d + 0.5*h*d^2) + weights*(U_blk + d)^2.
                d_blk = -(2*U_old_blk * weights[:, None] + lam * g_R_blk) * H_inv[:, None]

                # Enforce a minimum step length. A zero direction is left alone:
                # rescaling it would compute 0 * (max_step / 0) = NaN and poison U.
                d_norm = jnp.linalg.norm(d_blk)
                if 0.0 < d_norm < max_step:
                    d_blk = d_blk * (max_step / d_norm)
                slope = jnp.sum(g_blk * d_blk)  # used only by the first alternative rule below

                # Predicted change for the full step. Both control-cost terms
                # refer to this block only (other blocks are unchanged) and both
                # use the final, possibly rescaled, direction.
                L_base = compute_control_cost(U_old_blk)
                L_trial = compute_control_cost(U_old_blk + d_blk)
                Delta_k = lam*(jnp.sum(g_R_blk * d_blk) + γ*jnp.sum(h_diag[:, None] * (d_blk**2))) + L_trial - L_base

                # Backtracking line search.
                alpha = alpha0
                for ls_it in range(maxLineIters):
                    U_trial = U.at[blk].set(U_old_blk + alpha * d_blk)
                    J_trial = compute_full_objective(U_trial, lam, Gamma, Gamma_all)
                    # Alternative acceptance rules:
                    #   J_trial <= J_base + σ * alpha * slope
                    #   J_trial - J_base <= alpha*σ*(γ-1)*jnp.sum(h_diag[:, None] * (d_blk**2))
                    if J_trial - J_base <= σ * alpha * Delta_k:
                        U, J_base = U_trial, J_trial
                        break
                    alpha *= armijo_beta

            # Per-epoch bookkeeping. J_base is the objective at the accepted U
            # (J_trial may belong to a rejected trial). rho_curr is the smooth
            # robustness evaluated before the epoch's last block update.
            rho_hist.append(float(rho_curr))
            rho_window.append(float(rho_curr))
            cost_hist.append(float(J_base))
            cost_window.append(float(J_base))

            if len(cost_window) == stall_window:
                abs_change = abs(cost_window[0] - cost_window[-1])
                if abs_change < stall_tol:
                    break

            global_iter += 1
            rho_true = total_rho(U, 1e6, 1e6)  # near-hard robustness, computed once per epoch
            if global_iter % 10 == 0:
                print(f"{global_iter:<8} | {rho_curr:>12.4f} | {rho_true:>12.4f} |  {float(J_base):>12.2f} | {lam:>8.1f}")

            if rho_curr >= -tol or rho_true > 0:
                print("Successfully Satisfied with true robustness ρ =", rho_true)
                break

        # Tighten the penalty for the next outer round.
        lam = min(lam * 5, 1e10)

except KeyboardInterrupt:
    print("Optimization interrupted.")

# ===============================
# 4. Generate trajectories
# ===============================
# Roll out the final controls: shape (num_agents, 3, N + 1).
Xall_RURAMCA = np.array(vmap(simulate_agent_scan)(x0_list, U))

## Animation code

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import matplotlib.patches as patches
from IPython.display import Image, display

# 1. DEFINE CLIQUE STRUCTURE
# Each entry lists the 0-based agents of one meeting clique (ragged, no -1 padding).
# NOTE(review): this rebinds the global `meeting_groups_mat` that total_rho reads;
# if total_rho is ever retraced after this cell runs, it will see this ragged
# list instead of the padded jnp matrix from the optimization cell.
meeting_groups_mat = [
    [0, 1, 2],
    [2, 3],
    [0, 4],
    [3, 4],
    [3, 6],
    [4, 5],
    [6, 7],
    [5, 7],
    [5, 8],
    [8, 9],
    [7, 9],
]

def get_automated_meeting_configs(Xall, cliques):
    """Pick a meeting point for each clique of agents.

    For each clique, find the time step at which the largest pairwise distance
    between its members is smallest (the first such step on ties). The members'
    mean (x, y) position at that step is the meeting point.

    Args:
        Xall: array of shape (num_agents, >=2, T); rows 0 and 1 are x and y.
        cliques: iterable of agent-index sequences. Entries equal to -1 are
            treated as padding and ignored, so rows of the -1-padded meeting
            matrix can be passed directly. Cliques with no real members are
            skipped.

    Returns:
        List of dicts with keys 'pos' (mean (x, y) at the meeting step),
        'label' (e.g. "R1-R2 met", using 1-based agent numbers) and
        'members' (list of 0-based agent indices).
    """
    Xall = np.asarray(Xall)
    configs = []
    for members in cliques:
        members = [int(m) for m in members if int(m) != -1]
        if not members:
            continue
        group_traj = Xall[members, :2, :]                              # (k, 2, T)
        diff = group_traj[:, None, :, :] - group_traj[None, :, :, :]   # (k, k, 2, T)
        dists = np.sqrt(np.sum(diff**2, axis=2))                       # (k, k, T)
        max_dist_at_t = dists.max(axis=(0, 1))                         # (T,)

        t_meet = int(np.argmin(max_dist_at_t))
        meet_pos = np.mean(group_traj[:, :, t_meet], axis=0)
        label = f"R{'-R'.join(str(m + 1) for m in members)} met"
        configs.append({'pos': meet_pos, 'label': label, 'members': members})
    return configs

def _draw_static_scene(ax, num_regions, colors, obstacle_color):
    """Draw the fixed background: three labelled obstacles plus each agent's C_i / D_i region.

    Region i is drawn in colors[i % len(colors)], matching the agent colors.
    """
    for i, x_s in enumerate([2.85, 20.85, 38.55]):
        ax.add_patch(patches.Rectangle((x_s, 8.6), 6.3, 2.8, facecolor=obstacle_color, alpha=0.9, edgecolor='black'))
        ax.text(x_s + 3.15, 10, f"☐ ¬O{i+1}", ha='center', va='center', fontsize=10, fontweight='bold')

    for i in range(num_regions):
        c = colors[i % len(colors)]
        cx = 1.5 + (i * 5)
        # Diamond C_i (collection region)
        ax.add_patch(patches.RegularPolygon((cx, 1.5), 4, radius=1.5, facecolor=c, alpha=0.2, edgecolor='black'))
        ax.text(cx, -0.3, f"R{i+1} Collects", ha='center', va='top', fontsize=8)
        # Square D_i (delivery region)
        ax.add_patch(patches.Rectangle((cx-1, 18), 2, 2, facecolor=c, alpha=0.35, edgecolor='black'))
        ax.text(cx, 20.3, f"R{i+1} Delivers", ha='center', va='bottom', fontsize=8)


def animate_with_dynamic_meetings(Xall, filename="RURAMCA_unicycle_final.gif", cliques=None):
    """Render the multi-agent rollout as a GIF, save it and display it inline.

    The static scene (obstacles, C_i and D_i regions) is drawn once; each frame
    then updates every agent's path and heading arrow. A meeting marker and its
    label appear once all members of a clique are within 1.2 of the clique's
    meeting point (from get_automated_meeting_configs), and stay visible.

    Args:
        Xall: array of shape (num_agents, 3, T); rows are x, y and heading.
        filename: output GIF path.
        cliques: agent-index lists defining the meetings; defaults to the
            module-level meeting_groups_mat.
    """
    if cliques is None:
        cliques = meeting_groups_mat
    num_agents, _, total_steps = Xall.shape
    tikz_colors = [(0.066, 0.443, 0.745), (0.866, 0.329, 0.000), (0.929, 0.694, 0.125),
                   (0.521, 0.086, 0.819), (0.231, 0.666, 0.196), (0.184, 0.745, 0.937), (0.819, 0.015, 0.545)]
    agent_colors = [tikz_colors[i % len(tikz_colors)] for i in range(num_agents)]
    warmorange = (1, 0.6, 0)

    raw_configs = get_automated_meeting_configs(Xall, cliques)

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.set_xlim(-1, 49); ax.set_ylim(-1, 21); ax.set_aspect('equal')
    ax.set_xlabel('x position'); ax.set_ylabel('y position')
    _draw_static_scene(ax, num_agents, tikz_colors, warmorange)

    # Meeting markers and labels start invisible (alpha=0). Configs are sorted
    # by y so that nearby labels are stacked upward in a stable order.
    sorted_configs = sorted(raw_configs, key=lambda cfg: cfg['pos'][1])
    meeting_patches, meeting_texts, placed_label_positions = [], [], []
    for cfg in sorted_configs:
        pos = cfg['pos']
        patch = plt.Circle(pos, 0.7, color='green', alpha=0, zorder=4)

        # Raise the label by 1.0 for every earlier meeting point within 2.0.
        y_offset = 1.1
        for prev_pos in placed_label_positions:
            if np.linalg.norm(pos - prev_pos) < 2.0:
                y_offset += 1.0

        text = ax.text(pos[0], pos[1] + y_offset, cfg['label'],
                       fontsize=8, ha='center', alpha=0, fontweight='bold',
                       bbox=dict(facecolor='white', alpha=0, edgecolor='none', pad=0.5))

        placed_label_positions.append(pos)
        ax.add_patch(patch)
        meeting_patches.append(patch)
        meeting_texts.append(text)

    # Start positions (squares), animated paths and heading arrows.
    for i in range(num_agents):
        ax.plot(Xall[i, 0, 0], Xall[i, 1, 0], 's', color=agent_colors[i], markersize=6, markeredgecolor='black')
    lines = [ax.plot([], [], lw=1.5, alpha=0.8, color=agent_colors[i])[0] for i in range(num_agents)]
    heads = ax.quiver(Xall[:, 0, 0], Xall[:, 1, 0], np.cos(Xall[:, 2, 0]), np.sin(Xall[:, 2, 0]),
                      color=agent_colors, scale=35, pivot='mid')

    threshold = 1.2  # max distance of every member to the meeting point

    def update(frame):
        """Advance all artists to `frame`; returns the artists blitting must redraw."""
        heads.set_offsets(np.stack([Xall[:, 0, frame], Xall[:, 1, frame]], axis=1))
        heads.set_UVC(np.cos(Xall[:, 2, frame]), np.sin(Xall[:, 2, frame]))

        # frame + 1 so each path ends exactly at the current head position.
        for i, line in enumerate(lines):
            line.set_data(Xall[i, 0, :frame + 1], Xall[i, 1, :frame + 1])

        for patch, text, cfg in zip(meeting_patches, meeting_texts, sorted_configs):
            if all(np.linalg.norm(Xall[a, :2, frame] - cfg['pos']) < threshold for a in cfg['members']):
                patch.set_alpha(0.3)
                text.set_alpha(1.0)
                # White halo behind the label for readability.
                text.get_bbox_patch().set_alpha(0.6)

        return lines + [heads] + meeting_patches + meeting_texts

    ani = FuncAnimation(fig, update, frames=range(total_steps), blit=True, interval=50)
    ani.save(filename, writer=PillowWriter(fps=20))
    plt.close(fig)
    display(Image(filename=filename))

# Execute: render the optimized trajectories to a GIF and show it inline.
animate_with_dynamic_meetings(Xall_RURAMCA, filename="RURAMCA_unicycle_final.gif")