This code compares two trajectory files and reports statistics of their differences (per-trajectory and global RMSE). It accounts for failed trajectories, whose indices are listed in an optional third file.


In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Tuple, List, Optional

In [3]:

def load_trajectories(path: Path) -> List[np.ndarray]:
    """Load trajectories from a text file, one trajectory per line.

    Each non-blank line holds the steps of one trajectory separated by
    commas; each step is a state whose values are separated by semicolons,
    e.g. ``1.0;2.0,3.0;4.0`` is a two-step trajectory with two values per
    step.

    Args:
        path: Location of the trajectory file (anything ``open`` accepts).

    Returns:
        One array per non-blank line, built from the parsed steps in file
        order.

    Raises:
        ValueError: If a value on a non-blank line cannot be parsed as float.
    """
    trajectories = []
    with open(path, 'r') as f:
        for line in f:
            row = line.strip()
            if not row:
                # A blank line (typically a trailing newline at the end of
                # the file) holds no trajectory; parsing it would raise on
                # float('').
                continue
            trajectory = [[float(value) for value in step.split(';')]
                          for step in row.split(',')]
            trajectories.append(np.array(trajectory))
    return trajectories

In [4]:

def load_failed_indices(path: Optional[Path]) -> set:
    """Read the indices of failed trajectories from a text file.

    Every line whose stripped content consists only of digits contributes
    one integer; all other lines are ignored.

    Args:
        path: File with one index per line, or ``None`` for "no file".

    Returns:
        The set of indices found. Empty when ``path`` is ``None`` or the
        file does not exist (a warning is printed in the latter case).
    """
    indices = set()
    if path is not None:
        try:
            with open(path, 'r') as handle:
                for raw_line in handle:
                    token = raw_line.strip()
                    if token.isdigit():
                        indices.add(int(token))
        except FileNotFoundError:
            print(f"Warning: File {path} not found. Returning empty set.")
            return set()
    return indices

In [5]:
def compute_rmse_metrics(ref_traj: List[np.ndarray], eval_traj: List[np.ndarray], failed: set) -> Tuple[dict, List[float], pd.DataFrame]:
    """Compute RMSE statistics per trajectory, excluding failed ones.

    Reference trajectories whose index is in ``failed`` are skipped; every
    other reference trajectory is paired with the next unused entry of
    ``eval_traj``.

    NOTE(review): this assumes ``failed`` holds 0-based indices into
    ``ref_traj`` and that ``eval_traj`` contains only the non-failed
    trajectories, in reference order. Confirm against the code that writes
    the evaluated and failed files.

    Args:
        ref_traj: Reference trajectories, one 2-D array each.
        eval_traj: Evaluated trajectories, without entries for failed runs.
        failed: Indices of reference trajectories that have no evaluated
            counterpart.

    Returns:
        A tuple ``(global_metrics, per_traj_rmses, traj_df)``:
        ``global_metrics`` maps ``global_rmse_mean`` / ``global_rmse_std`` /
        ``global_rmse_max`` to the mean, standard deviation and maximum of
        the per-trajectory RMSEs (all NaN when no pair could be compared);
        ``per_traj_rmses`` lists the RMSE of every compared pair;
        ``traj_df`` has one row per compared pair with the columns
        ``trajectory`` (index into ``eval_traj``), ``rmse``, ``std`` and
        ``max``.
    """
    per_traj_rmses = []
    traj_stats = []
    eval_idx = 0  # position of the next unused evaluated trajectory

    for ref_idx, ref in enumerate(ref_traj):
        if ref_idx in failed:
            # Failed runs have no evaluated counterpart, so only the
            # reference side advances and the pairing stays aligned.
            print(f"Skipping failed trajectory {ref_idx+1}.")
            continue
        if eval_idx >= len(eval_traj):
            print(f"Finished: All the trajectories were evaluated already at trajectory {eval_idx+1}.")
            break

        idx = eval_idx
        eval_ = eval_traj[idx]
        eval_idx += 1

        if ref.shape[0] != eval_.shape[0]:
            if ref.shape[0] > eval_.shape[0]:
                # Drop the first reference row (presumably the initial state,
                # which the evaluated file omits -- TODO confirm).
                ref = ref[1:, :]
            else:
                print(f"Warning: Reference trajectory {idx} is shorter than evaluated trajectory. Skipping.")
                continue

        if ref.shape != eval_.shape:
            print(f"Warning: Shape mismatch at trajectory {idx}. Skipping.")
            continue

        error = ref - eval_
        rmse = np.sqrt(np.mean(np.square(error)))
        per_traj_rmses.append(rmse)
        traj_stats.append({
            'trajectory': idx,
            'rmse': rmse,
            'std': np.std(error),
            'max': np.max(np.abs(error))
        })

    # Global metrics; np.max raises on an empty sequence, so report NaN when
    # nothing could be compared.
    if per_traj_rmses:
        global_metrics = {
            'global_rmse_mean': np.mean(per_traj_rmses),
            'global_rmse_std': np.std(per_traj_rmses),
            'global_rmse_max': np.max(per_traj_rmses),
        }
    else:
        global_metrics = {
            'global_rmse_mean': float('nan'),
            'global_rmse_std': float('nan'),
            'global_rmse_max': float('nan'),
        }

    # Explicit columns keep the frame's schema stable even with zero rows.
    traj_df = pd.DataFrame(traj_stats, columns=['trajectory', 'rmse', 'std', 'max'])
    return global_metrics, per_traj_rmses, traj_df

