In [13]:
import gym
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
from matplotlib import pyplot as plt
import numpy as np
import random
from collections import deque

class DQN(nn.Module):
    """Fully connected Q-network mapping an observation to one value per action.

    The hidden widths expand and then contract
    (128 -> 256 -> 512 -> 256 -> 128), each hidden layer followed by a ReLU.
    The final linear layer has no activation.

    Args:
        in_dim: Size of the input feature vector.
        out_dim: Number of outputs (one per action).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # Consecutive pairs of this table give the (in, out) sizes of the
        # hidden linear layers.
        widths = (in_dim, 128, 256, 512, 256, 128)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], out_dim))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for input ``x``."""
        return self.model(x)


class ReplayBuffer(object):
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Once ``capacity`` transitions are stored, appending a new one evicts the
    oldest (``deque`` with ``maxlen``).
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one ``[state, action, reward, next_state, done]`` record."""
        self.buffer.append([state, action, reward, next_state, done])

    def sample(self, batch_size):
        """Draw up to ``batch_size`` distinct transitions uniformly at random.

        When fewer than ``batch_size`` transitions are stored, all of them
        are returned (in random order).

        Returns:
            Five tuples ``(states, actions, rewards, next_states, dones)``,
            aligned by position.
        """
        draw_count = min(batch_size, len(self.buffer))
        picked = random.sample(self.buffer, draw_count)
        # Transpose the list of records into per-field tuples.
        states, actions, rewards, next_states, dones = zip(*picked)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.buffer)


class DQNAgent(object):
    """DQN trainer for the ``LunarLander-v2`` gym environment.

    Each environment step performs two optimizer updates on the online
    network: a single-transition TD update bootstrapped from the online
    network itself, followed by a minibatch update from the replay buffer
    bootstrapped from a separate, softly updated target network.

    Args:
        tau: Interpolation factor of the soft target-network update.
        buffer_size: Capacity of the replay buffer.
    """

    def __init__(self, tau=1e-3, buffer_size=10000):
        self.env = gym.make("LunarLander-v2")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tau = tau
        self.memory = ReplayBuffer(buffer_size)

    def _to_state(self, observation):
        """Convert an environment observation to a float32 tensor on the device."""
        array = np.asarray(observation, dtype=np.float32)
        return torch.from_numpy(array).to(self.device)

    def _reset_env(self):
        """Reset the environment and return the initial state tensor."""
        result = self.env.reset()
        # gym >= 0.26 returns (observation, info); older versions return
        # the bare observation.
        observation = result[0] if isinstance(result, tuple) else result
        return self._to_state(observation)

    def _step_env(self, action):
        """Apply ``action``; return ``(next_state, reward, done)``.

        Accepts both the 4-tuple step API and the 5-tuple API of
        gym >= 0.26, where ``done`` is ``terminated or truncated``.
        """
        result = self.env.step(action)
        if len(result) == 5:
            observation, reward, terminated, truncated, _ = result
            done = terminated or truncated
        else:
            observation, reward, done, _ = result
        return self._to_state(observation), float(reward), bool(done)

    @staticmethod
    def _soft_update(target_model, model, tau):
        """Move each target parameter toward the online one: t <- tau*p + (1-tau)*t."""
        for target_param, param in zip(target_model.parameters(), model.parameters()):
            target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)

    @staticmethod
    def _save_curves(x_values, curves, path):
        """Plot each ``(values, label)`` pair against ``x_values`` and save to ``path``."""
        plt.figure(figsize=(24, 6))
        for y_values, label in curves:
            plt.plot(x_values, y_values, label=label)
        plt.xlabel("epoch")
        plt.legend()
        plt.savefig(path, dpi=400)
        plt.close()

    def solve(self, num_episodes=4000):
        """Train the agent and periodically save checkpoints and plots.

        Side effects: prints one progress line per episode, pickles the model
        every 2000 episodes (and whenever both the episode return and the
        100-episode mean exceed 200), writes loss/reward/length plots every
        1000 episodes, and closes the environment when training ends.

        Args:
            num_episodes: Number of episodes to run.

        Returns:
            The loss of the last single-transition update (a 0-d tensor), or
            ``0`` if no episode was run.
        """
        print(self.device)
        in_dim = self.env.observation_space.shape[0]
        out_dim = self.env.action_space.n
        gamma = 0.95
        epsilon = 1
        eps_end = 0.2
        eps_decay = 0.995
        batch_size = 64
        losses = []
        Rewards = []
        avg_rewards = []
        times = []
        epochs = []

        model = DQN(in_dim, out_dim).to(self.device)
        # model = torch.load("model5999.pth", map_location=self.device)
        target_model = DQN(in_dim, out_dim).to(self.device)
        # Copy the weights instead of rebinding the name: the target network
        # must be a distinct module, otherwise the soft update is a no-op and
        # the bootstrap target moves with every gradient step.
        target_model.load_state_dict(model.state_dict())
        target_model.eval()

        loss_fn = nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        scheduler = StepLR(optimizer, step_size=20, gamma=0.9)

        l = 0
        for i in range(num_episodes):
            # Every episode must start from the freshly reset observation.
            state = self._reset_env()
            done = False
            R = 0
            J = 0
            C = 0
            l = 0

            model.train()
            while not done:
                C += 1

                # Epsilon-greedy action selection.
                if torch.rand(1).item() > epsilon:
                    with torch.no_grad():
                        action = int(torch.argmax(model(state)))
                else:
                    action = int(torch.randint(high=out_dim, size=(1,))[0].item())

                next_state, reward, done = self._step_env(action)
                R += reward

                # Single-transition TD update bootstrapped from the online network.
                q_value = model(state)[action]
                with torch.no_grad():
                    best_next_q = model(next_state).max()
                e_q_value = reward + gamma * best_next_q * (1 - done)

                l = loss_fn(q_value, e_q_value)
                # Accumulate the real loss; integer truncation would log
                # every loss below 1 as 0.
                J += l.item()
                optimizer.zero_grad()
                l.backward()
                optimizer.step()

                # Store the transition BEFORE advancing the state so the
                # record holds (s, a, r, s') rather than (s', a, r, s').
                self.memory.push(state, action, reward, next_state, done)
                state = next_state

                # Sample a minibatch from the replay buffer.
                states, actions, rewards, next_states, dones = self.memory.sample(batch_size)

                # Convert to tensors on the training device.
                states = torch.stack(states).float().to(self.device)
                actions = torch.tensor(actions, dtype=torch.long).to(self.device)
                rewards = torch.tensor(rewards).float().to(self.device)
                next_states = torch.stack(next_states).float().to(self.device)
                dones = torch.tensor(dones).float().to(self.device)

                # Q-values of the sampled states, and of the actions taken.
                q_values = model(states)
                q_values_for_actions = q_values.gather(dim=1, index=actions.unsqueeze(-1)).squeeze(-1)

                # Bootstrap targets from the target network (no gradient).
                with torch.no_grad():
                    next_q_values = target_model(next_states)
                target_q_values = rewards + gamma * torch.max(next_q_values, dim=1)[0] * (1 - dones)

                loss = loss_fn(q_values_for_actions, target_q_values)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Soft-update the target network toward the online network.
                self._soft_update(target_model, model, self.tau)

            scheduler.step()
            epsilon = max(eps_end, epsilon * eps_decay)
            losses.append(J / C)
            Rewards.append(R)
            times.append(C)
            mean_reward = np.mean(Rewards[-100:])
            # NOTE(review): when the running mean has dropped by more than 5
            # the lr is scaled by 1.5 but capped at 0.0004, which LOWERS it
            # while it is still above that cap -- confirm this is intended.
            if i > 20 and i % 20 == 0 and avg_rewards[-15] - mean_reward > 5:
                optimizer.param_groups[0]['lr'] = min(optimizer.param_groups[0]['lr'] * 1.5, 0.0004)

            avg_rewards.append(mean_reward)

            epochs.append(i)
            print(f"epoch:{i}  Rewards:{R} avg_reward:{mean_reward} time:{C}, lr:{optimizer.param_groups[0]['lr']}")

            if (i + 1) % 2000 == 0:
                torch.save(model, f'model{i}.pth')
            if R > 200 and mean_reward > 200:
                torch.save(model, f'good_model{i}.pth')
            if (i + 1) % 1000 == 0:
                self._save_curves(epochs, [(losses, "loss")], f"model{i}loss.png")
                self._save_curves(
                    epochs,
                    [(Rewards, "reward"), (avg_rewards, "avg_reward")],
                    f"model{i}reward.png",
                )
                self._save_curves(epochs, [(times, "times")], f"model{i}times.png")

        # Release the environment once, after all episodes have finished.
        self.env.close()
        return l
    
        

# Build the agent with its default tau and replay-buffer size (this creates
# the LunarLander-v2 environment), then run the full training loop.
agent = DQNAgent() 
agent.solve()

cuda
epoch:0  Rewards:-373.4790131338285 avg_reward:-373.4790131338285 time:87, lr:0.001
epoch:1  Rewards:-119.20799028201708 avg_reward:-246.3435017079228 time:66, lr:0.001
epoch:2  Rewards:-125.51986358358612 avg_reward:-206.06895566647722 time:129, lr:0.001
epoch:3  Rewards:-119.39332017132946 avg_reward:-184.4000467926903 time:88, lr:0.001
epoch:4  Rewards:-56.56358612547957 avg_reward:-158.83275465924814 time:61, lr:0.001
epoch:5  Rewards:-78.11347433951462 avg_reward:-145.3795412726259 time:66, lr:0.001
epoch:6  Rewards:-141.51539138580745 avg_reward:-144.82751986022325 time:124, lr:0.001
epoch:7  Rewards:-36.19820505508919 avg_reward:-131.24885550958152 time:107, lr:0.001
epoch:8  Rewards:-108.15660550080452 avg_reward:-128.68304995305073 time:89, lr:0.001
epoch:9  Rewards:-100.05108572061512 avg_reward:-125.81985352980719 time:77, lr:0.001
epoch:10  Rewards:-82.50101289052492 avg_reward:-121.88177710805425 time:61, lr:0.001
epoch:11  Rewards:-138.63676586973992 avg_reward:-123.