# 深度强化学习 DQN 实验

本实验采用 Gymnasium(原 OpenAI Gym,现由 Farama Foundation 维护)的 Frozen Lake 环境。



In [12]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt

In [13]:
# Integer codes for the content of a grid cell. DQN.forward feeds the encoded
# grid through an embedding layer, so the codes must stay small and contiguous.
BLANK = 0
LAKE = 1
GIFT = 2
PLAYER = 3

# Static map of the world: one cell code per position, row-major.
grid = torch.tensor([
    [BLANK, BLANK, BLANK, BLANK],
    [BLANK, LAKE, BLANK, LAKE],
    [BLANK, BLANK, BLANK, LAKE],
    [LAKE, BLANK, BLANK, GIFT]
], dtype=torch.long)


def player_on_grid(state):
    """Render a batch of flat state indices onto copies of the map.

    Args:
        state: 1-D integer tensor of length ``batch``. Each entry is a
            row-major cell index into ``grid`` (``row * n_cols + col``).

    Returns:
        A ``(batch, n_rows, n_cols)`` long tensor. Every slice is a copy of
        ``grid`` in which the cell addressed by the matching state has been
        overwritten with ``PLAYER``. The module-level ``grid`` is not modified.

    Raises:
        IndexError: If any state index lies outside ``[0, grid.numel())``.
    """
    batch_size = state.size(0)
    n_cols = grid.size(1)

    # Keep all index tensors on the map's device so the indexed write below
    # never mixes devices.
    state = state.to(grid.device)

    # Negative indices would otherwise wrap around silently to a wrong cell.
    if bool(((state < 0) | (state >= grid.numel())).any()):
        raise IndexError(
            f"state index out of range for a grid with {grid.numel()} cells"
        )

    # ``repeat`` materialises a fresh tensor, so writing into it is safe.
    ret = grid.unsqueeze(0).repeat(batch_size, 1, 1)

    # Row-major decoding must divide by the number of COLUMNS of the grid.
    rows = torch.div(state, n_cols, rounding_mode="floor")
    cols = state % n_cols
    ret[torch.arange(batch_size, device=grid.device), rows, cols] = PLAYER
    return ret

