In [1]:
import gymnasium as gym
from gymnasium.envs.registration import register
import mujoco
from tqdm import tqdm
import torch
import math

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CallbackList
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecMonitor
from stable_baselines3.common.monitor import Monitor


In [2]:
register(
    id="galaxea_r1Pro",
    entry_point="galaxea_r1Pro:Galaxea_r1Pro",
)


In [3]:
def make_env(env_id, rank=0, seed=0):
    """
    env_id: Gymnasium 环境名
    rank: 每个环境编号（用于seed区分）
    seed: 基础随机种子
    """
    def _init():
        register(
            id="galaxea_r1Pro",
            entry_point="galaxea_r1Pro:Galaxea_r1Pro",
        )
        env = gym.make(env_id)
        env.reset(seed=seed + rank)
        env = Monitor(env)
        return env
    return _init

ENV_ID = "galaxea_r1Pro"
NUM_ENVS = 8
SEED = 0

## 创建并行环境
train_env = SubprocVecEnv([make_env(ENV_ID, i, SEED) for i in range(NUM_ENVS)])
train_env = VecMonitor(train_env)  # 配合VecMonitor使用

## 创建单个环境
# train_env = gym.make(ENV_ID)
# obs, _ = train_env.reset()
print(f"obs space: {train_env.observation_space.shape}, action space: {train_env.action_space.shape}")

obs space: (706,), action space: (24,)




In [4]:
print(train_env.num_envs)

8


In [5]:

log_dir = "./tb_log/"

total_timesteps = 2400000  # 总训练步数

# 评估环境
eval_env = gym.make(ENV_ID)   # 你的环境

eval_callback = EvalCallback(
    eval_env,
    best_model_save_path=log_dir+"best_model4",  # 自动保存最优模型的目录
    log_path=log_dir,                        # 保存评估日志
    eval_freq=10000,                          # 每 1 万步评估一次
    n_eval_episodes=5,                         # 每次评估 5 个 episode
    deterministic=True,                        # 评估时用确定性策略
    render=False
)


In [6]:
class RewardInfoCallback(BaseCallback):
    """
    在 TensorBoard 中记录 info 字典中的各个 reward 项。
    支持 vectorized environments。
    """
    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        infos = self.locals.get("infos", [])
        if not infos:
            return True

        # 对 vectorized env 平均 reward
        reward_sums = {}
        for info in infos:
            for key, value in info.items():
                if key.startswith("reward_"):
                    reward_sums[key] = reward_sums.get(key, 0.0) + float(value)

        for key, total in reward_sums.items():
            mean_value = total / len(infos)
            self.logger.record(f"reward/{key}", mean_value)

        return True
    
reward_info_cb = RewardInfoCallback()
    

In [7]:

# 自定义SAC网络结构
# obs space: (706,), action space: (24,)
policy_kwargs = dict(
    net_arch=dict(pi=[256, 256], qf=[256, 256]), # 每个隐藏层的神经元数量，也可以写成 [400, 300] 等
    activation_fn=torch.nn.ReLU  # 激活函数，可改为 torch.nn.Tanh
)

def warm_sin_lr(progress_remaining: float) -> float:
    """
    progress_remaining: 1 -> 0
    假设总共训练T步：RewardInfoCallback
      - 前10% steps: 线性从 1e-5 升到 3e-4 (warm-up)
      - 之后: 按正弦方式从 3e-4 降到 1e-5
    """
    lr_min = 1e-4   
    lr_max = 1e-3
    warm_ratio = 0.01  # 10% warm-up

    # progress_remaining=1 -> step=0; progress_remaining=0 -> step=end
    progress_done = 1.0 - progress_remaining

    if progress_done < warm_ratio:
        # warm-up: 线性上升
        return lr_min + (lr_max - lr_min) * (progress_done / warm_ratio)
    else:
        # sin下降：这里重新归一化到[0,1]
        x = (progress_done - warm_ratio) / (1 - warm_ratio)
        return lr_min + (lr_max - lr_min) * math.sin((1 - x) * math.pi / 2)

