# In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, lax, vmap
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
from collections import deque

# ===============================
# 1. Parameters
# ===============================
N, num_agents = 100, 10
dt, Gamma = 1.0, 10.0  # Increased Gamma for sharper satisfaction
alpha0, tol = 1.0, 1e-3
mu_meet = 0.5
w_v, w_w = 1.0, 20.0
maxIter, maxLineIters = 5000, 50
armijo_beta, armijo_c = 0.5, 1e-4
stall_window, stall_tol = 15, 1e-6 

x0_list = jnp.array([
    [1.5, 10, 0], [1.5, 15, 0], [11.5, 10, 0], [13.5, 10, 0], [15.5, 10, 0],
    [31.5, 10, 0], [33.5, 10, 0], [36.5, 10, 0], [47.0, 10, 0], [47.0, 15, 0]
])
O_centers = jnp.array([[6.0, 10.0], [24.0, 10.0], [42.0, 10.0]])
r_obs = jnp.array([3.162, 3.162, 3.162])
T_centers = jnp.array([jnp.array([1.5, 1.5]) + k*jnp.array([5.0, 0.0]) for k in range(num_agents)])
G_centers = jnp.array([jnp.array([1.5, 19.0]) + k*jnp.array([5.0, 0.0]) for k in range(num_agents)])
r_target, r_goal = jnp.full(num_agents, 0.8), jnp.full(num_agents, 0.8)

meeting_groups_mat = jnp.array([
    [0, 1, 2], [2, 3, -1], [0, 4, -1], [3, 4, -1], [3, 6, -1],
    [4, 5, -1], [6, 7, -1], [5, 7, -1], [5, 8, -1], [8, 9, -1], [7, 9, -1]
])

t1_low, t1_high = 11, 51
g1_low, g1_high = max(1, N-30), N
meet_low, meet_high = 0, 70

# ===============================
# 2. Differentiable Model
# ===============================
def smooth_min(vec, gamma=None):
    """Return a differentiable log-sum-exp approximation of ``min(vec)``.

    Args:
        vec: Array of values to reduce; all entries are reduced to one scalar.
        gamma: Sharpness of the approximation. Larger values track the true
            minimum more closely. Defaults to the module-level ``Gamma``.

    Returns:
        Scalar soft minimum of ``vec``.
    """
    sharpness = Gamma if gamma is None else gamma
    # Shifting by the true minimum keeps every exponent <= 0, so exp() cannot
    # overflow; the shift is added back after the logarithm.
    m = jnp.min(vec)
    shifted_sum = jnp.sum(jnp.exp(-sharpness * (vec - m)))
    return -(1 / sharpness) * (jnp.log(shifted_sum + 1e-6) - sharpness * m)

def smooth_max(vec, gamma=None):
    """Return a differentiable log-sum-exp approximation of ``max(vec)``.

    Args:
        vec: Array of values to reduce; all entries are reduced to one scalar.
        gamma: Sharpness of the approximation. Larger values track the true
            maximum more closely. Defaults to the module-level ``Gamma``.

    Returns:
        Scalar soft maximum of ``vec``.
    """
    sharpness = Gamma if gamma is None else gamma
    # Shifting by the true maximum keeps every exponent <= 0, so exp() cannot
    # overflow; the shift is added back after the logarithm.
    m = jnp.max(vec)
    shifted_sum = jnp.sum(jnp.exp(sharpness * (vec - m)))
    return (1 / sharpness) * (jnp.log(shifted_sum + 1e-6) + sharpness * m)

def step_unicycle(carry, u):
    """Advance the unicycle state by one explicit-Euler step of length ``dt``.

    Written as a ``lax.scan`` body: it returns the new carry together with the
    state *before* the step, so the scan collects the pre-step states.

    Args:
        carry: Current state ``(x, y, heading)``.
        u: Control pair; ``u[0]`` drives the position along the heading and
            ``u[1]`` is added to the heading.

    Returns:
        Tuple ``(next_state, carry)``.
    """
    pos_x, pos_y, heading = carry
    speed, turn_rate = u[0], u[1]
    advanced = jnp.array([
        pos_x + dt * speed * jnp.cos(heading),
        pos_y + dt * speed * jnp.sin(heading),
        heading + dt * turn_rate,
    ])
    return advanced, carry

