In [1]:
from pathlib import Path
import shutil

import gymnasium as gym
import torch

from deep_q import DeepQFunc, DeepQFuncTrainer, DeepQFuncTester, ReplayBuffer
from env import Env


#### 使用 CartPole-v1 环境，测试简单的 Deep Q 如何处理连续的 State 空间和离散的 Action 空间

In [3]:
# ---------------------------------------------------------------------------
# Environment
# ---------------------------------------------------------------------------
GYM_ENV_NAME = 'CartPole-v1'
_train_gym_env = gym.make(GYM_ENV_NAME)

# Print the number of discrete actions and the observation space so the
# network input/output sizes used below can be checked by eye.
action_nums, state_space = _train_gym_env.action_space.n, _train_gym_env.observation_space
print(f'action num: {action_nums}, space: {state_space}')

# ---------------------------------------------------------------------------
# Hyper-parameters (everything a reader might want to tune lives here)
# ---------------------------------------------------------------------------
TRAIN_EPOCH = 300
HIDDEN_DIM = 256
LEARNING_RATE = 2e-3
GAMMA = 0.99
BATCH_SIZE = 64             # forwarded to DeepQFuncTrainer as batch_size
REPLAY_BUFFER_SIZE = 10000  # the single argument given to ReplayBuffer

# Exponentially decaying epsilon-greedy schedule: entry `epoch` is
# START_EPSILON * DECAY_RATE ** epoch, never going below END_EPSILON.
START_EPSILON = 0.5
END_EPSILON = 0.05
DECAY_RATE = 0.99
EPSILON_LIST = [
    max(START_EPSILON * (DECAY_RATE ** epoch), END_EPSILON)
    for epoch in range(TRAIN_EPOCH)
]

# ---------------------------------------------------------------------------
# Logging: wipe the folder left by a previous run before handing it to the
# trainer, so old and new logs are not mixed.
# ---------------------------------------------------------------------------
log_path = Path('./logs/run2')
if log_path.exists():
    shutil.rmtree(log_path)

# ---------------------------------------------------------------------------
# Device selection: set PREFER_CUDA to True to opt in to the GPU; CUDA is
# only used when it is both requested and actually available.
# ---------------------------------------------------------------------------
PREFER_CUDA = False
_USE_CUDA = PREFER_CUDA and torch.cuda.is_available()

# ---------------------------------------------------------------------------
# Q-network, environment wrapper, replay buffer and trainer
# ---------------------------------------------------------------------------
q_func = DeepQFunc(
    state_space.shape[0],
    action_nums,
    hidden_dim=HIDDEN_DIM,
    # None is passed when CUDA is not used, leaving the choice to DeepQFunc.
    device=torch.device('cuda') if _USE_CUDA else None,
)

env = Env(_train_gym_env)

replay_buffer = ReplayBuffer(REPLAY_BUFFER_SIZE)
q_func_trainer = DeepQFuncTrainer(
    q_func=q_func,
    env=env,
    replay_buffer=replay_buffer,
    learning_rate=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    gamma=GAMMA,
    epsilon_list=EPSILON_LIST,
    logger_folder=log_path,
)


action num: 2, space: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)


In [4]:
# Kick off training for TRAIN_EPOCH epochs.
# NOTE(review): 64 equals the batch_size given to the trainer, so
# `minimal_replay_size_to_train` is presumably "ten batches' worth of stored
# transitions before the first update", and `max_steps` presumably a step cap
# — confirm both against DeepQFuncTrainer.train.
q_func_trainer.train(
    train_epoch=TRAIN_EPOCH,
    max_steps=1000,
    minimal_replay_size_to_train=64 * 10,
    target_q_update_freq=10,
)

  return func(*args, **kwargs)
100%|██████████| 300/300 [02:18<00:00,  2.17it/s, reward=500, step=500]


#### 开始测试

In [5]:
# Evaluation gets its own environment, created with render_mode='human',
# wrapped in the same Env adapter that training used.
_render_env = Env(gym.make(GYM_ENV_NAME, render_mode='human'))

# The tester is handed the trained network after moving it to the CPU.
q_func_tester = DeepQFuncTester(q_func=q_func.to('cpu'), env=_render_env)

In [8]:
# Evaluate the trained network in the rendering environment.
# NOTE(review): 2000 is presumably an upper bound on the number of steps for
# the test run — confirm against DeepQFuncTester.test.
q_func_tester.test(2000)

Test reward: 500.0
Step Rewards: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1

In [10]:
# Persist the trained Q-network under ./model.
p = Path('./model')
# Create the folder (and any missing parents) in one call; exist_ok avoids the
# separate exists() check and the window between that check and mkdir().
p.mkdir(parents=True, exist_ok=True)
# Same location as './model/trained_for_cartpole.pth', derived from `p` so the
# directory is only spelled out once in this cell.
q_func.save(p / 'trained_for_cartpole.pth')

In [11]:
# Build a fresh network with the same input size, action count and hidden
# width as the trained one, then restore its state from the saved checkpoint.
checkpoint_file = Path('./model/trained_for_cartpole.pth')
q_func_from_checkpoint = DeepQFunc(
    state_space.shape[0],
    action_nums,
    hidden_dim=HIDDEN_DIM,
)
q_func_from_checkpoint.load(checkpoint_file)

In [12]:
# Rebind the tester to the network restored from disk (moved to the CPU),
# reusing the rendering environment created for the first evaluation.
q_func_tester = DeepQFuncTester(q_func=q_func_from_checkpoint.to('cpu'), env=_render_env)

In [13]:
# Evaluate the checkpoint-restored network with the same argument as the
# earlier test run, so the two results can be compared.
# NOTE(review): 2000 is presumably an upper bound on the number of steps for
# the test run — confirm against DeepQFuncTester.test.
q_func_tester.test(2000)

Test reward: 500.0
Step Rewards: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1