In [2]:
import numpy as np
import pickle
import os


def encode_angle(states, angle_idx=2):
    """Replace one angle column with its sine and cosine.

    Encoding a periodic angle as ``(sin, cos)`` removes the wrap-around
    discontinuity, so nearby orientations get nearby feature values.

    Args:
        states: 2-D array-like of shape ``(n_samples, n_features)``.
        angle_idx: Column holding the angle. Defaults to 2, the column used
            throughout this notebook. Negative values count from the end, as
            in normal NumPy indexing. ``np.sin``/``np.cos`` interpret the
            values as radians.

    Returns:
        Array of shape ``(n_samples, n_features + 1)``. It contains the
        columns before ``angle_idx``, then ``sin(angle)``, then
        ``cos(angle)``, then the columns after ``angle_idx``.

    Raises:
        IndexError: If ``angle_idx`` is out of range for ``states``.
    """
    states = np.asarray(states)
    n_cols = states.shape[1]
    if not -n_cols <= angle_idx < n_cols:
        raise IndexError(
            f"angle_idx {angle_idx} is out of range for {n_cols} columns"
        )
    # Normalise negative indices; otherwise the ":angle_idx" / "angle_idx + 1:"
    # slices below would split the array at the wrong column.
    angle_idx %= n_cols

    angles = states[:, angle_idx]
    sin_angles = np.sin(angles)[:, np.newaxis]
    cos_angles = np.cos(angles)[:, np.newaxis]

    return np.concatenate(
        [states[:, :angle_idx], sin_angles, cos_angles, states[:, angle_idx + 1 :]],
        axis=1,
    )


def load_data(file_path):
    """Read one recorded dataset from disk and return it as NumPy arrays.

    The pickle must contain exactly three items, in the order
    ``(prev_states, actions, next_states)``. Each item is converted with
    ``np.array``.

    Warning: ``pickle.load`` can execute arbitrary code. Only call this on
    files you trust, such as your own recordings.

    Args:
        file_path: Path to the ``.pickle`` file.

    Returns:
        Tuple ``(prev_states, actions, next_states)`` of ``np.ndarray``.
    """
    with open(file_path, "rb") as handle:
        payload = pickle.load(handle)

    # The payload must unpack into exactly three parts.
    prev, acts, nxt = payload
    return np.array(prev), np.array(acts), np.array(nxt)


def aggregate_data(file_paths, encode_angle_bool):
    """Load several recordings and stack them into one supervised dataset.

    For each file, the model input is the previous state joined column-wise
    with the action, and the target is the next state. Rows from all files
    are stacked in the order of ``file_paths``.

    Args:
        file_paths: Iterable of paths readable by ``load_data``.
        encode_angle_bool: If true, both ``prev_states`` and ``next_states``
            are passed through ``encode_angle`` before stacking.

    Returns:
        ``(all_inputs, all_next_states)``, with one row per transition.
        ``all_inputs`` holds the (possibly encoded) state columns followed by
        the action columns.

    Raises:
        ValueError: If ``file_paths`` is empty, or if a file's
            ``prev_states``, ``actions`` and ``next_states`` do not all have
            the same number of rows.
    """
    all_inputs = []
    all_next_states = []

    for file_path in file_paths:
        prev_states, actions, next_states = load_data(file_path)

        # A row-count mismatch would silently misalign every input/target
        # pair from this file onward, so fail loudly instead.
        row_counts = (len(prev_states), len(actions), len(next_states))
        if len(set(row_counts)) != 1:
            raise ValueError(
                f"{file_path}: row counts differ (prev_states={row_counts[0]}, "
                f"actions={row_counts[1]}, next_states={row_counts[2]})"
            )

        if encode_angle_bool:
            prev_states = encode_angle(prev_states)
            next_states = encode_angle(next_states)

        # column_stack equals hstack for 2-D arrays and also accepts a 1-D
        # action vector, which it treats as a single column.
        all_inputs.append(np.column_stack((prev_states, actions)))
        all_next_states.append(next_states)

    if not all_inputs:
        raise ValueError("aggregate_data needs at least one file path")

    return (
        np.concatenate(all_inputs, axis=0),
        np.concatenate(all_next_states, axis=0),
    )


def calculate_statistics(all_inputs, all_next_states, decimals=3):
    """Compute the per-column mean and standard deviation of inputs and targets.

    Args:
        all_inputs: Array whose rows (axis 0) are samples, e.g. the inputs
            returned by ``aggregate_data``.
        all_next_states: Array of targets with samples along axis 0.
        decimals: Number of decimal places to round the results to. The
            default of 3 keeps the printed lists short. Pass ``None`` to get
            unrounded values.

    Returns:
        ``(mean_inputs, std_inputs, mean_next_states, std_next_states)``,
        each with one entry per column. The standard deviations use
        ``np.std``'s default population estimate (``ddof=0``).

    Note:
        Rounding can turn a very small standard deviation into ``0.0``. If
        you divide by these values to normalise data, pass
        ``decimals=None`` or guard against zero.
    """
    stats = (
        np.mean(all_inputs, axis=0),
        np.std(all_inputs, axis=0),
        np.mean(all_next_states, axis=0),
        np.std(all_next_states, axis=0),
    )
    if decimals is None:
        return stats
    return tuple(np.round(stat, decimals) for stat in stats)