model = SAC(
    "MlpPolicy",
    train_env,
    verbose=1,
    learning_rate=warm_sin_lr,
    buffer_size=100_000,      # 经验回放缓冲区大小. 这个参数PPO没有
    learning_starts=1000, 
    batch_size=256,             # 默认256
    tau=0.005,                  # 软更新系数
    gamma=0.99,                 # 折扣因子
    train_freq=1,               # 每步都训练，采集多少个环境步的数据后训练一次
    gradient_steps=1,           # 对replayBuffer中读取到的batch，进行多少次梯度下降更新
    tensorboard_log=log_dir,   # 日志目录
    policy_kwargs=policy_kwargs,  # 将自定义结构传进去
)


# 训练模型, total_timesteps自行调整
model.learn(total_timesteps=total_timesteps, 
            tb_log_name="sac", 
            progress_bar=True,
            callback=[eval_callback, reward_info_cb])
# 保存模型
model.save("galaxea_sac_lr_forward")
# model.save_replay_buffer("my_buffer.pkl")

Using cuda device
Logging to ./tb_log/sac_22


----------------------------------
| reward/            |           |
|    reward_contact  | -1.25     |
|    reward_ctrl     | -16.8     |
|    reward_forward  | -2.16     |
|    reward_survive  | 3.75      |
| rollout/           |           |
|    ep_len_mean     | 56.8      |
|    ep_rew_mean     | -1.45e+03 |
| time/              |           |
|    episodes        | 4         |
|    fps             | 6240      |
|    time_elapsed    | 0         |
|    total_timesteps | 472       |
----------------------------------
----------------------------------
| reward/            |           |
|    reward_contact  | 0         |
|    reward_ctrl     | -22.1     |
|    reward_forward  | -3.39     |
|    reward_survive  | 4.38      |
| rollout/           |           |
|    ep_len_mean     | 63.4      |
|    ep_rew_mean     | -1.61e+03 |
| time/              |           |
|    episodes        | 8         |
|    fps             | 6046      |
|    time_elapsed    | 0         |
|    total_timesteps



----------------------------------
| reward/            |           |
|    reward_contact  | 0         |
|    reward_ctrl     | -42.6     |
|    reward_forward  | -2.75     |
|    reward_survive  | 4.38      |
| rollout/           |           |
|    ep_len_mean     | 62.8      |
|    ep_rew_mean     | -1.68e+03 |
| time/              |           |
|    episodes        | 16        |
|    fps             | 1907      |
|    time_elapsed    | 0         |
|    total_timesteps | 1200      |
| train/             |           |
|    actor_loss      | 60.4      |
|    critic_loss     | 1.32e+03  |
|    ent_coef        | 1         |
|    ent_coef_loss   | 0.00625   |
|    learning_rate   | 0.000145  |
|    n_updates       | 24        |
----------------------------------
----------------------------------
| reward/            |           |
|    reward_contact  | 0         |
|    reward_ctrl     | -33.3     |
|    reward_forward  | -1.88     |
|    reward_survive  | 4.38      |
| rollout/          

----------------------------------
| eval/              |           |
|    mean_ep_length  | 58        |
|    mean_reward     | -3.36e+03 |
| reward/            |           |
|    reward_contact  | -0.126    |
|    reward_ctrl     | -57       |
|    reward_forward  | -2.21     |
|    reward_survive  | 5         |
| time/              |           |
|    total_timesteps | 80000     |
| train/             |           |
|    actor_loss      | 833       |
|    critic_loss     | 1.01e+03  |
|    ent_coef        | 0.0277    |
|    ent_coef_loss   | -2.64     |
|    learning_rate   | 0.000999  |
|    n_updates       | 9874      |
----------------------------------


----------------------------------
| reward/            |           |
|    reward_contact  | -0.0814   |
|    reward_ctrl     | -58.6     |
|    reward_forward  | -3.32     |
|    reward_survive  | 4.38      |
| rollout/           |           |
|    ep_len_mean     | 65.1      |
|    ep_rew_mean     | -3.89e+03 |
| time/              |           |
|    episodes        | 1056      |
|    fps             | 1189      |
|    time_elapsed    | 67        |
|    total_timesteps | 80240     |
| train/             |           |
|    actor_loss      | 876       |
|    critic_loss     | 1.17e+03  |
|    ent_coef        | 0.0282    |
|    ent_coef_loss   | 13.5      |
|    learning_rate   | 0.000999  |
|    n_updates       | 9904      |
----------------------------------
----------------------------------
| reward/            |           |
|    reward_contact  | 0         |
|    reward_ctrl     | -61.3     |
|    reward_forward  | -1.19     |
|    reward_survive  | 4.38      |
| rollout/          

