# JAX → PyTorch 轉換 Notebook

將 SAC checkpoint 轉換為 PyTorch TorchScript 格式

**Checkpoint 路徑**: `/Workspace/Users/adamlin@cheerstech.com.tw/.bundle/Booster_Soccer_plan/dev/files/exp/sac_mjx/checkpoints/u4oasfsj/final_checkpoint.pkl`

## Step 1: 驗證 Checkpoint 結構

In [None]:
import pickle
import numpy as np

# Location of the pickled SAC checkpoint to convert
CHECKPOINT_PATH = "/Workspace/Users/adamlin@cheerstech.com.tw/.bundle/Booster_Soccer_plan/dev/files/exp/sac_mjx/checkpoints/u4oasfsj/final_checkpoint.pkl"

# Deserialize the checkpoint.
# NOTE: pickle.load can execute arbitrary code embedded in the file, so only
# point this at checkpoints produced by a trusted training run.
with open(CHECKPOINT_PATH, "rb") as ckpt_file:
    ckpt = pickle.load(ckpt_file)

# Show what the checkpoint contains at the top level
print("Checkpoint top-level keys:", [*ckpt.keys()])

In [None]:
# Locate the actor parameters inside the checkpoint.
# Expected path: ckpt["agent"]["network"]["params"]["modules_actor"]

try:
    params = ckpt["agent"]["network"]["params"]["modules_actor"]
except KeyError as e:
    print(f"❌ 標準路徑不存在: {e}")
    print("\n嘗試探索 checkpoint 結構...")

    def explore_dict(d, prefix=""):
        """Print the nested key layout of ``d``, one entry per line.

        Nested dicts are shown as ``key/`` and recursed into; leaves are shown
        with their type name and ``shape`` attribute ('scalar' if absent).
        ``prefix`` is the current indentation and doubles as the depth guard.
        """
        for k, v in d.items():
            if isinstance(v, dict):
                print(f"{prefix}{k}/")
                # Each level adds two spaces, so this stops the recursion
                # once the indentation reaches 20 characters.
                if len(prefix) < 20:
                    explore_dict(v, prefix + "  ")
            else:
                shape = getattr(v, 'shape', 'scalar')
                print(f"{prefix}{k}: {type(v).__name__} {shape}")

    explore_dict(ckpt)
    # Re-raise after the diagnostic dump: every later cell needs `params`,
    # and continuing would only surface as an unrelated NameError.
    raise
else:
    print("✅ 標準路徑存在")
    print("Actor params keys:", list(params.keys()))

In [None]:
# Inspect the actor trunk parameters
actor_net = params["actor_net"]
print("actor_net keys:", list(actor_net.keys()))

# Collect the Dense entries and order them by the number after the last "_"
dense_layers = [name for name in actor_net.keys() if "Dense" in name]
dense_layers.sort(key=lambda name: int(name.rsplit("_", 1)[-1]))
print(f"\nDense layers: {dense_layers}")

# Read the dimensions straight from the parameter shapes:
# first kernel's leading axis is the input size, each bias length is a layer width.
hidden_dims = [actor_net[name]["bias"].shape[0] for name in dense_layers]
input_dim = actor_net[dense_layers[0]]["kernel"].shape[0]
action_dim = params["mean_net"]["bias"].shape[0]

print("\n=== 網路結構 ===")
print(f"Input dim:  {input_dim} (預期: 87)")
print(f"Hidden dims: {hidden_dims} (預期: [256, 256, 256])")
print(f"Action dim: {action_dim} (預期: 12)")

# Hard checks on the observation and action sizes
assert input_dim == 87, f"❌ Input dim mismatch: {input_dim} != 87"
assert action_dim == 12, f"❌ Action dim mismatch: {action_dim} != 12"
print("\n✅ 結構驗證通過！")

