In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

from pncbf.dyn.segway import Segway
from pncbf.dyn.sim_cts import SimCtsReal

import os

In [None]:
# Segway system object (pncbf.dyn.segway), built with its default arguments.
# It is read as a global by `nominal_policy` and passed to `collect_segway_data`.
segway = Segway()

In [None]:
# JAX/XLA client memory settings: switch off up-front preallocation and set
# the client memory fraction to 0.5.
# NOTE(review): this cell runs after `import jax` and after `Segway()` has been
# constructed; presumably the variables only take effect if the XLA backend
# has not been initialised yet — confirm, or move this cell above the imports.
os.environ.update({
    'XLA_PYTHON_CLIENT_PREALLOCATE': 'false',
    'XLA_PYTHON_CLIENT_MEM_FRACTION': '0.5',
})

In [None]:
def nominal_policy(state):
    """Nominal controller used for the rollouts.

    Thin wrapper that forwards `state` to the module-level `segway` object's
    `nom_pol_lqr` method and returns whatever that method returns.
    """
    control = segway.nom_pol_lqr(state)
    return control

In [None]:
def sample_segway_states(n_samples, seed=42):
    """Draw a stratified, reproducible set of segway initial states.

    Each state is a row ``[p, theta, v, omega]`` (position, angle, velocity,
    angular velocity). Three strata are sampled and then shuffled together:

    * 40% from a small box around the origin (labelled "safe" in this notebook),
    * 30% near ``p = +/-2.0`` or ``theta = +/-0.3*pi`` (the values drawn as
      dashed guide lines in the plots of this notebook),
    * the remainder uniformly over the full sampling box.

    The strata are only a sampling heuristic; whether a state is actually safe
    is decided later by rolling it out.

    Args:
        n_samples: total number of states to draw.
        seed: seed for ``np.random.default_rng``. The default (42) reproduces
            the samples of the original hard-coded version.

    Returns:
        ``np.ndarray`` of shape ``(n_samples, 4)``; shape ``(0, 4)`` when
        ``n_samples`` is 0 so that column indexing still works downstream.
    """
    rng = np.random.default_rng(seed)

    # Full sampling box used for the last stratum.
    p_bounds = (-3.0, 3.0)                      # pos
    theta_bounds = (-0.5 * np.pi, 0.5 * np.pi)  # angle
    v_bounds = (-5.0, 5.0)                      # vel
    omega_bounds = (-8.0, 8.0)                  # angular vel

    # Stratum sizes: 40% / 30% / rest (the remainder absorbs rounding).
    safe_count = int(0.4 * n_samples)
    boundary_count = int(0.3 * n_samples)
    unsafe_count = n_samples - safe_count - boundary_count

    # NOTE: the order of the rng calls below is kept exactly as in the original
    # so that the default seed yields the same samples as before.
    rows = []

    # Stratum 1: near the origin.
    for _ in range(safe_count):
        p = rng.uniform(-1.5, 1.5)
        theta = rng.uniform(-0.2 * np.pi, 0.2 * np.pi)
        v = rng.uniform(-3.0, 3.0)
        omega = rng.uniform(-4.0, 4.0)
        rows.append([p, theta, v, omega])

    # Stratum 2: close to either the position or the angle guide value.
    for _ in range(boundary_count):
        if rng.random() < 0.5:
            # position boundary
            p = rng.choice([-2.0, 2.0]) + rng.uniform(-0.3, 0.3)
            theta = rng.uniform(-0.2 * np.pi, 0.2 * np.pi)
        else:
            # angle boundary
            p = rng.uniform(-1.5, 1.5)
            theta = rng.choice([-0.3 * np.pi, 0.3 * np.pi]) + rng.uniform(-0.05 * np.pi, 0.05 * np.pi)

        v = rng.uniform(-3.0, 3.0)
        omega = rng.uniform(-4.0, 4.0)
        rows.append([p, theta, v, omega])

    # Stratum 3: uniform over the whole box.
    for _ in range(unsafe_count):
        p = rng.uniform(*p_bounds)
        theta = rng.uniform(*theta_bounds)
        v = rng.uniform(*v_bounds)
        omega = rng.uniform(*omega_bounds)
        rows.append([p, theta, v, omega])

    # reshape keeps the result two-dimensional even when no rows were drawn.
    all_states = np.array(rows, dtype=float).reshape(-1, 4)
    rng.shuffle(all_states)  # in-place shuffle of the rows

    return all_states

