In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from trajkit import TrajectorySet
from trajkit.cdv import CorrelationEnsembleAccumulator, correlation_batch, distance_threshold_pair_filter


In [None]:
# Load trajectory data from a remote CSV (replace with your URL)
csv_url = "https://your-database.org/path/to/res_xyti_time11.csv"

# Expect columns: id, x, y, t (frame or time). Adjust names below if needed.
df_whole = pd.read_csv(csv_url)

# Fail early with a readable message instead of a bare KeyError further down.
required_cols = {"id", "x", "y", "t"}
missing_cols = required_cols.difference(df_whole.columns)
if missing_cols:
    raise KeyError(
        f"CSV is missing expected column(s) {sorted(missing_cols)}; "
        f"found {list(df_whole.columns)}. Adjust the column names below."
    )

# Optional: downselect frames for a quick tutorial run
frame_max = 120
df = df_whole[df_whole["t"] < frame_max].copy()
assert not df.empty, f"no rows with t < {frame_max}; check frame_max against the data"

# Build a TrajectorySet from the table
trajset = TrajectorySet.from_dataframe(
    df,
    dataset_id="cdv_demo",
    track_id_col="id",
    position_cols=["x", "y"],
    time_col="t",
    frame_col="t",  # the same column serves as both time and frame index here
    # NOTE(review): only "x" is given a unit; presumably "y" is also "pixel" —
    # confirm how trajkit keys position units.
    units={"t": "frame", "x": "pixel"},
)
trajset.summary_table().head()


In [None]:
# Assemble per-frame displacements separated by delta_t.
# For each sample i of a track, find the sample j whose time lies within
# time_tol of t[i] + delta_t and record the displacement x[j] - x[i].

delta_t_frames = 1.0
time_tol = 1e-3

disp_rows = []
for tid, tr in trajset.trajectories.items():
    # NOTE(review): time_seconds() is offset by delta_t_frames below, so this
    # assumes it yields frame-valued times for this dataset (units={"t": "frame"})
    # — confirm against trajkit.
    t = np.asarray(tr.time_seconds(), dtype=float)
    if t.size == 0:
        continue
    frames = tr.frame if tr.frame is not None else np.arange(len(t), dtype=int)
    targets = t + delta_t_frames
    # Search from the LOWER edge of the tolerance window. Searching for `targets`
    # itself (side="left") skips a sample lying just below its target but still
    # within time_tol and lands on the following sample, dropping a valid pair.
    # searchsorted requires `t` to be sorted ascending.
    idx = np.searchsorted(t, targets - time_tol, side="left")
    for i, j in enumerate(idx):
        if j >= len(t) or abs(t[j] - targets[i]) > time_tol:
            continue
        dx = tr.x[j] - tr.x[i]
        row = {"track_id": tid, "t": float(t[i]), "frame": int(frames[i])}
        for k in range(tr.D):
            row[f"x{k}"] = float(tr.x[i, k])
            row[f"dx{k}"] = float(dx[k])
        disp_rows.append(row)

disp_df = pd.DataFrame(disp_rows)
# An empty frame has no columns at all; later cells would fail with a confusing KeyError.
assert not disp_df.empty, "no displacement pairs found; check delta_t_frames / time_tol against the time column"
print(f"Displacement rows: {len(disp_df)} (delta_t={delta_t_frames} frames)")
disp_df.head()


In [None]:
# CDV setup: grid, pair filter, and plotting helper

# Column groups produced by the displacement cell: positions x0, x1, ... and
# displacements dx0, dx1, ... (one pair per spatial dimension).
position_cols = [c for c in disp_df.columns if c.startswith("x")]
motion_cols = [c for c in disp_df.columns if c.startswith("dx")]
assert len(position_cols) == len(motion_cols) > 0, (position_cols, motion_cols)

max_pair_distance = 600  # drop pairs separated by more than this distance
pair_filter = distance_threshold_pair_filter(max_pair_distance)

# Grid in relative-position space
grid_half_width = 400.0  # grid spans [-grid_half_width, grid_half_width] on both axes
grid_points = 100        # samples per axis
x = np.linspace(-grid_half_width, grid_half_width, grid_points)
y = np.linspace(-grid_half_width, grid_half_width, grid_points)
X, Y = np.meshgrid(x, y, indexing="xy")
grid_centers = np.column_stack([X.ravel(), Y.ravel()])

plot_dir = Path("results/cdv_frames_tutorial")
plot_dir.mkdir(parents=True, exist_ok=True)


