## Python code: multi-agent trajectory optimization (smooth STL robustness, block-coordinate L-BFGS)

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, lax, vmap
from jaxopt import LBFGS
import numpy as np
from collections import deque

# ===============================
# 1. Parameters & Setup
# ===============================
# Horizon length (number of control steps) and team size.
N = 100
num_agents = 10
dt = 1.0  # integration step of the single-integrator dynamics

# Sharpness of the smooth min/max: inside each spec (Gamma) and for the
# global conjunction (Gamma_all).
Gamma = 2
Gamma_all = 2

mu_meet = 0.25  # largest pairwise distance that still counts as a meeting
d_min = 0.01    # NOTE(review): not referenced anywhere in this file

# Control-effort weights (squared into `weights` below).
w_v = 1.0
w_w = 1.0

# Penalty-method schedule: stall detection window/tolerance, initial
# penalty weight and its growth factor, inner-iteration budget.
stall_window = 10
stall_tol = 1e-1
lam = 1.0
eta_lam = 10.0
maxIter = 5000

weights = jnp.array([w_v ** 2, w_w ** 2])
agent_range = jnp.arange(num_agents)

# Initial (x, y) position of every agent.
x0_list = jnp.array([
    [1.5, 10], [1.5, 15], [11.5, 10], [13.5, 15], [17, 7],
    [31.5, 15], [33.5, 10], [36.5, 6], [47.0, 10], [47.0, 15],
])

# Obstacle disks: centres and radii.
O_centers = jnp.array([[6.0, 10.0], [24.0, 10.0], [42.0, 10.0]])
r_obs = jnp.array([3.8] * 3)

# Collect (T) and deliver (G) region centres: one per agent, 5.0 apart in x.
T_centers = jnp.array([1.5, 1.5]) + jnp.arange(num_agents)[:, None] * jnp.array([5.0, 0.0])
G_centers = jnp.array([1.5, 19.0]) + jnp.arange(num_agents)[:, None] * jnp.array([5.0, 0.0])
r_target = jnp.full(num_agents, 0.8)
r_goal = jnp.full(num_agents, 0.8)

# Meeting groups: each row lists the agent indices that must meet;
# -1 pads the two-member rows to width 3.
meeting_groups_mat = jnp.array([
    [0, 1, 2],
    [2, 3, -1], [0, 4, -1], [3, 4, -1], [3, 6, -1],
    [4, 5, -1], [6, 7, -1], [5, 7, -1], [5, 8, -1], [8, 9, -1], [7, 9, -1],
])

# Time windows (inclusive step indices): collect, deliver, meet.
t1_low, t1_high = 11, 51
g1_low = max(1, N - 30)
g1_high = N
meet_low, meet_high = 0, 70

# ===============================
# 2. Dynamics & Robustness
# ===============================
def smooth_min(vec, G):
    """Soft minimum of a 1-D array: -(1/G) * log(sum(exp(-G * vec))).

    The minimum is subtracted before exponentiating, so every exponent is
    <= 0 and cannot overflow. Larger G gives a closer approximation.
    """
    anchor = jnp.min(vec)
    log_mass = jnp.log(jnp.sum(jnp.exp(-G * (vec - anchor))))
    return -(1 / G) * (log_mass - G * anchor)

def smooth_max(vec, G):
    """Soft maximum of a 1-D array, shifted down by log(n)/G.

    Computes (1/G) * log(sum(exp(G * vec))) - log(n)/G, where n is the length
    of vec. The maximum is subtracted before exponentiating, so every exponent
    is <= 0. The log(n)/G shift keeps the result from exceeding the true max.
    """
    peak = jnp.max(vec)
    log_mass = jnp.log(jnp.sum(jnp.exp(G * (vec - peak))))
    count_correction = jnp.log(vec.shape[0]) / G
    return (1 / G) * (log_mass + G * peak) - count_correction

def step_linear(carry, u):
    """One single-integrator step (x_{t+1} = x_t + dt * u_t) for ``lax.scan``.

    The next state is returned both as the new carry and as the per-step
    output, so the scan's stacked outputs are the states x_1 ... x_N.
    """
    advanced = carry + dt * u
    return (advanced, advanced)

