# A1 Walk-Cycle Motion Planning with `ChainDynamicsGraph` (Python)

This notebook mirrors the C++ `newGraph()` pipeline in:

`examples/example_a1_walking/main.cpp`

The goal is to build the same multi-phase quadruped trajectory optimization with chain dynamics only, then inspect and plot the optimized motion.

## 1) Environment and Imports

Two details matter for this notebook:

1. We must import the local GTDynamics Python wrapper from `build/python`.
2. `gtsam` must come from the same build (or install) as the GTSAM that GTDynamics was compiled against. Mixing in an unrelated pip/conda `gtsam` wheel can cause runtime aborts when factor graphs are modified.

In [None]:
from pathlib import Path
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def find_repo_root(start: Path) -> Path:
    """Return the nearest directory at or above ``start`` that looks like the GTDynamics root.

    A directory qualifies when it contains a ``CMakeLists.txt`` file and an
    ``examples/example_a1_walking`` entry. ``start`` itself is checked first,
    then each ancestor from the closest outward.

    Raises:
        RuntimeError: if neither ``start`` nor any ancestor qualifies.
    """
    def _looks_like_root(path: Path) -> bool:
        has_cmake = (path / 'CMakeLists.txt').exists()
        return has_cmake and (path / 'examples' / 'example_a1_walking').exists()

    found = next((p for p in (start, *start.parents) if _looks_like_root(p)), None)
    if found is None:
        raise RuntimeError('Could not locate GTDynamics repository root from current directory.')
    return found


# Locate the GTDynamics checkout from the notebook's working directory and derive
# the local Python build directories.
repo_root = find_repo_root(Path.cwd().resolve())
gtd_build_python = repo_root / 'build' / 'python'
# Assumes GTSAM is checked out next to GTDynamics (../gtsam) and built in ../gtsam/build.
gtsam_build_python = repo_root.parent / 'gtsam' / 'build' / 'python'

# Prepend the local builds to sys.path. The membership checks keep re-runs of this cell
# from adding duplicates. When both paths are newly inserted, the GTDynamics path is
# inserted last and therefore ends up first. The GTSAM path is only added if that
# build directory exists.
if gtsam_build_python.exists() and str(gtsam_build_python) not in sys.path:
    sys.path.insert(0, str(gtsam_build_python))
if str(gtd_build_python) not in sys.path:
    sys.path.insert(0, str(gtd_build_python))

# These imports must come after the sys.path edits above so the local builds win.
import gtsam
import gtdynamics as gtd

print('repo_root        :', repo_root)
print('gtsam module     :', gtsam.__file__)
print('gtdynamics module:', gtd.__file__)

# Fail fast if a local GTSAM build exists but the imported gtsam module did not come
# from it. This is a plain substring test on the module's file path. Mismatched gtsam
# builds can abort at runtime when factor graphs are modified.
if gtsam_build_python.exists() and str(gtsam_build_python) not in gtsam.__file__:
    raise RuntimeError(
        'Loaded a non-local gtsam package. Put '
        f"{gtsam_build_python} first on PYTHONPATH before running this notebook."
    )

## 2) Helper Functions (C++ `newGraph()` equivalents)

These helpers correspond to pieces in `main.cpp`:

- `build_walk_trajectory(...)` mirrors `getTrajectory(...)`.
- `set_massless_except_trunk(...)` mirrors the mass/inertia simplification used in `newGraph()`.
- `add_subgraph(...)` is needed because Python exposes graph `add(...)` for single factors, while many GTDynamics objective helpers return a small `NonlinearFactorGraph`.

In [None]:
def add_subgraph(dst_graph: gtsam.NonlinearFactorGraph,
                 src_graph: gtsam.NonlinearFactorGraph) -> None:
    """Append every present factor of ``src_graph`` to ``dst_graph``, in index order.

    The Python graph ``add`` accepts a single factor, while many GTDynamics
    objective helpers return a whole ``NonlinearFactorGraph``; this bridges the
    two. Slots for which ``exists`` is false are skipped. ``dst_graph`` is
    modified in place.
    """
    n_factors = src_graph.size()  # read once, before any factor is appended
    present = (idx for idx in range(n_factors) if src_graph.exists(idx))
    for idx in present:
        dst_graph.add(src_graph.at(idx))


def set_massless_except_trunk(robot: gtd.Robot) -> None:
    """Zero the mass and inertia of every link whose name lacks the substring 'trunk'.

    This mirrors the mass simplification in the C++ ``newGraph()``: only
    trunk links keep their mass properties. ``robot``'s links are modified in
    place.
    """
    for link in robot.links():
        if 'trunk' in link.name():
            continue  # trunk links keep their mass properties
        link.setMass(0.0)
        link.setInertia(np.zeros((3, 3)))


def build_walk_trajectory(robot: gtd.Robot, repeat: int = 1) -> gtd.Trajectory:
    """Build the A1 walk cycle mirroring the C++ ``getTrajectory()``.

    Phases, in order:
      1. RR+FL feet in contact, 25 steps.
      2. All four feet in contact, 5 steps.
      3. RL+FR feet in contact, 25 steps.
      4. All four feet in contact, 5 steps.

    Every contact spec uses the offset ``(0, 0, -0.07)`` (named
    ``contact_in_com``; presumably expressed in each lower-leg link's COM frame).

    Args:
        robot: Robot providing the ``*_lower`` foot links.
        repeat: Passed to ``gtd.Trajectory`` as the walk-cycle repeat count.

    Returns:
        The ``gtd.Trajectory`` wrapping the walk cycle.
    """
    rlfr_feet = [robot.link('RL_lower'), robot.link('FR_lower')]
    rrfl_feet = [robot.link('RR_lower'), robot.link('FL_lower')]
    contact_in_com = gtsam.Point3(0.0, 0.0, -0.07)

    # One spec object per contact pattern; the stationary spec is reused for
    # both double-support phases.
    specs = {
        'stationary': gtd.FootContactConstraintSpec(rlfr_feet + rrfl_feet, contact_in_com),
        'rlfr': gtd.FootContactConstraintSpec(rlfr_feet, contact_in_com),
        'rrfl': gtd.FootContactConstraintSpec(rrfl_feet, contact_in_com),
    }

    # (contact pattern, phase length) pairs: RRFL -> stationary -> RLFR -> stationary.
    schedule = [('rrfl', 25), ('stationary', 5), ('rlfr', 25), ('stationary', 5)]

    walk_cycle = gtd.WalkCycle(
        [specs[pattern] for pattern, _ in schedule],
        [length for _, length in schedule],
    )
    return gtd.Trajectory(walk_cycle, repeat)

## 3) Build the Chain-Dynamics Graph and Objectives

This cell follows the structure of `newGraph()` in C++:

1. Load A1, make non-trunk links massless.
2. Build `ChainDynamicsGraph`.
3. Build multi-phase dynamics graph from the walk trajectory.
4. Add contact-point objectives.
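5. Add chain-specific boundary objectives at the first timestep (pose and twist of link 0, poses of links 3, 6, 9 and 12), plus joints-at-rest objectives at the first and last timesteps.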
6. Add base pose/twist objectives for all timesteps.
7. Constrain integration time (`dt = 1/20`).
8. Add joint angle priors (`lower`, `hip`, `upper`).

In [None]:
# Load the A1 model from the GTDynamics URDF directory and keep mass only on the trunk.
robot = gtd.CreateRobotFromFile(gtd.URDF_PATH + '/a1/a1.urdf', 'a1')
set_massless_except_trunk(robot)

sigma_dynamics = 1e-3
sigma_objectives = 1e-3

# Isotropic noise models: *_6 are 6-dimensional, *_1 are scalar.
# NOTE(review): only objectives_model_1 is used below; dynamics_model_6,
# dynamics_model_1 and objectives_model_6 appear unused in this notebook (kept for
# parity with the C++ example?) -- confirm before removing.
dynamics_model_6 = gtsam.noiseModel.Isotropic.Sigma(6, sigma_dynamics)
dynamics_model_1 = gtsam.noiseModel.Isotropic.Sigma(1, sigma_dynamics)
objectives_model_6 = gtsam.noiseModel.Isotropic.Sigma(6, sigma_objectives)
objectives_model_1 = gtsam.noiseModel.Isotropic.Sigma(1, sigma_objectives)

# Gravity along -Z. mu is passed to multiPhaseFactorGraph; presumably the contact
# friction coefficient -- confirm in GTDynamics.
gravity = np.array([0.0, 0.0, -9.8])
mu = 1.0

# Chain-dynamics graph builder with default optimizer settings.
opt = gtd.OptimizerSetting()
graph_builder = gtd.ChainDynamicsGraph(robot, opt, gravity)

# One walk cycle (4 phases), Euler collocation between timesteps.
trajectory = build_walk_trajectory(robot, repeat=1)
collocation = gtd.CollocationScheme.Euler

# Dynamics/collocation factors for all phases of the trajectory.
graph = trajectory.multiPhaseFactorGraph(robot, graph_builder, collocation, mu)

# Contact-point objectives with a very tight 3-D noise model (sigma 1e-6).
# step and ground_height are passed straight through; presumably the per-stance
# foot displacement and the ground z -- confirm in Trajectory::contactPointObjectives.
ground_height = 1.0
step = gtsam.Point3(0.25, 0.0, 0.0)
objectives = trajectory.contactPointObjectives(
    robot,
    gtsam.noiseModel.Isotropic.Sigma(3, 1e-6),
    step,
    ground_height,
)

# Index of the final timestep: end timestep of the last phase.
K = trajectory.getEndTimeStep(trajectory.numPhases() - 1)

# Boundary objectives at k = 0 (link ids hard-coded as in the C++ example):
# link 0 gets its bMcom pose and zero twist; links 3, 6, 9, 12 get their bMcom pose.
for link in robot.links():
    i = link.id()
    if i == 0:
        add_subgraph(
            objectives,
            gtd.LinkObjectives(i, 0)
               .pose(link.bMcom(), gtsam.noiseModel.Isotropic.Sigma(6, 1e-3))
               .twist(np.zeros(6), gtsam.noiseModel.Isotropic.Sigma(6, 1e-3)),
        )
    if i in (3, 6, 9, 12):
        add_subgraph(
            objectives,
            gtd.LinkObjectives(i, 0).pose(
                link.bMcom(), gtsam.noiseModel.Isotropic.Sigma(6, 1e-3)
            ),
        )

# Joints-at-rest objectives at the first (k = 0) and last (k = K) timesteps.
add_subgraph(objectives, gtd.JointsAtRestObjectives(robot, objectives_model_1, objectives_model_1, 0))
add_subgraph(objectives, gtd.JointsAtRestObjectives(robot, objectives_model_1, objectives_model_1, K))

# Trunk goal at every timestep: identity rotation at z = 0.4 and zero twist.
# These use looser sigmas (1e-2 pose, 5e-2 twist) than the boundary objectives.
trunk = robot.link('trunk')
for k in range(K + 1):
    add_subgraph(
        objectives,
        gtd.LinkObjectives(trunk.id(), k)
           .pose(
               gtsam.Pose3(gtsam.Rot3(), gtsam.Point3(0.0, 0.0, 0.4)),
               gtsam.noiseModel.Isotropic.Sigma(6, 1e-2),
           )
           .twist(np.zeros(6), gtsam.noiseModel.Isotropic.Sigma(6, 5e-2)),
    )

# Pin the integration time step to 1/20 s; sigma 1e-30 makes it effectively hard.
desired_dt = 1.0 / 20.0
trajectory.addIntegrationTimeFactors(objectives, desired_dt, 1e-30)

prior_model_angles = gtsam.noiseModel.Isotropic.Sigma(1, 1e-2)
prior_model_hip = gtsam.noiseModel.Isotropic.Sigma(1, 1e-2)

# Joint-angle priors at every timestep, selected by joint-name substring:
# 'lower' -> -1.4, 'hip' -> 0.0, 'upper' -> 0.7. The checks are independent ifs,
# so a name matching several substrings would receive several priors.
for joint in robot.joints():
    name = joint.name()
    jid = joint.id()
    for k in range(K + 1):
        if 'lower' in name:
            add_subgraph(objectives, gtd.JointObjectives(jid, k).angle(-1.4, prior_model_angles))
        if 'hip' in name:
            add_subgraph(objectives, gtd.JointObjectives(jid, k).angle(0.0, prior_model_hip))
        if 'upper' in name:
            add_subgraph(objectives, gtd.JointObjectives(jid, k).angle(0.7, prior_model_angles))

# Merge all objectives into the dynamics graph.
add_subgraph(graph, objectives)

print('Num phases           :', trajectory.numPhases())
print('Final timestep K     :', K)
print('Graph factors        :', graph.size())
print('Graph variable keys  :', graph.keys().size())

## 4) Initialize with `ChainInitializer` and Optimize

This mirrors the C++ initialization and Levenberg-Marquardt setup.

In [None]:
# Initial guess for all phases from ChainInitializer. gaussian_noise is passed straight
# to multiPhaseInitialValues; presumably it scales a random perturbation of the guess,
# so 1e-30 makes that perturbation negligible -- confirm in GTDynamics.
gaussian_noise = 1e-30
initializer = gtd.ChainInitializer()
init_values = trajectory.multiPhaseInitialValues(robot, initializer, gaussian_noise, desired_dt)

print('Initial values:', init_values.size())

# Levenberg-Marquardt settings: start heavily damped (initial lambda equals the upper
# bound, 1e10), allow lambda down to 1e-7, and use an absolute error tolerance of 1.0
# as a stopping criterion.
params = gtsam.LevenbergMarquardtParams()
params.setlambdaInitial(1e10)
params.setlambdaLowerBound(1e-7)
params.setlambdaUpperBound(1e10)
params.setAbsoluteErrorTol(1.0)

optimizer = gtsam.LevenbergMarquardtOptimizer(graph, init_values, params)
result = optimizer.optimize()

print('Final objective error:', graph.error(result))

## 5) Convert the optimized `Values` to a table

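We record joint angle, velocity, acceleration and torque over time. The column names are `<joint>` for the angle, `<joint>.1` for velocity, `<joint>.2` for acceleration and `<joint>.3` for torque. Keys missing from the optimized `Values` are stored as NaN. We also track the trunk position to visualize forward progression.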

In [None]:
def read_or_nan(values: gtsam.Values, key: int) -> float:
    """Return the double stored under ``key`` in ``values``, or NaN when the key is absent."""
    if not values.exists(key):
        return np.nan
    return values.atDouble(key)

# One row per timestep k = 0..K. Joint columns are named '<joint>' (angle),
# '<joint>.1' (velocity), '<joint>.2' (acceleration) and '<joint>.3' (torque);
# keys absent from `result` are recorded as NaN by read_or_nan.
rows = []
joints = list(robot.joints())
trunk_id = robot.link('trunk').id()

# (column suffix, key constructor), in the order the columns are emitted per joint.
joint_fields = (
    ('', gtd.JointAngleKey),
    ('.1', gtd.JointVelKey),
    ('.2', gtd.JointAccelKey),
    ('.3', gtd.TorqueKey),
)

for k in range(K + 1):
    # Time axis uses the nominal step; the integration-time factors in the graph pin
    # each step's duration to desired_dt with sigma 1e-30.
    row = {'t': k * desired_dt}

    # Evaluate the trunk translation once per timestep rather than once per axis.
    trunk_x, trunk_y, trunk_z = gtd.Pose(result, trunk_id, k).translation()
    row['trunk_x'] = trunk_x
    row['trunk_y'] = trunk_y
    row['trunk_z'] = trunk_z

    for joint in joints:
        j = joint.id()
        name = joint.name()
        for suffix, make_key in joint_fields:
            row[f'{name}{suffix}'] = read_or_nan(result, make_key(j, k))

    rows.append(row)

traj_df = pd.DataFrame(rows)
traj_df.head(3)

## 6) Save trajectory CSV

We save both:

- A pandas-friendly table with extra columns (for analysis/plotting).
- The native GTDynamics CSV format via `trajectory.writeToFile(...)`, written to the current working directory.

In [None]:
# Analysis table goes under build/examples/example_a1_walking (created if missing).
out_dir = repo_root / 'build' / 'examples' / 'example_a1_walking'
out_dir.mkdir(parents=True, exist_ok=True)

csv_table = out_dir / 'a1_traj_chain_dynamics_graph_python.csv'
traj_df.to_csv(csv_table, index=False)
print('Wrote table CSV:', csv_table)

# ChainInitializer does not necessarily populate every key that writeToFile()
# expects (notably torque keys). Create an export copy and fill missing entries.
# Note: missing torques are exported as 0.0 here, whereas the pandas table above
# records them as NaN; `result` itself is left unchanged.
export_values = gtsam.Values(result)

for joint in robot.joints():
    j = joint.id()
    for k in range(K + 1):
        tk = gtd.TorqueKey(j, k)
        if not export_values.exists(tk):
            export_values.insert(tk, 0.0)

# Any missing phase-duration key is filled with the nominal step desired_dt.
for pidx in range(trajectory.numPhases()):
    pk = gtd.PhaseKey(pidx)
    if not export_values.exists(pk):
        export_values.insert(pk, desired_dt)

# A bare file name is passed, so the native CSV lands in the current working
# directory rather than in out_dir.
native_name = 'a1_traj_CDG_massless_python.csv'
trajectory.writeToFile(robot, native_name, export_values)
print('Wrote native GTDynamics CSV in current working directory:', native_name)


## 7) Plot joint trajectories and contact-point heights

The first two figures show joint angles and joint velocities per leg. The last figure shows foot contact-point heights (world Z) for all four feet, which helps check alternating stance/swing behavior.

In [None]:
legs = ['FL', 'FR', 'RL', 'RR']
joint_order = ['hip', 'upper', 'lower']
colors = {'hip': 'tab:blue', 'upper': 'tab:orange', 'lower': 'tab:green'}


def plot_joint_group(suffix: str, title: str, y_label: str,
                     df: 'pd.DataFrame | None' = None) -> None:
    """Plot one per-joint quantity for every leg, one stacked subplot per leg.

    For each leg in ``legs`` and each joint in ``joint_order``, the column
    ``'<leg>_<joint>_joint<suffix>'`` is plotted against ``'t'``. Columns
    missing from the table are skipped.

    Args:
        suffix: Column suffix selecting the quantity ('' angle, '.1' velocity,
            '.2' acceleration, '.3' torque).
        title: Figure title.
        y_label: Quantity label; each subplot's y-axis reads '<leg> <y_label>'.
        df: Table to plot. Defaults to the notebook-level ``traj_df`` when None.
    """
    data = traj_df if df is None else df
    # squeeze=False keeps a 2-D Axes array even when only one leg is plotted.
    fig, axs = plt.subplots(len(legs), 1, figsize=(12, 11), sharex=True, squeeze=False)
    axs = axs[:, 0]
    for ax, leg in zip(axs, legs):
        plotted = False
        for joint in joint_order:
            col = f'{leg}_{joint}_joint{suffix}'
            if col in data.columns:
                ax.plot(data['t'], data[col], label=joint, color=colors[joint], linewidth=1.4)
                plotted = True
        ax.grid(alpha=0.3)
        ax.set_ylabel(f'{leg} {y_label}')
        # Only draw a legend when a line exists; an empty legend triggers a warning.
        if plotted:
            ax.legend(loc='upper right', ncol=3)
    axs[-1].set_xlabel('time [s]')
    fig.suptitle(title)
    fig.tight_layout()
    plt.show()


# Joint angles ('' suffix) and velocities ('.1' suffix) per leg.
plot_joint_group('', 'Joint Angles (ChainDynamicsGraph walk cycle)', 'q')
plot_joint_group('.1', 'Joint Velocities (ChainDynamicsGraph walk cycle)', 'qdot')

# Same contact offset as build_walk_trajectory uses; keep the two values in sync.
contact_in_com = gtsam.Point3(0.0, 0.0, -0.07)
foot_links = ['FL_lower', 'FR_lower', 'RL_lower', 'RR_lower']

# Z coordinate of each foot's contact point at every timestep k = 0..K.
# PointOnLink.predict(result, k) is assumed to return the world-frame point at
# timestep k (the plot title relies on this) -- confirm in GTDynamics.
foot_z = {'t': traj_df['t'].to_numpy()}
for name in foot_links:
    cp = gtd.PointOnLink(robot.link(name).shared(), contact_in_com)
    foot_z[name] = np.array([cp.predict(result, k)[2] for k in range(K + 1)])

fig, ax = plt.subplots(figsize=(12, 4))
for name in foot_links:
    ax.plot(foot_z['t'], foot_z[name], label=name, linewidth=1.5)
ax.grid(alpha=0.3)
ax.set_xlabel('time [s]')
ax.set_ylabel('contact point z [m]')
ax.set_title('Foot contact-point heights in world frame')
ax.legend(ncol=4)
plt.tight_layout()
plt.show()

## 8) Quick diagnostics

In [None]:
# Sanity checks on the trajectory table. A finite ratio of 1.0 means every entry in
# that column group is finite; NaN entries come from keys missing in `result`.
meta_cols = {'t', 'trunk_x', 'trunk_y', 'trunk_z'}
derivative_suffixes = ('.1', '.2', '.3')

angle_cols = [c for c in traj_df.columns
              if c not in meta_cols and not c.endswith(derivative_suffixes)]
vel_cols = [c for c in traj_df.columns if c.endswith('.1')]
acc_cols = [c for c in traj_df.columns if c.endswith('.2')]
torque_cols = [c for c in traj_df.columns if c.endswith('.3')]

print('Finite ratio (angles):', np.isfinite(traj_df[angle_cols].to_numpy()).mean())
print('Finite ratio (vels)  :', np.isfinite(traj_df[vel_cols].to_numpy()).mean())
print('Finite ratio (accels):', np.isfinite(traj_df[acc_cols].to_numpy()).mean())
# ChainInitializer may not populate torque keys, so a ratio below 1.0 here can be
# expected rather than a sign of a failed optimization.
print('Finite ratio (torque):', np.isfinite(traj_df[torque_cols].to_numpy()).mean())
print('Trunk z mean/std     :', traj_df['trunk_z'].mean(), traj_df['trunk_z'].std())
print('Trunk x start/end    :', traj_df['trunk_x'].iloc[0], traj_df['trunk_x'].iloc[-1])