# Process Steps Data for RL Preference Learning

This notebook loads RL episode data from a directory of Stable-Baselines3 (SB3) model files and uses it to generate validation segments and rewards for preference learning.

## Import Required Libraries

In [None]:
import torch
from pref_rl.utils.pref import Sampler
import os

## Set Constants

In [None]:
# Feature widths used below to slice each step vector and to size the Sampler:
# OBS_DIM leading columns are observations, the next ACT_DIM are actions.
OBS_DIM, ACT_DIM = 78, 12

## Set the Path to the Raw Data Files

Update the path below to point to the directory containing your SB3 model files (`final_model_*.zip`).

In [None]:
# Set the path to your data dir (should contain final_model_*.zip files)
data_dir = '../outputs/2025-07-02/00-27-27'

# Verify the path exists AND is a directory. os.path.exists alone also accepts
# a regular file, which would only fail later (with a misleading
# "no episodes loaded" error) once globbing for model files finds nothing.
if not os.path.exists(data_dir):
    raise FileNotFoundError(f"The directory {data_dir} does not exist")
if not os.path.isdir(data_dir):
    raise NotADirectoryError(f"{data_dir} exists but is not a directory")
print(f"Directory found: {data_dir}")

## Load and Examine the Steps Data

In [None]:
# Load the steps data from multiple SB3 model files
print(f"Loading episode data from {data_dir}...")

import glob
from stable_baselines3.common.save_util import load_from_zip_file


def _get_done_eps(data):
    """Return the ``done_eps`` collection stored in an SB3 archive's ``data`` dict.

    Two layouts are accepted, checked in this order:
      * ``data['episode_buffer']``            (LoggingPPO-style archives)
      * ``data['pref_ppo_data']['buffer']``   (PrefPPO-style archives)

    Raises:
        KeyError: if ``data`` is missing/empty or neither layout is present,
            so the caller reports it as an unexpected data structure.
    """
    # load_from_zip_file yields data=None when the archive has no data entry.
    if not data:
        raise KeyError("archive has no 'data' entry")
    buffer = data.get('episode_buffer')
    if buffer is None:
        # `or {}` also covers a stored pref_ppo_data value of None.
        buffer = (data.get('pref_ppo_data') or {}).get('buffer')
    if buffer is None:
        raise KeyError("neither 'episode_buffer' nor 'pref_ppo_data'['buffer'] is present")
    return buffer.done_eps


# Find all model files in the directory. glob returns files in arbitrary
# order, so sort to make the resulting episode order deterministic.
model_files = sorted(glob.glob(os.path.join(data_dir, "final_model_*.zip")))
print(f"Found {len(model_files)} model files")

all_episodes = []
for model_file in model_files:
    print(f"Loading from {os.path.basename(model_file)}...")
    try:
        # Returns (data, params, pytorch_variables); only `data` is needed here.
        data, _, _ = load_from_zip_file(model_file, load_data=True)
        if data:
            print(f"  Data keys: {list(data.keys())}")
        done_eps = _get_done_eps(data)
        all_episodes.extend(done_eps)
        print(f"  Loaded {len(done_eps)} episodes")
    except KeyError as e:
        print(f"  Warning: Could not find expected data structure in {model_file}: {e}")
    except Exception as e:
        # Best-effort: skip an unreadable archive instead of aborting the whole load.
        print(f"  Error loading {model_file}: {e}")

if not all_episodes:
    raise ValueError(f"No episodes were loaded from any model files in {data_dir}")

# torch.stack requires identical shapes; report a mismatch explicitly instead
# of surfacing torch's generic size-mismatch error.
episode_shapes = {tuple(ep.shape) for ep in all_episodes}
if len(episode_shapes) > 1:
    raise ValueError(
        f"Episodes have inconsistent shapes and cannot be stacked: {sorted(episode_shapes)}"
    )

eps = torch.stack(all_episodes)
print(f"\nTotal loaded episodes: {len(eps)}")
print(f"Stacked episodes shape: {eps.shape}")
print(f"Episode {0:2} shape: {eps[0].shape}")
print(f"Episode {len(eps)-1:2} shape: {eps[-1].shape}")