def simulate_agent_scan(x0, u_seq):
    """Roll out one agent from ``x0`` under the control sequence ``u_seq``.

    u_seq has shape (2, N); it is transposed so the scan consumes one control
    column per step. Returns the trajectory of shape (2, N+1), with the
    initial state in column 0.
    """
    _, visited = lax.scan(step_linear, x0, u_seq.T)               # (N, 2)
    trajectory = jnp.concatenate([x0[None, :], visited], axis=0)  # (N+1, 2)
    return trajectory.T

@jit
def total_rho(U, G, Gall):
    """
    Computes smooth robustness for:
    (Eventually Ci) AND (Eventually Di) AND (Obstacle Avoidance) AND (Meetings)
    Note: Inter-agent collision avoidance has been removed.

    Args:
        U: controls of shape (num_agents, 2, N); U[i] is agent i's sequence.
        G: sharpness of the smooth min/max used inside each agent's spec and
           over each meeting's time window.
        Gall: sharpness of the final smooth min over all agent and meeting
           robustness values.

    Returns:
        Scalar robustness; larger G/Gall make the smooth operators approach
        the exact min/max. The largest pairwise distance inside a meeting
        group is an exact max; only its time window is smoothed.
    """
    Xall = vmap(simulate_agent_scan)(x0_list, U)  # (num_agents, 2, N+1)
    t_idx = jnp.arange(N + 1)

    def agent_metrics(idx, Tc, Gc, rt, rg):
        # Extract position sequence for agent idx: (N+1, 2)
        pos = Xall[idx, :2, :].T

        # --- A. Static Obstacle Avoidance ---
        # Clearance to each obstacle boundary: (N+1, num_obstacles).
        # The 1e-9 keeps sqrt differentiable at zero distance.
        d_obs = jnp.sqrt(jnp.sum((pos[:, None, :] - O_centers[None, :, :])**2, axis=-1) + 1e-9) - r_obs
        # Smooth min over obstacles at each time step -> (N+1,); the smooth
        # min over time is taken in the conjunction below ("always").
        m_obs = vmap(lambda v: smooth_min(v, G))(d_obs)

        # --- B. Reachability (Eventually) ---
        dist_to_Ci = jnp.sqrt(jnp.sum((pos - Tc)**2, axis=-1) + 1e-9)
        dist_to_Di = jnp.sqrt(jnp.sum((pos - Gc)**2, axis=-1) + 1e-9)

        # Eventually Ci in window [t1_low, t1_high]; steps outside the window
        # get -1e5 so they cannot win the (smooth) max.
        mask_ci = (t_idx >= t1_low) & (t_idx <= t1_high)
        rho_eventually_Ci = smooth_max(jnp.where(mask_ci, rt - dist_to_Ci, -1e5), G)

        # Eventually Di in window [g1_low, g1_high]
        mask_di = (t_idx >= g1_low) & (t_idx <= g1_high)
        rho_eventually_Di = smooth_max(jnp.where(mask_di, rg - dist_to_Di, -1e5), G)

        # Conjunction of Obstacle Safety and Reachability Goals
        return smooth_min(jnp.array([
            smooth_min(m_obs, G),
            rho_eventually_Ci,
            rho_eventually_Di
        ]), G)

    # Compute robustness for each agent independently
    rhos_agents = vmap(agent_metrics)(agent_range, T_centers, G_centers, r_target, r_goal)

    # --- C. Collaborative Meeting Tasks ---
    def compute_group_rho(member_indices):
        # Rows are padded with -1 for absent members; padded slots read agent 0
        # and are excluded from every pair below.
        mask = (member_indices != -1)
        safe_idx = jnp.where(mask, member_indices, 0)
        group_pos = Xall[safe_idx, :2, :]  # (group_size, 2, N+1)

        diff = group_pos[:, None, :, :] - group_pos[None, :, :, :]
        dists = jnp.sqrt(jnp.sum(diff**2, axis=2) + 1e-9)  # (group_size, group_size, N+1)

        # Distinct pairs of real members. The size comes from the row width
        # rather than a hard-coded 3, so wider meeting matrices also work.
        group_size = member_indices.shape[0]
        pair_mask = mask[:, None] * mask[None, :] * (jnp.eye(group_size) == 0)
        max_dist_at_t = jnp.max(jnp.where(pair_mask[:, :, None], dists, 0.0), axis=(0, 1))
        rho_at_t = mu_meet - max_dist_at_t

        # Eventually, within [meet_low, meet_high], all members are within mu_meet.
        rho_window = jnp.where((t_idx >= meet_low) & (t_idx <= meet_high), rho_at_t, -1e5)
        return smooth_max(rho_window, G)

    rhos_groups = vmap(compute_group_rho)(meeting_groups_mat)

    # --- D. Global Conjunction ---
    return smooth_min(jnp.concatenate([rhos_agents, rhos_groups]), Gall)