----------------------------------
| eval/              |           |
|    mean_ep_length  | 43        |
|    mean_reward     | -1.19e+03 |
| reward/            |           |
|    reward_contact  | -0.357    |
|    reward_ctrl     | -31.3     |
|    reward_forward  | -3.42     |
|    reward_survive  | 5         |
| time/              |           |
|    total_timesteps | 160000    |
| train/             |           |
|    actor_loss      | 1.13e+03  |
|    critic_loss     | 962       |
|    ent_coef        | 0.156     |
|    ent_coef_loss   | 0.305     |
|    learning_rate   | 0.000996  |
|    n_updates       | 19874     |
----------------------------------


----------------------------------
| reward/            |           |
|    reward_contact  | 0         |
|    reward_ctrl     | -31.8     |
|    reward_forward  | -3.95     |
|    reward_survive  | 4.38      |
| rollout/           |           |
|    ep_len_mean     | 43.3      |
|    ep_rew_mean     | -1.59e+03 |
| time/              |           |
|    episodes        | 2544      |
|    fps             | 1189      |
|    time_elapsed    | 134       |
|    total_timesteps | 160064    |
| train/             |           |
|    actor_loss      | 1.12e+03  |
|    critic_loss     | 912       |
|    ent_coef        | 0.157     |
|    ent_coef_loss   | 2.25      |
|    learning_rate   | 0.000996  |
|    n_updates       | 19882     |
----------------------------------
----------------------------------
| reward/            |           |
|    reward_contact  | 0         |
|    reward_ctrl     | -38       |
|    reward_forward  | -5        |
|    reward_survive  | 4.38      |
| rollout/          

----------------------------------
| eval/              |           |
|    mean_ep_length  | 42.4      |
|    mean_reward     | -1.18e+03 |
| reward/            |           |
|    reward_contact  | 0         |
|    reward_ctrl     | -44.7     |
|    reward_forward  | -4.4      |
|    reward_survive  | 5         |
| time/              |           |
|    total_timesteps | 240000    |
| train/             |           |
|    actor_loss      | 824       |
|    critic_loss     | 259       |
|    ent_coef        | 0.106     |
|    ent_coef_loss   | 2.43      |
|    learning_rate   | 0.000991  |
|    n_updates       | 29874     |
----------------------------------


----------------------------------
| reward/            |           |
|    reward_contact  | -1.25     |
|    reward_ctrl     | -32       |
|    reward_forward  | -4.39     |
|    reward_survive  | 3.75      |
| rollout/           |           |
|    ep_len_mean     | 42.8      |
|    ep_rew_mean     | -1.59e+03 |
| time/              |           |
|    episodes        | 4408      |
|    fps             | 1211      |
|    time_elapsed    | 198       |
|    total_timesteps | 240152    |
| train/             |           |
|    actor_loss      | 774       |
|    critic_loss     | 209       |
|    ent_coef        | 0.107     |
|    ent_coef_loss   | -0.543    |
|    learning_rate   | 0.000991  |
|    n_updates       | 29893     |
----------------------------------
----------------------------------
| reward/            |           |
|    reward_contact  | 0         |
|    reward_ctrl     | -33.9     |
|    reward_forward  | -4.93     |
|    reward_survive  | 4.38      |
| rollout/          

---------------------------------
| eval/              |          |
|    mean_ep_length  | 43       |
|    mean_reward     | -466     |
| reward/            |          |
|    reward_contact  | 0        |
|    reward_ctrl     | -15.1    |
|    reward_forward  | -4.88    |
|    reward_survive  | 5        |
| time/              |          |
|    total_timesteps | 320000   |
| train/             |          |
|    actor_loss      | 584      |
|    critic_loss     | 147      |
|    ent_coef        | 0.125    |
|    ent_coef_loss   | -2.01    |
|    learning_rate   | 0.000983 |
|    n_updates       | 39874    |
---------------------------------