In [8]:

def compare_trajectories(reference_path: Path, constrained_path: Path, failed_path: Optional[Path] = None, verbose: bool = True) -> Tuple[dict, pd.DataFrame]:
    """Compare reference and generated trajectories and report RMSE metrics.

    Args:
        reference_path: File with the reference trajectories.
        constrained_path: File with the generated trajectories.
        failed_path: File listing the indices of failed trajectories.
            ``None`` or an empty string means no trajectories failed.
        verbose: When true, save the per-trajectory RMSE histogram to
            ``rmse_histogram.pdf`` and the per-trajectory statistics to
            ``trajectory_rmse_stats.csv`` in the current working directory.

    Returns:
        A tuple ``(global_metrics, traj_df)`` as produced by
        ``compute_rmse_metrics`` (without the raw per-trajectory RMSE list).
    """
    reference = load_trajectories(reference_path)
    constrained = load_trajectories(constrained_path)
    # Both None and "" stand for "no failed-index file".
    failed = load_failed_indices(failed_path) if failed_path else set()

    global_metrics, per_traj_rmses, traj_df = compute_rmse_metrics(reference, constrained, failed)

    if verbose:
        # The figure is only ever written to disk, so it is built solely when
        # it will be saved, and closed even if saving fails.
        plt.figure(figsize=(8, 5))
        try:
            plt.hist(per_traj_rmses, bins=30, color="skyblue", edgecolor="black")
            plt.title("Per-Trajectory RMSE Histogram")
            plt.xlabel("RMSE")
            plt.ylabel("Frequency")
            plt.grid(True)
            plt.tight_layout()
            histogram_path = Path("rmse_histogram.pdf")
            plt.savefig(histogram_path)
        finally:
            plt.close()

        # Save trajectory-wise metrics
        traj_df.to_csv("trajectory_rmse_stats.csv", index=False)

    return global_metrics, traj_df

In [9]:
# Driver cell: compare the policy training trajectories against the
# torque-constrained MPC dataset and print the global RMSE metrics.
from pathlib import Path

# Machine-specific absolute dataset root; adjust when running elsewhere.
# NOTE(review): the paths below are built by string concatenation, so they
# are plain str objects rather than Path instances.
DATASET_DIR = "/home/jeauscq/Desktop/jcq_thesis/datasets/"
ref_path = DATASET_DIR + "Policy/training_trajectories_policy_n.csv"
gen_path = DATASET_DIR + "Policy/MPC/torque_Const/mpc_generated_tor_constraint_dataset_n.csv"
# NOTE(review): fail_path points into pos_vel_Const while gen_path points
# into torque_Const, and it is not passed to the call below -- confirm that
# this is intended.
fail_path = DATASET_DIR + "Policy/MPC/pos_vel_Const/failed_constrained_trajectories.txt"

# "" is passed as the failed-index path, so no trajectory is treated as
# failed; verbose=False means no histogram or CSV file is written.
metrics, traj_df = compare_trajectories(ref_path, gen_path, "", verbose=False)
print(metrics)


{'global_rmse_mean': 0.0583133431338275, 'global_rmse_std': 0.05844708784972515, 'global_rmse_max': 1.4603586645523172}