# ===============================
# 3. Solver Setup
# ===============================
@jit
def compute_full_objective(U, lam, G, Gall):
    """Penalized objective: weighted control effort + lam * max(0, -rho)^2.

    Only violated robustness (rho < 0) is penalized, quadratically, with
    weight ``lam``. ``weights`` scales the two control channels of every agent.
    """
    effort = jnp.sum(weights[:, None] * U**2)
    violation = jnp.maximum(0.0, -total_rho(U, G, Gall))
    return effort + lam * violation**2

@jit
def block_objective(u_i_flat, agent_idx, U_full, lam):
    """Full objective as a function of a single agent's flattened controls.

    ``u_i_flat`` (length 2*N) replaces agent ``agent_idx``'s (2, N) controls
    in ``U_full``. Every other agent is held fixed.
    """
    U_trial = U_full.at[agent_idx].set(jnp.reshape(u_i_flat, (2, N)))
    return compute_full_objective(U_trial, lam, Gamma, Gamma_all)

lbfgs_block = LBFGS(fun=block_objective, maxiter=1, jit=True)

# jaxopt's LBFGS.update() starts each step from the objective value and
# gradient cached in the solver state. In this block-coordinate scheme that
# cache goes stale whenever another agent moves or lam changes, so each
# block's state is refreshed with this function right before its update.
block_value_and_grad = jit(jax.value_and_grad(block_objective))

# ===============================
# 4. Optimization Loop
# ===============================
# Controls start at zero (every agent stationary): the random draw is scaled
# by 0.0. Change that factor to start from small noise instead.
key = jax.random.PRNGKey(42)
U = jax.random.normal(key, (num_agents, 2, N)) * 0.0

# Seeded so the random block-update order (and hence the result) is reproducible.
rng = np.random.default_rng(42)

cost_window = deque(maxlen=stall_window)
block_states = [lbfgs_block.init_state(U[i].flatten(), agent_idx=i, U_full=U, lam=lam)
                for i in range(num_agents)]

print(f"{'Outer':<6} | {'Inner':<6} | {'rho(smooth)':>12} | {'rho(true)':>12} | {'Objective':>12} | {'lam':>10}")
print("-" * 73)

solved = False
try:
    for out_it in range(15):
        # Losses at different lam values are not comparable: restart the stall window.
        cost_window.clear()

        for inner_it in range(maxIter):
            # 1. One epoch: a single L-BFGS step per agent, in random order,
            #    with every other agent's controls held fixed.
            for blk in rng.permutation(num_agents):
                u_init = U[blk].flatten()
                value, grad_u = block_value_and_grad(u_init, blk, U, lam)
                fresh_state = block_states[blk]._replace(value=value, grad=grad_u)
                new_u_flat, block_states[blk] = lbfgs_block.update(
                    u_init, fresh_state, agent_idx=blk, U_full=U, lam=lam
                )
                U = U.at[blk].set(new_u_flat.reshape(2, N))

            # 2. Measure progress after all agents have moved.
            loss = float(compute_full_objective(U, lam, Gamma, Gamma_all))
            rho_s = float(total_rho(U, Gamma, Gamma_all))
            # Very large sharpness makes the smooth min/max effectively exact.
            rho_t = float(total_rho(U, 1e6, 1e6))

            # 3. Success is checked before stall, so a satisfying iterate is
            #    never reported as a stall.
            if rho_t > 0:
                solved = True
                print(f"Success! Final True Robustness: {rho_t:.4f}")
                break

            # 4. Stall: the loss changed by less than stall_tol (absolute) over
            #    the last stall_window epochs -> move on to a larger lam.
            cost_window.append(loss)
            if len(cost_window) == stall_window and abs(cost_window[0] - cost_window[-1]) < stall_tol:
                print(f"Stall at inner {inner_it}. Increasing lam...")
                break

            if inner_it % 10 == 0:
                print(f"{out_it:<6} | {inner_it:<6} | {rho_s:>12.4f} | {rho_t:>12.4f} | {loss:>12.2f} | {lam:>10.3g}")

        if solved:
            break

        lam = min(lam * eta_lam, 1e9)

    if not solved:
        print("No positive true robustness within the outer-iteration budget.")