## Step 2: 定義 PyTorch 模型

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TorchMLP(nn.Module):
    """Stack of Linear layers with ReLU and optional LayerNorm.

    Per the original note, this is meant to match the MLP in jax2torch.py.

    Args:
        input_dim: size of the input features.
        hidden_layers: output width of each Linear layer, in order.
        activate_final: when False, the last Linear output is returned
            without ReLU/LayerNorm.
        layer_norm: when True, a LayerNorm (eps=1e-6) is created per layer
            and applied after the ReLU.
    """
    def __init__(self, input_dim, hidden_layers, activate_final=True, layer_norm=True):
        super().__init__()
        self.hidden_layers = nn.ModuleList()
        self.layer_norms = nn.ModuleList()
        self.activate_final = activate_final
        self.layer_norm = layer_norm

        # Pair each layer's input width with its output width.
        widths = [input_dim, *hidden_layers]
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.hidden_layers.append(nn.Linear(fan_in, fan_out))
            if layer_norm:
                self.layer_norms.append(nn.LayerNorm(fan_out, eps=1e-6))

    def forward(self, x):
        """Apply Linear -> ReLU -> (LayerNorm) for each layer and return the result."""
        final_idx = len(self.hidden_layers) - 1
        for idx, dense in enumerate(self.hidden_layers):
            x = dense(x)
            # The last layer stays linear unless activate_final is set.
            if idx == final_idx and not self.activate_final:
                continue
            x = F.relu(x)
            if self.layer_norm:
                x = self.layer_norms[idx](x)
        return x


class TorchGCActor(nn.Module):
    """Actor network: an MLP trunk followed by mean and log-std heads.

    Per the original note, this is meant to match the actor in jax2torch.py.

    Args:
        obs_dim: size of the observation vector fed to the trunk.
        hidden_layers: widths of the trunk's layers; the last one is the
            input width of both heads.
        action_dim: output size of each head.
    """
    def __init__(self, obs_dim, hidden_layers, action_dim):
        super().__init__()
        # Trunk always uses a final activation and LayerNorm.
        self.actor_net = TorchMLP(obs_dim, hidden_layers, activate_final=True, layer_norm=True)
        feature_dim = hidden_layers[-1]
        self.mean_net = nn.Linear(feature_dim, action_dim)
        self.log_std_net = nn.Linear(feature_dim, action_dim)

    def forward(self, obs, temperature=1.0):
        """Return ``(mean, std)`` for ``obs``.

        The log-std head is clamped to [-5, 2] before exponentiation, and the
        resulting std is scaled by ``temperature``. The original note says the
        downstream BoosterModel only consumes the mean (index 0).
        """
        features = self.actor_net(obs)
        clamped_log_std = self.log_std_net(features).clamp(min=-5.0, max=2.0)
        std = clamped_log_std.exp() * temperature
        mean = self.mean_net(features)
        return mean, std

# Confirmation message emitted once this cell has finished defining the model classes.
print("✅ PyTorch 模型類別定義完成")

## Step 3: 權重轉換

In [None]:
def load_dense(torch_layer, jax_layer):
    """Copy a JAX Dense layer's parameters into a ``torch.nn.Linear``.

    Args:
        torch_layer: destination layer exposing ``weight`` and ``bias``.
        jax_layer: mapping with ``"kernel"`` and ``"bias"`` arrays. The kernel
            is laid out as (in, out) and is transposed to PyTorch's (out, in).

    Raises:
        ValueError: if the transposed kernel or the bias does not have the
            same shape as the destination parameter. Nothing is modified in
            that case.
    """
    # np.asarray gives a NumPy view of the (possibly read-only) JAX array;
    # .copy() then hands torch a writable, contiguous buffer.
    kernel = np.asarray(jax_layer["kernel"])
    bias = np.asarray(jax_layer["bias"])
    new_weight = torch.tensor(kernel.T.copy())
    new_bias = torch.tensor(bias.copy())

    # Validate before touching the layer: assigning through `.data` would
    # silently accept any shape and produce a broken model.
    if new_weight.shape != torch_layer.weight.shape:
        raise ValueError(
            f"Dense kernel shape mismatch: got {tuple(new_weight.shape)} after transpose, "
            f"expected {tuple(torch_layer.weight.shape)}"
        )
    if new_bias.shape != torch_layer.bias.shape:
        raise ValueError(
            f"Dense bias shape mismatch: got {tuple(new_bias.shape)}, "
            f"expected {tuple(torch_layer.bias.shape)}"
        )

    # In-place copy keeps the parameter's dtype and device (e.g. a float64
    # checkpoint is cast to the layer's float32 instead of replacing it).
    with torch.no_grad():
        torch_layer.weight.copy_(new_weight)
        torch_layer.bias.copy_(new_bias)