---------------------------------
| reward/            |          |
|    reward_contact  | -1.25    |
|    reward_ctrl     | -14.6    |
|    reward_forward  | -4.24    |
|    reward_survive  | 4.38     |
| rollout/           |          |
|    ep_len_mean     | 42.7     |
|    ep_rew_mean     | -519     |
| time/              |          |
|    episodes        | 6248     |
|    fps             | 1208     |
|    time_elapsed    | 264      |
|    total_timesteps | 320120   |
| train/             |          |
|    actor_loss      | 584      |
|    critic_loss     | 174      |
|    ent_coef        | 0.125    |
|    ent_coef_loss   | 1.89     |
|    learning_rate   | 0.000983 |
|    n_updates       | 39889    |
---------------------------------
---------------------------------
| reward/            |          |
|    reward_contact  | 0        |
|    reward_ctrl     | -13.5    |
|    reward_forward  | -6.95    |
|    reward_survive  | 4.38     |
| rollout/           |          |
|    ep_len_me

---------------------------------
| eval/              |          |
|    mean_ep_length  | 42       |
|    mean_reward     | -242     |
| reward/            |          |
|    reward_contact  | 0        |
|    reward_ctrl     | -7.25    |
|    reward_forward  | -4.54    |
|    reward_survive  | 5        |
| time/              |          |
|    total_timesteps | 400000   |
| train/             |          |
|    actor_loss      | 237      |
|    critic_loss     | 55.5     |
|    ent_coef        | 0.0469   |
|    ent_coef_loss   | -5.1     |
|    learning_rate   | 0.000972 |
|    n_updates       | 49874    |
---------------------------------


---------------------------------
| reward/            |          |
|    reward_contact  | 0        |
|    reward_ctrl     | -6.39    |
|    reward_forward  | -4.87    |
|    reward_survive  | 4.38     |
| rollout/           |          |
|    ep_len_mean     | 42       |
|    ep_rew_mean     | -267     |
| time/              |          |
|    episodes        | 8136     |
|    fps             | 1190     |
|    time_elapsed    | 335      |
|    total_timesteps | 400096   |
| train/             |          |
|    actor_loss      | 240      |
|    critic_loss     | 39.7     |
|    ent_coef        | 0.0468   |
|    ent_coef_loss   | 11.7     |
|    learning_rate   | 0.000972 |
|    n_updates       | 49886    |
---------------------------------
---------------------------------
| reward/            |          |
|    reward_contact  | 0        |
|    reward_ctrl     | -7.51    |
|    reward_forward  | -6.01    |
|    reward_survive  | 4.38     |
| rollout/           |          |
|    ep_len_me

---------------------------------
| eval/              |          |
|    mean_ep_length  | 42       |
|    mean_reward     | -236     |
| reward/            |          |
|    reward_contact  | 0        |
|    reward_ctrl     | -6.18    |
|    reward_forward  | -5.01    |
|    reward_survive  | 5        |
| time/              |          |
|    total_timesteps | 480000   |
| train/             |          |
|    actor_loss      | 158      |
|    critic_loss     | 31.3     |
|    ent_coef        | 0.0209   |
|    ent_coef_loss   | 5.31     |
|    learning_rate   | 0.000959 |
|    n_updates       | 59874    |
---------------------------------


---------------------------------
| reward/            |          |
|    reward_contact  | 0        |
|    reward_ctrl     | -6.79    |
|    reward_forward  | -5.04    |
|    reward_survive  | 4.38     |
| rollout/           |          |
|    ep_len_mean     | 42       |
|    ep_rew_mean     | -264     |
| time/              |          |
|    episodes        | 10032    |
|    fps             | 1192     |
|    time_elapsed    | 402      |
|    total_timesteps | 480072   |
| train/             |          |
|    actor_loss      | 159      |
|    critic_loss     | 9.91     |
|    ent_coef        | 0.021    |
|    ent_coef_loss   | 4.53     |
|    learning_rate   | 0.000959 |
|    n_updates       | 59883    |
---------------------------------
---------------------------------
| reward/            |          |
|    reward_contact  | 0        |
|    reward_ctrl     | -6.19    |
|    reward_forward  | -5.12    |
|    reward_survive  | 4.38     |
| rollout/           |          |
|    ep_len_me

