In [33]:
import gymnasium as gym
import ale_py
import wandb
wandb.login(key="390acfbb12bfaf0cc52b7a946e4db99a58ed88f3")

""" test code """
# env = gym.make("ALE/Alien-v5", render_mode="human")
# observation, info = env.reset()

# for _ in range(500):
#     action = env.action_space.sample()  # 随机动作
#     observation, reward, terminated, truncated, info = env.step(action)

#     if terminated or truncated:
#         observation, info = env.reset()

# env.close()

# state, info = env.reset()
# print (state.shape)
# print (info)

' test code '

In [43]:
import numpy as np
import random
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim
from gymnasium.wrappers import AtariPreprocessing
from gymnasium.wrappers import FrameStackObservation
from tqdm import tqdm

"""
step1: 预处理环境。
why: atari的obs是210*160*3，应该先处理
how: 1.转灰度 2.压缩 3.堆叠最近的4帧画面(这样能同时展示s和a)
"""

env = gym.make("ALE/Alien-v5",frameskip=1, render_mode="human")

#print
obs, info = env.reset()
print ('obs', obs.shape)
print ('info', info)

# AtaAtariPreprocessing函数详解：
# - frame_skip=4 一个动作保持4帧 -> 1.可以减少计算量 2.让动作效果更加明显， 这里atari已经内置了
# - grayscale_obs=True 转化为灰度图像
# - scale_obs=True 把像素值从【0,255】-> 【0,1】
env = AtariPreprocessing(env, grayscale_obs=True, scale_obs=True)

# 表示每次obs由最近的4帧图形拼接而成
# shape（4,210,160）
env = FrameStackObservation(env, stack_size=4)

# print 
obs, info = env.reset()
print ('obs', obs.shape)
print ('info', info)
print (env.action_space.n)
print(env.spec)              # 查看环境配置

