# 官方環境驗證 (Gate 3)

驗證轉換後的 `model.pt` 在 SAI 官方環境中的行為。

**成功標準**：
- ✅ 機器人存活 > 50 步（每個任務以所有 episode 的平均步數判定）
- ✅ 有移動/踢球意圖（需人工觀察輸出；程式不會自動檢查）
- ⚠️ 不要求高分數

## Step 1: 環境設置

In [None]:
# 確保 sai-rl 已安裝（如果 Cluster Library 未包含）
# %pip install sai-rl --quiet

In [None]:
import sys
import numpy as np
import torch

# Add the project path so the `submission` package can be imported.
PROJECT_ROOT = "/Workspace/Users/adamlin@cheerstech.com.tw/.bundle/Booster_Soccer_plan/dev/files"
# Notebook cells get re-run: drop any earlier occurrence before putting the
# project root first, so sys.path never accumulates duplicate entries while
# the project still takes import priority. Slice assignment keeps the same
# list object that the import system holds.
sys.path[:] = [PROJECT_ROOT, *(entry for entry in sys.path if entry != PROJECT_ROOT)]

print(f"Project root: {PROJECT_ROOT}")
print(f"PyTorch version: {torch.__version__}")
print(f"NumPy version: {np.__version__}")

## Step 2: 載入模型和 Preprocessor

In [None]:
from submission.model import BoosterModel
from submission.preprocessor import Preprocessor

# Location of the model file inside the project tree.
MODEL_PATH = f"{PROJECT_ROOT}/submission/model.pt"

# Load the model from MODEL_PATH and build the preprocessor (no constructor
# arguments); both objects are reused by the test cells further down.
model = BoosterModel(MODEL_PATH)
preprocessor = Preprocessor()

print(f"✅ 模型載入成功: {MODEL_PATH}")

In [None]:
# Smoke-test model inference on a single random float32 observation.
# NOTE(review): 87 is presumably the input width the model expects — confirm
# against the preprocessor output.
test_obs = np.random.randn(1, 87).astype(np.float32)
test_action = model(test_obs)

print(f"Input shape:  {test_obs.shape}")
print(f"Output shape: {test_action.shape}")
print(f"Output range: [{test_action.min():.3f}, {test_action.max():.3f}]")
print("✅ 模型推理正常")

## Step 3: 連接 SAI 環境

In [None]:
from sai_rl import SAIClient

# Fetch the API key from Databricks Secrets.
# NOTE(review): `dbutils` is neither defined nor imported in this notebook;
# presumably it is the global injected by the Databricks runtime — confirm.
API_KEY = dbutils.secrets.get(scope="sai-credentials", key="api-key")
# Only the key's length is printed, never the key itself.
print(f"✅ API Key 已載入（長度: {len(API_KEY)}）")

In [None]:
# Task configuration. Each task maps to its SAI competition id and a one-hot
# vector; a task's position in the listing below is the index of the 1.0 entry.
TASKS = {
    task: {"comp_id": competition, "one_hot": np.eye(3)[slot].copy()}
    for slot, (task, competition) in enumerate(
        (
            ("GoaliePenaltyKick", "lower-t1-penalty-kick-goalie"),
            ("ObstaclePenaltyKick", "lower-t1-penalty-kick-obstacle"),
            ("KickToTarget", "lower-t1-kick-to-target"),
        )
    )
}

# Echo the task -> competition id mapping for a quick visual check.
print("任務配置：")
for name, cfg in TASKS.items():
    print(f"  {name}: {cfg['comp_id']}")

## Step 4: 定義測試函數

