In [68]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.patches import Rectangle
from IPython.display import HTML
import sys

# Add path to GP modules
sys.path.insert(0, '/home/egmc/ws_px4_rta_mm_gpr/src/px4_rta_mm_gpr/px4_rta_mm_gpr/jax_mm_rta')
import jax.numpy as jnp
from TVGPR import TVGPR  # Time-Varying Gaussian Process Regression

In [69]:
# Configure matplotlib for animation in Jupyter
%matplotlib notebook

plt.rcParams.update({
    "text.usetex": False,
    "font.family": "sans-serif",
    "font.size": 14
})

In [70]:
# Load the flight log (CSV with a header row) to analyse.
# testing_NOwind_estimation_with17Ywind.log
log_file = 'log_files/testing_NOwind_estimation_with17Ywind.log'
data = pd.read_csv(log_file)

# Quick look at the schema and the first rows of the log
print("Columns:", list(data.columns))
print(f"Total rows: {len(data)}")
print("\nFirst few rows:")
print(data.head())

Columns: ['time', 'x', 'y', 'z', 'yaw', 'ctrl_comp_time', 'rollout_comptime', 'y_ref', 'z_ref', 'yaw_ref', 'throttle', 'roll_rate', 'pitch_rate', 'yaw_rate', 'save_tube_pyH', 'save_tube_pyL', 'save_tube_pzH', 'save_tube_pzL', 'wy', 'wz']
Total rows: 13332

First few rows:
        time         x         y          z       yaw  ctrl_comp_time  \
0  15.004004 -0.035847  4.219260 -12.427777  0.001232        0.176533   
1  15.446543 -0.035847  4.219260 -12.427777  0.001232        0.157226   
2  15.578314 -0.043350  4.205662 -12.445257  0.001495        0.002068   
3  15.585445 -0.043350  4.205662 -12.445257  0.001495        0.002305   
4  15.593912 -0.045743  4.205756 -12.445498  0.001524        0.002017   

   rollout_comptime     y_ref      z_ref   yaw_ref  throttle  roll_rate  \
0          0.119684  4.210073 -12.410289 -0.333022  0.603607  -0.969025   
1          0.119684  4.204127 -12.400228 -0.353022  0.602500  -1.000000   
2          0.109661  4.197543 -12.388531 -0.373022  0.601345  -

In [71]:
def get_RMSE(data, num_refs_per_step=12, ref_offset=0):
    """
    Calculate RMSE between actual position and reference trajectory.

    The log stores two row layouts in the same table:
      * the actual state for timestep n is in row n (columns ``y``, ``z``);
      * the references saved at timestep n are in rows
        ``num_refs_per_step*n`` .. ``num_refs_per_step*n + num_refs_per_step - 1``
        (columns ``y_ref``, ``z_ref``).
    Timestep n is compared with reference row ``num_refs_per_step*n + ref_offset``.
    The default ``ref_offset=0`` (first entry of each group) is what this analysis has
    always used; ``ref_offset=num_refs_per_step - 1`` selects the last entry instead.

    Evaluation stops at the first timestep whose actual y/z is NaN (the state columns
    "run out") or whose reference row lies past the end of the table. Timesteps whose
    selected reference is NaN are skipped rather than ending the evaluation.

    Args:
        data: DataFrame with columns y, z, y_ref, z_ref
        num_refs_per_step: Number of reference/tube saves per timestep (default: 12);
            must be >= 1.
        ref_offset: Index of the compared reference inside each timestep's group,
            0 <= ref_offset < num_refs_per_step (default: 0).

    Returns:
        RMSE value (combined y and z error), sqrt(mean(dy**2 + dz**2)), or np.nan if
        no timestep could be evaluated.

    Raises:
        ValueError: if num_refs_per_step < 1 or ref_offset is outside [0, num_refs_per_step).
    """
    if num_refs_per_step < 1:
        raise ValueError(f"num_refs_per_step must be >= 1, got {num_refs_per_step}")
    if not 0 <= ref_offset < num_refs_per_step:
        raise ValueError(
            f"ref_offset must be in [0, {num_refs_per_step}), got {ref_offset}"
        )

    y = data["y"].to_numpy(dtype=float)
    z = data["z"].to_numpy(dtype=float)
    y_ref = data["y_ref"].to_numpy(dtype=float)
    z_ref = data["z_ref"].to_numpy(dtype=float)

    # ref_rows[n] is the reference row for timestep n; its length caps how many
    # timesteps have a reference at all.
    ref_rows = np.arange(ref_offset, len(y_ref), num_refs_per_step)
    n_steps = min(len(y), len(ref_rows))

    # Stop at the first timestep whose actual position is NaN.
    state_missing = np.isnan(y[:n_steps]) | np.isnan(z[:n_steps])
    if state_missing.any():
        n_steps = int(np.argmax(state_missing))

    rows = ref_rows[:n_steps]
    y_r = y_ref[rows]
    z_r = z_ref[rows]

    # NaN references are skipped, not treated as the end of the data.
    usable = ~(np.isnan(y_r) | np.isnan(z_r))
    if not usable.any():
        return np.nan

    dy = y[:n_steps][usable] - y_r[usable]
    dz = z[:n_steps][usable] - z_r[usable]
    return np.sqrt(np.mean(dy * dy + dz * dz))