def load_layernorm(torch_ln, jax_ln):
    """Copy a JAX LayerNorm's parameters into a ``torch.nn.LayerNorm``.

    Args:
        torch_ln: destination module exposing ``weight`` and ``bias``.
        jax_ln: mapping with ``"scale"`` (copied to ``weight``) and
            ``"bias"`` arrays.

    Raises:
        ValueError: if ``scale`` or ``bias`` does not have the same shape as
            the destination parameter. Nothing is modified in that case.
    """
    # np.asarray gives a NumPy view of the (possibly read-only) JAX array;
    # .copy() then hands torch a writable buffer.
    scale = np.asarray(jax_ln["scale"])
    bias = np.asarray(jax_ln["bias"])
    new_weight = torch.tensor(scale.copy())
    new_bias = torch.tensor(bias.copy())

    # Validate first: assigning through `.data` would accept any shape.
    if new_weight.shape != torch_ln.weight.shape:
        raise ValueError(
            f"LayerNorm scale shape mismatch: got {tuple(new_weight.shape)}, "
            f"expected {tuple(torch_ln.weight.shape)}"
        )
    if new_bias.shape != torch_ln.bias.shape:
        raise ValueError(
            f"LayerNorm bias shape mismatch: got {tuple(new_bias.shape)}, "
            f"expected {tuple(torch_ln.bias.shape)}"
        )

    # In-place copy keeps the parameter's dtype and device.
    with torch.no_grad():
        torch_ln.weight.copy_(new_weight)
        torch_ln.bias.copy_(new_bias)

# Build the PyTorch model with the dimensions read from the checkpoint
torch_model = TorchGCActor(
    obs_dim=input_dim,
    hidden_layers=hidden_dims,
    action_dim=action_dim,
)

# Copy the trunk's Dense weights, in layer order
for i, dname in enumerate(dense_layers):
    load_dense(torch_model.actor_net.hidden_layers[i], actor_net[dname])
    print(f"  載入 {dname} → hidden_layers[{i}]")

# Copy the trunk's LayerNorm weights, ordered by the number after the last "_"
ln_layers = sorted([k for k in actor_net.keys() if "LayerNorm" in k], key=lambda x: int(x.split("_")[-1]))

# The PyTorch trunk is always built with LayerNorm enabled. If the checkpoint
# holds a different number of LayerNorm entries, some torch LayerNorms would
# silently keep their default initialisation, so stop here instead.
expected_ln = len(torch_model.actor_net.layer_norms)
if len(ln_layers) != expected_ln:
    raise ValueError(
        f"❌ LayerNorm count mismatch: checkpoint has {len(ln_layers)}, "
        f"PyTorch model expects {expected_ln}"
    )

for i, lname in enumerate(ln_layers):
    load_layernorm(torch_model.actor_net.layer_norms[i], actor_net[lname])
    print(f"  載入 {lname} → layer_norms[{i}]")

# Copy the two output heads
load_dense(torch_model.mean_net, params["mean_net"])
print("  載入 mean_net")
load_dense(torch_model.log_std_net, params["log_std_net"])
print("  載入 log_std_net")

print("\n✅ 權重轉換完成！")

## Step 4: 驗證輸出

In [None]:
# Smoke-test the converted model on one random observation
torch_model.eval()
test_obs = torch.randn(1, input_dim)

with torch.no_grad():
    mean, std = torch_model(test_obs)

