In [1]:
import pickle
import numpy as np
import os
from datetime import datetime

def load_observation_data(pkl_file_path):
    """Read a pickle file and return the object stored in it.

    Args:
        pkl_file_path: Path of the pickle file to open (binary read mode).

    Returns:
        Whatever object ``pickle.load`` reconstructs from the file.

    Note:
        Unpickling can execute arbitrary code, so only open files that come
        from a trusted source.
    """
    with open(pkl_file_path, mode='rb') as handle:
        return pickle.load(handle)

# Frame fields that are converted to numpy arrays. The order here is the order
# in which they appear in each observation dict (after 'frame_index' and
# before 'ref_motion_phase').
_OBSERVATION_ARRAY_FIELDS = (
    'raw_base_ang_vel',
    'raw_base_lin_vel',
    'raw_dof_pos',
    'raw_dof_vel',
    'scaled_base_ang_vel',
    'scaled_base_lin_vel',
    'scaled_dof_pos',
    'scaled_dof_vel',
    'projected_gravity',
)

def _frame_to_observation(frame):
    """Build the observation dict for a single recorded frame."""
    observation = {'frame_index': frame['index']}
    for field in _OBSERVATION_ARRAY_FIELDS:
        observation[field] = np.array(frame[field])
    # The phase value is stored exactly as recorded (not wrapped in an array).
    observation['ref_motion_phase'] = frame['ref_motion_phase']
    return observation

def _summarize_actions(actions_array):
    """Return mean/std/min/max along axis 0 as plain lists."""
    if actions_array.size == 0:
        # np.min / np.max raise ValueError on a zero-size array and np.mean
        # yields NaN with a warning, so an empty recording gets empty stats.
        return {'mean': [], 'std': [], 'min': [], 'max': []}
    return {
        'mean': np.mean(actions_array, axis=0).tolist(),
        'std': np.std(actions_array, axis=0).tolist(),
        'min': np.min(actions_array, axis=0).tolist(),
        'max': np.max(actions_array, axis=0).tolist(),
    }

def _dump_pickle(obj, directory, filename):
    """Pickle ``obj`` into ``directory/filename`` and return the full path."""
    path = os.path.join(directory, filename)
    with open(path, 'wb') as f:
        pickle.dump(obj, f)
    return path

def _save_extracted_data(save_path, observations, actions, frame_count, scaling_coeffs):
    """Write observations, actions, the action array and a summary to ``save_path``."""
    os.makedirs(save_path, exist_ok=True)
    # One timestamp is shared by all files so they can be matched up later.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    obs_path = _dump_pickle(observations, save_path, f"extracted_observations_{timestamp}.pkl")
    print(f"Observations saved to: {obs_path}")

    actions_path = _dump_pickle(actions, save_path, f"extracted_actions_{timestamp}.pkl")
    print(f"Actions saved to: {actions_path}")

    # Save as numpy arrays for easier analysis
    actions_array = np.array(actions)
    np.save(os.path.join(save_path, f"actions_array_{timestamp}.npy"), actions_array)
    print(f"Actions array saved as numpy file: actions_array_{timestamp}.npy")

    summary = {
        'total_frames': frame_count,
        'action_shape': actions_array.shape,
        'action_stats': _summarize_actions(actions_array),
        'scaling_coefficients': scaling_coeffs,
    }
    summary_path = _dump_pickle(summary, save_path, f"data_summary_{timestamp}.pkl")
    print(f"Summary saved to: {summary_path}")

def extract_observations_and_actions(data, save_path=None):
    """
    Extract observations and actions from loaded data and optionally save them.

    Args:
        data: Mapping with a 'scaling_coefficients' entry and a 'frames'
            entry. Each frame is indexed by the keys 'index',
            'ref_motion_phase', 'network_action' and the raw/scaled
            observation fields listed in ``_OBSERVATION_ARRAY_FIELDS``.
        save_path: Directory to write the extracted data to (optional). It is
            created if missing. When falsy, nothing is written to disk.

    Returns:
        observations: List with one observation dict per frame.
        actions: List with one numpy action array per frame.

    Raises:
        KeyError: If ``data`` or a frame lacks one of the expected keys.

    When saving, four timestamped files are written: the pickled
    observations, the pickled actions, the stacked actions as ``.npy`` and a
    pickled summary. With no frames the summary statistics are empty lists
    instead of raising on the zero-size array.
    """
    scaling_coeffs = data['scaling_coefficients']
    frames = data['frames']

    print(f"Processing {len(frames)} frames...")
    print(f"Scaling coefficients: {scaling_coeffs}")

    observations = [_frame_to_observation(frame) for frame in frames]
    actions = [np.array(frame['network_action']) for frame in frames]

    # Save extracted data if path is provided
    if save_path:
        _save_extracted_data(save_path, observations, actions, len(frames), scaling_coeffs)

    return observations, actions

