# Task Index 維度驗證

## 目的
確認官方環境 `info['task_index']` 的真實維度，解決 83 vs 87 維度差異問題。

## 背景
- JAX preprocessor 當前輸出 **83 維**
- 官方 `main.py` 使用 `n_features=87`
- 假設：`info['task_index']` 可能是 **7 維**（而非 3 維 one-hot）；7 − 3 = 4，恰好等於 87 − 83 的差距

In [None]:
# Cell 1: dependency check
# Dependencies are installed via the Cluster Library (requirements.txt).
# To install manually: %pip install sai-mujoco gymnasium "numpy<2" --quiet

import sys
print(f"Python: {sys.version}")

try:
    import numpy as np
    # NumPy 1.x is expected; the message flags the 2.x ABI concern.
    print(f"NumPy: {np.__version__} {'✅' if np.__version__.startswith('1.') else '⚠️ NumPy 2.x 可能導致 ABI 問題'}")
except ImportError:
    print("❌ NumPy not installed")

try:
    import gymnasium
    print(f"Gymnasium: {gymnasium.__version__} ✅")
except ImportError:
    print("❌ Gymnasium not installed - run: %pip install gymnasium 'numpy<2'")

# sai-mujoco is listed in the install line above but was never checked.
# NOTE(review): importing it is presumably what registers the LowerT1* envs
# with Gymnasium, so gym.make(...) in later cells depends on it — confirm.
try:
    import sai_mujoco  # noqa: F401  (imported for its registration side effect)
    sai_mujoco_available = True
    print(f"sai-mujoco: {getattr(sai_mujoco, '__version__', 'unknown')} ✅")
except ImportError:
    sai_mujoco_available = False
    print("❌ sai-mujoco not installed - run: %pip install sai-mujoco")

In [None]:
# Cell 2: imports
import gymnasium as gym
import numpy as np
import sys
import os

# Put the project root on sys.path so `booster_soccer_showdown` is importable.
# Assumes the notebook's working directory is two levels below the project root.
project_root = os.path.abspath(os.path.join(os.getcwd(), '../..'))

# Re-running this cell used to prepend a duplicate entry each time; drop any
# existing copies first so the root appears exactly once, at highest priority.
while project_root in sys.path:
    sys.path.remove(project_root)
sys.path.insert(0, project_root)

print(f"Project root: {project_root}")

In [None]:
# Cell 3: inspect info['task_index'] in each of the three official environments
envs_to_test = [
    "LowerT1GoaliePenaltyKick-v0",
    "LowerT1ObstaclePenaltyKick-v0",
    "LowerT1KickToTarget-v0",
]

# env_name -> dict of observed facts, or {'error': message} if the env failed
results = {}

for env_name in envs_to_test:
    print(f"\n{'='*50}")
    print(f"Testing: {env_name}")
    print(f"{'='*50}")

    env = None
    try:
        env = gym.make(env_name)
        obs, info = env.reset()

        # Record the raw observation shape whether or not task_index exists,
        # so every successful entry in `results` carries it.
        obs_shape = np.array(obs).shape
        print(f"Raw obs shape: {obs_shape}")

        # Inspect task_index
        if 'task_index' in info:
            task_index = np.array(info['task_index'])
            results[env_name] = {
                'obs_shape': obs_shape,
                'task_index_shape': task_index.shape,
                'task_index_dtype': str(task_index.dtype),
                'task_index_value': task_index.tolist(),
            }
            print(f"task_index.shape: {task_index.shape}")
            print(f"task_index.dtype: {task_index.dtype}")
            print(f"task_index value: {task_index}")
        else:
            print("⚠️ task_index NOT FOUND in info dict")
            print(f"Available info keys: {list(info.keys())}")
            results[env_name] = {'obs_shape': obs_shape, 'task_index': 'NOT_FOUND'}

    except Exception as e:
        # Diagnostic loop: record the failure and keep probing the other envs.
        print(f"❌ Error: {e}")
        results[env_name] = {'error': str(e)}
    finally:
        # Previously close() was skipped when reset()/inspection raised,
        # leaking the environment; always release it once created.
        if env is not None:
            env.close()

print(f"\n\n{'='*50}")
print("Summary:")
print(f"{'='*50}")
for env_name, data in results.items():
    print(f"\n{env_name}:")
    for k, v in data.items():
        print(f"  {k}: {v}")

In [None]:
# Cell 4: verify the final observation dimension with the official Preprocessor
# NOTE(review): importing training_scripts/main.py executes its module-level
# code; confirm it is guarded by `if __name__ == "__main__"` so the import has
# no side effects such as starting training.
try:
    from booster_soccer_showdown.training_scripts.main import Preprocessor
except Exception as e:
    # Kept broad on purpose: the import can fail via missing packages or via
    # module-level code raising. Only import failures are reported here now.
    print(f"❌ Error importing official Preprocessor: {e}")
    print("Note: This cell requires the official SAI environment to be properly installed")
else:
    env = None
    try:
        preprocessor = Preprocessor()

        # One environment is enough to read the output dimension.
        env = gym.make("LowerT1GoaliePenaltyKick-v0")
        obs, info = env.reset()

        # np.asarray so `.shape` works even if modify_state returns a list.
        processed_obs = np.asarray(preprocessor.modify_state(obs, info))

        print(f"\n官方 Preprocessor 輸出：")
        print(f"  Input obs shape: {np.array(obs).shape}")
        print(f"  Output obs shape: {processed_obs.shape}")
        print(f"  Final dimension: {processed_obs.shape[-1]}")
    except Exception as e:
        # Previously mislabeled as an import error; these are runtime failures.
        print(f"❌ Error running official Preprocessor: {e}")
    finally:
        # Release the environment even when reset()/modify_state raised.
        if env is not None:
            env.close()

In [None]:
# Cell 5: list the shape and element count of every field in the info dict
env = gym.make("LowerT1GoaliePenaltyKick-v0")
try:
    obs, info = env.reset()

    print("Info dict 欄位維度：")
    print(f"{'='*60}")

    total_dims = 0
    # Dict keys are unique, so sorting the items orders purely by key.
    for key, value in sorted(info.items()):
        arr = np.array(value)
        dims = arr.size
        total_dims += dims
        # str(key) keeps the ':40s' format spec valid even for non-str keys.
        print(f"{str(key):40s} | shape: {str(arr.shape):15s} | dims: {dims}")

    print(f"{'='*60}")
    print(f"Total info dimensions: {total_dims}")

    # Raw observation size, for comparison against the info-field total
    obs_arr = np.array(obs)
    print(f"Raw obs dimensions: {obs_arr.size}")
    print(f"\nExpected final obs: raw_obs({obs_arr.size}) + info_subset + task_index")
finally:
    # Release the environment even if reset() or the inspection above raised.
    env.close()

## 結論

運行此 notebook 後，請確認：

1. `task_index` 維度是 **3** 還是 **7**？
2. 官方 Preprocessor 的最終輸出是 **83** 還是 **87** 維？
3. 如果是 83 維，則 `main.py` 中的 `n_features=87` 是錯誤的硬編碼值

### 下一步行動
- 如果確認是 83 維：更新 `preprocessor_jax.py` 註解，繼續開發
- 如果確認是 87 維：找出缺少的 4 維來源