In [1]:
!pip install gym[atari] ale-py autorom --quiet
!AutoROM --accept-license --quiet


AutoROM will download the Atari 2600 ROMs.
They will be installed to:
	/usr/local/lib/python3.11/dist-packages/AutoROM/roms

Existing ROMs will be overwritten.


In [None]:
import gym
import numpy as np
import random
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
from tqdm import trange

def preprocess(obs):
    """Turn an RGB frame into an 84x84 grayscale float32 image.

    Pixel values are divided by 255, so a uint8 source frame ends up in [0, 1].

    Args:
        obs: RGB image array as returned by the environment (H x W x 3 assumed).

    Returns:
        np.ndarray of shape (84, 84) with dtype float32.
    """
    resized_gray = cv2.resize(cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY), (84, 84))
    as_float = resized_gray.astype(np.float32)
    return as_float / 255.0

def stack_frames(frames, new_frame, is_new=False):
    """Maintain a rolling window of the 4 most recent frames.

    Args:
        frames: deque(maxlen=4) from a previous call; ignored when ``is_new`` is True.
        new_frame: the latest preprocessed frame.
        is_new: True at the start of an episode; the window is then filled with
            four references to ``new_frame``.

    Returns:
        (stacked, frames): ``stacked`` is ``np.stack`` of the window along a new
        leading axis (oldest first). ``frames`` is the deque, which is freshly created
        when ``is_new`` and otherwise the caller's deque mutated in place.
    """
    if not is_new:
        frames.append(new_frame)
    else:
        frames = deque((new_frame for _ in range(4)), maxlen=4)
    stacked = np.stack(frames, axis=0)
    return stacked, frames