---------------------------------
| eval/              |          |
|    mean_ep_length  | 44       |
|    mean_reward     | -182     |
| reward/            |          |
|    reward_contact  | -1.25    |
|    reward_ctrl     | -4.74    |
|    reward_forward  | -4.05    |
|    reward_survive  | 5        |
| time/              |          |
|    total_timesteps | 560000   |
| train/             |          |
|    actor_loss      | 132      |
|    critic_loss     | 10.5     |
|    ent_coef        | 0.0252   |
|    ent_coef_loss   | -3.87    |
|    learning_rate   | 0.000944 |
|    n_updates       | 69874    |
---------------------------------


---------------------------------
| reward/            |          |
|    reward_contact  | 0        |
|    reward_ctrl     | -4.96    |
|    reward_forward  | -4       |
|    reward_survive  | 3.75     |
| rollout/           |          |
|    ep_len_mean     | 44.1     |
|    ep_rew_mean     | -225     |
| time/              |          |
|    episodes        | 11868    |
|    fps             | 1194     |
|    time_elapsed    | 468      |
|    total_timesteps | 560144   |
| train/             |          |
|    actor_loss      | 142      |
|    critic_loss     | 28.4     |
|    ent_coef        | 0.0249   |
|    ent_coef_loss   | 2.23     |
|    learning_rate   | 0.000944 |
|    n_updates       | 69892    |
---------------------------------
---------------------------------
| reward/            |          |
|    reward_contact  | 0        |
|    reward_ctrl     | -12.2    |
|    reward_forward  | -4.96    |
|    reward_survive  | 4.38     |
| rollout/           |          |
|    ep_len_me

---------------------------------
| eval/              |          |
|    mean_ep_length  | 44       |
|    mean_reward     | -177     |
| reward/            |          |
|    reward_contact  | 0        |
|    reward_ctrl     | -4.77    |
|    reward_forward  | -2.64    |
|    reward_survive  | 5        |
| time/              |          |
|    total_timesteps | 640000   |
| train/             |          |
|    actor_loss      | 124      |
|    critic_loss     | 14.5     |
|    ent_coef        | 0.0152   |
|    ent_coef_loss   | -0.127   |
|    learning_rate   | 0.000926 |
|    n_updates       | 79874    |
---------------------------------


---------------------------------
| reward/            |          |
|    reward_contact  | -0.163   |
|    reward_ctrl     | -6.77    |
|    reward_forward  | -4.58    |
|    reward_survive  | 4.38     |
| rollout/           |          |
|    ep_len_mean     | 44.7     |
|    ep_rew_mean     | -198     |
| time/              |          |
|    episodes        | 13676    |
|    fps             | 1201     |
|    time_elapsed    | 532      |
|    total_timesteps | 640096   |
| train/             |          |
|    actor_loss      | 119      |
|    critic_loss     | 30.4     |
|    ent_coef        | 0.0152   |
|    ent_coef_loss   | -3.07    |
|    learning_rate   | 0.000926 |
|    n_updates       | 79886    |
---------------------------------
---------------------------------
| reward/            |          |
|    reward_contact  | 0        |
|    reward_ctrl     | -7.31    |
|    reward_forward  | -2.65    |
|    reward_survive  | 3.75     |
| rollout/           |          |
|    ep_len_me

---------------------------------
| eval/              |          |
|    mean_ep_length  | 44       |
|    mean_reward     | -166     |
| reward/            |          |
|    reward_contact  | -0.0199  |
|    reward_ctrl     | -5.18    |
|    reward_forward  | -4.13    |
|    reward_survive  | 5        |
| time/              |          |
|    total_timesteps | 720000   |
| train/             |          |
|    actor_loss      | 120      |
|    critic_loss     | 27.8     |
|    ent_coef        | 0.0184   |
|    ent_coef_loss   | 3.22     |
|    learning_rate   | 0.000906 |
|    n_updates       | 89874    |
---------------------------------


---------------------------------
| reward/            |          |
|    reward_contact  | -0.042   |
|    reward_ctrl     | -4.83    |
|    reward_forward  | -3.17    |
|    reward_survive  | 4.38     |
| rollout/           |          |
|    ep_len_mean     | 44       |
|    ep_rew_mean     | -190     |
| time/              |          |
|    episodes        | 15416    |
|    fps             | 1207     |
|    time_elapsed    | 596      |
|    total_timesteps | 720040   |
| train/             |          |
|    actor_loss      | 122      |
|    critic_loss     | 25.8     |
|    ent_coef        | 0.0184   |
|    ent_coef_loss   | 4.28     |
|    learning_rate   | 0.000906 |
|    n_updates       | 89879    |
---------------------------------
---------------------------------
| reward/            |          |
|    reward_contact  | -0.0107  |
|    reward_ctrl     | -5.04    |
|    reward_forward  | -5.35    |
|    reward_survive  | 4.38     |
| rollout/           |          |
|    ep_len_me