except KeyboardInterrupt:
    print("Optimization interrupted.")

# Final result: roll out the optimized controls, shape (num_agents, 2, N+1).
Xall_R2AM_LBFGS_linear = np.array(vmap(simulate_agent_scan)(x0_list, U))

## Animation code: render the optimized trajectories as an animated GIF

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import matplotlib.patches as patches
from IPython.display import Image, display

# 1. DEFINE CLIQUE STRUCTURE
# Meeting groups as plain lists of 0-based agent indices (no -1 padding).
# Kept as lists, not tuples, because they are used as numpy fancy indices.
meeting_groups_mat = [
    [0, 1, 2],
    [2, 3], [0, 4], [3, 4], [3, 6],
    [4, 5], [6, 7], [5, 7], [5, 8],
    [8, 9], [7, 9],
]

def get_automated_meeting_configs(Xall, cliques):
    """Locate where each clique meets.

    For every clique, the meeting step is the one where the largest pairwise
    x/y distance among its members is smallest (the first such step on ties).
    The meeting point is the members' mean x/y position at that step.

    Args:
        Xall: array indexed as Xall[agent, coordinate, time step].
        cliques: iterable of lists of 0-based agent indices.

    Returns:
        One dict per clique, in input order, with keys 'pos' (meeting point),
        'label' (e.g. "R1-R2 met") and 'members' (the clique as given).
    """
    def describe(members):
        traj = Xall[members, :2, :]                                     # (k, 2, T)
        offsets = traj[:, None, :, :] - traj[None, :, :, :]             # (k, k, 2, T)
        spread = np.sqrt(np.sum(offsets**2, axis=2)).max(axis=(0, 1))   # (T,)
        t_best = np.argmin(spread)
        centroid = np.mean(traj[:, :, t_best], axis=0)
        tag = '-R'.join(str(m + 1) for m in members)
        return {'pos': centroid, 'label': f"R{tag} met", 'members': members}

    return [describe(members) for members in cliques]

