In [1]:
# 安装依赖（首次运行时执行一次）
!pip install gym[atari] ale-py autorom numpy==1.23.5
!AutoROM --accept-license

AutoROM will download the Atari 2600 ROMs.
They will be installed to:
	/usr/local/lib/python3.11/dist-packages/AutoROM/roms

Existing ROMs will be overwritten.


In [None]:
import gym
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import cv2
from tqdm import trange


# Image preprocessing: RGB frame -> 84x84 grayscale, rescaled by 1/255
def preprocess(obs):
    """Convert a raw RGB observation into an 84x84 single-channel float32 image.

    The grayscale, resized frame is divided by 255, so a uint8 frame ends up in
    [0, 1] (assumes the env yields uint8 RGB frames — TODO confirm).
    """
    target_size = (84, 84)
    gray_frame = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
    small_frame = cv2.resize(gray_frame, target_size)
    scaled = small_frame.astype(np.float32)
    return scaled / 255.0

# Frame stacking: keep the 4 most recent preprocessed frames as the agent's state
def stack_frames(frames, state, is_new=False):
    """Push `state` into the rolling frame history and return the stacked view.

    Args:
        frames: deque(maxlen=4) holding the previous frames; ignored when `is_new`.
        state: the newest preprocessed frame.
        is_new: if True, start a fresh history by repeating `state` 4 times
            (used right after an env reset).

    Returns:
        (stacked, frames): `stacked` is np.stack of the history along a new
        leading axis; `frames` is the (possibly new) deque to pass next time.
    """
    if not is_new:
        frames.append(state)  # maxlen=4 drops the oldest frame automatically
    else:
        frames = deque([state for _ in range(4)], maxlen=4)
    stacked = np.stack(list(frames), axis=0)
    return stacked, frames

