### import和device

In [1]:
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# set up matplotlib
# `is_ipython` is True when the active matplotlib backend name contains 'inline'.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Switch matplotlib to interactive mode.
plt.ion()

# if GPU is to be used
# Select CUDA when torch reports it as available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Show which torch device was selected for this session.
print(device)

cuda


### Replay Memory

In [29]:
# One stored environment step: (state, action, next_state, reward).
Transition = namedtuple('Transition', ['state', 'action', 'next_state', 'reward'])


class ReplayMemory:
    """Bounded store of Transition records with uniform random sampling."""

    def __init__(self, capacity):
        # A deque with maxlen drops its oldest entry when a new one is
        # appended to a full buffer.
        self.memory = deque(maxlen=capacity)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)

    def push(self, state, action, next_state, reward):
        """Wrap the four values in a Transition and append it to the buffer."""
        record = Transition(
            state=state,
            action=action,
            next_state=next_state,
            reward=reward,
        )
        self.memory.append(record)

    def sample(self, batch_size):
        """Return a list of `batch_size` stored transitions drawn without replacement.

        `random.sample` raises ValueError when `batch_size` exceeds the
        number of stored transitions.
        """
        return random.sample(self.memory, k=batch_size)
    

In [30]:
class DQN(nn.Module):
    """Fully connected network: n_states inputs -> 128 -> 128 -> n_actions outputs."""

    def __init__(self, n_states, n_actions):
        super().__init__()
        hidden_width = 128
        # Attribute names and creation order are kept so state_dict keys and
        # parameter initialisation order stay the same.
        self.layer1 = nn.Linear(n_states, hidden_width)
        self.layer2 = nn.Linear(hidden_width, hidden_width)
        self.layer3 = nn.Linear(hidden_width, n_actions)

    def forward(self, x):
        """Apply two ReLU-activated linear layers, then a linear output layer."""
        first_hidden = F.relu(self.layer1(x))
        second_hidden = F.relu(self.layer2(first_hidden))
        # No activation on the output layer.
        return self.layer3(second_hidden)
    

In [31]:
# Demo of the transpose idiom used when batching: zip(*trs) regroups a list of
# Transition records field by field, and Transition(*...) rebuilds one
# Transition whose fields are tuples (one entry per original record).
trs = [Transition(1, 2, 3, 4), Transition(5, 6, 7, 8)]
trs = Transition(*zip(*trs))
trs

Transition(state=(1, 5), action=(2, 6), next_state=(3, 7), reward=(4, 8))

In [32]:
# Create the CartPole-v1 environment; the test and training cells below use `env` as a global.
env = gym.make("CartPole-v1")

### 定义Agent
+ optimize_model()
+ memorize()
+ choose_action()

