In [3]:
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.env_checker import check_env
from env import SimpleBallEnv  # 导入 SimpleBallEnv 类

# Build the environment and validate it against the Gymnasium API before training.
env = SimpleBallEnv()
check_env(env)

# Hand DQN an explicitly vectorized environment (a DummyVecEnv holding the one env).
vec_env = DummyVecEnv([lambda: env])

# Create the DQN agent with the MLP policy.
model = DQN("MlpPolicy", vec_env, verbose=1)

# Train, then save immediately so that a problem in the evaluation loop below
# cannot lose the trained weights.
model.learn(total_timesteps=10000)
model.save("dqn_simple_ball")

# --- Evaluate one episode ---------------------------------------------------
# The VecEnv API differs from plain Gymnasium: reset() returns only the batch of
# observations, and step() returns (obs, rewards, dones, infos) -- the fourth
# value is the list of info dicts, not a `truncated` flag.
MAX_TEST_STEPS = 1_000  # safety cap so a policy that never finishes cannot hang the cell
obs = vec_env.reset()
total_reward = 0.0

print("\n--- Testing the trained model ---\n")

for _step in range(MAX_TEST_STEPS):
    vec_env.envs[0].render()
    # deterministic=True: evaluate the greedy policy instead of sampling
    # exploratory actions.
    action, _next_hidden = model.predict(obs, deterministic=True)
    obs, rewards, dones, infos = vec_env.step(action)
    # There is exactly one sub-environment, so index 0 is its reward/done flag.
    # Accumulating a plain float makes the total print as a number, not a 1-element array.
    total_reward += float(rewards[0])
    if dones[0]:
        break
else:
    print(f"\nStopped after {MAX_TEST_STEPS} steps without the episode ending.")

print(f"\nTotal Reward: {total_reward}")

# Release the environment's resources.
vec_env.close()

Using cpu device
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.832    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 6001     |
|    time_elapsed     | 0        |
|    total_timesteps  | 177      |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.16     |
|    n_updates        | 19       |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.786    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 5503     |
|    time_elapsed     | 0        |
|    total_timesteps  | 225      |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.687    |
|    n_updates        | 31       |
----------------------------------
----------------------------------
| rollout/            |          |
|  

In [29]:
# The training cell closes `vec_env` when it finishes, so build a fresh
# vectorized environment here instead of reusing the closed one.
vec_env = DummyVecEnv([SimpleBallEnv])

# Load the saved model and attach it to the fresh environment.
model = DQN.load("dqn_simple_ball", env=vec_env)

# --- Evaluate one episode ---------------------------------------------------
# The VecEnv API differs from plain Gymnasium: reset() returns only the batch of
# observations, and step() returns (obs, rewards, dones, infos) -- the fourth
# value is the list of info dicts, not a `truncated` flag.
MAX_TEST_STEPS = 1_000  # safety cap so a policy that never finishes cannot hang the cell
obs = vec_env.reset()
total_reward = 0.0

print("\n--- Testing the loaded model ---\n")

for _step in range(MAX_TEST_STEPS):
    vec_env.envs[0].render()
    # deterministic=True: evaluate the greedy policy instead of sampling
    # exploratory actions.
    action, _next_hidden = model.predict(obs, deterministic=True)
    obs, rewards, dones, infos = vec_env.step(action)
    # There is exactly one sub-environment, so index 0 is its reward/done flag.
    # Accumulating a plain float makes the total print as a number, not a 1-element array.
    total_reward += float(rewards[0])
    if dones[0]:
        break
else:
    print(f"\nStopped after {MAX_TEST_STEPS} steps without the episode ending.")

print(f"\nTotal Reward: {total_reward}")

# Release the environment's resources.
vec_env.close()


--- Testing the loaded model ---

-----o-----
------o----
-------o---
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
-------o---
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-
--------o--
---------o-

Total Reward: [50.]


In [1]:
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.env_checker import check_env
from snake_env import SnakeEnv   
    
# Instantiate the Snake environment and run SB3's environment checker on it
# to verify it follows the Gym interface.
env = SnakeEnv()
check_env(env)

# Wrap the single environment in a DummyVecEnv (Stable Baselines3 works on
# vectorized environments).
vec_env = DummyVecEnv(env_fns=[lambda: env])

# Create the DQN agent with the MLP policy on the vectorized environment.
model = DQN(policy="MlpPolicy", env=vec_env, verbose=1)



Using cpu device