def simulate_agent_scan(x0, u_seq):
    """Roll out one agent's trajectory from ``x0`` under ``u_seq``.

    Args:
        x0: Initial state ``(x, y, heading)``.
        u_seq: Controls with one column per time step (it is transposed before
            being scanned over).

    Returns:
        Array with one row per state component and one column per time step,
        the initial state first and the final state last.
    """
    final_state, visited = lax.scan(step_unicycle, x0, jnp.transpose(u_seq))
    # The scan emits the state before each step, so the state reached after
    # the last control has to be appended explicitly.
    trajectory = jnp.concatenate([visited, final_state[None, :]], axis=0)
    return jnp.transpose(trajectory)

@jit
def total_rho(U):
    """Return the smooth overall robustness of the controls ``U``.

    Every agent is simulated from its row of ``x0_list``. The result is the
    soft minimum over all per-agent and per-group robustness values, so it is
    non-negative only when (approximately) every requirement is met.

    Args:
        U: Controls indexed as ``[agent, control, time step]``.

    Returns:
        Scalar robustness value.
    """
    Xall = vmap(simulate_agent_scan)(x0_list, U)

    # Take the number of time samples from the rollout itself rather than
    # from N, so the index always matches the simulated trajectories.
    t_idx = jnp.arange(Xall.shape[-1])
    in_goal_window = (t_idx >= g1_low) & (t_idx <= g1_high)
    in_target_window = (t_idx >= t1_low) & (t_idx <= t1_high)
    in_meet_window = (t_idx >= meet_low) & (t_idx <= meet_high)

    def agent_metrics(X, Tc, Gc, rt, rg):
        """Robustness of one agent: avoid obstacles, visit target and goal."""
        pos = X[:2, :].T
        # Signed clearance to every obstacle at every time step; the 1e-9
        # keeps the sqrt gradient finite at zero distance.
        d_obs = jnp.sqrt(jnp.sum((pos[:, None, :] - O_centers[None, :, :])**2, axis=-1) + 1e-9) - r_obs
        m_obs = vmap(smooth_min)(d_obs)
        d_goal = jnp.sqrt(jnp.sum((pos - Gc)**2, axis=-1) + 1e-9)
        d_target = jnp.sqrt(jnp.sum((pos - Tc)**2, axis=-1) + 1e-9)
        # Outside the window the margin is set to -1e5 so those samples
        # cannot win the soft maximum.
        m_goal = jnp.where(in_goal_window, rg - d_goal, -1e5)
        m_target = jnp.where(in_target_window, rt - d_target, -1e5)
        return smooth_min(jnp.array([smooth_min(m_obs), smooth_max(m_goal), smooth_max(m_target)]))

    rhos_agents = vmap(agent_metrics)(Xall, T_centers, G_centers, r_target, r_goal)

    def compute_group_rho(member_indices):
        """Robustness of one meeting group (row of ``meeting_groups_mat``)."""
        n_slots = member_indices.shape[0]
        mask = (member_indices != -1)
        # Padded slots are redirected to agent 0 for a valid gather and are
        # excluded again through pair_mask.
        safe_idx = jnp.where(mask, member_indices, 0)
        group_pos = Xall[safe_idx, :2, :]
        diff = group_pos[:, None, :, :] - group_pos[None, :, :, :]
        dists = jnp.sqrt(jnp.sum(diff**2, axis=2) + 1e-9)
        # Valid pairs: both slots occupied and not the same slot. The size
        # follows the row width instead of assuming three slots per group.
        pair_mask = mask[:, None] & mask[None, :] & (jnp.eye(n_slots) == 0)
        rho_at_t = mu_meet - jnp.max(jnp.where(pair_mask[:, :, None], dists, 0.0), axis=(0, 1))
        rho_window = jnp.where(in_meet_window, rho_at_t, -1e5)
        return smooth_max(rho_window)

    rhos_groups = vmap(compute_group_rho)(meeting_groups_mat)
    return smooth_min(jnp.concatenate([rhos_agents, rhos_groups]))