In [79]:
class Agent:
    """DQN agent with replay memory, policy/target networks and epsilon-greedy actions.

    The agent owns a ReplayMemory, a policy network that is trained, and a
    target network that starts as a copy of the policy network. This class
    never updates the target network itself; the caller is expected to do so.

    Tensor conventions, as used by the callers in this notebook:
    states are float tensors of shape (1, n_states), actions are long tensors
    of shape (1, 1), rewards are tensors of shape (1,), and a terminal
    transition is stored with ``next_state=None``.
    """

    def __init__(self, n_states, n_actions, eta=0.5, gamma=0.99, capacity=10000, batch_size=32,
                 eps_start=0.9, eps_end=0.05, eps_decay=1000, lr=1e-4):
        """Build the networks, optimizer and replay memory.

        Args:
            n_states: size of the state vector (input width of the networks).
            n_actions: number of discrete actions (output width of the networks).
            eta: stored on the instance but not used by any method of this
                class; kept so existing callers keep working.
            gamma: discount factor applied to the bootstrapped next-state value.
            capacity: maximum number of transitions kept in the replay memory.
            batch_size: number of transitions sampled per optimization step.
            eps_start: exploration probability at step 0.
            eps_end: value the exploration probability decays towards.
            eps_decay: decay time constant, in calls to ``choose_action``.
            lr: learning rate of the Adam optimizer (previously hard-coded to 1e-4).
        """
        self.n_states = n_states
        self.n_actions = n_actions
        self.eta = eta
        self.gamma = gamma
        self.batch_size = batch_size
        self.eps_start = eps_start
        self.eps_end = eps_end
        self.eps_decay = eps_decay

        self.memory = ReplayMemory(capacity)

        self.policy_net = DQN(n_states, n_actions).to(device)
        self.target_net = DQN(n_states, n_actions).to(device)
        # Start the target network as an exact copy of the policy network.
        self.target_net.load_state_dict(self.policy_net.state_dict())

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)

        # Number of choose_action() calls so far; drives the epsilon decay.
        self.steps_done = 0

    def optimize_model(self):
        """Run one gradient step on the policy network from a sampled replay batch.

        Returns immediately (doing nothing) until the memory holds at least
        ``batch_size`` transitions.

        Prediction: Q_policy(state_t, action_t)
        Target:     reward + gamma * max_a Q_target(state_{t+1}, a),
        where the bootstrapped term is 0 for transitions whose next_state is None.
        """
        if len(self.memory) < self.batch_size:
            return

        transitions = self.memory.sample(self.batch_size)
        # Transpose the list of Transition records into a single Transition
        # whose fields are tuples with one entry per sampled record.
        batch = Transition(*zip(*transitions))

        # Next states that exist (non-terminal transitions) and a boolean mask
        # marking which batch rows they belong to.
        non_final_next_states = [s for s in batch.next_state if s is not None]
        non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                      device=device, dtype=torch.bool)

        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        # Q-value of the action actually taken in each sampled state; gather
        # along dim 1 picks one column per row, giving shape (batch_size, 1).
        state_action_values = self.policy_net(state_batch).gather(dim=1, index=action_batch)

        # Rows for terminal transitions keep the value 0.
        next_state_values = torch.zeros(self.batch_size, device=device)
        # torch.cat raises on an empty list, so skip the target-network pass
        # when every sampled transition is terminal.
        if non_final_next_states:
            non_final_next_state_batch = torch.cat(non_final_next_states)
            with torch.no_grad():
                next_state_values[non_final_mask] = \
                    self.target_net(non_final_next_state_batch).max(dim=1).values
        expected_state_action_values = reward_batch + self.gamma * next_state_values

        # x.unsqueeze(i) inserts a new dimension at position i of x.
        # expected_state_action_values has shape (batch_size,), like reward_batch;
        # unsqueeze(1) turns it into (batch_size, 1) to match state_action_values.
        loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        # Clamp each gradient element to [-100, 100] before the optimizer step.
        torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 100)
        self.optimizer.step()

    def memorize(self, state, action, next_state, reward):
        """Store one transition in the replay memory."""
        self.memory.push(state, action, next_state, reward)

    def choose_action(self, state):
        """Pick an action epsilon-greedily and return it as a (1, 1) long tensor.

        The exploration probability decays exponentially from ``eps_start``
        towards ``eps_end`` with time constant ``eps_decay``; every call
        increments ``steps_done``.

        Args:
            state: input for the policy network; the greedy branch reshapes the
                result with ``view(1, 1)``, so it must be a single state
                (one row).
        """
        eps_threshold = self.eps_end + (self.eps_start - self.eps_end) * \
            math.exp(-1. * self.steps_done / self.eps_decay)
        self.steps_done += 1

        if random.random() < eps_threshold:
            # Explore: draw uniformly from the agent's own action count instead
            # of reaching for a global `env`, so the class works on its own.
            # This assumes actions are the integers 0..n_actions-1, matching
            # the width of the network output.
            random_action = random.randrange(int(self.n_actions))
            return torch.tensor([[random_action]], device=device, dtype=torch.long)

        with torch.no_grad():
            # Exploit: .max(1) takes the maximum over each row. The network
            # output here has a single row with one column per action, so this
            # selects the best action. .indices (equivalently [1]) is used
            # because we want the action index, not its value. view(1, 1)
            # reshapes the result into a 1x1 tensor.
            return self.policy_net(state).max(1).indices.view(1, 1)
        

### Agent与环境交互（训练过程）

In [80]:
# this cell is for test
# Compare the raw observation with its tensor form, without and with a leading batch dimension.
state, info = env.reset()
print(state)
state1 = torch.tensor(state, dtype=torch.float32, device=device)
print(state1)
# unsqueeze(0) inserts a new leading dimension: shape [4] becomes [1, 4].
state2 = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
print(state2)
print(state1.shape)
print(state2.shape)