print(f"Test input shape:  {test_obs.shape}")
print(f"Mean output shape: {mean.shape} (預期: [1, 12])")
print(f"Std output shape:  {std.shape} (預期: [1, 12])")

mean_values = mean.numpy().flatten()
std_values = std.numpy().flatten()
print(f"\nMean values: {mean_values}")
print(f"Std values:  {std_values}")

# Report the range of the raw mean. This model applies no tanh itself, so the
# values are not bounded to [-1, 1] here; any squashing would have to happen
# in whatever consumes the model.
mean_lo, mean_hi = mean.min(), mean.max()
print(f"\nMean range: [{mean_lo:.3f}, {mean_hi:.3f}]")

## Step 5: 保存 TorchScript 模型

In [None]:
# Produce a TorchScript module by tracing the model on one random observation
trace_input = torch.randn(1, input_dim)
scripted_model = torch.jit.trace(torch_model, (trace_input,))

# Destination for the exported model (a Databricks workspace path)
OUTPUT_PATH = "/Workspace/Users/adamlin@cheerstech.com.tw/.bundle/Booster_Soccer_plan/dev/files/submission/model.pt"

# Create the parent directory if it is missing
import os
output_dir = os.path.dirname(OUTPUT_PATH)
os.makedirs(output_dir, exist_ok=True)

# Write the traced module to disk
torch.jit.save(scripted_model, OUTPUT_PATH)
print(f"✅ TorchScript 模型已保存到: {OUTPUT_PATH}")

In [None]:
# Reload the saved TorchScript file and check it reproduces the in-memory model
loaded_model = torch.jit.load(OUTPUT_PATH)
loaded_model.eval()

with torch.no_grad():
    loaded_mean, loaded_std = loaded_model(test_obs)

# Compare both outputs against the values computed before saving.
# Previously only the mean was checked and the std output was ignored.
mean_diff = (mean - loaded_mean).abs().max().item()
std_diff = (std - loaded_std).abs().max().item()
print(f"Loaded model output diff: {mean_diff:.2e}")
print(f"Loaded model std diff: {std_diff:.2e}")
assert mean_diff < 1e-5, "❌ 載入後輸出不一致！"
# std can be several units large (log-std is clamped to at most 2), so use a
# relative tolerance for it rather than a purely absolute one.
assert torch.allclose(std, loaded_std, rtol=1e-5, atol=1e-5), "❌ 載入後 std 輸出不一致！"
print("✅ 模型載入驗證通過！")

## Step 6: 下載到本地（可選）

在 Databricks 中，可以使用以下方式下載文件到本地：

1. **Databricks CLI**:
```bash
databricks workspace export /Workspace/Users/adamlin@cheerstech.com.tw/.bundle/Booster_Soccer_plan/dev/files/submission/model.pt ./submission/model.pt --overwrite
```

2. **手動下載**: 在 Workspace UI 中右鍵點擊文件 → Download

3. **DBFS 複製**（如果需要）:
```python
dbutils.fs.cp("file:" + OUTPUT_PATH, "dbfs:/FileStore/submission/model.pt")
# 然後從 https://<workspace>/files/submission/model.pt 下載
```

In [None]:
# Optional: publish the model through DBFS FileStore so it can be fetched by URL.
# `dbutils` only exists inside Databricks; elsewhere the lookup raises NameError
# and the copy is skipped.
try:
    dbutils.fs.cp(f"file:{OUTPUT_PATH}", "dbfs:/FileStore/submission/model.pt")
except NameError:
    print("⚠️ 不在 Databricks 環境中，跳過 DBFS 複製")
else:
    print("✅ 已複製到 DBFS FileStore")
    print("下載 URL: https://<your-workspace>/files/submission/model.pt")

## 總結

轉換完成！下一步：
1. 下載 `model.pt` 到本地 `submission/` 目錄
2. 複製 `booster_soccer_showdown/imitation_learning/submission/model.py` 到 `submission/`
3. 在官方環境測試