In [None]:
def test_single_task(
    task_name: str,
    model: BoosterModel,
    preprocessor: Preprocessor,
    api_key: str,
    num_episodes: int = 3,
    max_steps: int = 1000,
):
    """Roll the model out on one SAI task and check the Gate 3 survival criterion.

    Args:
        task_name: Key into the module-level ``TASKS`` mapping.
        model: Policy called with a batched observation; the value it returns
            must expose ``.numpy()``.
        preprocessor: Object whose ``modify_state(obs, info)`` turns a raw
            observation and its info dict into the model input.
        api_key: Credential forwarded to ``SAIClient``.
        num_episodes: Number of episodes to run.
        max_steps: Upper bound on the number of steps per episode.

    Returns:
        A dict with the keys ``task``, ``avg_reward``, ``avg_length``,
        ``success`` (True when the mean episode length exceeds 50),
        ``episode_rewards`` and ``episode_lengths``.

    Raises:
        KeyError: If ``task_name`` is not a key of ``TASKS``.
    """
    task_cfg = TASKS[task_name]
    comp_id = task_cfg["comp_id"]
    task_one_hot = task_cfg["one_hot"]

    print(f"\n{'='*50}")
    print(f"測試任務: {task_name}")
    print(f"comp_id: {comp_id}")
    print(f"{'='*50}")

    # Connect to the environment.
    sai = SAIClient(comp_id=comp_id, api_key=api_key)
    env = sai.make_env()

    episode_rewards = []
    episode_lengths = []

    # The environment must be released even if reset/step, the preprocessor or
    # the model raises part-way through a rollout, hence try/finally.
    try:
        for ep in range(num_episodes):
            obs, info = env.reset()
            total_reward = 0
            step_count = 0

            for _ in range(max_steps):
                # Inject the task one-hot only when the environment's info
                # dict does not already carry a "task_index" entry.
                if "task_index" not in info:
                    info["task_index"] = task_one_hot

                # Preprocess with the two-argument (obs, info) signature.
                processed_obs = preprocessor.modify_state(obs, info)

                # Add a leading batch axis for the model, then strip it from
                # the resulting action.
                action = model(processed_obs[np.newaxis, :]).numpy().squeeze()

                # Apply the action.
                obs, reward, terminated, truncated, info = env.step(action)
                total_reward += reward
                step_count += 1

                if terminated or truncated:
                    break

            episode_rewards.append(total_reward)
            episode_lengths.append(step_count)
            print(f"  Episode {ep+1}: reward={total_reward:.2f}, steps={step_count}")
    finally:
        env.close()

    avg_reward = np.mean(episode_rewards)
    avg_length = np.mean(episode_lengths)

    print(f"\n  平均 reward: {avg_reward:.2f}")
    print(f"  平均 steps:  {avg_length:.1f}")

    # Basic success criterion: mean episode length above 50 steps.
    success = avg_length > 50
    print(f"  Gate 3: {'✅ 通過' if success else '❌ 失敗'}（存活 > 50 步）")

    return {
        "task": task_name,
        "avg_reward": avg_reward,
        "avg_length": avg_length,
        "success": success,
        "episode_rewards": episode_rewards,
        "episode_lengths": episode_lengths,
    }

print("✅ 測試函數定義完成")

## Step 5: 執行測試

In [None]:
import traceback

# Run every configured task. A failure in one task is reported and recorded
# as an unsuccessful result, and the remaining tasks still run.
results = []

for task_name in TASKS:
    try:
        result = test_single_task(
            task_name=task_name,
            model=model,
            preprocessor=preprocessor,
            api_key=API_KEY,
            num_episodes=3,
            max_steps=1000,
        )
    except Exception as exc:
        print(f"❌ 測試 {task_name} 時發生錯誤: {exc}")
        traceback.print_exc()
        results.append({"task": task_name, "success": False, "error": str(exc)})
    else:
        results.append(result)

## Step 6: 結果總結

In [None]:
print(f"\n{'='*60}")
print("Gate 3 驗證結果總結")
print(f"{'='*60}")

# One line per task: the captured error message if the run raised, otherwise
# the averaged metrics.
for r in results:
    status = "✅" if r.get("success") else "❌"
    if "error" in r:
        print(f"  {status} {r['task']}: 錯誤 - {r['error']}")
    else:
        print(f"  {status} {r['task']}: reward={r['avg_reward']:.2f}, steps={r['avg_length']:.1f}")

# all() is vacuously True for an empty iterable, so an empty results list
# would otherwise be reported as a pass; require at least one result.
all_passed = bool(results) and all(r.get("success", False) for r in results)
print(f"\n總體結果: {'✅ Gate 3 通過' if all_passed else '⚠️ Gate 3 未完全通過'}")

if all_passed:
    print("\n下一步：可以進行 SAI 提交")
else:
    print("\n下一步：檢查失敗原因，可能需要：")
    print("  1. 檢查 Preprocessor 維度/四元數順序")
    print("  2. 使用 Feature Freeze 策略進行微調")
    print("  3. 檢查 docs/troubleshooting.md 尋找解決方案")

## （可選）單一任務測試

如果只想測試特定任務，取消註解下方程式碼：

In [None]:
# # 只測試 GoaliePenaltyKick
# result = test_single_task(
#     task_name="GoaliePenaltyKick",
#     model=model,
#     preprocessor=preprocessor,
#     api_key=API_KEY,
#     num_episodes=3,
# )