[-0.02023418  0.01963305  0.04826971  0.02622762]
tensor([-0.0202,  0.0196,  0.0483,  0.0262], device='cuda:0')
tensor([[-0.0202,  0.0196,  0.0483,  0.0262]], device='cuda:0')
torch.Size([4])
torch.Size([1, 4])


In [81]:
# this cell is for test
# Smoke test: build an agent, choose one action and step the environment once.
n_states = env.observation_space.shape[0]
n_actions = env.action_space.n
agent = Agent(n_states, n_actions)

state, info = env.reset()
# state.shape is [4] before unsqueeze and [1, 4] after
state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)

action = agent.choose_action(state)
# NOTE(review): the training loop unpacks this same call as
# (observation, reward, terminated, truncated, info), so `is_done` here holds
# only the `terminated` flag and the `truncated` flag is discarded.
observation, reward, is_done, _, _ = env.step(action.item())
print(observation)
# Convert the new observation to a [1, 4] tensor, the same form as `state`.
next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
print(next_state)

[ 0.02329497 -0.23557281 -0.03836909  0.32001054]
tensor([[ 0.0233, -0.2356, -0.0384,  0.3200]], device='cuda:0')


In [82]:
# Training loop: interact with the environment, store each transition, run one
# optimization step per environment step, and soft-update the target network.
n_states = env.observation_space.shape[0]
n_actions = env.action_space.n

max_episodes = 500
# Interpolation factor for the soft target-network update below.
tau = 0.005

# Number of steps each finished episode lasted.
episode_durations = []

agent = Agent(n_states, n_actions)

# NOTE(review): `finish_flag` is not read anywhere in the visible code.
finish_flag = False
for i_episode in range(max_episodes):
    state, info = env.reset()
    # state.shape is [4] before unsqueeze and [1, 4] after
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)

    for t in count():
        # the choose_action() might be wrong
        action = agent.choose_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        is_done = terminated or truncated

        # Only `terminated` stores None as the next state; a truncated step
        # keeps its observation, so optimize_model still bootstraps from it.
        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        agent.memorize(state, action, next_state, reward)
        # `state` may become None here; the loop breaks below before it is used again.
        state = next_state
        agent.optimize_model()

        # Soft update of the target network after every environment step:
        # target <- tau * policy + (1 - tau) * target
        target_dict = agent.target_net.state_dict()
        policy_dict = agent.policy_net.state_dict()
        for key in policy_dict:
            target_dict[key] = policy_dict[key]*tau + target_dict[key]*(1-tau)
        agent.target_net.load_state_dict(target_dict)

        if is_done:
            # t is zero-based, so the episode lasted t + 1 steps.
            episode_durations.append(t+1)
            print(f'steps: {t+1}')
            break
    
        


steps: 9
steps: 19
steps: 14
steps: 20
steps: 29
steps: 14
steps: 18
steps: 10
steps: 34
steps: 16
steps: 12
steps: 18
steps: 11
steps: 15
steps: 22
steps: 31
steps: 57
steps: 19
steps: 13
steps: 18
steps: 16
steps: 14
steps: 13
steps: 10
steps: 27
steps: 10
steps: 12
steps: 9
steps: 25
steps: 11
steps: 18
steps: 11
steps: 16
steps: 13
steps: 15
steps: 13
steps: 10
steps: 11
steps: 11
steps: 15
steps: 13
steps: 9
steps: 16
steps: 12
steps: 9
steps: 12
steps: 13
steps: 10
steps: 10
steps: 14
steps: 9
steps: 20
steps: 15
steps: 9
steps: 11
steps: 12
steps: 19
steps: 10
steps: 11
steps: 9
steps: 10
steps: 9
steps: 20
steps: 9
steps: 11
steps: 10
steps: 9
steps: 14
steps: 22
steps: 10
steps: 9
steps: 14
steps: 11
steps: 12
steps: 11
steps: 12
steps: 15
steps: 11
steps: 9
steps: 9
steps: 13
steps: 11
steps: 13
steps: 11
steps: 15
steps: 15
steps: 9
steps: 16
steps: 14
steps: 15
steps: 21
steps: 13
steps: 10
steps: 12
steps: 24
steps: 42
steps: 38
steps: 10
steps: 28
steps: 19
steps: 11
step

KeyboardInterrupt: 