In [None]:
import tensorflow as tf
import numpy as np
import os
from tqdm import tqdm

def parse_tfrecord(record_bytes):
    """Decode one serialized example into a dict of fixed-shape tensors.

    The schema describes 128 agent slots per scene: per-agent identifiers
    and flags, a single current timestep, 10 past timesteps, 80 future
    timesteps, and the scenario identifier string. Every feature is
    declared without a default value.
    """
    f32, i64 = tf.float32, tf.int64

    # One row per expected feature: (name, shape, dtype).
    layout = (
        # Per-agent attributes, shape [128]
        ("state/id", [128], f32),
        ("state/type", [128], f32),
        ("state/is_sdc", [128], i64),
        ("state/tracks_to_predict", [128], i64),

        # Current timestep, shape [128, 1]
        ("state/current/x", [128, 1], f32),
        ("state/current/y", [128, 1], f32),
        ("state/current/bbox_yaw", [128, 1], f32),
        ("state/current/valid", [128, 1], i64),
        ("state/current/width", [128, 1], f32),
        ("state/current/length", [128, 1], f32),

        # Past timesteps, shape [128, 10]
        ("state/past/x", [128, 10], f32),
        ("state/past/y", [128, 10], f32),
        ("state/past/bbox_yaw", [128, 10], f32),
        ("state/past/valid", [128, 10], i64),

        # Future timesteps, shape [128, 80]
        ("state/future/x", [128, 80], f32),
        ("state/future/y", [128, 80], f32),
        ("state/future/bbox_yaw", [128, 80], f32),
        ("state/future/valid", [128, 80], i64),

        # Scenario identifier
        ("scenario/id", [1], tf.string),
    )

    feature_description = {
        name: tf.io.FixedLenFeature(shape, dtype, default_value=None)
        for name, shape, dtype in layout
    }
    return tf.io.parse_single_example(record_bytes, feature_description)

def process_trajectories(tfrecord_path, output_path, max_samples=None):
    """Process a TFRecord file and save each scene as NPZ.

    NOTE(review): a later ``def process_trajectories`` in this file rebinds
    the name, so this version is replaced once the whole module has run.

    NOTE(review): every scene is written to the same ``output_path``, so the
    file left on disk holds only the last successfully processed scene.
    ``process_multiple_tfrecords`` accumulates scenes instead; confirm which
    output is intended.

    Args:
        tfrecord_path: Path of the TFRecord file to read.
        output_path: Destination handed to ``np.savez_compressed``.
        max_samples: If truthy, stop once this many records have been read;
            otherwise every record is processed.
    """
    dataset = tf.data.TFRecordDataset(tfrecord_path)

    # Process each example (scene)
    for i, record_bytes in tqdm(enumerate(dataset)):
        if max_samples and i >= max_samples:
            break

        try:
            example = parse_tfrecord(record_bytes)

            # Converted once; reused for the history mask and 'agent_valid'.
            current_valid = example['state/current/valid'].numpy()

            # History = 10 past steps followed by the current step.
            history_x = np.concatenate([
                example['state/past/x'].numpy(),
                example['state/current/x'].numpy()
            ], axis=1)  # Shape: [128, 11]

            history_y = np.concatenate([
                example['state/past/y'].numpy(),
                example['state/current/y'].numpy()
            ], axis=1)  # Shape: [128, 11]

            history_valid = np.concatenate([
                example['state/past/valid'].numpy(),
                current_valid
            ], axis=1)  # Shape: [128, 11]

            # Future trajectories
            future_x = example['state/future/x'].numpy()  # Shape: [128, 80]
            future_y = example['state/future/y'].numpy()  # Shape: [128, 80]
            future_valid = example['state/future/valid'].numpy()  # Shape: [128, 80]

            # Taken from the parsed data rather than hard-coded.
            n_agents = history_valid.shape[0]

            scene_data = {
                'file_name': os.path.basename(tfrecord_path),
                'scenario_id': example['scenario/id'].numpy()[0].decode(),
                'agent_id': example['state/id'].numpy(),
                'agent_type': example['state/type'].numpy(),
                'agent_valid': current_valid[:, 0],
                'width': example['state/current/width'].numpy()[:, 0],
                'length': example['state/current/length'].numpy()[:, 0],
                'history/xy': np.stack([history_x, history_y], axis=2),  # Shape: [128, 11, 2]
                'history/yaw': np.concatenate([
                    example['state/past/bbox_yaw'].numpy(),
                    example['state/current/bbox_yaw'].numpy()
                ], axis=1),
                'history/valid': history_valid,
                'future/xy': np.stack([future_x, future_y], axis=2),  # Shape: [128, 80, 2]
                'future/yaw': example['state/future/bbox_yaw'].numpy(),
                'future/valid': future_valid
            }

            # Overwrites output_path on every record (see NOTE in docstring).
            np.savez_compressed(output_path, **scene_data)
            print(f"\nProcessed data saved to {output_path}")
            print(f"Number of agents: {n_agents}")

        except Exception as e:
            # Best effort: report the bad record and move on to the next one.
            print(f"Error processing record {i}: {e}")

