# C++ vs JAX Newton-Raphson Comparison

Compares the C++ and JAX implementations of the Newton-Raphson controller on the same trajectories; log files are paired by file name across the two log directories listed below.

**Log files:**
- `log_files/newton_raphson_px4/` — JAX implementation
- `log_files/newton_raphson_px4_cpp/` — C++ implementation

**Contents:**
1. Trajectory comparison plots (2-row grid, same style as ControllerComparison.ipynb)
2. Computation time distributions (box plots + histograms)
3. RMSE table
4. Computation time summary table (mean, std, min, max, 1/mean)
5. Speed-up summary
6. Mean computation time bar chart
7. Control inputs over time (throttle, p, q, r)
8. CBF correction values over time

## 1. Import Libraries

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Force reload utilities module
import importlib
import sys
if 'utilities' in sys.modules:
    importlib.reload(sys.modules['utilities'])

from utilities import (
    load_csv,
    detect_trajectory_plane,
    plot_trajectory_2d,
    calculate_position_rmse,
    align_reference_to_actual,
)

# Matplotlib defaults for every figure in this notebook: start from the
# built-in style, keep LaTeX text rendering off, and prefer these
# sans-serif fonts in the listed order.
plt.style.use('default')
plt.rcParams.update({
    'text.usetex': False,
    'font.family': 'sans-serif',
    'font.sans-serif': ['DejaVu Sans', 'Liberation Sans', 'Arial'],
})

## 2. Configuration

In [None]:
# ===== CONFIGURATION =====
# Log directories for the two implementations being compared.
JAX_LOG_DIR = "log_files/newton_raphson_px4/"
CPP_LOG_DIR = "log_files/newton_raphson_px4_cpp/"

# Labels used in plots and tables
JAX_LABEL, CPP_LABEL = "JAX", "C++"

# Plot styling
SUBFIG_TITLE_FONTSIZE = 14
TICK_SIZE = 10
SHOW_GRIDLINES = False

# Output directory. Later cells build paths as f"{output_path}<file>",
# so the trailing slash is required.
output_path = "output/"
os.makedirs(output_path, exist_ok=True)
# =========================

print(
    f"JAX logs : {JAX_LOG_DIR}\n"
    f"C++ logs : {CPP_LOG_DIR}\n"
    f"Output   : {output_path}"
)

## 3. Load Data

In [None]:
def load_dir_csvs(directory: str) -> dict:
    """Return every ``*.csv`` file in *directory*, loaded with ``load_csv``.

    Keys are file names without the extension (``Path.stem``). Entries are
    inserted in sorted path order. A directory with no matching files
    yields an empty dict.
    """
    csv_paths = sorted(Path(directory).glob("*.csv"))
    return {path.stem: load_csv(str(path)) for path in csv_paths}


jax_data = load_dir_csvs(JAX_LOG_DIR)
cpp_data = load_dir_csvs(CPP_LOG_DIR)

# Partition the file stems into those logged by both implementations and
# those present in only one directory.
jax_stems, cpp_stems = set(jax_data), set(cpp_data)
common_keys = sorted(jax_stems & cpp_stems)
jax_only = sorted(jax_stems - cpp_stems)
cpp_only = sorted(cpp_stems - jax_stems)

print(f"JAX files  : {list(jax_data)}")
print(f"C++ files  : {list(cpp_data)}")
print(f"Common     : {common_keys}")
for _caption, _stems in (("JAX only   ", jax_only), ("C++ only   ", cpp_only)):
    if _stems:
        print(f"{_caption}: {_stems}")

## 4. Trajectory Comparison Plots

Two-row grid: **JAX** on top, **C++** on bottom.  
Each column is one log file (trajectory).  
Red = actual, blue dashed = reference.

In [None]:
def prep_df(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* prepared for trajectory plotting and RMSE.

    Negates the ``z`` and ``z_ref`` columns (intended as the NED → ENU
    vertical flip; x/y are not touched), then aligns the reference to the
    actual trajectory with ``align_reference_to_actual`` at
    ``sampling_rate=10.0``. The input frame itself is not modified.
    """
    flipped = df.assign(z=-df['z'], z_ref=-df['z_ref'])
    return align_reference_to_actual(flipped, sampling_rate=10.0)


def make_label(stem: str) -> str:
    """Convert a log-file stem into a human-readable trajectory label.

    The stem is split on ``_``. Empty pieces (from repeated underscores) and
    the prefix/controller tokens ``sim``, ``hw``, ``nr``, ``std``,
    ``enhanced`` and ``mpc`` (case-insensitive) are dropped. Each remaining
    piece is capitalized and the pieces are joined with spaces, e.g.
    ``'sim_nr_std_circle_horz_2x'`` -> ``'Circle Horz 2x'``.

    If nothing is left after filtering, the stem is returned unchanged so
    that plot titles and table rows are never blank.
    """
    skip = {'sim', 'hw', 'nr', 'std', 'enhanced', 'mpc'}
    # `part and ...` drops the empty strings produced by '__' or a
    # leading/trailing '_', which would otherwise add stray spaces.
    kept = [part for part in stem.split('_')
            if part and part.lower() not in skip]
    if not kept:
        return stem
    return ' '.join(part.capitalize() for part in kept)


# Trajectories (one column / group each) used by every comparison below.
# Prefer stems logged by both implementations. If there is no overlap, fall
# back to every stem from either directory rather than JAX-only, so C++-only
# logs are not silently dropped; the code below checks `key in data_dict`
# before using a side.
keys_to_plot = common_keys or sorted(set(jax_data) | set(cpp_data))
if not keys_to_plot:
    # Fail here with a clear message instead of deep inside plt.subplots(2, 0).
    raise ValueError(
        f"No CSV logs found in {JAX_LOG_DIR!r} or {CPP_LOG_DIR!r}; "
        "nothing to compare."
    )
n_cols = len(keys_to_plot)

from matplotlib.lines import Line2D

fig, axes = plt.subplots(2, n_cols, figsize=(4.0 * n_cols, 9), squeeze=False)

# Row 0 = JAX, row 1 = C++. Panels are lettered a), b), ... in row-major order.
rows = [
    (jax_data, JAX_LABEL),
    (cpp_data, CPP_LABEL),
]

for row_idx, (data_dict, impl_label) in enumerate(rows):
    for col_idx, key in enumerate(keys_to_plot):
        ax = axes[row_idx, col_idx]
        panel = chr(ord('a') + row_idx * n_cols + col_idx)
        heading = f"{panel}) {impl_label}: {make_label(key)}"

        if key not in data_dict:
            ax.text(0.5, 0.5, 'No data', ha='center', va='center',
                    transform=ax.transAxes, color='gray', fontsize=12)
            ax.set_title(heading, fontsize=SUBFIG_TITLE_FONTSIZE)
        else:
            df = prep_df(data_dict[key])
            plot_trajectory_2d(ax, df, plane=detect_trajectory_plane(df),
                               actual_color='red', ref_color='blue',
                               actual_label='Actual', ref_label='Reference',
                               flip_z=False, align_lookahead=False)

            # A single figure-level legend replaces the per-axes ones.
            axes_legend = ax.get_legend()
            if axes_legend:
                axes_legend.remove()

            rmse = calculate_position_rmse(df, flip_z=False, align_lookahead=False)
            ax.set_title(f"{heading}\nRMSE = {rmse*100:.2f} cm",
                         fontsize=SUBFIG_TITLE_FONTSIZE)

        ax.grid(SHOW_GRIDLINES)
        ax.tick_params(labelsize=TICK_SIZE)

# Shared legend built from proxy artists. The styles assume plot_trajectory_2d
# draws the reference dashed (as the section text states) — TODO confirm.
legend_elements = [
    Line2D([0], [0], color='red', linestyle='-', label='Actual'),
    Line2D([0], [0], color='blue', linestyle='--', label='Reference'),
]
fig.legend(handles=legend_elements, loc='lower center', ncol=2,
           fontsize=12, frameon=True, bbox_to_anchor=(0.5, -0.02))

# Reserve a strip at the bottom for the figure legend.
plt.tight_layout(rect=[0, 0.04, 1, 1])

save_path = f"{output_path}cpp_vs_jax_trajectories.pdf"
fig.savefig(save_path, format='pdf', bbox_inches='tight', dpi=300)
print(f"Saved: {save_path}")
plt.show()

## 5. Computation Time Analysis

Distributions and comparison of per-step computation times.

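> **Note:** The first JAX sample is typically large due to JIT compilation warm-up, so wherever warm-up is excluded, the first sample is dropped from *both* implementations.  
> The box plots show both raw and warm-up-excluded statistics; the histograms show warm-up-excluded samples only.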

In [None]:
# Collect one computation-time series per (trajectory, implementation),
# converted from seconds to µs for readability. Records are ordered by
# trajectory, with JAX before C++. An implementation with no log for a
# trajectory contributes no record.
comp_records = []
for key in keys_to_plot:
    label = make_label(key)
    for impl_label, data_dict in [(JAX_LABEL, jax_data), (CPP_LABEL, cpp_data)]:
        if key in data_dict:
            comp_records.append(dict(
                key=key,
                label=label,
                impl=impl_label,
                comp_time_us=data_dict[key]['comp_time'].dropna() * 1e6,
            ))

# ── Box plots ──────────────────────────────────────────────────────────────
# Left panel: all samples. Right panel: first sample dropped (JIT warm-up).
# One box per (implementation, trajectory), colored by implementation.
# Everything that does not depend on the panel is built once, up front.
from matplotlib.patches import Patch

palette = {JAX_LABEL: '#4C72B0', CPP_LABEL: '#DD8452'}
group_labels = [f"{rec['impl']}\n{rec['label']}" for rec in comp_records]
colors = [palette[rec['impl']] for rec in comp_records]
legend_patches = [Patch(facecolor=color, alpha=0.7, label=impl)
                  for impl, color in palette.items()]

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

for ax, skip_first in zip(axes, [False, True]):
    group_data = [
        (rec['comp_time_us'].iloc[1:] if skip_first else rec['comp_time_us']).values
        for rec in comp_records
    ]

    bp = ax.boxplot(group_data, patch_artist=True, notch=False,
                    medianprops=dict(color='black', linewidth=2))
    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)

    # boxplot places boxes at x = 1..N.
    ax.set_xticks(range(1, len(group_labels) + 1))
    ax.set_xticklabels(group_labels, fontsize=10)
    ax.set_ylabel('Computation time [µs]', fontsize=12)
    title_suffix = " (excl. 1st sample)" if skip_first else " (all samples)"
    ax.set_title(f"Computation Time Distribution{title_suffix}", fontsize=13)
    ax.grid(True, axis='y', alpha=0.3)
    ax.legend(handles=legend_patches, fontsize=10)

plt.tight_layout()
save_path = f"{output_path}cpp_vs_jax_comp_time_boxplot.pdf"
fig.savefig(save_path, format='pdf', bbox_inches='tight', dpi=300)
print(f"Saved: {save_path}")
plt.show()

In [None]:
# ── Histogram / KDE overlay per log file ──────────────────────────────────
# One panel per trajectory: density-normalized histograms of the per-step
# computation time (first sample dropped, seconds → µs), JAX and C++ overlaid.
palette = {JAX_LABEL: '#4C72B0', CPP_LABEL: '#DD8452'}
n_files = len(keys_to_plot)
fig, axes = plt.subplots(1, n_files, figsize=(6 * n_files, 4), squeeze=False)

for ax, key in zip(axes[0], keys_to_plot):
    for impl_label, data_dict in ((JAX_LABEL, jax_data), (CPP_LABEL, cpp_data)):
        if key in data_dict:
            warm_us = data_dict[key]['comp_time'].dropna().iloc[1:] * 1e6
            ax.hist(warm_us, bins=40, alpha=0.5, color=palette[impl_label],
                    label=impl_label, density=True)

    ax.set_xlabel('Computation time [µs]', fontsize=11)
    ax.set_ylabel('Density', fontsize=11)
    ax.set_title(f"{make_label(key)}\n(excl. 1st sample)", fontsize=12)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
save_path = f"{output_path}cpp_vs_jax_comp_time_hist.pdf"
fig.savefig(save_path, format='pdf', bbox_inches='tight', dpi=300)
print(f"Saved: {save_path}")
plt.show()

## 6. RMSE Table

In [None]:
# Position RMSE per trajectory for both implementations, using the same
# preprocessing as the trajectory plots. A missing log yields NaN entries.
rmse_rows = []

for key in keys_to_plot:
    row = {'Trajectory': make_label(key)}
    for impl_label, data_dict in [(JAX_LABEL, jax_data), (CPP_LABEL, cpp_data)]:
        if key in data_dict:
            rmse = calculate_position_rmse(prep_df(data_dict[key]),
                                           flip_z=False, align_lookahead=False)
        else:
            rmse = np.nan
        row[f'RMSE {impl_label} [m]'] = rmse
        row[f'RMSE {impl_label} [cm]'] = rmse * 100
    rmse_rows.append(row)

rmse_df = pd.DataFrame(rmse_rows)

# Difference column. Every row carries both RMSE columns, so this guard only
# fails for an empty table. The header is built from the configured labels,
# like the other columns, so it stays correct if JAX_LABEL/CPP_LABEL change.
jax_col = f'RMSE {JAX_LABEL} [cm]'
cpp_col = f'RMSE {CPP_LABEL} [cm]'
if jax_col in rmse_df.columns and cpp_col in rmse_df.columns:
    rmse_df[f'Δ RMSE ({JAX_LABEL}−{CPP_LABEL}) [cm]'] = rmse_df[jax_col] - rmse_df[cpp_col]

print("Position RMSE Comparison")
print("=" * 60)
print(rmse_df.to_string(index=False, float_format=lambda x: f"{x:.4f}"))

rmse_df.to_csv(f"{output_path}cpp_vs_jax_rmse.csv", index=False)
print(f"\nSaved: {output_path}cpp_vs_jax_rmse.csv")

## 7. Computation Time Summary Table

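Each trajectory/implementation pair gets two rows: one over all samples and one with the first sample excluded. Columns: n (sample count), mean, std, min, max, 1/mean (effective frequency).  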
All times in **µs** (microseconds); frequency in **kHz**.

In [None]:
# Per trajectory and implementation: statistics over all samples and over the
# samples with the first one dropped (JIT warm-up). Times in µs.
ct_rows = []

for key in keys_to_plot:
    traj_label = make_label(key)
    for impl_label, data_dict in [(JAX_LABEL, jax_data), (CPP_LABEL, cpp_data)]:
        if key not in data_dict:
            continue
        ct_all = data_dict[key]['comp_time'].dropna() * 1e6  # µs
        ct_warm = ct_all.iloc[1:]  # exclude first (JIT warmup)

        for label_suffix, ct in [('(all)', ct_all), ('(excl. 1st)', ct_warm)]:
            mean_us = ct.mean()
            ct_rows.append({
                'Trajectory':            traj_label,
                'Implementation':        impl_label,
                'Samples':               label_suffix,
                'n':                     len(ct),
                'Mean [µs]':             mean_us,
                'Std [µs]':              ct.std(),
                'Min [µs]':              ct.min(),
                'Max [µs]':              ct.max(),
                '1/Mean [kHz]':          (1.0 / (mean_us * 1e-6)) / 1e3,   # Hz → kHz
            })

ct_df = pd.DataFrame(ct_rows)

# Three-decimal display format for the float columns. It is used for printing
# only; the CSV below is written without a float format.
float_cols = ['Mean [µs]', 'Std [µs]', 'Min [µs]', 'Max [µs]', '1/Mean [kHz]']
fmt = {c: "{:.3f}".format for c in float_cols}

print("Computation Time Summary")
print("=" * 90)
print(ct_df.to_string(index=False, formatters=fmt))

ct_df.to_csv(f"{output_path}cpp_vs_jax_comp_times.csv", index=False)
print(f"\nSaved: {output_path}cpp_vs_jax_comp_times.csv")

## 8. Speed-up Summary

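Speed-up of C++ over JAX, computed as mean JAX time / mean C++ time with the first sample excluded (values > 1 mean C++ is faster). Only trajectories logged by both implementations are included.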

In [None]:
# Speed-up of C++ over JAX, for trajectories logged by both implementations.
# Means exclude the first sample; speed-up = JAX mean / C++ mean.
speedup_rows = []

for key in keys_to_plot:
    if key not in jax_data or key not in cpp_data:
        continue
    traj_label = make_label(key)

    jax_ct = jax_data[key]['comp_time'].dropna().iloc[1:] * 1e6   # µs, excl. 1st
    cpp_ct = cpp_data[key]['comp_time'].dropna().iloc[1:] * 1e6

    jax_mean = jax_ct.mean()
    cpp_mean = cpp_ct.mean()

    speedup_rows.append({
        'Trajectory':                traj_label,
        f'{JAX_LABEL} mean [µs]':    jax_mean,
        f'{CPP_LABEL} mean [µs]':    cpp_mean,
        # Header derived from the configured labels, like its neighbours.
        f'Speed-up ({JAX_LABEL}/{CPP_LABEL})': jax_mean / cpp_mean,
        f'{JAX_LABEL} 1/mean [kHz]': (1.0 / (jax_mean * 1e-6)) / 1e3,
        f'{CPP_LABEL} 1/mean [kHz]': (1.0 / (cpp_mean * 1e-6)) / 1e3,
    })

speedup_df = pd.DataFrame(speedup_rows)

# Every column except the label is numeric; print those to three decimals.
float_cols_s = [c for c in speedup_df.columns if c != 'Trajectory']
fmt_s = {c: "{:.3f}".format for c in float_cols_s}

print("Speed-up Summary (excl. 1st sample)")
print("=" * 70)
print(speedup_df.to_string(index=False, formatters=fmt_s))

speedup_df.to_csv(f"{output_path}cpp_vs_jax_speedup.csv", index=False)
print(f"\nSaved: {output_path}cpp_vs_jax_speedup.csv")

## 9. Mean Computation Time Bar Chart

In [None]:
# Mean ± std of per-step computation time per trajectory, with JAX and C++
# side by side. Left panel: all samples. Right panel: first sample dropped.
# Panel-independent values are computed once.
palette = {JAX_LABEL: '#4C72B0', CPP_LABEL: '#DD8452'}
x = np.arange(len(keys_to_plot))
width = 0.35
traj_labels = [make_label(k) for k in keys_to_plot]

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

for ax, skip_first in zip(axes, [False, True]):
    for i, (impl_label, data_dict) in enumerate([(JAX_LABEL, jax_data), (CPP_LABEL, cpp_data)]):
        means = []
        stds = []
        for key in keys_to_plot:
            if key in data_dict:
                ct = data_dict[key]['comp_time'].dropna()
                ct = ct.iloc[1:] if skip_first else ct
                ct_us = ct * 1e6
                means.append(ct_us.mean())
                stds.append(ct_us.std())
            else:
                # Missing log: NaN entries produce no visible bar.
                means.append(np.nan)
                stds.append(np.nan)

        # Offsets -width/2 (JAX) and +width/2 (C++) center each pair on its tick.
        offset = (i - 0.5) * width
        ax.bar(x + offset, means, width,
               yerr=stds, capsize=4,
               label=impl_label,
               color=palette[impl_label],
               alpha=0.8)

    ax.set_xticks(x)
    ax.set_xticklabels(traj_labels, fontsize=11)
    ax.set_ylabel('Mean computation time [µs]', fontsize=12)
    suffix = " (excl. 1st)" if skip_first else " (all samples)"
    ax.set_title(f"Mean Comp. Time ± Std{suffix}", fontsize=13)
    ax.legend(fontsize=11)
    ax.grid(True, axis='y', alpha=0.3)

plt.tight_layout()
save_path = f"{output_path}cpp_vs_jax_comp_time_bar.pdf"
fig.savefig(save_path, format='pdf', bbox_inches='tight', dpi=300)
print(f"Saved: {save_path}")
plt.show()

## 10. Control Inputs Over Time

Plots of `throttle_input`, `p_input`, `q_input`, `r_input` for each log file.  
JAX (blue) and C++ (orange) overlaid on the same axes per input channel.

In [None]:
CTRL_COLS = ['throttle_input', 'p_input', 'q_input', 'r_input']
CTRL_LABELS = [
    'Throttle input',
    r'$p$ input [rad/s]',
    r'$q$ input [rad/s]',
    r'$r$ input [rad/s]',
]

palette = {JAX_LABEL: '#4C72B0', CPP_LABEL: '#DD8452'}
n_inputs = len(CTRL_COLS)

# One figure per trajectory. Each control channel gets its own panel, with JAX
# and C++ overlaid on it.
for key in keys_to_plot:
    traj_label = make_label(key)

    fig, axes = plt.subplots(n_inputs, 1, figsize=(12, 2.5 * n_inputs),
                             sharex=True)

    for impl_label, data_dict in [(JAX_LABEL, jax_data), (CPP_LABEL, cpp_data)]:
        if key not in data_dict:
            continue
        df = data_dict[key]
        # Use traj_time if available, else time
        t = df['traj_time'].values if 'traj_time' in df.columns else df['time'].values
        color = palette[impl_label]

        for ax, col in zip(axes, CTRL_COLS):
            # A channel missing from this log is skipped, not an error.
            if col in df.columns:
                ax.plot(t, df[col].values, color=color,
                        label=impl_label, linewidth=1.2, alpha=0.85)

    # Per-panel cosmetics, applied once rather than once per implementation.
    for ax, ylabel in zip(axes, CTRL_LABELS):
        ax.set_ylabel(ylabel, fontsize=10)
        ax.grid(True, alpha=0.3)

    axes[-1].set_xlabel('Trajectory time [s]', fontsize=11)
    fig.suptitle(f"Control Inputs — {traj_label}", fontsize=13, y=1.01)

    # Shared legend, deduplicated across ALL panels. Taking handles from
    # axes[0] alone would drop an implementation whose log lacks the first
    # channel (throttle_input).
    by_label = {}
    for ax in axes:
        for handle, lbl in zip(*ax.get_legend_handles_labels()):
            by_label.setdefault(lbl, handle)
    fig.legend(list(by_label.values()), list(by_label.keys()),
               loc='upper right', fontsize=10, frameon=True)

    plt.tight_layout()
    save_path = f"{output_path}cpp_vs_jax_ctrl_inputs_{key}.pdf"
    fig.savefig(save_path, format='pdf', bbox_inches='tight', dpi=300)
    print(f"Saved: {save_path}")
    plt.show()

## 11. CBF Values Over Time

Plots of `cbf_v_throttle`, `cbf_v_p`, `cbf_v_q`, `cbf_v_r` — the CBF correction terms applied to each control channel.  
JAX (blue) and C++ (orange) overlaid on the same axes.

In [None]:
CBF_COLS = ['cbf_v_throttle', 'cbf_v_p', 'cbf_v_q', 'cbf_v_r']
CBF_LABELS = [
    'CBF throttle correction',
    r'CBF $p$ correction [rad/s]',
    r'CBF $q$ correction [rad/s]',
    r'CBF $r$ correction [rad/s]',
]

palette = {JAX_LABEL: '#4C72B0', CPP_LABEL: '#DD8452'}
n_cbf = len(CBF_COLS)

# One figure per trajectory. Each CBF correction channel gets its own panel,
# with JAX and C++ overlaid on it.
for key in keys_to_plot:
    traj_label = make_label(key)

    fig, axes = plt.subplots(n_cbf, 1, figsize=(12, 2.5 * n_cbf),
                             sharex=True)

    for impl_label, data_dict in [(JAX_LABEL, jax_data), (CPP_LABEL, cpp_data)]:
        if key not in data_dict:
            continue
        df = data_dict[key]
        # Use traj_time if available, else time
        t = df['traj_time'].values if 'traj_time' in df.columns else df['time'].values
        color = palette[impl_label]

        for ax, col in zip(axes, CBF_COLS):
            # A channel missing from this log is skipped, not an error.
            if col in df.columns:
                ax.plot(t, df[col].values, color=color,
                        label=impl_label, linewidth=1.2, alpha=0.85)

    # Per-panel cosmetics, applied once. Previously the zero line was redrawn
    # once for each implementation.
    for ax, ylabel in zip(axes, CBF_LABELS):
        ax.axhline(0, color='gray', linewidth=0.8, linestyle='--')
        ax.set_ylabel(ylabel, fontsize=10)
        ax.grid(True, alpha=0.3)

    axes[-1].set_xlabel('Trajectory time [s]', fontsize=11)
    fig.suptitle(f"CBF Corrections — {traj_label}", fontsize=13, y=1.01)

    # Shared legend, deduplicated across ALL panels. Taking handles from
    # axes[0] alone would drop an implementation whose log lacks the first
    # channel (cbf_v_throttle).
    by_label = {}
    for ax in axes:
        for handle, lbl in zip(*ax.get_legend_handles_labels()):
            by_label.setdefault(lbl, handle)
    fig.legend(list(by_label.values()), list(by_label.keys()),
               loc='upper right', fontsize=10, frameon=True)

    plt.tight_layout()
    save_path = f"{output_path}cpp_vs_jax_cbf_{key}.pdf"
    fig.savefig(save_path, format='pdf', bbox_inches='tight', dpi=300)
    print(f"Saved: {save_path}")
    plt.show()

## Summary

| Output file | Description |
|---|---|
| `cpp_vs_jax_trajectories.pdf` | 2-row trajectory comparison plot |
| `cpp_vs_jax_comp_time_boxplot.pdf` | Box plots of per-step computation times |
| `cpp_vs_jax_comp_time_hist.pdf` | Histograms of computation time distributions |
| `cpp_vs_jax_comp_time_bar.pdf` | Mean ± std bar chart |
| `cpp_vs_jax_ctrl_inputs_<traj>.pdf` | Control inputs over time per trajectory |
| `cpp_vs_jax_cbf_<traj>.pdf` | CBF corrections over time per trajectory |
| `cpp_vs_jax_rmse.csv` | RMSE values for each trajectory and implementation |
| `cpp_vs_jax_comp_times.csv` | Full computation time statistics |
| `cpp_vs_jax_speedup.csv` | Speed-up of C++ relative to JAX |