---------------------------------
| eval/              |          |
|    mean_ep_length  | 45       |
|    mean_reward     | -138     |
| reward/            |          |
|    reward_contact  | 0        |
|    reward_ctrl     | -4.1     |
|    reward_forward  | -2.03    |
|    reward_survive  | 4.38     |
| time/              |          |
|    total_timesteps | 800000   |
| train/             |          |
|    actor_loss      | 124      |
|    critic_loss     | 19.3     |
|    ent_coef        | 0.0227   |
|    ent_coef_loss   | 4.55     |
|    learning_rate   | 0.000884 |
|    n_updates       | 99874    |
---------------------------------


---------------------------------
| reward/            |          |
|    reward_contact  | 0        |
|    reward_ctrl     | -6.84    |
|    reward_forward  | -5.66    |
|    reward_survive  | 4.38     |
| rollout/           |          |
|    ep_len_mean     | 45.6     |
|    ep_rew_mean     | -187     |
| time/              |          |
|    episodes        | 17212    |
|    fps             | 1211     |
|    time_elapsed    | 660      |
|    total_timesteps | 800200   |
| train/             |          |
|    actor_loss      | 129      |
|    critic_loss     | 8.88     |
|    ent_coef        | 0.0227   |
|    ent_coef_loss   | -3.96    |
|    learning_rate   | 0.000884 |
|    n_updates       | 99899    |
---------------------------------
---------------------------------
| reward/            |          |
|    reward_contact  | 0        |
|    reward_ctrl     | -4.14    |
|    reward_forward  | -2.4     |
|    reward_survive  | 4.38     |
| rollout/           |          |
|    ep_len_me

----------------------------------
| eval/              |           |
|    mean_ep_length  | 2.69e+03  |
|    mean_reward     | -3.63e+03 |
| reward/            |           |
|    reward_contact  | 0         |
|    reward_ctrl     | -6.56     |
|    reward_forward  | -0.00674  |
|    reward_survive  | 5         |
| time/              |           |
|    total_timesteps | 880000    |
| train/             |           |
|    actor_loss      | 143       |
|    critic_loss     | 67.4      |
|    ent_coef        | 0.0334    |
|    ent_coef_loss   | -1.76     |
|    learning_rate   | 0.00086   |
|    n_updates       | 109874    |
----------------------------------


---------------------------------
| eval/              |          |
|    mean_ep_length  | 2.02e+04 |
|    mean_reward     | 2.13e+03 |
| reward/            |          |
|    reward_contact  | 0        |
|    reward_ctrl     | -4.66    |
|    reward_forward  | -0.0623  |
|    reward_survive  | 5        |
| time/              |          |
|    total_timesteps | 960000   |
| train/             |          |
|    actor_loss      | 254      |
|    critic_loss     | 55.7     |
|    ent_coef        | 0.0555   |
|    ent_coef_loss   | 2.1      |
|    learning_rate   | 0.000833 |
|    n_updates       | 119874   |
---------------------------------


----------------------------------
| eval/              |           |
|    mean_ep_length  | 3.64e+04  |
|    mean_reward     | -7.68e+04 |
| reward/            |           |
|    reward_contact  | 0         |
|    reward_ctrl     | -4.38     |
|    reward_forward  | -0.000446 |
|    reward_survive  | 5         |
| time/              |           |
|    total_timesteps | 1040000   |
| train/             |           |
|    actor_loss      | 262       |
|    critic_loss     | 75.8      |
|    ent_coef        | 0.0432    |
|    ent_coef_loss   | -0.801    |
|    learning_rate   | 0.000804  |
|    n_updates       | 129874    |
----------------------------------


KeyboardInterrupt: 

In [None]:
# resume training
model = SAC.load("galaxea_sac_lr_forward", env=train_env, tensorboard_log=log_dir)
model.load_replay_buffer("my_buffer.pkl")
model.learning_starts = 0  # 继续训练时，不需要再等待采集数据了,因为已经用了之前的replaybuffer
model.learning_rate = 1e-4
model.learn(total_timesteps=240000,
            tb_log_name="sac",
            progress_bar=True,
            callback=[eval_callback, reward_info_cb])