In [2]:
# Train the model for a fixed budget of environment steps.
TOTAL_TIMESTEPS = 10_000
model.learn(total_timesteps=TOTAL_TIMESTEPS)

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.985    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 11072    |
|    time_elapsed     | 0        |
|    total_timesteps  | 16       |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.958    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 15299    |
|    time_elapsed     | 0        |
|    total_timesteps  | 44       |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.935    |
| time/               |          |
|    episodes         | 12       |
|    fps              | 16068    |
|    time_elapsed     | 0        |
|    total_timesteps  | 68       |
----------------------------------
----------------------------------
| rollout/          

<stable_baselines3.dqn.dqn.DQN at 0x330e682b0>

In [37]:
# Train the model: run `learn` on the existing `model` object for 10,000 steps.
# NOTE(review): this repeats the earlier training cell -- confirm whether a
# second training pass is intended or whether this cell is a leftover.
model.learn(total_timesteps=10_000)


--- Testing the trained model ---

[[  0   0   0 101   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   1   4   3   2   1]
 [  0   0   0   0   0   2   0   0   0   0]
 [  0   0   0   0   0   3   0   0   0   0]
 [  0   0   0   0   0   0   0   0 101   0]
 [  0   0   0   0   0   0   0   0   0   0]]
[(5, 5), (6, 5), (7, 5)]
[[  0   0   0 101   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   1   0   0   0   0]
 [  0   0   0   0   0   2   0   0   0   0]
 [  0   0   0   0   0   3   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0 101   0]
 [  0   0   0   0   0   0   0   0   0   0]]
[(4, 5), (5, 5), (6, 5)]
[[  0   0   0 101   0   0   0   0   0   0]
 [  0   0

In [1]:
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.env_checker import check_env
from snake_env import SnakeEnv

# Create the Snake environment.
env = SnakeEnv()

# Verify that the environment follows the Gym interface.
check_env(env)

# Create the DQN agent with the MLP policy. The raw environment is passed in
# directly (no manual DummyVecEnv here); the cell output shows SB3 adding the
# Monitor and DummyVecEnv wrappers itself.
model = DQN(policy="MlpPolicy", env=env, verbose=1)



Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [2]:
# Training budget (environment steps) and the path the model is saved under.
TOTAL_TIMESTEPS = 10_000
MODEL_PATH = "dqn_snake_model"

# Train the model, then save it.
model.learn(total_timesteps=TOTAL_TIMESTEPS)
model.save(MODEL_PATH)

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 2.5      |
|    ep_rew_mean      | 0        |
|    exploration_rate | 0.991    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 3507     |
|    time_elapsed     | 0        |
|    total_timesteps  | 10       |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.62     |
|    ep_rew_mean      | 0        |
|    exploration_rate | 0.972    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 5959     |
|    time_elapsed     | 0        |
|    total_timesteps  | 29       |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.92     |
|    ep_rew_mean      | 0        |
|    exploration_rate | 0.955    |
| time/               |          |
|    episodes       

In [3]:
# Test the trained model for one episode on a fresh SnakeEnv instance.
# This uses the plain Gymnasium API: reset() -> (obs, info) and
# step() -> (obs, reward, terminated, truncated, info).
MAX_TEST_STEPS = 1_000  # safety cap: a greedy policy can circle forever without ending the episode

env = SnakeEnv()
state, info = env.reset()
total_reward = 0

print("\n--- Testing the trained model ---\n")

for _step in range(MAX_TEST_STEPS):
    env.render('auto')
    # deterministic=True: evaluate the greedy policy instead of sampling
    # exploratory actions.
    action, _next_hidden = model.predict(state, deterministic=True)
    state, reward, terminated, truncated, info = env.step(action)
    total_reward += reward
    # The episode is over when the env reports either flag.
    if terminated or truncated:
        break
else:
    print(f"\nStopped after {MAX_TEST_STEPS} steps without the episode ending.")

print(f"\nTotal Reward: {total_reward}")

# Close the environment.
env.close()


--- Testing the trained model ---

[H[2J. . . . . . . . . . 
. . . C . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . S . . . . 
. . . . . S . . . . 
. . . . . S . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
[H[2J. . . . . . . . . . 
. . . C . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . S . . . . 
. . . . . S . . . . 
. . . . . S . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
[H[2J. . . . . . . . . . 
. . . C . . . . . . 
. . . . . . . . . . 
. . . . . S . . . . 
. . . . . S . . . . 
. . . . . S . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
[H[2J. . . . . . . . . . 
. . . C . . . . . . 
. . . . . S . . . . 
. . . . . S . . . . 
. . . . . S . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
[H[2J. . . . . . . . . . 
. . . C . S . . . . 
. . . . . S . . . . 
. . . . . S . . . . 
. . .