In [None]:
def collect_segway_data(system, policy, initial_states, sim_time=5.0, chunk_size=250):
    """Roll out every initial state and record its largest constraint value.

    For each state a trajectory is produced with `SimCtsReal.rollout_plot`;
    `system.h_components` is evaluated at every trajectory state and the
    maximum over components and over the trajectory is stored. A rollout that
    raises is reported on stdout and recorded as 10.0.

    Args:
        system: object providing `dt` and `h_components(state)`.
        policy: controller handed to `SimCtsReal`.
        initial_states: indexable collection of initial states.
        sim_time: simulation horizon passed to `SimCtsReal`.
        chunk_size: number of states handled between checkpoints.

    Returns:
        np.ndarray with one entry per initial state.

    Side effects:
        After each chunk the cumulative results are written to
        `segway_data_chunk_<last index + 1>.npy` in the working directory,
        `jax.clear_caches()` is called, and progress is printed.
    """
    sim = SimCtsReal(system, policy, sim_time, system.dt)

    # One throw-away rollout from the origin so compilation happens up front.
    warmup_state = jnp.array([0.0, 0.0, 0.0, 0.0])  # [p, theta, v, omega]
    try:
        sim.rollout_plot(warmup_state)
        print("pre-compilation: Done!")
    except Exception as e:
        print(f"pre-compilation failed: {e}")

    total = len(initial_states)
    n_chunks = (total + chunk_size - 1) // chunk_size  # ceiling division
    max_violations = []

    print(f"processing {total} states in {n_chunks} chunks of {chunk_size}")

    # Work through the states chunk by chunk to keep memory use bounded.
    for chunk_no, start in enumerate(range(0, total, chunk_size), start=1):
        stop = min(start + chunk_size, total)
        batch = initial_states[start:stop]

        print(f"\n--- chunk {chunk_no}/{n_chunks} ---")
        print(f"states {start} to {stop-1}")

        chunk_violations = []
        for i, x0 in enumerate(batch):
            try:
                T_states, _, _ = sim.rollout_plot(x0)

                # Largest constraint component at each trajectory state ...
                h_values = [jnp.max(system.h_components(state)) for state in T_states]
                # ... and the largest of those over the whole trajectory.
                chunk_violations.append(np.max(h_values))

                if i % 50 == 0:
                    print(f"  processed {i}/{len(batch)} in current chunk")

            except Exception as e:
                print(f"  error at state {i} in chunk: {e}")
                chunk_violations.append(10.0)  # sentinel: large value for failed states

        max_violations.extend(chunk_violations)

        # Checkpoint everything collected so far.
        np.save(
            f'segway_data_chunk_{stop}.npy',
            {
                'states': initial_states[:stop],
                'violations': np.array(max_violations),
            },
        )

        # Release JAX caches before moving on to the next chunk.
        jax.clear_caches()

        print(f"total progress: {stop}/{total} ({stop/total*100:.1f}%)")

    return np.array(max_violations)