In [None]:
train_env.close()

### 测试模型可视化

In [None]:
# 加载模型
model = SAC.load("./tb_log/best_model4/best_model.zip")
# 创建测试环境
visual_env = gym.make("galaxea_r1Pro", render_mode="human")

for i in range(5):
    # 测试模型
    obs, info = visual_env.reset()
    cum_reward = 0
    for _ in tqdm(range(1500)):
        visual_env.render()
        action, _ = model.predict(obs, deterministic=True)
        next_obs, reward, terminated, truncated, info = visual_env.step(action)
        cum_reward += reward
        if terminated or truncated:
            print("累积奖励: ", cum_reward)
            break
            
        obs = next_obs

visual_env.close()

### 环境debug

In [None]:
# 可视化环境
train_env = gym.make("galaxea_r1Pro", render_mode="human")
unwrapped_env = train_env.unwrapped

model = unwrapped_env.model  # MjModel

print(f"obs space: {train_env.observation_space.shape}, action space: {train_env.action_space.shape}")
print(f"action range: {train_env.action_space.low} to {train_env.action_space.high}")


print(f"actuator size: {model.nu}, ctrl_size: {unwrapped_env.data.ctrl.shape}")  # actuators and muscles
print(f"obs space: {train_env.observation_space.shape}, action space: {train_env.action_space.shape}")


# 随机采样动作
obs, _ = train_env.reset()
for _ in tqdm(range(1000)):
    train_env.render()
    action = train_env.action_space.sample()
    obs, reward, terminated, truncated, info = train_env.step(action)
    if terminated or truncated:
        obs, _ = train_env.reset()
train_env.close()


In [None]:
train_env = gym.make("galaxea_r1Pro")
obs, _ = train_env.reset()
unwrapped_env = train_env.unwrapped

mj_model = unwrapped_env.model  # MjModel

print(f"qpos size: {mj_model.nq}, qvel size: {mj_model.nv}, num_joints: {mj_model.njnt}")  # 都是旋转关节，所以这一项都相同
print(f"actuator size: {mj_model.nu}, ctrl_size: {unwrapped_env.data.ctrl.shape}")  # actuators and muscle
print(f"body_size: {mj_model.nbody}, body pos size: {unwrapped_env.data.xipos.shape}")  # nbody, 3

# print(f"action range: {env.action_space.low} to {env.action_space.high}")


qpos_idx = 0
for joint_id in range(mj_model.njnt):
    joint_name = mujoco.mj_id2name(mj_model, mujoco.mjtObj.mjOBJ_JOINT, joint_id)
    joint_type = mj_model.jnt_type[joint_id]
    
    # 根据关节类型确定占用的 qpos 数量
    if joint_type == mujoco.mjtJoint.mjJNT_FREE:    # 自由关节：7个qpos (x,y,z,qw,qx,qy,qz)
        for i, coord in enumerate(['x', 'y', 'z', 'qw', 'qx', 'qy', 'qz']):
            print(f"qpos[{qpos_idx:2d}]: {joint_name}_{coord}")
            qpos_idx += 1
    elif joint_type == mujoco.mjtJoint.mjJNT_HINGE:  # 铰链关节：1个qpos
        print(f"qpos[{qpos_idx:2d}]: {joint_name}")
        qpos_idx += 1
    elif joint_type == mujoco.mjtJoint.mjJNT_SLIDE:  # 滑动关节：1个qpos  
        print(f"qpos[{qpos_idx:2d}]: {joint_name}")
        qpos_idx += 1


data  = mujoco.MjData(mj_model)
mujoco.mj_forward(mj_model, data)  # 必须有这一步
pos = data.xipos            # shape = (nbody, 3)
x, y, z = pos[:, 0], pos[:, 1], pos[:, 2]
names = [mujoco.mj_id2name(mj_model, mujoco.mjtObj.mjOBJ_BODY, i)
         for i in range(mj_model.nbody)]

for i, name in enumerate(names):
    print(f"body[{i:2d}]: {name}, pos=({x[i]:.3f}, {y[i]:.3f}, {z[i]:.3f})")


train_env.close()