@jit
def compute_full_objective1(U, lam):
    """Return the penalised objective: control effort plus violation penalty.

    Args:
        U: Controls indexed as ``[agent, control, time step]``.
        lam: Weight on the squared robustness violation.

    Returns:
        Scalar ``sum(w_v * U[:, 0]**2 + w_w * U[:, 1]**2) + lam * max(0, -rho)**2``.
    """
    # Only a negative robustness contributes to the penalty.
    violation = jnp.maximum(0.0, -total_rho(U))
    first_ctrl, second_ctrl = U[:, 0, :], U[:, 1, :]
    effort = jnp.sum(w_v * first_ctrl**2 + w_w * second_ctrl**2)
    return effort + lam * violation**2

# Gradient of the TOTAL objective (control cost + penalty) with respect to U.
# Despite its name this is not the gradient of rho alone.
rho_grad_fun = jax.jit(jax.grad(compute_full_objective1, argnums=0))

# ===============================
# 3. Optimization Loop
# ===============================
# Penalty method: the outer loop raises the penalty weight lam, the inner
# loops run randomised block-coordinate descent in which each step updates one
# agent's controls with a diagonally scaled gradient step and Armijo
# backtracking.
U = jnp.zeros((num_agents, 2, N))
rho_hist, gnorm_hist, cost_hist = [], [], []
# Defaults are set BEFORE the optional checkpoint restore below, so that
# uncommenting the restore really overrides them.
lam, global_iter = 10.0, 0

# Optional warm start: uncomment to resume from a saved checkpoint.
# U = jnp.array(np.load("U_checkpoint.npy"))
# meta = np.load("checkpoint_meta.npz")

# lam = float(meta["lam"])
# global_iter = int(meta["global_iter"])
# rho_hist = meta["rho_hist"].tolist()
# cost_hist = meta["cost_hist"].tolist()
# gnorm_hist = meta["gnorm_hist"].tolist()


# Objective values after the most recent block updates (stall detection).
cost_window = deque(maxlen=stall_window)

# Inverse of the control-cost Hessian diagonal (the second derivative of
# w * u**2 is 2 * w), used as a fixed per-control scaling of the gradient.
# It does not depend on lam, so it is built once.
H_inv = jnp.array([1.0/(2*w_v), 1.0/(2*w_w)])

print(f"{'Iter':<8} | {'rho':<10} | {'Total Cost':<12} | {'gnorm':<10} | {'lam':<8}")
print("-" * 65)

try:
    for out_it in range(50):
        # Objective values from a previous lam are not comparable.
        cost_window.clear()

        for epoch in range(maxIter):
            # Visit every agent once per epoch, in random order.
            indices = np.random.permutation(num_agents)
            for blk in indices:
                # Current base cost and robustness
                rho_curr = total_rho(U)
                J_base = compute_full_objective1(U, lam)

                # Full gradient (Control + Penalty)
                full_grad = rho_grad_fun(U, lam)
                curr_gnorm = float(jnp.linalg.norm(full_grad))

                # Histories
                rho_hist.append(float(rho_curr))
                cost_hist.append(float(J_base))
                gnorm_hist.append(curr_gnorm)

                # Scaled steepest-descent direction for this agent only.
                g_blk = full_grad[blk]
                d_blk = -g_blk * H_inv[:, None]
                slope = jnp.sum(g_blk * d_blk)

                # Armijo backtracking line search; if no step length is
                # accepted within maxLineIters, U is left unchanged.
                alpha, U_old_blk = alpha0, U[blk]
                for _ in range(maxLineIters):
                    U_trial = U.at[blk].set(U_old_blk + alpha * d_blk)
                    J_trial = compute_full_objective1(U_trial, lam)
                    if J_trial <= J_base + armijo_c * alpha * slope:
                        U, J_base = U_trial, J_trial
                        break
                    alpha *= armijo_beta

                cost_window.append(float(J_base))

            # Convergence checks (only once the stall window is full)
            if len(cost_window) == stall_window:
                rel_change = abs(cost_window[0] - cost_window[-1]) / (abs(cost_window[0]) + 1e-9)
                if curr_gnorm < tol:
                    break
                # Stalled while still violating: leave so lam gets raised.
                if rel_change < stall_tol and rho_curr < -0.01:
                    break

            global_iter += 1
            if global_iter % 20 == 0:
                print(f"{global_iter:<8} | {float(rho_curr):>10.4f} | {float(J_base):>12.2f} | {curr_gnorm:>10.4f} | {lam:>8.1f}")

        # Re-evaluate on the current U: the value cached inside the sweep was
        # computed before the last block update and may be stale.
        rho_curr = total_rho(U)
        if rho_curr >= -1e-3:
            print("Satisfied!")
            break

        # Increase penalty
        lam = min(lam * 2.0, 50000.0)