def verify_npz_data(npz_path):
    """Print the shape of every array in an NPZ file, plus sample points.

    For the ``history/xy`` and ``future/xy`` arrays, the first five entries
    of the first row are printed as well.

    Args:
        npz_path: Path of the ``.npz`` archive to inspect.
    """
    # np.load returns an NpzFile that keeps the archive open; the context
    # manager makes sure it is closed when we are done.
    with np.load(npz_path) as data:
        print("\nData Verification:")
        print("==================")

        for key in data.files:
            # Each data[key] access reads the member again, so load it once.
            array = data[key]
            print(f"{key}: {array.shape}")

            # Print sample values for key fields
            if key in ('history/xy', 'future/xy'):
                print(f"\nSample {key} values (first agent, first 5 timesteps):")
                print(array[0, :5])


def process_multiple_tfrecords(tfrecord_paths, output_path):
    """Process several TFRecord files and combine all scenes into one NPZ.

    Array-valued fields are concatenated along axis 0 across scenes; scalar
    fields (file name, scenario id) become one array entry per scene.
    Records that fail to parse are reported and skipped.

    Args:
        tfrecord_paths: Iterable of TFRecord file paths to read.
        output_path: Destination handed to ``np.savez_compressed``.

    Returns:
        None. If no scene could be processed (for example an empty path
        list), a message is printed and no file is written.
    """
    # key -> list of per-scene values, filled as scenes are processed
    combined_data = {}

    for tfrecord_path in tfrecord_paths:
        dataset = tf.data.TFRecordDataset(tfrecord_path)
        print(f"\nProcessing: {tfrecord_path}")

        for i, record_bytes in tqdm(enumerate(dataset)):
            try:
                example = parse_tfrecord(record_bytes)

                # History = past steps followed by the current step.
                history_x = np.concatenate([
                    example['state/past/x'].numpy(),
                    example['state/current/x'].numpy()
                ], axis=1)

                history_y = np.concatenate([
                    example['state/past/y'].numpy(),
                    example['state/current/y'].numpy()
                ], axis=1)

                history_valid = np.concatenate([
                    example['state/past/valid'].numpy(),
                    example['state/current/valid'].numpy()
                ], axis=1)

                scene_data = {
                    'file_name': os.path.basename(tfrecord_path),
                    'scenario_id': example['scenario/id'].numpy()[0].decode(),
                    'agent_id': example['state/id'].numpy(),
                    'agent_type': example['state/type'].numpy(),
                    'agent_valid': example['state/current/valid'].numpy()[:, 0],
                    'width': example['state/current/width'].numpy()[:, 0],
                    'length': example['state/current/length'].numpy()[:, 0],
                    'history/xy': np.stack([history_x, history_y], axis=2),
                    'history/yaw': np.concatenate([
                        example['state/past/bbox_yaw'].numpy(),
                        example['state/current/bbox_yaw'].numpy()
                    ], axis=1),
                    'history/valid': history_valid,
                    'future/xy': np.stack([
                        example['state/future/x'].numpy(),
                        example['state/future/y'].numpy()
                    ], axis=2),
                    'future/yaw': example['state/future/bbox_yaw'].numpy(),
                    'future/valid': example['state/future/valid'].numpy()
                }

                # Only reached when the whole scene was built successfully,
                # so every key always gains exactly one entry per scene.
                for key, value in scene_data.items():
                    combined_data.setdefault(key, []).append(value)

            except Exception as e:
                # Best effort: report the bad record and keep going.
                print(f"Error processing record {i} from {tfrecord_path}: {e}")

    # Nothing to combine: avoid crashing on the empty accumulator.
    if not combined_data:
        print("\nNo scenes were processed; nothing was saved.")
        return

    # Convert lists to arrays
    final_data = {}
    for key, values in combined_data.items():
        try:
            final_data[key] = np.concatenate(values, axis=0)
        except ValueError:
            # Zero-dimensional entries (the string fields) cannot be
            # concatenated; keep them as one array element per scene.
            final_data[key] = np.array(values)

    # Save combined data
    np.savez_compressed(output_path, **final_data)
    print(f"\nSaved combined data to: {output_path}")

    # Print data shapes
    print("\nFinal data shapes:")
    for key, value in final_data.items():
        print(f"{key}: {value.shape}")

