In [4]:
import glob
import math
import os
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np

from alpasim_grpc.v0.common_pb2 import Pose
from alpasim_grpc.v0.logging_pb2 import ActorPoses, RolloutMetadata
from alpasim_utils.logs import async_read_pb_log
from alpasim_utils.scenario import AABB
from alpasim_utils.trajectory import QVec, Trajectory

In [5]:
@dataclass
class ActorTrajectory:
    """Timestamped pose history of a single actor in one rollout, plus its box size."""

    # Actor ID as it appears in the log's `actor_poses` entries ("EGO" for the ego vehicle).
    id: str
    # Box dimensions built from the log's `actor_aabb` sizes (or a placeholder for EGO).
    bbox: AABB
    # Poses of this actor stacked over time.
    trajectory: Trajectory

@dataclass
class RolloutTrajectories:
    """All actor trajectories recovered from one rollout log, tagged with its scene."""

    # Scene ID taken from the log's `rollout_metadata.session_metadata`.
    scene_id: str
    # Per-actor trajectories, keyed by actor ID.
    actor_trajectories: dict[str, ActorTrajectory]

async def read_trajectory(fname: str) -> RolloutTrajectories:
    """
    Read a log stream, select actor poses and combine them into per-actor trajectories.

    Args:
        fname: Path to a rollout log readable by ``async_read_pb_log``.

    Returns:
        A ``RolloutTrajectories`` whose ``scene_id`` comes from the (last)
        ``rollout_metadata`` entry and which holds one ``ActorTrajectory`` per
        actor ID seen in ``actor_poses`` entries. Poses are kept in log order.

    Raises:
        RuntimeError: If the log contains no ``rollout_metadata`` entry.
        KeyError: If an actor has poses but no bounding box (it is neither in the
            metadata's ``actor_definitions`` nor the built-in "EGO" placeholder).
    """
    scene_id: str | None = None
    raw_trajectories: dict[str, list[tuple[int, Pose]]] = {}
    # The ego vehicle gets a placeholder box; it is replaced if the metadata also
    # defines an "EGO" actor. NOTE(review): 1x1x1 is not a realistic ego size.
    aabbs: dict[str, AABB] = {
        "EGO": AABB(1., 1., 1.)
    }

    async for message in async_read_pb_log(fname):
        entry_kind = message.WhichOneof('log_entry')
        if entry_kind == 'rollout_metadata':
            metadata: RolloutMetadata = message.rollout_metadata
            scene_id = metadata.session_metadata.scene_id
            for actor_aabb in metadata.actor_definitions.actor_aabb:
                aabb = actor_aabb.aabb
                aabbs[actor_aabb.actor_id] = AABB(x=aabb.size_x, y=aabb.size_y, z=aabb.size_z)
        elif entry_kind == 'actor_poses':
            poses_message: ActorPoses = message.actor_poses
            timestamp_us = poses_message.timestamp_us
            for pose in poses_message.actor_poses:
                raw_trajectories.setdefault(pose.actor_id, []).append((timestamp_us, pose.actor_pose))
        # Other log entries are not used for now.

    if scene_id is None:
        raise RuntimeError("Did not find scene ID in log file (`rollout_metadata` missing).")

    trajectories: dict[str, ActorTrajectory] = {}
    for actor_id, raw_trajectory in raw_trajectories.items():
        if actor_id not in aabbs:
            raise KeyError(
                f"Actor {actor_id!r} in {fname} has poses but no bounding box "
                "(missing from `rollout_metadata.actor_definitions`)."
            )

        trajectory = Trajectory(
            timestamps_us=np.array([ts for ts, _ in raw_trajectory], dtype=np.uint64),
            poses=QVec.stack([QVec.from_grpc_pose(pose) for _, pose in raw_trajectory]),
        )

        trajectories[actor_id] = ActorTrajectory(
            id=actor_id,
            bbox=aabbs[actor_id],
            trajectory=trajectory,
        )

    return RolloutTrajectories(scene_id=scene_id, actor_trajectories=trajectories)

In [6]:
glob_pattern = ("**/*.asl")
glob_pattern = ("/home/mwatson/Documents/alpamayo/alpasim/tutorial/asl
fnames = glob.glob(glob_pattern, recursive=True)
rollout_trajectories: list[RolloutTrajectories] = []
for fname in fnames:
    rollout_trajectory = await read_trajectory(fname)
    rollout_trajectories.append(rollout_trajectory)

In [7]:
# Distinct scenes across all loaded rollouts (several rollouts may share one scene).
scene_ids = {rollout.scene_id for rollout in rollout_trajectories}
print(len(scene_ids))

0


## Plot rollout trajectories

In [None]:
# Plot the x/y path of every actor in the first loaded rollout.
assert rollout_trajectories, "No rollouts loaded -- nothing to plot."
rollout_trajectory = rollout_trajectories[0]

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
for actor_name, actor_trajectory in rollout_trajectory.actor_trajectories.items():
    vec3 = actor_trajectory.trajectory.poses.vec3
    ax.plot(vec3[:, 0], vec3[:, 1], label=actor_name)

# Equal scaling so path shapes are not distorted.
ax.set_aspect("equal", adjustable="datalim")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title(f"Scene {rollout_trajectory.scene_id}")
ax.legend(fontsize="small");


In [None]:
# One square panel per scene; every rollout of that scene is overlaid in its panel.
# Sorted so the panel layout is reproducible (set iteration order is not).
sorted_scene_ids = sorted(scene_ids)
assert sorted_scene_ids, "No scenes to plot -- no rollouts were loaded."
plot_side = math.ceil(math.sqrt(len(sorted_scene_ids)))

fig, axes = plt.subplots(plot_side, plot_side, figsize=(15, 15), constrained_layout=True, squeeze=False)

name_to_ax = dict(zip(sorted_scene_ids, axes.flat))

for rollout_trajectory in rollout_trajectories:
    ax = name_to_ax[rollout_trajectory.scene_id]
    for actor_id, actor_trajectory in rollout_trajectory.actor_trajectories.items():
        vec3 = actor_trajectory.trajectory.poses.vec3
        # Ego is drawn solid, every other actor dashed.
        linestyle = '-' if actor_id == 'EGO' else '--'
        ax.plot(vec3[..., 0], vec3[..., 1], linestyle)

for name, ax in name_to_ax.items():
    ax.set_box_aspect(1)
    ax.set_title(name)

# Hide the grid cells left over when the scene count is not a perfect square.
for unused_ax in axes.flat[len(name_to_ax):]:
    unused_ax.axis("off")