def analyze_data(observations, actions):
    """Print overall and per-joint statistics for the extracted actions.

    Args:
        observations: Per-frame observations; only their count is reported.
        actions: Per-frame action arrays; stacked into one array for the
            statistics. When empty, only the header lines are printed.
    """
    print("\n=== Data Analysis ===")
    print(f"Total frames: {len(observations)}")
    print(f"Action dimensions: {len(actions[0]) if actions else 0}")

    # Nothing further to report without any actions.
    if not actions:
        return

    stacked = np.array(actions)
    overall_min = np.min(stacked)
    overall_max = np.max(stacked)
    print(f"Actions shape: {stacked.shape}")
    print(f"Action range: [{overall_min:.4f}, {overall_max:.4f}]")
    print(f"Action mean: {np.mean(stacked):.4f}")
    print(f"Action std: {np.std(stacked):.4f}")

    # Per-joint breakdown: column i of the stacked array is paired with the
    # i-th label below.
    print("\n=== Per-Joint Action Statistics ===")
    left_leg = [
        "left_hip_pitch", "left_hip_roll", "left_hip_yaw",
        "left_knee", "left_ankle_pitch", "left_ankle_roll",
    ]
    right_leg = [
        "right_hip_pitch", "right_hip_roll", "right_hip_yaw",
        "right_knee", "right_ankle_pitch", "right_ankle_roll",
    ]
    left_arm = [
        "left_shoulder_pitch", "left_shoulder_roll", "left_shoulder_yaw",
        "left_elbow", "left_wrist_roll",
    ]
    right_arm = [
        "right_shoulder_pitch", "right_shoulder_roll", "right_shoulder_yaw",
        "right_elbow", "right_wrist_roll",
    ]
    labels = left_leg + right_leg + ["waist_yaw"] + left_arm + right_arm

    # Labels beyond the number of action columns are skipped.
    column_count = stacked.shape[1]
    for column, label in enumerate(labels[:column_count]):
        series = stacked[:, column]
        avg = np.mean(series)
        spread = np.std(series)
        low = np.min(series)
        high = np.max(series)
        print(f"{label:20s}: mean={avg:8.4f}, std={spread:8.4f}, range=[{low:8.4f}, {high:8.4f}]")

def main():
    """Load the recorded pickle, extract its data and print a short report.

    Failures are reported on stdout instead of being raised: a missing input
    file gets a dedicated message, any other exception a generic one.
    """
    # Recording to read
    source_path = "StraightPunch_observation_data_20250528_204217.pkl"

    # Directory intended for extracted files. It is not passed to the
    # extraction call below, so nothing is written to disk.
    output_dir = "extracted_data"

    try:
        print(f"Loading data from: {source_path}")
        recording = load_observation_data(source_path)

        # Saving is switched off here; pass output_dir as the second argument
        # to also write the extracted files.
        observations, actions = extract_observations_and_actions(recording)

        analyze_data(observations, actions)

        print("\n=== Sample Data ===")
        if observations:
            first_keys = list(observations[0].keys())
            print("First observation keys:", first_keys)
            print("First action:", actions[0][:5], "... (showing first 5 elements)")

    except FileNotFoundError:
        print(f"Error: File {source_path} not found!")
        print("Please make sure the file path is correct.")
    except Exception as exc:
        print(f"Error processing data: {exc}")

# Run the extraction/report pipeline only when this file is executed as a
# script (not when it is imported as a module).
if __name__ == "__main__":
    main()

Loading data from: StraightPunch_observation_data_20250528_204217.pkl
Processing 1169 frames...
Scaling coefficients: {'dof_pos_scale': 1.0, 'dof_vel_scale': 0.05, 'ang_vel_scale': 0.25, 'lin_vel_scale': 2.0, 'action_scale': 0.25}

=== Data Analysis ===
Total frames: 1169
Action dimensions: 23
Actions shape: (1169, 23)
Action range: [-9.0213, 29.4426]
Action mean: 0.4963
Action std: 4.0439

=== Per-Joint Action Statistics ===
left_hip_pitch      : mean= -1.8382, std=  1.0749, range=[ -5.0275,   1.2909]
left_hip_roll       : mean=  1.7911, std=  1.7354, range=[ -2.1870,   6.8813]
left_hip_yaw        : mean=  0.2470, std=  1.1319, range=[ -4.9908,   3.6769]
left_knee           : mean=  1.2047, std=  0.8740, range=[ -0.7919,   3.4245]
left_ankle_pitch    : mean=  1.5853, std=  2.3774, range=[ -9.0213,  13.5245]
left_ankle_roll     : mean= -0.5583, std=  1.4666, range=[ -3.8571,   6.3442]
right_hip_pitch     : mean= -0.8738, std=  0.8740, range=[ -3.9458,   1.8976]
right_hip_roll      : me