In [14]:
class DQN(nn.Module):
    """Q-network that scores every action for a batch of grid-world states.

    A state index is rendered onto the map by ``player_on_grid``, each cell
    code is embedded, three 3x3 convolutions extract features and a two-layer
    MLP maps them to one Q-value per action. The module owns its loss
    (``MSELoss``) and its optimizer (``Adam``), so ``update`` performs a
    complete training step.

    Args:
        state_size: ``(height, width)`` of the grid.
        action_size: Number of discrete actions, i.e. network outputs.
        grid_states: Number of distinct cell codes (embedding vocabulary).
        d_model: Embedding width; the convolutions widen it to ``8 * d_model``.
        lr: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, state_size=(4, 4), action_size=4, grid_states=4, d_model=16, lr=1e-3):
        super().__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.grid_states = grid_states
        self.d_model = d_model

        # One learned vector per cell code.
        self.emb = nn.Embedding(self.grid_states, self.d_model)

        # Shape comments are (H, W, C) for the default 4x4 grid and d_model=16.
        self.featnet = nn.Sequential(
            nn.Conv2d(d_model, 2*d_model, kernel_size=3, stride=1, padding=1),          # (4,4,16) -> (4,4,32)
            nn.ReLU(),
            nn.Conv2d(2*d_model, 4*d_model, kernel_size=3, stride=1, padding=1),        # (4,4,32) -> (4,4,64)
            nn.ReLU(),
            # No padding here, so each spatial dimension shrinks by 2.
            nn.Conv2d(4*d_model, 8*d_model, kernel_size=3, stride=1),                   # (4,4,64) -> (2,2,128)
            nn.ReLU()
        )

        self.vnet = nn.Sequential(
            nn.Linear(self._feat_size(), self._feat_size() // 8),                        # (512) -> (64)
            nn.ReLU(),
            nn.Linear(self._feat_size() // 8, self.action_size)                          # (64) -> (4)
        )

        self.criterion = nn.MSELoss()
        # Created last so that it sees every parameter registered above.
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def _feat_size(self):
        """Return the flattened size of the ``featnet`` output for one sample."""
        # 8 * d_model channels times the spatial size left by the unpadded conv.
        return 8 * (self.state_size[0] - 2) * (self.state_size[1] - 2) * self.d_model

    def forward(self, x):
        """Compute Q-values for a batch of states.

        Args:
            x: 1-D integer tensor of flat state indices, length ``batch``.

        Returns:
            Float tensor of shape ``(batch, action_size)``.
        """
        batch_size = x.size(0)
        x = player_on_grid(x)          # (B, H, W) cell codes
        x = self.emb(x)                # (B, H, W, d_model)
        x = x.permute(0, 3, 1, 2)      # channels first, as Conv2d expects
        x = self.featnet(x)
        x = x.reshape(batch_size, -1)
        x = self.vnet(x)
        return x

    def act(self, x, eps=0.1):
        """Pick an action for a single state with an epsilon-greedy policy.

        Args:
            x: 1-D tensor holding exactly one state index.
            eps: Probability of choosing a uniformly random action.

        Returns:
            The chosen action as a Python ``int`` in ``[0, action_size)``.
        """
        if random.random() > eps:
            # Greedy branch: no gradient bookkeeping needed for action selection.
            with torch.no_grad():
                return int(torch.argmax(self.forward(x), dim=-1).squeeze().item())
        # Exploration branch. ``torch.randint`` needs an explicit ``size`` and
        # returns a tensor, so draw a plain int from the ``random`` module.
        return random.randrange(self.action_size)

    def update(self, sample, gamma=1.0):
        """Run one Q-learning gradient step on a transition or a batch of them.

        Args:
            sample: Tuple ``(state, action, reward, next_state, done)``.
                ``state`` and ``next_state`` are 1-D tensors of state indices.
                ``action``, ``reward`` and ``done`` may be scalars or tensors
                with one entry per transition; ``done`` may also be a single
                flag shared by the whole batch.
            gamma: Discount factor applied to the bootstrapped value.

        Returns:
            The scalar loss of this step as a Python ``float``.
        """
        state, action, reward, next_state, done = sample

        self.optimizer.zero_grad()

        # Q(s, a) for the taken actions, flattened to (B,) so that it lines up
        # with the target instead of broadcasting to (B, B) inside MSELoss.
        action = torch.as_tensor(action, dtype=torch.long).reshape(-1, 1)
        q = self.forward(state).gather(1, action).squeeze(1)

        reward = torch.as_tensor(reward, dtype=q.dtype, device=q.device).reshape(-1)
        not_done = 1.0 - torch.as_tensor(done, dtype=q.dtype, device=q.device).reshape(-1)

        # The bootstrap target is treated as a constant: no gradient may flow
        # back through the next-state estimate.
        with torch.no_grad():
            if bool(not_done.any()):
                max_next_q = self.forward(next_state).max(dim=1).values
            else:
                # Every transition is terminal; skip the forward pass entirely.
                max_next_q = torch.zeros_like(q)
            # Terminal transitions contribute only their immediate reward.
            target = reward + gamma * not_done * max_next_q

        loss = self.criterion(q, target)
        loss.backward()
        self.optimizer.step()
        return loss.item()
        
# Smoke test: build a network with the default hyper-parameters and evaluate
# it on a batch of three state indices. The result has one row per state and
# one column per action.
dqn = DQN()
dqn(torch.tensor([0, 1, 7]))

tensor([[ 0.0442, -0.0743, -0.0089,  0.0957],
        [ 0.0491, -0.0706, -0.0074,  0.0978],
        [ 0.0420, -0.0793, -0.0050,  0.0978]], grad_fn=<AddmmBackward0>)