def verify_trajectory_continuity(history_xy, future_xy, history_valid,
                                 future_valid, max_gap=1.0):
    """Check that a trajectory is continuous where history meets future.

    Compares the last valid history point with the first valid future point
    and prints the Euclidean distance between them.

    Args:
        history_xy: Array of history points, one row per timestep.
        future_xy: Array of future points, one row per timestep.
        history_valid: Per-timestep validity flags for ``history_xy``
            (bool or 0/1 integers).
        future_valid: Per-timestep validity flags for ``future_xy``
            (bool or 0/1 integers).
        max_gap: Largest distance still considered continuous, in the units
            of the input coordinates.

    Returns:
        True if the gap is below ``max_gap``; False otherwise, including
        when either segment has no valid point.
    """
    # Integer 0/1 flags must be cast to bool: indexing with an integer array
    # would pick rows 0 and 1 instead of masking the valid timesteps.
    history_mask = np.asarray(history_valid).astype(bool)
    future_mask = np.asarray(future_valid).astype(bool)

    valid_history = np.asarray(history_xy)[history_mask]
    valid_future = np.asarray(future_xy)[future_mask]

    # Without a valid point on both sides there is no gap to measure.
    if len(valid_history) == 0 or len(valid_future) == 0:
        print("Gap between history and future: undefined (no valid points)")
        return False

    diff = np.linalg.norm(valid_history[-1] - valid_future[0])
    print(f"Gap between history and future: {diff}")
    return bool(diff < max_gap)