In [72]:
# Tracking RMSE, comparing each timestep's actual position with its saved reference group
rmse = get_RMSE(data=data, num_refs_per_step=12)
print(f"RMSE (12 refs per step): {rmse:.6f}")
rmse

RMSE (12 refs per step): 0.094602


np.float64(0.09460167725407102)

In [73]:
# Data layout of the log:
# - At timestep n: actual state (time, y, z, yaw, wy, wz) comes from row n
# - At timestep n: num_horizon forward-looking references/tubes come from rows
#   num_horizon*n to num_horizon*n + num_horizon - 1
# - Only timesteps with num_horizon*(n+1) <= total_rows have a complete reference
#   group, which is why the data appears to end early.

# Number of forward-looking predictions per time step
num_horizon = 12

# Number of animatable timesteps (n = 0 .. max_timestep-1): the largest count for which
# num_horizon*(n+1) <= len(data)
max_timestep = len(data) // num_horizon
print(f"Total rows in data: {len(data)}")
print(f"Maximum animatable timestep: {max_timestep} (can animate n=0 to n={max_timestep-1})")

# Actual state - every row contains actual state data
time_all = data['time'].to_numpy()
y_all = data['y'].to_numpy()
z_all = -data['z'].to_numpy()  # FLIP Z to make it positive up
yaw_all = data['yaw'].to_numpy()

# Wind data
wy_all = data['wy'].to_numpy()
wz_all = data['wz'].to_numpy()

# Reference data - organized by timestep (num_horizon rows per timestep)
y_ref_all = data['y_ref'].to_numpy()
z_ref_all = -data['z_ref'].to_numpy()  # FLIP Z

# Tube bounds - negating Z turns the old low bound into the new high bound, hence the swap
pyH_all = data['save_tube_pyH'].to_numpy()
pyL_all = data['save_tube_pyL'].to_numpy()
pzH_all = -data['save_tube_pzL'].to_numpy()  # FLIP and SWAP
pzL_all = -data['save_tube_pzH'].to_numpy()  # FLIP and SWAP


def print_step_verification(n, note=""):
    """Print the actual state (row n) and the reference group used at timestep n."""
    ref_start = num_horizon * n
    ref_end = ref_start + num_horizon
    print(f"\nAt n={n}{note}:")
    print(f"  Actual state from row {n}:")
    print(f"    Time: {time_all[n]:.6f}")
    print(f"    Position: y={y_all[n]:.6f}, z={z_all[n]:.6f}, yaw={yaw_all[n]:.6f}")
    print(f"    Wind: wy={wy_all[n]:.6f}, wz={wz_all[n]:.6f}")
    print(f"  {num_horizon} References from rows {ref_start}-{ref_end - 1}:")
    print(f"    y_ref: {y_ref_all[ref_start:ref_end]}")
    print(f"    z_ref: {z_ref_all[ref_start:ref_end]}")


