# 01 — Data Exploration

Visualize random rollout data collected from CarRacing-v3.

**Prerequisites:** Run `scripts/collect_data.py` first to generate `.npz` files.

In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

sys.path.insert(0, str(Path.cwd().parent))
from src.config import Config

# Locate the rollout episode files produced by scripts/collect_data.py.
config = Config()
data_dir = Path.cwd().parent / config.data.data_dir
npz_files = sorted(data_dir.glob("*.npz"))
print(f"Found {len(npz_files)} episode files")
# Fail early with an actionable message; otherwise the next cell dies with a
# bare IndexError on npz_files[0].
if not npz_files:
    raise FileNotFoundError(
        f"No .npz episode files found in {data_dir}; "
        "run scripts/collect_data.py first."
    )

In [None]:
# Load a sample episode (the first file in sorted order).
# np.load on an .npz returns an NpzFile that keeps the archive open; using it
# as a context manager closes the handle once the arrays have been read into
# memory (each key access materializes the full array).
with np.load(npz_files[0]) as episode:
    obs = episode["observations"]
    actions = episode["actions"]
    rewards = episode["rewards"]
print(f"Observations: {obs.shape}, dtype={obs.dtype}")
print(f"Actions: {actions.shape}")
print(f"Rewards: {rewards.shape}, total={rewards.sum():.1f}")

In [None]:
# Show evenly spaced frames from the loaded episode on a 2x5 grid,
# each panel titled with its timestep index.
n_rows, n_cols = 2, 5
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 6))
frame_ids = np.linspace(0, len(obs) - 1, n_rows * n_cols, dtype=int)
for t, panel in zip(frame_ids, axes.ravel()):
    panel.imshow(obs[t])
    panel.set_title(f"t={t}")
    panel.axis("off")
fig.suptitle("Sample Frames from Episode 0")
fig.tight_layout()
plt.show()

In [None]:
# Episode length and reward distribution over a subset of episodes.
# Only the first `n_sampled` files (in sorted filename order) are read, to keep
# this cell fast.
n_sampled = 50
lengths = []
total_rewards = []
for f in npz_files[:n_sampled]:
    # Close each archive promptly, and read "rewards" only once: NpzFile does
    # not cache arrays, so every key access re-reads the member from disk.
    with np.load(f) as ep:
        ep_rewards = ep["rewards"]
    lengths.append(len(ep_rewards))
    total_rewards.append(ep_rewards.sum())

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.hist(lengths, bins=20)
ax1.set_title("Episode Lengths")
ax1.set_xlabel("Steps")
ax2.hist(total_rewards, bins=20)
ax2.set_title("Total Rewards (Random Policy)")
ax2.set_xlabel("Reward")
plt.tight_layout()
plt.show()
print(f"Avg length: {np.mean(lengths):.0f}, Avg reward: {np.mean(total_rewards):.1f}")

In [None]:
# Histogram each action component of the loaded episode, one panel per
# column of `actions` (labelled Steering / Gas / Brake).
labels = ["Steering", "Gas", "Brake"]
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for col, label in enumerate(labels):
    panel = axes[col]
    panel.hist(actions[:, col], bins=30, alpha=0.7)
    panel.set_title(label)
    panel.set_xlabel("Value")
fig.suptitle("Action Distribution (Episode 0)")
fig.tight_layout()
plt.show()