obs (210, 160, 3)
info {'lives': 3, 'episode_frame_number': 0, 'frame_number': 0}
obs (4, 84, 84)
info {'lives': 3, 'episode_frame_number': 18, 'frame_number': 18}
18
EnvSpec(id='ALE/Alien-v5', entry_point='ale_py.env:AtariEnv', reward_threshold=None, nondeterministic=False, max_episode_steps=None, order_enforce=True, disable_env_checker=False, kwargs={'game': 'alien', 'repeat_action_probability': 0.25, 'full_action_space': False, 'frameskip': 1, 'max_num_frames_per_episode': 108000, 'render_mode': 'human'}, namespace='ALE', name='Alien', version=5, additional_wrappers=(WrapperSpec(name='AtariPreprocessing', entry_point='gymnasium.wrappers.atari_preprocessing:AtariPreprocessing', kwargs={'noop_max': 30, 'frame_skip': 4, 'screen_size': 84, 'terminal_on_life_loss': False, 'grayscale_obs': True, 'grayscale_newaxis': False, 'scale_obs': True}), WrapperSpec(name='FrameStackObservation', entry_point='gymnasium.wrappers.stateful_observation:FrameStackObservation', kwargs={'stack_size': 4, 'pa

### EnvSpec 逐项解读（重点说明含义与影响）

- id='ALE/Alien-v5'
环境标识：Atari 的 Alien（版本 v5）。

- entry_point='ale_py.env:AtariEnv'
真正创建底层环境的类是 ALE（Arcade Learning Environment）的 AtariEnv。

- kwargs={...}（底层 env 的参数）

- game: 'alien'：游戏名。

- repeat_action_probability: 0.25：sticky actions（黏性动作）概率为 0.25，表示有 25% 概率重复上一个动作，增加环境随机性。

- full_action_space: False：使用 minimal action set 而不是所有动作。

- frameskip: 1：底层 ALE 本身的 frameskip=1（非常重要——不是最终跳帧的意思；wrapper 可能会再做跳帧）。

- max_num_frames_per_episode: 108000：每个 episode 最多 108000 帧（Atari 的标准——约 30 分钟）。

- render_mode: 'human'：渲染模式。

- max_episode_steps=None
表示没有被 TimeLimit（Gym 的 step 上限封装器）包裹，所以 env.spec 的 max_episode_steps 是 None。不过底层的 max_num_frames_per_episode 仍然存在（上面那项）。

- additional_wrappers=(WrapperSpec(...), WrapperSpec(...))
这非常关键 —— Gymnasium 在 make("ALE/Alien-v5") 时自动为你套了两个 wrapper（你通常不需要再手动套一次）：

**AtariPreprocessing 的 kwargs:**

- noop_max=30：reset 时会随机做 0–30 个 NOOP，用于打乱起始状态。

- frame_skip=4：AtariPreprocessing 会把每个动作重复执行 4 帧 —— 这就是常说的跳帧（把 60 FPS 降为 15 FPS 的效果）。

- screen_size=84：会把屏幕缩到 84×84。

- terminal_on_life_loss=False：失去一条命不会把 episode 标记为 terminated（很多实现可选这个行为）。

- grayscale_obs=True：转灰度。

- grayscale_newaxis=False：灰度不会被加成单独的最后轴（意味着单帧是 2D (H,W) 而不是 (H,W,1)）。

- scale_obs=True：把像素归一化到 [0,1]（仍为 uint8、0–255）。

**FrameStackObservation 的 kwargs:**

- stack_size=4：堆叠最近 4 帧 → 最终 observation 包含 4 帧历史。

- padding_type='reset'：在 episode 开始时，空的前帧用 reset 的观测填充（而不是用 0）。

- 注意：顺序通常是先 AtariPreprocessing（做灰度/resize/跳帧），再 FrameStackObservation（在预处理后的帧上做堆叠）。

- vector_entry_point='ale_py.vector_env:AtariVectorEnv'
- 环境支持 vectorized（并行）版本，用于同时跑多个 env。

In [50]:
"""
class DQN
input : state s and action a
        in other words, 4-frame obs (N, 4, 84, 84)
return : q(s,a)
"""
class DQN(nn.Module):
    def __init__(self, action_dim):
        super(DQN, self).__init__()
        self.net = nn.Sequential(

            nn.Conv2d(4, 32, kernel_size=8, stride=4), ## output=(N, 32, 20, 20)
            nn.ReLU(),
            nn.Conv2d(32,64, kernel_size=4, stride=2), ## output=(N, 64, 9, 9)
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), ## output = (N, 64, 7, 7)
            nn.ReLU(),
            nn.Flatten(), ## 64*7*7 = 3136
            nn.Linear(3136, 512),
            nn.ReLU(),
            nn.Linear(512, action_dim) # 对应每个a的q值
        )

    def forward(self, x):
        y = self.net(x)
        return y

"""
test DQN
input: a dummy tensor to test DQN
"""
def test_DQN():
    print ("===test DQN===")
    dqn = DQN(env.action_space.n)
    dummy_x = torch.randn((1,4,84,84))
    print (dummy_x.shape)
    y = dqn(dummy_x)
    print (y.shape)
    print (y)
    print ("===test DQN===\n")

test_DQN()


"""
class ReplayBuffer
why: 1.打破数据的事件关联性 2.增强样本的利用率 3.提高训练的稳定性
how: 把经验放在deque容器buffer中， 需要的时候随机采样batch_size个
"""
class ReplayBuffer:
    def __init__(self, capacity=100000):
        self.buffer = deque(maxlen=capacity) # 双端队列deque, 如果满了， 最老的经验会被删除

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size) #从buffer中采样 batch_size 个样本

        # *是解包操作unpacking, 这是把batch解包后， 按列组合
        state, action, reward, next_state, done = map(np.array, zip(*batch)) 
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)


"""
关键参数赋值
"""
epsilon_start, epsilon_end, epsilon_decay = 1.0, 0.1, 100000
gamma = 0.98
batch_size = 32 # 一次梯度下降用这么多数量的数据
update_target = 5000 # 更新target的频率

num_episodes = 5000
returns = []



"""
def train()
封装整个训练过程
"""
def train(env):
    device = torch.device('cuda')
    action_dim = env.action_space.n
    global_step = 0

    #qnet and target_net
    qnet = DQN(action_dim).to(device)
    target_net = DQN(action_dim).to(device)

    # 将qnet 的参数传入 target_net
    # why两个net ： 1.虽然两个都是估计q(s,a) 
    #               2.先固定target_net，对损失函数求偏导的时候就不会太复杂，然后再把更新后的参数赋给targetnet
    #               3. 这样还可以稳定训练，减少 Q 值振荡
    target_net.load_state_dict(qnet.state_dict())

    # 创建实例
    optimizer = optim.Adam(qnet.parameters(), lr=1e-4)
    buffer = ReplayBuffer()

    for episode in tqdm(range(num_episodes)):
        state, _ = env.reset()
        state = np.array(state)

        # max = num_episodes * steps
        # steps 最大为 108000 / 4, 根据 env.spec()查看得 'max_num_frames_per_episode': 108000
        # 上面有设置 skip_frame = 4, 
        episode_reward = 0

        done = False
        while not done:

            # 用 epsilon-greeedy 策略采样episode， 刚开始多exploration, 后来多exploitation
            # global_step = 0 -> epsilon = 1.0
            # global_step = 正无穷 -> epsilon = 0.1
            epsilon = epsilon_end + (epsilon_start- epsilon_end) * np.exp(-global_step / epsilon_decay)

            if np.random.random() < epsilon:
                action = env.action_space.sample()
            else:
                with torch.no_grad():
                    s = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)  # (1, 4, 84, 84)
                    q_values = qnet(s)
                    action = q_values.argmax(1).item() # q_value (1,18) -> argmax(1)指定在action space 维度

            next_state, reward, terminated, truncated, _ = env.step(action)
            next_state = np.array(next_state)
            done = terminated or truncated

            # 放入经验 buffer 中
            buffer.push(state, action, reward, next_state, done)

            #更新状态
            state = next_state
            episode_reward += reward
            global_step += 1


            # 用经验回放来训练
            if len(buffer) > batch_size:
                s, a, r, ns, d = buffer.sample(batch_size)

                #放在gpu上
                s = torch.tensor(s, dtype=torch.float32, device=device)  
                ns = torch.tensor(ns, dtype=torch.float32, device=device)
                a = torch.tensor(a, dtype=torch.long, device=device)     
                r = torch.tensor(r, dtype=torch.float32, device=device)  
                d = torch.tensor(d, dtype=torch.float32, device=device)

                # gather(dim, idx): dim=1表示在行上， 选取idx列的数据
                # a = [0, 2, 1],  a.unsqueeze(1) = [[0], [2], [1]]
                # 本质是从每个 batch 中选出采样到 qvalue
                q_values = qnet(s).gather(1, a.unsqueeze(1)).squeeze(1)

                with torch.no_grad():
                    # 在计算 target_q 的时候要锁住梯度计算， 不然会很复杂
                    max_next_q = target_net(ns).max(1)[0]
                    target_q = r + gamma * (1 - d) * max_next_q

                loss = nn.MSELoss()(q_values, target_q)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if global_step % update_target == 0:
                target_net.load_state_dict(qnet.state_dict())

        returns.append(episode_reward)
        print (f"Episode {episode}, Return {episode_reward}, Epsilon {epsilon:.3f}")
                