In [None]:
def visualize_segway_data(states, violations, plotname):
    """Four-panel overview of the collected data, shown and written to disk.

    Panels: angle vs angular velocity, position vs velocity, histogram of
    `violations`, and position vs angle. Scatter points are coloured by
    `violations` (colour scale clipped to [-1, 1]); dashed grey lines mark
    position +/-2.0 and angle +/-0.3*pi.

    Args:
        states: 2-D array indexed as columns 0..3 = position, angle,
            velocity, angular velocity (per the axis labels used here).
        violations: one value per row of `states`.
        plotname: output path without extension; `.png`, `.pdf` and `.eps`
            files are written.
    """
    theta_lim = 0.3 * np.pi
    p_lim = 2.0
    dashed = {'color': 'xkcd:grey', 'linestyle': '--'}
    coloring = {'c': violations, 'cmap': 'RdYlGn_r', 'vmin': -1, 'vmax': 1}

    plt.figure(figsize=(16, 12))

    # top-left: angle vs angular velocity
    plt.subplot(2, 2, 1)
    plt.scatter(states[:, 1], states[:, 3], **coloring)
    for x in (-theta_lim, theta_lim):
        plt.axvline(x, **dashed)
    plt.xlabel(r'Angle $(\theta)$')
    plt.ylabel(r'Angular Velocity $(\omega)$')
    plt.title('Angle vs Angular Velocity')
    plt.colorbar()

    # top-right: position vs velocity
    plt.subplot(2, 2, 2)
    plt.scatter(states[:, 0], states[:, 2], **coloring)
    for x in (-p_lim, p_lim):
        plt.axvline(x, **dashed)
    plt.xlabel('Position (p)')
    plt.ylabel('Velocity (v)')
    plt.title('Position vs Velocity')
    plt.colorbar()

    # bottom-left: histogram of the recorded values
    plt.subplot(2, 2, 3)
    plt.hist(violations, bins=50)
    plt.axvline(0, **dashed)
    plt.xlabel('Constraint Violation')
    plt.ylabel('Count')
    plt.title('Distribution of Maximum Constraint Violations')

    # bottom-right: position vs angle
    plt.subplot(2, 2, 4)
    plt.scatter(states[:, 0], states[:, 1], **coloring)
    for y in (-theta_lim, theta_lim):
        plt.axhline(y, **dashed)
    for x in (-p_lim, p_lim):
        plt.axvline(x, **dashed)
    plt.xlabel('Position (p)')
    plt.ylabel(r'Angle $(\theta)$')
    plt.title('Position vs Angle')
    plt.colorbar()

    plt.tight_layout()

    # Same figure in three formats; only the raster output gets an explicit dpi.
    for suffix, fmt, extra in (('.png', 'png', {'dpi': 300}),
                               ('.pdf', 'pdf', {}),
                               ('.eps', 'eps', {})):
        plt.savefig(plotname + suffix, format=fmt, bbox_inches='tight',
                    facecolor='white', edgecolor='none', **extra)
    plt.show()

In [None]:
# Sample 10k initial states, then roll each one out under the nominal policy
# and keep the largest constraint value seen along its trajectory.
initial_states = sample_segway_states(10000)

max_violations = collect_segway_data(
    segway, nominal_policy, initial_states, sim_time=5.0, chunk_size=250
)

In [None]:
# Persist the full dataset (states + per-state maximum constraint value).
training_data = {
    'states': initial_states,
    'violations': max_violations
}
np.save('segway_training_data_10k.npy', training_data)

# `plotname` is a required argument of visualize_segway_data; the original call
# omitted it and raised TypeError. The figure files reuse the dataset's base name.
visualize_segway_data(initial_states, max_violations, plotname='segway_training_data_10k')

print("data collection: Done!")

In [None]:
def analyze_data_distribution(training_data):
    """Plot and print summary statistics of a collected dataset.

    Draws three panels (histogram of the values, position/angle scatter
    coloured by value, bar chart of category counts) and prints the counts,
    their percentages and the value range.

    Args:
        training_data: mapping with keys 'states' (2-D array; columns 0 and 1
            are plotted as position and angle) and 'violations' (array with
            one value per state). Values <= 0 are counted as safe, > 0 as
            unsafe, and |value| < 0.1 as near the boundary.
    """
    states = training_data['states']
    values = training_data['violations']

    plt.figure(figsize=(15, 5))

    # panel 1: histogram with the zero threshold marked
    plt.subplot(1, 3, 1)
    plt.hist(values, bins=50, edgecolor='black')
    plt.axvline(0, color='r', linestyle='--', label='Safety Threshold')
    plt.xlabel('Maximum Constraint Violation')
    plt.ylabel('Count')
    plt.title('Distribution of Values')
    plt.legend()

    # panel 2: where the samples lie in the position/angle plane
    plt.subplot(1, 3, 2)
    plt.scatter(states[:, 0], states[:, 1], c=values, cmap='RdYlGn_r', alpha=0.6)
    plt.colorbar(label='Max Violation')
    plt.xlabel('Position')
    plt.ylabel('Angle')
    plt.title('Spatial Distribution of States')

    # panel 3: category counts (the near-boundary group overlaps the other two)
    plt.subplot(1, 3, 3)
    n_safe = np.sum(values <= 0)
    n_unsafe = np.sum(values > 0)
    n_boundary = np.sum(np.abs(values) < 0.1)

    plt.bar(
        ['Safe\n'r'(h$\leq$0)', 'Unsafe\n'r'(h>0)', 'Near Boundary\n'r'(|h|<0.1)'],
        [n_safe, n_unsafe, n_boundary],
        color=['green', 'red', 'orange'],
        alpha=0.7,
    )
    plt.ylabel('Number of States')
    plt.title('States Distribution')

    plt.tight_layout()
    plt.show()

    # textual summary
    print(f"total samples: {len(values)}")
    for label, count in (("safe states", n_safe),
                         ("unsafe states", n_unsafe),
                         ("near boundary", n_boundary)):
        print(f"{label}: {count} ({count/len(values)*100:.1f}%)")
    print(f"value range: [{values.min():.3f}, {values.max():.3f}]")