def plot_cdv_frame(
    frame_idx,
    disp_temp,
    batch_rotated,
    ensemble_accumulator,
    X,
    Y,
    *,
    save_dir=None,
    quiver_stride=5,
    motion_quiver_scale=1.0,
    flow_quiver_scale=80,
):
    """Plot displacements, rotated correlations, and current ensemble flow for one frame.

    Parameters
    ----------
    frame_idx : int
        Frame number; used in panel titles and in the output filename
        (``frame_XXX.png``).
    disp_temp : pandas.DataFrame
        Displacement rows for this frame with columns ``x0``, ``x1``,
        ``dx0``, ``dx1``.
    batch_rotated
        Correlation batch exposing ``relative_positions`` and ``tracer_motion``
        arrays, indexed as ``[:, 0]`` / ``[:, 1]``.
    ensemble_accumulator
        Object whose ``finalize()`` returns ``(mean, _, _)`` where ``mean`` has
        ``X.size`` rows and at least 2 columns (presumably ordered like
        ``X.ravel()`` — TODO confirm).
    X, Y : numpy.ndarray
        Meshgrid of the relative-position grid the ensemble is accumulated on.
    save_dir : path-like, optional
        Directory to write the PNG into. If None, the figure is not saved.
    quiver_stride : int
        Subsampling step for the flow-field arrows (values < 1 are treated as 1).
    motion_quiver_scale : float
        ``quiver`` scale (data units) for the displacement and correlation panels.
    flow_quiver_scale : float
        ``quiver`` scale for the ensemble-flow arrows.

    Returns
    -------
    pathlib.Path or None
        Path of the saved figure, or None when ``save_dir`` is None.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    ax_disp, ax_corr, ax_flow = axes

    # Panel 1: displacements of the current frame, in image coordinates.
    ax_disp.quiver(
        disp_temp["x0"],
        disp_temp["x1"],
        disp_temp["dx0"],
        disp_temp["dx1"],
        angles="xy",
        scale_units="xy",
        scale=motion_quiver_scale,
        color="tab:blue",
        alpha=0.9,
    )
    ax_disp.scatter(disp_temp["x0"], disp_temp["x1"], s=16, color="tab:blue", alpha=0.9)
    ax_disp.set_title(f"Displacements (frame {frame_idx})")
    ax_disp.set_xlabel("x")
    ax_disp.set_ylabel("y")

    # Panel 2: rotated correlation batch — tracer motion drawn at its relative position.
    ax_corr.quiver(
        batch_rotated.relative_positions[:, 0],
        batch_rotated.relative_positions[:, 1],
        batch_rotated.tracer_motion[:, 0],
        batch_rotated.tracer_motion[:, 1],
        angles="xy",
        scale_units="xy",
        scale=motion_quiver_scale,
        color="tab:orange",
        alpha=0.6,
    )
    ax_corr.set_title("Rotated tracer motion vs. rel. pos.")
    ax_corr.set_xlabel("r1")
    ax_corr.set_ylabel("r2")

    # Panel 3: ensemble mean accumulated so far, on the relative-position grid.
    # NOTE(review): finalize() is called once per frame here and again after the
    # loop, so this assumes it does not reset the accumulator — confirm.
    mean, _, _ = ensemble_accumulator.finalize()
    U = mean[:, 0].reshape(X.shape)
    V = mean[:, 1].reshape(Y.shape)
    mag = np.ma.array(np.hypot(U, V), mask=np.isnan(U) | np.isnan(V))

    # Only draw contours + colorbar once at least one grid cell has data; a fully
    # masked field has no valid contour levels.
    if mag.count() > 0:
        cf = ax_flow.contourf(X, Y, mag, levels=20, cmap="YlGnBu", alpha=0.85)
        cbar = fig.colorbar(cf, ax=ax_flow, shrink=0.8, pad=0.02)
        cbar.set_label("|flow|")

    step = max(1, quiver_stride)
    ax_flow.quiver(
        X[::step, ::step],
        Y[::step, ::step],
        np.nan_to_num(U)[::step, ::step],
        np.nan_to_num(V)[::step, ::step],
        color="white",
        scale=flow_quiver_scale,
        alpha=0.9,
    )
    ax_flow.set_title(f"Ensemble flow <= frame {frame_idx}")
    # Same relative-position coordinates as panel 2 (the grid the ensemble lives on).
    ax_flow.set_xlabel("r1")
    ax_flow.set_ylabel("r2")

    # Shared styling only — per-panel axis labels are set above and must not be overwritten.
    for ax in axes:
        ax.set_aspect("equal")
        ax.grid(True, linestyle=":", alpha=0.3)

    fig.tight_layout()
    save_path = None
    if save_dir is not None:
        save_path = Path(save_dir) / f"frame_{frame_idx:03d}.png"
        fig.savefig(save_path, dpi=200)
    # Close explicitly: this is called once per frame inside a loop.
    plt.close(fig)
    return save_path


In [None]:
# Run CDV loop: for each frame, correlate source displacements (optionally
# restricted to a region of interest) with all tracer displacements of the same
# frame, rotate the batch, accumulate it onto the grid, and save a diagnostic plot.

frame_start = 0
frame_stop = 60  # exclusive; adjust for your dataset (or use disp_df["frame"].max() + 1)
# Region of interest for SOURCE displacements: inclusive bounds applied to both
# x0 and x1 (same units as the positions). Set to None to use all sources.
roi_bounds = (300, 1100)

ensemble = CorrelationEnsembleAccumulator(
    grid_centers,
    kernel=30.0,  # hard cutoff radius in rel_pos space
    value_fn=lambda rel, tracer, source, meta_row: tracer,
    weight_fn=lambda rel, tracer, source, meta_row: np.linalg.norm(source),
)

# Group once instead of re-scanning the whole table for every frame number;
# frames absent from the data are simply never visited.
frames_in_window = disp_df[disp_df["frame"].between(frame_start, frame_stop - 1)]
for frame_no, disp_temp in frames_in_window.groupby("frame", sort=True):
    frame_no = int(frame_no)
    if roi_bounds is None:
        disp_active = disp_temp
    else:
        roi_lo, roi_hi = roi_bounds
        disp_active = disp_temp[
            disp_temp["x0"].between(roi_lo, roi_hi) & disp_temp["x1"].between(roi_lo, roi_hi)
        ]
    if disp_active.empty:
        continue

    batch, _ = correlation_batch(
        disp_active,
        disp_temp,
        source_frame_col="frame", tracer_frame_col="frame",
        source_position_cols=position_cols, tracer_position_cols=position_cols,
        source_motion_cols=motion_cols, tracer_motion_cols=motion_cols,
        pair_filter=pair_filter,
    )
    batch_r = batch.rotate_to_source_x()
    ensemble.add(batch_r)
    plot_cdv_frame(
        frame_no,
        disp_temp,
        batch_r,
        ensemble,
        X,
        Y,
        save_dir=plot_dir,
        quiver_stride=6,
    )

mean, sum_w, counts = ensemble.finalize()
mean[:5]