except KeyboardInterrupt:
    # Deliberate: stopping by hand keeps the current U and the histories so
    # the code that follows can still use them.
    pass

# ===============================
# 4. Plots
# ===============================
# Ensure all histories are numpy-friendly
rho_np = np.array(rho_hist)
cost_np = np.array(cost_hist)
gnorm_np = np.array(gnorm_hist)

fig, axes = plt.subplots(1, 3, figsize=(18, 4))

# Plot Robustness
# Raw strings keep the mathtext backslashes: in a normal string "\r" and "\n"
# are a carriage return and a newline, which corrupts "\rho" and "\nabla".
axes[0].plot(rho_np, color='blue')
axes[0].set_title(r"Robustness ($\rho$)")
axes[0].grid(True)
axes[0].axhline(0, color='red', linestyle='--')

# Plot Total Cost (with safety check for log scale)
axes[1].plot(cost_np, color='purple')
axes[1].set_title("Total Objective ($J$)")
axes[1].grid(True)
if len(cost_np) > 0 and np.all(cost_np > 0):
    axes[1].set_yscale('log')

# Plot Gradient Norm (same log-scale safety check as the cost plot)
axes[2].plot(gnorm_np, color='green')
axes[2].set_title(r"Gradient Norm $||\nabla J||$")
axes[2].grid(True)
if len(gnorm_np) > 0 and np.all(gnorm_np > 0):
    axes[2].set_yscale('log')

plt.tight_layout()
plt.show()

# 1. Generate final trajectories from the last U
Xall_final = np.array(vmap(simulate_agent_scan)(x0_list, U))
# Robustness of exactly the controls being plotted. The recorded history is
# sampled before each update, so its last entry can lag behind U, and it is
# empty if the run was interrupted before the first step.
final_rho = float(total_rho(U))

plt.figure(figsize=(10, 7))

# 2. Draw Obstacles (only the first one is labelled, so the legend shows a
# single "Obstacle" entry)
for obs_idx, (c, r) in enumerate(zip(np.asarray(O_centers), np.asarray(r_obs))):
    circle = patches.Circle(c, r, color='red', alpha=0.3, label='Obstacle' if obs_idx == 0 else "")
    plt.gca().add_patch(circle)

# 3. Plot each agent's path
for i in range(num_agents):
    # Xall_final[agent_index, state_dim (x=0, y=1), time_step]
    plt.plot(Xall_final[i, 0, :], Xall_final[i, 1, :], label=f'Agent {i}', linewidth=1.5)
    # Plot starting point
    plt.scatter(Xall_final[i, 0, 0], Xall_final[i, 1, 0], marker='o')
    # Plot end point
    plt.scatter(Xall_final[i, 0, -1], Xall_final[i, 1, -1], marker='x')

plt.gca().set_aspect('equal')
plt.grid(True, linestyle='--', alpha=0.6)
# Raw f-string: in a normal string "\r" is a carriage return, which breaks
# the "\rho" mathtext symbol.
plt.title(rf"Final Trajectories (Final $\rho$: {final_rho:.4f})")
plt.xlabel("X Position")
plt.ylabel("Y Position")
plt.legend(loc='upper right', bbox_to_anchor=(1.2, 1))
plt.show()