# options
# Root directory of the recorded Spot datasets; every path below is relative to it.
DATA_DIR = "/home/bhoffman/Documents/MT_FS24/simulation_transfer/data"

# (sub-directory, file name) of each recording to aggregate.
# Earlier commented-out entries were removed as duplicates:
# - copies of the v0 recordings under ".../MT FS24/..." with the old
#   "dataset_learn_jax_<start>_<end>.pickle" naming;
# - repeated v1-v3 entries.
# Their recording timestamps all appear below, some after being moved to
# test_data_spot or renamed.
# NOTE(review): the "test_data_spot" files are included in the normalisation
# statistics. If they are held-out evaluation data, this leaks test-set
# statistics into normalisation. Confirm this is intended.
recordings = [
    ("recordings_spot_v0", "dataset_learn_jax_20240815-151559_v0_1.pickle"),
    ("recordings_spot_v0", "dataset_learn_jax_20240819-142443_v0_3.pickle"),
    ("recordings_spot_v1", "dataset_learn_jax_test20240830-111841_v1_1.pickle"),
    ("recordings_spot_v1", "dataset_learn_jax_test20240830-112255_v1_3.pickle"),
    ("recordings_spot_v2", "dataset_learn_jax_test20240903-132044_v2_1.pickle"),
    ("recordings_spot_v2", "dataset_learn_jax_test20240903-132303_v2_2.pickle"),
    ("recordings_spot_v2", "dataset_learn_jax_test20240903-132514_v2_3.pickle"),
    ("recordings_spot_v3", "dataset_learn_jax_test20240904-154043_v3_2.pickle"),
    ("recordings_spot_v3", "dataset_learn_jax_test20240904-154353_v3_3.pickle"),
    ("recordings_spot_v3", "dataset_learn_jax_test20240904-155015_v3_4.pickle"),
    ("recordings_spot_v4", "dataset_learn_jax_test20240909-142029_v4_1.pickle"),
    ("recordings_spot_v4", "dataset_learn_jax_test20240909-142230_v4_2.pickle"),
    # Excluded: ("test_data_spot", "dataset_learn_jax_20240819-141455_v0_2.pickle"),
    ("test_data_spot", "dataset_learn_jax_test20240830-112105_v1_2.pickle"),
    ("test_data_spot", "dataset_learn_jax_test20240904-153813_v3_1.pickle"),
    ("test_data_spot", "dataset_learn_jax_test20240909-142535_v4_3.pickle"),
    ("test_data_spot", "dataset_learn_jax_test20240909-142945_v4_4.pickle"),
]
file_paths = [os.path.join(DATA_DIR, subdir, name) for subdir, name in recordings]

# Fail early with the complete list of missing files, instead of stopping at
# the first one inside load_data.
missing_files = [path for path in file_paths if not os.path.isfile(path)]
if missing_files:
    raise FileNotFoundError("Missing dataset files:\n" + "\n".join(missing_files))

encode_angle_bool = True
all_inputs, all_next_states = aggregate_data(file_paths, encode_angle_bool)

# Calculate statistics
mean_inputs, std_inputs, mean_next_states, std_next_states = calculate_statistics(
    all_inputs, all_next_states
)

print("Angle encoding: ", encode_angle_bool)
print("Dims inputs & outputs:", all_inputs.shape[1], all_next_states.shape[1])
print("Number of samples:", all_inputs.shape[0])
print("Mean inputs:", list(mean_inputs))
print("Std inputs:", list(std_inputs))
print("Mean next states:", list(mean_next_states))
print("Std next states:", list(std_next_states))

Angle encoding:  True
Dims inputs & outputs: 19 13
Number of samples: 8753
Mean inputs: [0.887, 0.148, 0.058, 0.978, 0.01, 0.003, -0.0, 1.911, 0.198, 0.475, 0.011, 0.003, -0.002, 0.02, -0.001, -0.003, 0.028, -0.002, -0.007]
Std inputs: [0.994, 0.332, 0.194, 0.052, 0.436, 0.191, 0.249, 1.021, 0.39, 0.2, 0.459, 0.321, 0.184, 0.438, 0.157, 0.232, 0.194, 0.207, 0.192]
Mean next states: [0.889, 0.148, 0.058, 0.978, 0.01, 0.003, -0.0, 1.913, 0.198, 0.475, 0.011, 0.003, -0.002]
Std next states: [0.994, 0.332, 0.194, 0.052, 0.436, 0.191, 0.249, 1.021, 0.39, 0.2, 0.459, 0.321, 0.184]