def process_trajectories(tfrecord_path, output_path, max_samples=None):
    """Process a TFRecord file and save each scene as NPZ, with verification.

    For the first record, the last three history points and first three
    future points of agent 0 are printed and the transition between them is
    checked with ``verify_trajectory_continuity``.

    NOTE(review): every scene is written to the same ``output_path``, so the
    file left on disk holds only the last successfully processed scene.
    ``process_multiple_tfrecords`` accumulates scenes instead; confirm which
    output is intended.

    Args:
        tfrecord_path: Path of the TFRecord file to read.
        output_path: Destination handed to ``np.savez_compressed``.
        max_samples: If truthy, stop once this many records have been read;
            otherwise every record is processed.
    """
    dataset = tf.data.TFRecordDataset(tfrecord_path)

    for i, record_bytes in tqdm(enumerate(dataset)):
        if max_samples and i >= max_samples:
            break

        try:
            example = parse_tfrecord(record_bytes)

            # History = 10 past steps followed by the current step.
            history_x = np.concatenate([
                example['state/past/x'].numpy(),
                example['state/current/x'].numpy()
            ], axis=1)

            history_y = np.concatenate([
                example['state/past/y'].numpy(),
                example['state/current/y'].numpy()
            ], axis=1)

            history_xy = np.stack([history_x, history_y], axis=2)
            future_xy = np.stack([
                example['state/future/x'].numpy(),
                example['state/future/y'].numpy()
            ], axis=2)

            # Validity flags must be rebuilt for every record: computing them
            # only for the first record would save that record's flags with
            # every later scene.
            history_valid = np.concatenate([
                example['state/past/valid'].numpy(),
                example['state/current/valid'].numpy()
            ], axis=1)
            future_valid = example['state/future/valid'].numpy()

            # Print sample trajectories for verification
            if i == 0:  # First example
                print("\nSample trajectory verification:")
                print("History last 3 points:")
                print(history_xy[0, -3:])
                print("\nFuture first 3 points:")
                print(future_xy[0, :3])

                # The flags are int64 0/1; cast to bool so they are used as
                # masks rather than as row indices.
                first_history_mask = history_valid[0].astype(bool)
                first_future_mask = future_valid[0].astype(bool)

                # A diagnostic must not cost us the record: skip the check
                # when agent 0 has nothing valid on one side.
                if first_history_mask.any() and first_future_mask.any():
                    is_continuous = verify_trajectory_continuity(
                        history_xy[0], future_xy[0],
                        first_history_mask, first_future_mask
                    )
                    print(f"Trajectory is continuous: {is_continuous}")
                else:
                    print("Trajectory continuity check skipped: "
                          "agent 0 has no valid history or future points")

            # Create scene data dictionary
            scene_data = {
                'file_name': os.path.basename(tfrecord_path),
                'scenario_id': example['scenario/id'].numpy()[0].decode(),
                'agent_id': example['state/id'].numpy(),
                'agent_type': example['state/type'].numpy(),
                'agent_valid': example['state/current/valid'].numpy()[:, 0],
                'width': example['state/current/width'].numpy()[:, 0],
                'length': example['state/current/length'].numpy()[:, 0],
                'history/xy': history_xy,
                'history/yaw': np.concatenate([
                    example['state/past/bbox_yaw'].numpy(),
                    example['state/current/bbox_yaw'].numpy()
                ], axis=1),
                'history/valid': history_valid,
                'future/xy': future_xy,
                'future/yaw': example['state/future/bbox_yaw'].numpy(),
                'future/valid': future_valid
            }

            # Overwrites output_path on every record (see NOTE in docstring).
            np.savez_compressed(output_path, **scene_data)

        except Exception as e:
            # Best effort: report the bad record and move on to the next one.
            print(f"Error processing record {i}: {e}")

def main():
    """Convert each configured TFRecord file into an NPZ file.

    Outputs go to ``processed_data_testing`` (created if missing), one file
    per input, named ``processed_<input basename>.npz``.
    """
    output_dir = "processed_data_testing"

    # Inputs to convert; commented-out entries are alternative inputs.
    tfrecord_paths = [
        #"uncompressed_tf_example_training_training_tfexample.tfrecord-00000-of-01000",
        #"uncompressed_tf_example_training_training_tfexample.tfrecord-00010-of-01000",
        "uncompressed_tf_example_testing_testing_tfexample.tfrecord-00001-of-00150",
    ]

    os.makedirs(output_dir, exist_ok=True)

    for tfrecord_path in tfrecord_paths:
        print(f"\nProcessing: {tfrecord_path}")
        output_name = f"processed_{os.path.basename(tfrecord_path)}.npz"
        process_trajectories(tfrecord_path, os.path.join(output_dir, output_name))

# Run the conversion only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()