# Experience replay
class ReplayBuffer:
    """Bounded FIFO store of (state, action, reward, next_state, done) transitions.

    Once `capacity` transitions are held, pushing a new one evicts the oldest.
    """

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, s, a, r, s_, done):
        """Append one transition."""
        transition = (s, a, r, s_, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw `batch_size` distinct transitions uniformly at random.

        Returns a tuple of NumPy arrays (states, actions, rewards, next_states,
        dones); rewards are float32 and dones are uint8. Raises ValueError if
        `batch_size` exceeds the number of stored transitions (random.sample).
        """
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            np.array(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.array(next_states),
            np.array(dones, dtype=np.uint8),
        )

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)


class DQN(nn.Module):
    """Convolutional Q-network mapping a stack of 4 frames to one Q-value per action.

    Args:
        action_dim: number of discrete actions (size of the output layer).

    The first Linear layer expects 3136 = 64 * 7 * 7 features, which is what the
    three conv layers produce for 84x84 inputs (84 -> 20 -> 9 -> 7), so inputs
    must have shape (batch, 4, 84, 84).
    """

    def __init__(self, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(),   # 84 -> 20
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),  # 20 -> 9
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),  # 9 -> 7
            nn.Flatten(),
            nn.Linear(3136, 512), nn.ReLU(),
            nn.Linear(512, action_dim),
        )

    def forward(self, x):
        """Return Q-values of shape (batch, action_dim).

        `x` may be a torch.Tensor or anything torch.as_tensor accepts (e.g. a
        NumPy array); non-tensor input is converted to float32. The input is
        moved to whatever device this module's parameters live on, so the class
        does not depend on a module-level `device` variable.
        """
        param_device = next(self.parameters()).device
        if not isinstance(x, torch.Tensor):
            x = torch.as_tensor(x, dtype=torch.float32, device=param_device)
        return self.net(x.to(param_device))



# Configuration: seeding, environment, networks, optimizer, hyperparameters
SEED = 42
# Seed every RNG the script draws from (epsilon-greedy coin flips, replay
# sampling, network initialisation) so runs are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

env = gym.make("Adventure-v4", render_mode=None)
env.reset(seed=SEED)         # seeds the env's RNG; later unseeded reset() calls continue from it
env.action_space.seed(SEED)  # makes env.action_space.sample() reproducible
action_dim = env.action_space.n

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy_net = DQN(action_dim).to(device)
target_net = DQN(action_dim).to(device)
target_net.load_state_dict(policy_net.state_dict())
# The target net is only synced via load_state_dict and never optimized, so
# disable autograd tracking for its parameters (no graph is built on its forward).
target_net.requires_grad_(False)

optimizer = optim.Adam(policy_net.parameters(), lr=1e-4)
replay_buffer = ReplayBuffer(100_000)

epsilon_start = 1.0
epsilon_final = 0.1
epsilon_decay = 30000   # decay time constant, in frames
batch_size = 32
gamma = 0.99            # discount factor
sync_freq = 1000        # frames between target-network syncs
frame_limit = 500_000   # intended cap on total environment steps (frame_idx)

# Exploration schedule: epsilon decays EXPONENTIALLY (not linearly) from
# epsilon_start toward epsilon_final, with time constant epsilon_decay frames.
def epsilon_by_frame(frame_idx):
    """Return the epsilon-greedy exploration rate for global step `frame_idx`."""
    span = epsilon_start - epsilon_final
    decay_factor = np.exp(-1. * frame_idx / epsilon_decay)
    return epsilon_final + span * decay_factor

# Main training loop
LEARNING_STARTS = 10_000        # transitions to collect before gradient updates begin
NUM_EPISODES = 1000
MAX_STEPS_PER_EPISODE = 10_000


def _to_uint8(stack):
    """Compress a float frame stack scaled by 1/255 into uint8 for replay storage.

    A float32 4x84x84 stack costs ~113 KB, so 100k transitions would need ~11 GB
    of RAM; uint8 is 4x smaller. Assumes the env yields uint8 RGB frames, so
    preprocess() produces exact multiples of 1/255 and rounding back is lossless
    (TODO confirm). Sampled batches are divided by 255 again before training.
    """
    return np.rint(stack * 255.0).astype(np.uint8)


loss_fn = nn.MSELoss()
all_rewards = []
frame_idx = 0

pbar = trange(NUM_EPISODES)
for episode in pbar:
    state = preprocess(env.reset()[0])
    state_stack, frames = stack_frames(None, state, is_new=True)
    state_u8 = _to_uint8(state_stack)
    episode_reward = 0

    for _ in range(MAX_STEPS_PER_EPISODE):
        epsilon = epsilon_by_frame(frame_idx)
        frame_idx += 1

        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            with torch.no_grad():
                q_vals = policy_net(np.expand_dims(state_stack, axis=0))
                action = q_vals.argmax().item()

        # Gym >= 0.26 step API: `terminated` = the game actually ended,
        # `truncated` = the episode was cut off (e.g. by a time limit).
        next_state, reward, terminated, truncated, _ = env.step(action)
        next_state = preprocess(next_state)
        next_stack, frames = stack_frames(frames, next_state, is_new=False)
        next_u8 = _to_uint8(next_stack)

        # Only a true termination should zero the bootstrapped target.
        # next_u8 is reused as the next transition's state, so consecutive
        # transitions share one array instead of storing it twice.
        replay_buffer.push(state_u8, action, reward, next_u8, terminated)
        state_stack, state_u8 = next_stack, next_u8
        episode_reward += reward

        # Train the network
        if len(replay_buffer) > LEARNING_STARTS:
            s, a, r, s_, d = replay_buffer.sample(batch_size)
            s = torch.as_tensor(s, device=device).float().div_(255.0)
            a = torch.as_tensor(a, dtype=torch.long, device=device)
            r = torch.as_tensor(r, dtype=torch.float32, device=device)
            s_ = torch.as_tensor(s_, device=device).float().div_(255.0)
            d = torch.as_tensor(d, dtype=torch.float32, device=device)

            q_value = policy_net(s).gather(1, a.unsqueeze(1)).squeeze(1)
            # The TD target is a constant w.r.t. the policy parameters.
            with torch.no_grad():
                next_q_value = target_net(s_).max(1)[0]
                expected_q = r + gamma * next_q_value * (1 - d)

            loss = loss_fn(q_value, expected_q)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if frame_idx % sync_freq == 0:
            target_net.load_state_dict(policy_net.state_dict())

        if terminated or truncated or frame_idx >= frame_limit:
            break

    all_rewards.append(episode_reward)
    # pbar.write keeps the tqdm progress bar intact (plain print breaks it).
    pbar.write(f"Episode {episode}: Reward = {episode_reward}")

    # Honour the configured total-step budget.
    if frame_idx >= frame_limit:
        break

torch.save(policy_net.state_dict(), "dqn_adventure.pth")
print("✅ 模型已保存")


  0%|          | 1/1000 [00:06<1:41:13,  6.08s/it]

Episode 0: Reward = 0.0


  0%|          | 2/1000 [00:26<4:03:27, 14.64s/it]

Episode 1: Reward = 0.0


  0%|          | 3/1000 [00:31<2:47:37, 10.09s/it]

Episode 2: Reward = 0.0


  0%|          | 4/1000 [00:47<3:24:56, 12.35s/it]

Episode 3: Reward = 0.0


  0%|          | 5/1000 [00:59<3:25:44, 12.41s/it]

Episode 4: Reward = 0.0


  1%|          | 6/1000 [01:16<3:50:05, 13.89s/it]

Episode 5: Reward = 0.0


  1%|          | 7/1000 [02:13<7:44:18, 28.05s/it]

Episode 6: Reward = 0.0


  1%|          | 8/1000 [02:30<6:44:09, 24.44s/it]

Episode 7: Reward = 0.0


  1%|          | 9/1000 [02:56<6:51:40, 24.92s/it]

Episode 8: Reward = 0.0


  1%|          | 10/1000 [04:06<10:43:58, 39.03s/it]

Episode 9: Reward = 0.0


  1%|          | 11/1000 [05:18<13:27:14, 48.97s/it]

Episode 10: Reward = 0.0


  1%|          | 12/1000 [05:34<10:40:57, 38.92s/it]

Episode 11: Reward = 0.0


  1%|▏         | 13/1000 [06:47<13:28:46, 49.17s/it]

Episode 12: Reward = 0.0


  1%|▏         | 14/1000 [07:14<11:37:45, 42.46s/it]

Episode 13: Reward = 0.0


  2%|▏         | 15/1000 [07:35<9:54:20, 36.20s/it] 

Episode 14: Reward = 0.0


  2%|▏         | 16/1000 [08:49<12:58:52, 47.49s/it]

Episode 15: Reward = 0.0


  2%|▏         | 17/1000 [10:02<15:04:58, 55.24s/it]

Episode 16: Reward = 0.0


  2%|▏         | 18/1000 [10:35<13:12:05, 48.40s/it]

Episode 17: Reward = 0.0


  2%|▏         | 19/1000 [10:47<10:13:49, 37.54s/it]

Episode 18: Reward = 0.0