In [None]:
# Split episodes into observations, actions, and rewards
print("Splitting episodes into (obs, actions, rewards)...")

# The slicing below indexes three axes and assumes each step's feature vector
# starts with OBS_DIM observation columns, then ACT_DIM action columns, and
# ends with a reward column. Validate that before slicing: otherwise a
# too-narrow feature axis silently yields truncated actions and a "reward"
# column that overlaps the actions.
min_feat_dim = OBS_DIM + ACT_DIM + 1
if eps.ndim != 3:
    raise ValueError(f"Expected a 3-D episodes tensor, got shape {tuple(eps.shape)}")
if eps.shape[-1] < min_feat_dim:
    raise ValueError(
        f"Expected at least {min_feat_dim} features per step "
        f"(OBS_DIM={OBS_DIM} + ACT_DIM={ACT_DIM} + 1 reward), got {eps.shape[-1]}"
    )

# Split obs, actions, rewards based on dimensions
obs = eps[:, :, :OBS_DIM]
actions = eps[:, :, OBS_DIM:OBS_DIM + ACT_DIM]
rewards = eps[:, :, -1]

print(f"Observations shape: {obs.shape}")
print(f"Actions shape: {actions.shape}")
print(f"Rewards shape: {rewards.shape}")

steps = {
    'obs': obs,
    'actions': actions,
    'rewards': rewards
}

## Initialize the Sampler

Initialize the `Sampler` with the segment length and the observation/action dimensions (`OBS_DIM`, `ACT_DIM`) defined above.

In [None]:
# Segment length handed to the Sampler (presumably measured in steps — confirm
# against Sampler's definition)
segment_length = 50

# Echo the configuration before constructing the sampler
print("Initializing Sampler with parameters:")
for param_name, param_value in (
    ("segment_length", segment_length),
    ("obs_dim", OBS_DIM),
    ("action_dim", ACT_DIM),
):
    print(f"  {param_name}: {param_value}")

# Positional order: segment length, observation dim, action dim
sampler = Sampler(segment_length, OBS_DIM, ACT_DIM)

## Sample Segments

Use the `sample_segments` method to sample fixed-length segments from the stacked episodes (`eps`).

In [None]:
print("Sampling segments...")
num_segments = 8000
# Single source of truth for the method, so the log line below can never
# disagree with what is actually passed to the sampler.
sampling_method = 'uniform'
print(f"  Number of segments: {num_segments}")
print(f"  Sampling method: {sampling_method}")

# Sample from the stacked episodes `eps`. The fourth positional argument is
# passed as None; its meaning is defined by Sampler.sample_segments (not
# visible here). Only the segments and their rewards are kept.
sa, r, _ = sampler.sample_segments(
    eps,
    num_segments,
    sampling_method,
    None,
    stratified=True,
    compute_uniform_metrics=False,
)
print(f"  Segments shape: {sa.shape}")
print(f"  Rewards shape: {r.shape}")

## Save Segments and Rewards

In [None]:
data_file = '../data/validation_data.pkl'

# torch.save does not create missing parent directories, so ensure the output
# directory exists first (`or '.'` covers a bare filename with no directory part).
os.makedirs(os.path.dirname(data_file) or '.', exist_ok=True)

print(f"Saving segments and rewards to {data_file}...")
# .contiguous() is a no-op for already-contiguous tensors and otherwise
# produces a packed copy before serialization.
torch.save([sa.contiguous(), r.contiguous()], data_file)

print("Processing completed successfully!")

## Optional: Verify Saved Files

In [None]:
# Report the on-disk size of each output artifact, or note that it is missing
files_to_check = [data_file]
for path in files_to_check:
    if not os.path.exists(path):
        print(f"{path} not found")
        continue
    size_mb = os.path.getsize(path) / 2**20  # bytes -> MiB
    print(f"{path} - Size: {size_mb:.2f} MB")