# Sanity-check the layout at the first, second and last animatable timesteps. Only steps
# that actually exist are printed, so short logs do not raise IndexError.
print("\n=== VERIFICATION ===")
last_n = max_timestep - 1
for step in sorted({0, 1, last_n}):
    if 0 <= step < max_timestep:
        print_step_verification(step, " (last valid timestep)" if step == last_n else "")

Total rows in data: 13332
Maximum animatable timestep: 1111 (can animate n=0 to n=1110)

=== VERIFICATION ===
At n=0:
  Actual state from row 0:
    Time: 15.004004
    Position: y=4.219260, z=12.427777, yaw=0.001232
    Wind: wy=0.000000, wz=0.000000
  12 References from rows 0-11:
    y_ref: [4.21007289 4.20412718 4.1975435  4.19030676 4.18240347 4.17382194
 4.16455238 4.15458712 4.14392069 4.13255003 4.12047458 4.10769644]
    z_ref: [12.41028937 12.40022767 12.38853054 12.3751672  12.36010412 12.34330516
 12.32473161 12.30434227 12.2820936  12.25793979 12.2318329  12.20372295]

At n=1:
  Actual state from row 1:
    Time: 15.446543
    Position: y=4.219260, z=12.427777, yaw=0.001232
    Wind: wy=0.000000, wz=0.000000
  12 References from rows 12-23:
    y_ref: [4.20717878 4.20091606 4.19400764 4.18643921 4.17819814 4.16927369
 4.15965711 4.1493418  4.13832349 4.12660034 4.11417311 4.10104527]
    z_ref: [12.40546113 12.3945854  12.38205918 12.36785032 12.35192399 12.33424273
 12.31

In [74]:
# Helper function to draw a planar quadrotor
def quadplot(ypos, zpos, rot, scale):
    """
    Return the outline of a planar quadrotor as two coordinate arrays.

    The unscaled shape is a flat bar from (-1, 0) to (1, 0) with vertical segments up to
    height 0.4 at both ends, each capped by a small curve (presumably the rotors). It is
    multiplied by ``scale``, rotated counter-clockwise by ``rot`` radians about the
    origin, and translated to (``ypos``, ``zpos``).

    Returns:
        (xs, ys): horizontal and vertical outline coordinates, ready for ``Line2D.set_data``.
    """
    # Parameter values for the two end curves
    angles = 0.3 - 0.5 * np.arange(np.pi, 1.5 * np.pi, 0.2)
    cos_a = np.cos(angles)
    curve_dx = cos_a
    curve_dz = np.sin(angles) * cos_a

    left_x = 0.4 * curve_dx - 1
    right_x = 0.4 * curve_dx + 1
    cap_z = 0.3 * curve_dz + 0.4

    # left curve -> down the left mast -> along the bar -> up the right mast -> right curve
    frame_x = np.concatenate((left_x, [-1.0, -1.0, 1.0, 1.0], right_x))
    frame_z = np.concatenate((cap_z, [0.4, 0.0, 0.0, 0.4], cap_z))
    outline = np.vstack((scale * frame_x, scale * frame_z))

    c, s = np.cos(rot), np.sin(rot)
    rotated = np.array([[c, -s], [s, c]]) @ outline
    return rotated[0] + ypos, rotated[1] + zpos


# Build Time-Varying Gaussian Process (TVGPR) models for the wind forces:
#   wy as a function of (time, height):     input = [time, z], output = wy
#   wz as a function of (time, y position): input = [time, y], output = wz
# These seed models are built from the first few timesteps; the animation refits the
# wind models on a sliding window of observations as it plays.

# Number of initial observations to seed the TVGPR
n_init_obs = 20
init_indices = np.arange(0, min(n_init_obs, max_timestep))


def build_wind_observations(indices, pos_all, wind_all):
    """
    Stack TVGPR training rows [time, position, wind] for the given log rows.

    Rows where any of the three values is NaN are dropped.
    """
    obs = np.column_stack((time_all[indices], pos_all[indices], wind_all[indices]))
    return obs[~np.isnan(obs).any(axis=1)]


def describe_observations(title, obs, input_name, output_name):
    """Print the observation count and per-column ranges (np.min/np.max fail on empty input)."""
    print(f"{title} with {len(obs)} initial observations")
    if len(obs) == 0:
        print("  (no valid observations)")
        return
    for col, name in enumerate(("time", input_name, output_name)):
        print(f"  {name} range: [{np.min(obs[:, col]):.2f}, {np.max(obs[:, col]):.2f}]")


init_obs_wy = build_wind_observations(init_indices, z_all, wy_all)  # wy(t, z)
init_obs_wz = build_wind_observations(init_indices, y_all, wz_all)  # wz(t, y)

describe_observations("Building TVGPR for wy(t,z)", init_obs_wy, "z", "wy")
print()
describe_observations("Building TVGPR for wz(t,y)", init_obs_wz, "y", "wz")

# Create initial TVGPR models (the animation refits wind models from a sliding window)
tvgp_wy_init = TVGPR(jnp.array(init_obs_wy), sigma_f=5.0, l=2.0, sigma_n=0.01, epsilon=0.25)
tvgp_wz_init = TVGPR(jnp.array(init_obs_wz), sigma_f=5.0, l=2.0, sigma_n=0.01, epsilon=0.25)

print("\nTime-Varying GP models created successfully!")

Building TVGPR for wy(t,z) with 20 initial observations
  time range: [15.00, 15.87]
  z range: [12.43, 12.45]
  wy range: [0.00, 0.00]

Building TVGPR for wz(t,y) with 20 initial observations
  time range: [15.00, 15.87]
  y range: [4.21, 4.22]
  wz range: [0.00, 0.00]

Time-Varying GP models created successfully!


In [75]:
# ====== WIND ARROW CONTROL PARAMETERS ======
SHOW_WIND_Y = True   # Show gray arrows for wind in Y direction
SHOW_WIND_Z = True   # Show orange arrows for wind in Z direction (toggle off if only want Y)
WIND_Y_SCALE = 15    # Scale for Y-direction arrows (smaller = larger arrows)
WIND_Z_SCALE = 20    # Scale for Z-direction arrows (smaller = larger arrows)
# ============================================

# Set up the figure and axis
fig, ax = plt.subplots(figsize=(14, 10))

FONTSIZE = 18
LINEWIDTH = 2.5
QUADSIZE = 0.25


def _padded_limits(values, frac=0.15, min_pad=0.5):
    """
    Axis limits covering the non-NaN entries of `values` plus a margin on each side.

    The margin is `frac` times the data span. If the span is zero (e.g. no motion along
    this axis), `min_pad` is used instead so the axis limits never coincide.
    """
    finite = values[~np.isnan(values)]
    lo, hi = np.min(finite), np.max(finite)
    pad = (hi - lo) * frac
    if pad == 0:
        pad = min_pad
    return lo - pad, hi + pad


# Set axis limits based on data (NaN rows ignored)
y_min, y_max = _padded_limits(y_all)
z_min, z_max = _padded_limits(z_all)

# Wind arrow grid of evenly spaced points. meshgrid's default 'xy' indexing followed by
# ravel() orders the points row by row: all y positions for the first z, then the next z, ...
n_arrows_y = 8   # number of arrows in y direction
n_arrows_z = 10  # number of arrows in z direction
arrow_y_coords = np.linspace(y_min, y_max, n_arrows_y)
arrow_z_coords = np.linspace(z_min, z_max, n_arrows_z)
arrow_grid_y, arrow_grid_z = np.meshgrid(arrow_y_coords, arrow_z_coords)
arrow_points = np.column_stack((arrow_grid_y.ravel(), arrow_grid_z.ravel()))

# Initial wind directions from the seed TVGPR models at t = time_all[0].
# wy depends only on (t, z) and wz only on (t, y), so each model is evaluated once per grid
# row / column and broadcast to every point (n_arrows_z + n_arrows_y evaluations instead of
# 2 * n_arrows_y * n_arrows_z).
t_init = time_all[0]
arrow_dirs_y = np.zeros_like(arrow_points)
arrow_dirs_z = np.zeros_like(arrow_points)
if SHOW_WIND_Y:
    wy_per_row = [float(tvgp_wy_init.mean(jnp.array([t_init, zc]))[0][0]) for zc in arrow_z_coords]
    arrow_dirs_y[:, 0] = np.repeat(wy_per_row, n_arrows_y)
if SHOW_WIND_Z:
    wz_per_col = [float(tvgp_wz_init.mean(jnp.array([t_init, yc]))[0][0]) for yc in arrow_y_coords]
    arrow_dirs_z[:, 1] = np.tile(wz_per_col, n_arrows_z)

# Plot wind arrows (updated every frame by the animation)
if SHOW_WIND_Y:
    wind_arrows_y = ax.quiver(arrow_points[:, 0], arrow_points[:, 1],
                              arrow_dirs_y[:, 0], arrow_dirs_y[:, 1],
                              color='tab:gray', alpha=0.5, label='Wind Y', scale=WIND_Y_SCALE)
else:
    wind_arrows_y = None

if SHOW_WIND_Z:
    wind_arrows_z = ax.quiver(arrow_points[:, 0], arrow_points[:, 1],
                              arrow_dirs_z[:, 0], arrow_dirs_z[:, 1],
                              color='tab:orange', alpha=0.5, label='Wind Z', scale=WIND_Z_SCALE)
else:
    wind_arrows_z = None

# Actual trajectory - solid blue line showing path taken
actual_line, = ax.plot([], [], label='Actual Trajectory',
                       color='blue', linewidth=LINEWIDTH)

# Reference trajectory - dashed line showing future horizon
reference_horizon, = ax.plot([], [], label=f'Reference Horizon ({num_horizon}-step)',
                             linestyle='dashed', color='green',
                             linewidth=LINEWIDTH, alpha=0.8,
                             marker='o', markersize=6)

# Reachable tube - one rectangle per horizon step; only the first carries a legend label
tube_rects = []
for i in range(num_horizon):
    rect = Rectangle((0, 0), 1, 1, linewidth=1.5,
                     edgecolor='red', facecolor='red', alpha=0.15)
    ax.add_patch(rect)
    tube_rects.append(rect)
    if i == 0:
        rect.set_label('Reachable Tube Horizon')

# Quadrotor representation
quad_line, = ax.plot([], [], color='black', linewidth=LINEWIDTH)

# Current position marker
current_pos, = ax.plot([], [], 'ro', markersize=10, label='Current Position')

# Set up the plot
ax.set_xlabel('Y Position (m)', fontsize=FONTSIZE)
ax.set_ylabel('Z Position (m)', fontsize=FONTSIZE)
ax.set_title('Planar Quadrotor with Time-Varying Wind Field', fontsize=FONTSIZE+2)

# Legend outside the plot area, to the right of the axes
ax.legend(loc='center left', bbox_to_anchor=(1.02, 0.5), fontsize=FONTSIZE - 4)

ax.grid(True, alpha=0.3)
ax.set_aspect('equal', adjustable='box')

ax.set_xlim(y_min, y_max)
ax.set_ylim(z_min, z_max)

# Elapsed-time label in the bottom-left corner, inside the axes
time_text = ax.text(0.02, 0.02, '', transform=ax.transAxes,
                    fontsize=FONTSIZE-2, verticalalignment='bottom',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

# Step / reference-row label in the bottom-right corner, inside the axes
step_text = ax.text(0.98, 0.02, '', transform=ax.transAxes,
                    fontsize=FONTSIZE-4, verticalalignment='bottom',
                    horizontalalignment='right',
                    bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))

# Adjust layout to make room for legend
plt.tight_layout()

wind_status = []
if SHOW_WIND_Y:
    wind_status.append("Y")
if SHOW_WIND_Z:
    wind_status.append("Z")
wind_status_str = "+".join(wind_status) if wind_status else "None"
print(f"Figure initialized with time-varying wind arrows: {wind_status_str}")

<IPython.core.display.Javascript object>

Figure initialized with time-varying wind arrows: Y+Z


In [76]:
# Animation update function with time-varying GP wind updates.
# The wind models are refit every frame from a sliding window of recent observations.
obs_window_size = 50  # Number of recent observations to keep in the GP
min_window_obs = 5    # Refit only when the window holds MORE than this many valid rows


def _refit_tvgpr(window, pos_all, wind_all, fallback):
    """
    Fit a TVGPR on the [time, position, wind] rows selected by `window` (NaN rows dropped).

    Returns `fallback` unchanged when at most `min_window_obs` valid rows remain.
    """
    obs = np.column_stack((time_all[window], pos_all[window], wind_all[window]))
    obs = obs[~np.isnan(obs).any(axis=1)]
    if len(obs) <= min_window_obs:
        return fallback
    return TVGPR(jnp.array(obs), sigma_f=5.0, l=2.0, sigma_n=0.01, epsilon=0.25)


def _predict_wind(model, t, coords):
    """Mean wind predicted by a TVGPR `model` at (t, c) for each c in `coords`, as a float array."""
    return np.array([float(model.mean(jnp.array([t, c]))[0][0]) for c in coords])


def update(n):
    """
    Draw the animation frame for timestep n and return the artists to blit.

    Row n of the log supplies the actual state. Rows num_horizon*n .. num_horizon*n +
    num_horizon - 1 supply the reference horizon and reachable tube. The wind models are
    refit from the last `obs_window_size` rows (ending at row n) and kept local to this call.
    That way a frame's content does not depend on which frames were drawn before it
    (FuncAnimation repeats, and to_jshtml re-renders every frame).
    """
    # Artists returned for blitting (same set every frame)
    artists = [actual_line, reference_horizon, quad_line, current_pos, time_text, step_text]
    if SHOW_WIND_Y and wind_arrows_y is not None:
        artists.append(wind_arrows_y)
    if SHOW_WIND_Z and wind_arrows_z is not None:
        artists.append(wind_arrows_z)
    artists.extend(tube_rects)

    # Skip if we're beyond the valid range
    if n >= max_timestep or np.isnan(time_all[n]):
        return artists

    # 1. Actual trajectory (path taken up to and including the current step)
    actual_line.set_data(y_all[:n + 1], z_all[:n + 1])

    # 2. Current actual state from row n
    t_curr, y_curr, z_curr, yaw_curr = time_all[n], y_all[n], z_all[n], yaw_all[n]

    # 3. Sliding window of exactly obs_window_size rows ending at row n. When the window is
    # too small (the first few steps), the seed models from the initialisation cell are used.
    window = slice(max(0, n + 1 - obs_window_size), n + 1)

    # 4. Wind arrows. wy depends only on (t, z) and wz only on (t, y), so each model is
    # evaluated once per grid row / column and broadcast; arrow_points is ordered row by row
    # (meshgrid 'xy' indexing + ravel).
    if SHOW_WIND_Y and wind_arrows_y is not None:
        model_wy = _refit_tvgpr(window, z_all, wy_all, fallback=tvgp_wy_init)
        wy = np.repeat(_predict_wind(model_wy, t_curr, arrow_z_coords), len(arrow_y_coords))
        wind_arrows_y.set_UVC(wy, np.zeros_like(wy))

    if SHOW_WIND_Z and wind_arrows_z is not None:
        model_wz = _refit_tvgpr(window, y_all, wz_all, fallback=tvgp_wz_init)
        wz = np.tile(_predict_wind(model_wz, t_curr, arrow_y_coords), len(arrow_z_coords))
        wind_arrows_z.set_UVC(np.zeros_like(wz), wz)

    # 5. Reference horizon and reachable tube from rows ref_start .. ref_end-1
    ref_start = num_horizon * n
    ref_end = ref_start + num_horizon

    if ref_end <= len(data):
        ref_rows = slice(ref_start, ref_end)
        reference_horizon.set_data(y_ref_all[ref_rows], z_ref_all[ref_rows])

        tube_bounds = zip(tube_rects, pyL_all[ref_rows], pyH_all[ref_rows],
                          pzL_all[ref_rows], pzH_all[ref_rows])
        for rect, y_lo, y_hi, z_lo, z_hi in tube_bounds:
            width = y_hi - y_lo
            height = z_hi - z_lo
            # Only draw if dimensions are reasonable. NaN bounds make every comparison
            # False, so they are hidden as well.
            if 0 < width < 50 and 0 < height < 50:
                rect.set_xy((y_lo, z_lo))
                rect.set_width(width)
                rect.set_height(height)
                rect.set_visible(True)
            else:
                rect.set_visible(False)
    else:
        # No valid references
        reference_horizon.set_data([], [])
        for rect in tube_rects:
            rect.set_visible(False)

    # 6. Quadrotor outline and current-position marker
    if np.isnan(y_curr) or np.isnan(z_curr):
        quad_line.set_visible(False)
        current_pos.set_visible(False)
    else:
        # NOTE(review): the logged 'yaw' is used as the in-plane rotation while Z has been
        # flipped to positive-up — confirm the angle convention and sign.
        quad_y, quad_z = quadplot(y_curr, z_curr, yaw_curr, QUADSIZE)
        quad_line.set_data(quad_y, quad_z)
        quad_line.set_visible(True)
        current_pos.set_data([y_curr], [z_curr])
        current_pos.set_visible(True)

    # 7. Text overlays
    time_text.set_text(f't = {t_curr:.3f} s')
    step_text.set_text(f'Step n = {n}\nRefs: rows {ref_start}-{ref_end-1}')

    return artists

In [None]:
# ====== ANIMATION CONTROL PARAMETERS ======
# FRAME_SKIP trades rendering time for smoothness: only every FRAME_SKIP-th timestep is drawn.
# FRAME_SKIP = 1  : animate every frame (slowest render, smoothest animation)
# FRAME_SKIP = 5  : animate every 5th frame (faster render)
# FRAME_SKIP = 10 : animate every 10th frame (fastest render, choppiest)

FRAME_SKIP = 10 # <-- CHANGE THIS VALUE to skip frames
interval_ms = 50  # delay between displayed frames: 50 ms = 20 fps

# ==========================================

if FRAME_SKIP < 1:
    raise ValueError(f"FRAME_SKIP must be a positive integer, got {FRAME_SKIP}")

# Median logging period over the animatable steps (positive increments only). It does not
# influence FRAME_SKIP or the interval; it is used to report playback speed vs. real time.
time_valid = time_all[:max_timestep]
time_valid = time_valid[~np.isnan(time_valid)]
positive_dts = np.diff(time_valid)
positive_dts = positive_dts[positive_dts > 0]
avg_dt = float(np.median(positive_dts)) if positive_dts.size else 0.1  # fallback for short logs

# Create frame list with skipping applied
frames_to_use = list(range(0, max_timestep, FRAME_SKIP))
# Logged seconds shown per wall-clock second of playback (approximate, median-based)
playback_speed = (avg_dt * FRAME_SKIP * 1000) / interval_ms

print(f"Average time between steps: {avg_dt:.4f} seconds")
print(f"Frame skip setting: {FRAME_SKIP}")
print(f"Total timesteps available: {max_timestep}")
print(f"Total frames to animate: {len(frames_to_use)}")
print(f"Animation interval: {interval_ms} ms")
print(f"Playback speed: {playback_speed:.2f}x real time")
print(f"Estimated animation duration: {len(frames_to_use) * interval_ms / 1000:.1f} seconds")

# Create the animation
print("\nCreating animation...")
ani = animation.FuncAnimation(fig, update, frames=frames_to_use,
                              interval=interval_ms, blit=True, repeat=True)

# Display as HTML (renders every frame into an embedded JavaScript player)
HTML(ani.to_jshtml())

Average time between steps: 0.0100 seconds
Frame skip setting: 10
Total timesteps available: 1111
Total frames to animate: 112
Animation interval: 50 ms
Estimated animation duration: 5.6 seconds

Creating animation...


In [None]:
# OPTIONAL: Save the animation as a GIF
# Uncomment and run this cell only if you want to save the animation

# print("Saving animation as GIF (this may take a few minutes)...")
# ani.save('log_animation1.gif', writer='pillow', fps=10)
# print("Animation saved as 'log_animation1.gif'!")