In [1]:
import numpy as np
import gymnasium as gym
from sai_rl import SAIClient

In [2]:
# simplify_obs is duplicated here from the main project code so this notebook runs standalone.
# Keep this copy in sync with the original definition (or import it instead).
def simplify_obs(obs, include_velocities=False):
    """Split a flat observation vector into named components.

    Two layouts are handled:

    * with velocities (31 values): joint_positions [0:9], joint_velocities
      [9:18], ball_pos [18:21], club_pos [21:24], club_quat [24:28],
      hole_pos [28:31].
    * without velocities (22 values): joint_positions [0:9], ball_pos [9:12],
      club_pos [12:15], club_quat [15:19], hole_pos [19:22].

    The with-velocities layout is used when ``include_velocities`` is True or
    when ``obs`` has exactly 31 values; otherwise the 22-value layout is used.

    Args:
        obs: Flat, sliceable observation (presumably a NumPy array from the env).
        include_velocities: If True, return the joint velocities stored in
            ``obs[9:18]``. If False, ``joint_velocities`` is a float32 zero
            vector of length 9, even when ``obs`` does contain velocities.

    Returns:
        dict with keys "joint_positions", "joint_velocities", "ball_pos",
        "club_pos", "club_quat", "hole_pos". Apart from zeroed velocities, the
        values are slices of ``obs`` (views when ``obs`` is a NumPy array).

    Raises:
        ValueError: If ``obs`` is too short for the layout it would be parsed
            with, instead of returning truncated or misaligned slices.
    """
    obs_len = len(obs)
    if include_velocities and obs_len < 31:
        # A 22-value observation has no velocity block; slicing it with the
        # 31-value layout would shift every object component.
        raise ValueError(
            f"include_velocities=True needs at least 31 observation values, got {obs_len}"
        )
    if obs_len < 22:
        raise ValueError(f"expected at least 22 observation values, got {obs_len}")

    # Object state starts right after the joint velocities when that block is
    # present, otherwise right after the joint positions.
    has_velocity_block = include_velocities or obs_len == 31
    start = 18 if has_velocity_block else 9

    if include_velocities:
        joint_velocities = obs[9:18]
    else:
        joint_velocities = np.zeros(9, dtype=np.float32)

    return {
        "joint_positions": obs[0:9],
        "joint_velocities": joint_velocities,
        "ball_pos": obs[start:start + 3],
        "club_pos": obs[start + 3:start + 6],
        "club_quat": obs[start + 6:start + 10],
        "hole_pos": obs[start + 10:start + 13],
    }


def test_env_randomization(env_id="franka-ml-hiring", num_resets=5):
    print(f"Testing {env_id} for starting position randomization...\n")

    sai = SAIClient(comp_id=env_id)

    # Training environment (NO rendering for speed)
    env = sai.make_env()

    positions = []
    for i in range(num_resets):
        # Use a random seed for each reset
        seed = np.random.randint(0, 1_000_000)
        obs, _ = env.reset(seed=seed)

        components = simplify_obs(obs, include_velocities=True)
        club_pos = components["club_pos"]

        positions.append(club_pos)
        print(f"Reset {i}: Seed={seed}, Club Pos={club_pos}")

    print("\nSummary of Club Positions:")
    for idx, pos in enumerate(positions):
        print(f"{idx}: {pos}")

    if all(np.allclose(positions[0], p, atol=1e-4) for p in positions[1:]):
        print("\n❌ Environment appears STATIC (club position is the same every reset).")
    else:
        print("\n✅ Environment randomization ACTIVE (club positions differ across resets).")

    env.close()


test_env_randomization(env_id="franka-ml-hiring", num_resets=5)  # run the check with its default settings

Testing franka-ml-hiring for starting position randomization...



Reset 0: Seed=846618, Club Pos=[0.7   0.    0.135]
Reset 1: Seed=755839, Club Pos=[0.7   0.    0.135]
Reset 2: Seed=938500, Club Pos=[0.7   0.    0.135]
Reset 3: Seed=250805, Club Pos=[0.7   0.    0.135]
Reset 4: Seed=410679, Club Pos=[0.7   0.    0.135]

Summary of Club Positions:
0: [0.7   0.    0.135]
1: [0.7   0.    0.135]
2: [0.7   0.    0.135]
3: [0.7   0.    0.135]
4: [0.7   0.    0.135]

❌ Environment appears STATIC (club position is the same every reset).