===test DQN===
torch.Size([1, 4, 84, 84])
torch.Size([1, 18])
tensor([[ 0.0149,  0.0533,  0.0003, -0.0331,  0.0247,  0.0067, -0.0500, -0.0023,
          0.0110, -0.0395,  0.0067,  0.0283,  0.0028, -0.0329,  0.0209,  0.0276,
          0.0064, -0.0637]], grad_fn=<AddmmBackward0>)
===test DQN===



In [51]:
import wandb

wandb.init(
    project="dqn-atari",       
    entity="Mingyu Liu",    
    config={
        "env": "ALE/Alien-v5",
        "batch_size": 32,
        "gamma": 0.99,
        "lr": 1e-4,
        "epsilon_start": 1.0,
        "epsilon_end": 0.1
    }
)
config = wandb.config


train(env)

  0%|                                        | 1/5000 [00:42<58:33:01, 42.16s/it]

Episode 0, Return 130.0, Epsilon 0.995


  0%|                                        | 2/5000 [01:47<77:41:05, 55.96s/it]

Episode 1, Return 200.0, Epsilon 0.988


  0%|                                        | 3/5000 [02:33<71:24:19, 51.44s/it]

Episode 2, Return 150.0, Epsilon 0.983


  0%|                                       | 4/5000 [04:30<107:09:32, 77.22s/it]

Episode 3, Return 260.0, Epsilon 0.975


  0%|                                        | 5/5000 [05:26<96:15:24, 69.37s/it]

Episode 4, Return 200.0, Epsilon 0.970


  0%|                                        | 6/5000 [06:19<88:49:21, 64.03s/it]

Episode 5, Return 160.0, Epsilon 0.964


  0%|                                        | 7/5000 [06:41<69:51:28, 50.37s/it]

Episode 6, Return 140.0, Epsilon 0.959


  0%|                                        | 8/5000 [07:02<56:50:48, 41.00s/it]

Episode 7, Return 170.0, Epsilon 0.954


  0%|                                        | 9/5000 [07:30<51:12:28, 36.94s/it]

Episode 8, Return 180.0, Epsilon 0.949


  0%|                                       | 10/5000 [09:03<74:55:05, 54.05s/it]

Episode 9, Return 230.0, Epsilon 0.944


  0%|                                       | 11/5000 [10:13<81:36:19, 58.89s/it]

Episode 10, Return 120.0, Epsilon 0.939


  0%|                                       | 12/5000 [11:33<90:38:57, 65.42s/it]

Episode 11, Return 190.0, Epsilon 0.932


  0%|                                      | 13/5000 [13:03<101:00:25, 72.91s/it]

Episode 12, Return 120.0, Epsilon 0.927


  0%|                                      | 14/5000 [14:35<108:54:12, 78.63s/it]

Episode 13, Return 180.0, Epsilon 0.922


  0%|                                      | 15/5000 [16:32<124:56:55, 90.23s/it]

Episode 14, Return 220.0, Epsilon 0.916


  0%|                                      | 16/5000 [18:11<128:39:09, 92.93s/it]

Episode 15, Return 100.0, Epsilon 0.911


  0%|▏                                    | 17/5000 [20:10<139:16:47, 100.62s/it]

Episode 16, Return 250.0, Epsilon 0.905


  0%|▏                                     | 18/5000 [21:30<130:53:27, 94.58s/it]

Episode 17, Return 190.0, Epsilon 0.900


  0%|▏                                     | 19/5000 [22:35<118:36:02, 85.72s/it]

Episode 18, Return 120.0, Epsilon 0.896


  0%|▏                                     | 20/5000 [24:21<126:55:15, 91.75s/it]

Episode 19, Return 210.0, Epsilon 0.891


  0%|▏                                     | 20/5000 [25:38<106:24:29, 76.92s/it]


KeyboardInterrupt: 