def animate_with_dynamic_meetings(Xall, filename="R2AMCA_Linear_Final.gif", cliques=None):
    """Animate the agents' trajectories, save them as a GIF and display it inline.

    Args:
        Xall: array indexed as Xall[agent, coordinate, time step]; coordinates
            0 and 1 are drawn as the planar x/y position.
        filename: path of the GIF written with PillowWriter at 20 fps.
        cliques: meeting groups (lists of 0-based agent indices). Defaults to
            the module-level ``meeting_groups_mat``.

    A group's meeting marker and label appear once every member is within 1.2
    of the group's detected meeting point, and then stay visible.
    """
    if cliques is None:
        cliques = meeting_groups_mat
    num_agents, _, total_steps = Xall.shape
    tikz_colors = [(0.066, 0.443, 0.745), (0.866, 0.329, 0.000), (0.929, 0.694, 0.125),
                   (0.521, 0.086, 0.819), (0.231, 0.666, 0.196), (0.184, 0.745, 0.937), (0.819, 0.015, 0.545)]
    warmorange = (1, 0.6, 0)

    def agent_color(i):
        # The palette repeats when there are more agents than colors.
        return tikz_colors[i % len(tikz_colors)]

    # Generate automated meeting data
    raw_configs = get_automated_meeting_configs(Xall, cliques)

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        ax.set_xlim(-1, 49); ax.set_ylim(-1, 21); ax.set_aspect('equal')
        ax.set_xlabel('x position'); ax.set_ylabel('y position')

        # Draw obstacles. NOTE: they are drawn as rectangles, while the
        # optimization cell models them as disks (O_centers, r_obs).
        obs_coords = [2.85, 20.85, 38.55]
        for i, x_s in enumerate(obs_coords):
            ax.add_patch(patches.Rectangle((x_s, 8.6), 6.3, 2.8, facecolor=warmorange, alpha=0.9, edgecolor='black'))
            ax.text(x_s + 3.15, 10, f"☐ ¬O{i+1}", ha='center', va='center', fontsize=10, fontweight='bold')

        # Per-agent collect (diamond, bottom) and deliver (square, top) regions,
        # spaced 5 apart in x; one pair per agent actually present in Xall.
        for i in range(num_agents):
            c = agent_color(i)
            cx = 1.5 + (i * 5)
            ax.add_patch(patches.RegularPolygon((cx, 1.5), 4, radius=1.5, facecolor=c, alpha=0.2, edgecolor='black'))
            ax.text(cx, -0.3, f"R{i+1} Collects", ha='center', va='top', fontsize=8)
            ax.add_patch(patches.Rectangle((cx-1, 18), 2, 2, facecolor=c, alpha=0.35, edgecolor='black'))
            ax.text(cx, 20.3, f"R{i+1} Delivers", ha='center', va='bottom', fontsize=8)

        # Meeting circles and labels start fully transparent. Labels near an
        # already-placed meeting point are pushed up to reduce overlap.
        meeting_patches = []
        meeting_texts = []
        sorted_configs = sorted(raw_configs, key=lambda x: x['pos'][1])
        placed_label_positions = []

        for cfg in sorted_configs:
            pos = cfg['pos']
            p = plt.Circle(pos, 0.7, color='green', alpha=0, zorder=4)
            y_offset = 1.1
            for prev_pos in placed_label_positions:
                if np.linalg.norm(pos - prev_pos) < 2.0:
                    y_offset += 1.0

            t = ax.text(pos[0], pos[1] + y_offset, cfg['label'],
                        fontsize=8, ha='center', alpha=0, fontweight='bold',
                        bbox=dict(facecolor='white', alpha=0, edgecolor='none', pad=0.5))

            placed_label_positions.append(pos)
            ax.add_patch(p)
            meeting_patches.append(p)
            meeting_texts.append(t)

        # Agents: path lines, faded start markers, and a scatter of current
        # positions (the state has no heading, so no orientation is drawn).
        lines = [ax.plot([], [], lw=1.5, alpha=0.8, color=agent_color(i))[0] for i in range(num_agents)]

        for i in range(num_agents):
            ax.plot(Xall[i, 0, 0], Xall[i, 1, 0], 's', color=agent_color(i), markersize=6, markeredgecolor='black', alpha=0.5)

        heads = ax.scatter(Xall[:, 0, 0], Xall[:, 1, 0], color=[agent_color(i) for i in range(num_agents)],
                           s=40, edgecolor='black', zorder=10)

        threshold = 1.2  # every member must be this close to the meeting point

        def update(frame):
            heads.set_offsets(Xall[:, :2, frame])

            # frame+1 so the path includes the current position
            for i in range(num_agents):
                lines[i].set_data(Xall[i, 0, :frame+1], Xall[i, 1, :frame+1])

            # Reveal a meeting once all members are near its point; it is never hidden again.
            for idx, cfg in enumerate(sorted_configs):
                pos = cfg['pos']
                dists = [np.linalg.norm(Xall[a, :2, frame] - pos) for a in cfg['members']]
                if all(d < threshold for d in dists):
                    meeting_patches[idx].set_alpha(0.3)
                    meeting_texts[idx].set_alpha(1.0)
                    meeting_texts[idx].get_bbox_patch().set_alpha(0.6)

            # blit=True requires returning every artist that may have changed
            return lines + [heads] + meeting_patches + meeting_texts

        ani = FuncAnimation(fig, update, frames=range(total_steps), blit=True, interval=50)
        ani.save(filename, writer=PillowWriter(fps=20))
    finally:
        # Close this figure explicitly (not just the "current" one), even if saving fails.
        plt.close(fig)
    display(Image(filename=filename))

# Render and display the trajectories rolled out at the end of the optimization cell.
animate_with_dynamic_meetings(Xall=Xall_R2AM_LBFGS_linear)