class ReplayBuffer:
    """Fixed-capacity FIFO store of (state, action, reward, next_state, done) transitions.

    Once ``capacity`` transitions are held, each push evicts the oldest one.
    """

    def __init__(self, capacity):
        # A deque with maxlen drops items from the left automatically when full.
        self.buffer = deque(maxlen=capacity)

    def push(self, s, a, r, s_, d):
        """Append one transition (state, action, reward, next_state, done)."""
        transition = (s, a, r, s_, d)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random, without replacement.

        Returns:
            Tuple of arrays ``(states, actions, rewards, next_states, dones)``.
            ``rewards`` is float32 and ``dones`` is uint8.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored transitions
                (raised by ``random.sample``).
        """
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            np.array(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.array(next_states),
            np.array(dones, dtype=np.uint8),
        )

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

class DQN(nn.Module):
    """Convolutional Q-network (Nature-DQN layout): 4 stacked 84x84 frames -> one Q-value per action.

    Args:
        action_dim: number of discrete actions, i.e. the size of the output layer.
    """

    def __init__(self, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            # (4, 84, 84) -> (32, 20, 20)
            nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(),
            # -> (64, 9, 9)
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            # -> (64, 7, 7); flattened this is 64 * 7 * 7 = 3136 features,
            # which is why the input must be 84x84.
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3136, 512), nn.ReLU(),
            nn.Linear(512, action_dim)
        )

    def forward(self, x):
        """Compute Q-values for a batch of frame stacks.

        Args:
            x: tensor or array-like of shape (batch, 4, 84, 84). Non-tensor input is
                converted to float32; tensors keep their dtype.

        Returns:
            Tensor of shape (batch, action_dim) on the module's device.
        """
        # Use the device the parameters actually live on rather than a notebook-level
        # global, so the module works wherever it has been moved with .to().
        param_device = next(self.parameters()).device
        if not isinstance(x, torch.Tensor):
            x = torch.as_tensor(x, dtype=torch.float32)
        return self.net(x.to(param_device))

# Reproducibility: seed every RNG the training loop draws from.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

env = gym.make("Atlantis-v4", render_mode=None)
env.reset(seed=SEED)         # seed the environment's own RNG (Gym >= 0.26 reset API)
env.action_space.seed(SEED)  # random exploration actions come from env.action_space.sample()
action_dim = env.action_space.n

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Online network plus a target network initialised with identical weights.
policy_net = DQN(action_dim).to(device)
target_net = DQN(action_dim).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.Adam(policy_net.parameters(), lr=1e-4)
replay_buffer = ReplayBuffer(100_000)  # max stored transitions; the oldest are evicted first

# Hyperparameters
epsilon_start = 1.0    # initial exploration rate
epsilon_final = 0.1    # asymptotic exploration rate
epsilon_decay = 30000  # frames per e-fold of the exponential epsilon schedule
batch_size = 32        # transitions per gradient step
gamma = 0.99           # discount factor
sync_freq = 1000       # frames between copies of policy_net into target_net
frame_limit = 500_000  # total environment-step budget for training

def epsilon_by_frame(frame_idx):
    """Exponentially decaying exploration rate for epsilon-greedy action selection.

    Returns ``epsilon_start`` at frame 0 and approaches ``epsilon_final`` with an
    e-folding time of ``epsilon_decay`` frames. All three are module-level
    hyperparameters.
    """
    decay_factor = np.exp(-frame_idx / epsilon_decay)
    return epsilon_final + (epsilon_start - epsilon_final) * decay_factor

frame_idx = 0
all_rewards = []

num_episodes = 1000            # upper bound on episodes; frame_limit may stop training sooner
max_steps_per_episode = 10000  # per-episode step cap (a truncation, not a terminal state)
learning_starts = 10000        # replay size required before gradient updates begin
loss_fn = nn.SmoothL1Loss()    # Huber loss, robust to large TD errors (Mnih et al., 2015)


def _to_uint8(stack):
    """Quantize a float frame stack back to uint8 for compact replay storage.

    preprocess() divides uint8 grayscale pixels by 255, so round(x * 255) recovers the
    original integer pixel values exactly (assuming uint8 source frames). A float32
    4x84x84 stack is about 113 KB and the uint8 copy about 28 KB, so a 100k-transition
    buffer (two stacks per transition) needs roughly 5.6 GB instead of about 22.6 GB.
    """
    return np.round(stack * 255.0).astype(np.uint8)


def _optimize_model():
    """Run one DQN gradient step on a uniformly sampled minibatch from replay_buffer."""
    s, a, r, s_, d = replay_buffer.sample(batch_size)
    # States are stored as uint8; rescale to the [0, 1] floats the network sees when acting.
    s = torch.as_tensor(s, dtype=torch.float32, device=device) / 255.0
    s_ = torch.as_tensor(s_, dtype=torch.float32, device=device) / 255.0
    a = torch.as_tensor(a, dtype=torch.long, device=device)
    r = torch.as_tensor(r, dtype=torch.float32, device=device)
    d = torch.as_tensor(d, dtype=torch.float32, device=device)

    q_value = policy_net(s).gather(1, a.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        # Bootstrapped target from the frozen network; d == 1 only for true terminal states.
        next_q_value = target_net(s_).max(1)[0]
        expected_q = r + gamma * next_q_value * (1 - d)

    loss = loss_fn(q_value, expected_q)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


for episode in trange(num_episodes):
    if frame_idx >= frame_limit:
        break  # environment-step budget exhausted

    obs, _ = env.reset()
    state = preprocess(obs)
    state_stack, frames = stack_frames(None, state, is_new=True)
    episode_reward = 0

    for _ in range(max_steps_per_episode):
        epsilon = epsilon_by_frame(frame_idx)
        frame_idx += 1

        # Epsilon-greedy action selection.
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            with torch.no_grad():
                q_vals = policy_net(np.expand_dims(state_stack, axis=0))
                action = q_vals.argmax().item()

        next_obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        next_state = preprocess(next_obs)
        next_stack, frames = stack_frames(frames, next_state, is_new=False)

        # Store `terminated` (not `done`) so a time-limit truncation still bootstraps
        # from the next state. Clip the training reward to its sign; episode_reward
        # keeps the raw game score for logging.
        replay_buffer.push(
            _to_uint8(state_stack), action, float(np.sign(reward)),
            _to_uint8(next_stack), terminated,
        )
        state_stack = next_stack
        episode_reward += reward

        if len(replay_buffer) > learning_starts:
            _optimize_model()

        if frame_idx % sync_freq == 0:
            target_net.load_state_dict(policy_net.state_dict())

        if done or frame_idx >= frame_limit:
            break

    # If frame_limit cut this episode short, this is a partial-episode score.
    all_rewards.append(episode_reward)
    print(f"Episode {episode}: Reward = {episode_reward}")

env.close()
torch.save(policy_net.state_dict(), "dqn_atlantis.pth")
# Message reads: "Model saved as dqn_atlantis.pth".
print("✅ 模型已保存为 dqn_atlantis.pth")

  0%|          | 1/1000 [00:01<21:55,  1.32s/it]

Episode 0: Reward = 17500.0


  0%|          | 2/1000 [00:03<26:06,  1.57s/it]

Episode 1: Reward = 34600.0


  0%|          | 3/1000 [00:04<25:52,  1.56s/it]

Episode 2: Reward = 18900.0


  0%|          | 4/1000 [00:06<29:17,  1.76s/it]

Episode 3: Reward = 30900.0


  0%|          | 5/1000 [00:13<56:45,  3.42s/it]

Episode 4: Reward = 12900.0


  1%|          | 6/1000 [00:28<2:05:23,  7.57s/it]

Episode 5: Reward = 17700.0


  1%|          | 7/1000 [00:54<3:41:46, 13.40s/it]

Episode 6: Reward = 42200.0


  1%|          | 8/1000 [01:09<3:52:40, 14.07s/it]

Episode 7: Reward = 16700.0


  1%|          | 9/1000 [01:19<3:32:35, 12.87s/it]

Episode 8: Reward = 11700.0


  1%|          | 10/1000 [01:28<3:10:25, 11.54s/it]

Episode 9: Reward = 10600.0


  1%|          | 11/1000 [01:41<3:17:53, 12.01s/it]

Episode 10: Reward = 14800.0


  1%|          | 12/1000 [01:51<3:06:20, 11.32s/it]

Episode 11: Reward = 7600.0


  1%|▏         | 13/1000 [01:59<2:51:58, 10.45s/it]

Episode 12: Reward = 12400.0


  1%|▏         | 14/1000 [02:10<2:52:07, 10.47s/it]

Episode 13: Reward = 14300.0


  2%|▏         | 15/1000 [02:23<3:04:02, 11.21s/it]

Episode 14: Reward = 17800.0