In [None]:
# Reload the saved dataset. It was stored as a pickled dict, hence
# allow_pickle=True and .item() to unwrap the 0-d object array.
data = np.load('segway_training_data_10k.npy', allow_pickle=True).item()
initial_states, max_violations = data['states'], data['violations']

In [None]:
# Four-panel overview of the reloaded data; writes data_visualiztn.{png,pdf,eps}.
visualize_segway_data(
    initial_states,
    max_violations,
    'data_visualiztn',
)

In [None]:
# Summary plots and printed counts for the reloaded dataset.
analyze_data_distribution(data)

### Poster

In [None]:
def segwa_data_visual_pos_n_angle(states, violations, plotname):
    """Single large position-vs-angle scatter, shown and written to disk.

    Points are coloured by `violations` (colour scale clipped to [-1, 1]);
    dashed grey lines mark position +/-2.0 and angle +/-0.3*pi.

    Args:
        states: 2-D array; column 0 is plotted as position, column 1 as angle.
        violations: one value per row of `states`.
        plotname: output path without extension; `.png`, `.pdf` and `.eps`
            files are written.
    """
    theta_lim = 0.3 * np.pi
    p_lim = 2.0
    dashed = {'color': 'xkcd:grey', 'linestyle': '--'}

    plt.figure(figsize=(16, 12))

    plt.scatter(states[:, 0], states[:, 1], c=violations, cmap='RdYlGn_r', vmin=-1, vmax=1)
    for y in (-theta_lim, theta_lim):
        plt.axhline(y, **dashed)
    for x in (-p_lim, p_lim):
        plt.axvline(x, **dashed)
    plt.xlabel('Position (p)')
    plt.ylabel(r'Angle $(\theta)$')
    plt.title('Position vs Angle')
    plt.colorbar()
    plt.tight_layout()

    # Same figure in three formats; only the raster output gets an explicit dpi.
    for suffix, fmt, extra in (('.png', 'png', {'dpi': 300}),
                               ('.pdf', 'pdf', {}),
                               ('.eps', 'eps', {})):
        plt.savefig(plotname + suffix, format=fmt, bbox_inches='tight',
                    facecolor='white', edgecolor='none', **extra)

    plt.show()

In [None]:
def segwa_data_visual_angle_n_angleVel(states, violations, plotname):
    """Single large angle-vs-angular-velocity scatter, shown and written to disk.

    Points are coloured by `violations` (colour scale clipped to [-1, 1]);
    dashed grey lines mark angle +/-0.3*pi.

    Args:
        states: 2-D array; column 1 is plotted as angle, column 3 as angular
            velocity.
        violations: one value per row of `states`.
        plotname: output path without extension; `.png`, `.pdf` and `.eps`
            files are written.
    """
    theta_lim = 0.3 * np.pi
    dashed = {'color': 'xkcd:grey', 'linestyle': '--'}

    plt.figure(figsize=(16, 12))
    plt.scatter(states[:, 1], states[:, 3], c=violations, cmap='RdYlGn_r', vmin=-1, vmax=1)
    for x in (-theta_lim, theta_lim):
        plt.axvline(x, **dashed)
    plt.xlabel(r'Angle $(\theta)$')
    plt.ylabel(r'Angular Velocity $(\omega)$')
    plt.title('Angle vs Angular Velocity')
    plt.colorbar()
    plt.tight_layout()

    # Same figure in three formats; only the raster output gets an explicit dpi.
    for suffix, fmt, extra in (('.png', 'png', {'dpi': 300}),
                               ('.pdf', 'pdf', {}),
                               ('.eps', 'eps', {})):
        plt.savefig(plotname + suffix, format=fmt, bbox_inches='tight',
                    facecolor='white', edgecolor='none', **extra)
    plt.show()

In [None]:
# Poster figures. The original passed `states`, which is never defined at
# notebook level (NameError on a fresh kernel); the sampled states live in
# `initial_states`.
segwa_data_visual_pos_n_angle(initial_states, max_violations, 'data_visual_pos_n_angle')
segwa_data_visual_angle_n_angleVel(initial_states, max_violations, 'data_visual_angle_n_angleVel')