In [327]:
from torch import nn
import torch.nn.functional as F
from collections import deque
import copy
from torch.optim import Adam
import torch

In [328]:
class MLP(nn.Module):
    
    def __init__(self,n_state,n_action,hidden_dim=200):
        super().__init__()
        self.fc1 = nn.Linear(n_state,hidden_dim)
        self.fc2 = nn.Linear(hidden_dim,hidden_dim)
        self.fc3 = nn.Linear(hidden_dim,n_action)
    
    
    def forward(self,x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [329]:
class Memory:
    
    def __init__(self):
        self.buffer = deque(maxlen=200)
    
    def push(self,transitions):
        self.buffer.append(transitions)
    
    def sample(self):
        return zip(*self.buffer)
    
    def clear(self):
        self.buffer.clear()
    
    def __len__(self):
        return len(self.buffer)

In [330]:
class Agent:
    
    def __init__(self,n_states,n_actions):
        self.qnet = MLP(n_states,n_actions)
        self.target_net = MLP(n_states,n_actions)
#         for p1,p2 in zip(self.qnet.parameters(),self.target_net.parameters()):
#             p2.data.copy_(p1.data)
        
        self.optimizer = Adam(self.qnet.parameters())
        self.memory = Memory()
        self.n_actions = n_actions
    
    def sample_action(self,state):
        if np.random.random() < 0.1:
            return np.random.choice(n_actions)
        else:
            values = self.qnet(state)
            return values.max(1)[1].item()
    
    @torch.no_grad()
    def predict_action(self,state):
        values = self.qnet(state)
        return values.max(1)[1].item()
    
    def update(self):
        if len(self.memory) != 200:
            return
        states,actions,rewards,next_states,dones = self.memory.sample()
        states = torch.tensor(states) # n n_states
        actions = torch.tensor(actions).view(-1,1) # n 
        next_states = torch.tensor(next_states)
        dones = torch.tensor(np.float32(dones))
        rewards = torch.tensor(rewards) # n 1
        qvalues = self.qnet(states) # n n_actions
        qvalues = qvalues.gather(1,actions)
        next_qvalues = self.target_net(next_states) # n n_actions
        target_values = rewards + 0.95 * next_qvalues.max(1)[0].detach()*(1-dones) # max(1)[1]: n values
        loss = F.mse_loss(qvalues.view(-1),target_values.view(-1))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

In [331]:
import gym
import numpy as np
np.bool8 = np.bool_

In [332]:
env_name = 'CartPole-v0'
env = gym.make(env_name)

In [333]:
n_states = env.observation_space.shape[0]
n_actions = env.action_space.n

In [334]:
n_states,n_actions

(4, 2)

In [335]:
agent = Agent(n_states,n_actions)

In [336]:
for epoch in range(1000):
    
    state,_ = env.reset()
    
    for i in range(200):
        action = agent.sample_action(torch.tensor(state).unsqueeze(0))
        next_state, reward, done, _,_ = env.step(action)  # 更新环境，返回transition
        agent.memory.push((state, action, reward,next_state, done))  # 保存transition
        state = next_state
        agent.update()
        if done:
            break
    
    if (epoch+1) % 100 == 0:
        for p1,p2 in zip(agent.qnet.parameters(),agent.target_net.parameters()):
            p2.data.copy_(p1.data)
    
    
    ### test
    if epoch % 100 == 0:
        rewards = 0
        state,_ = env.reset()
        for i in range(200):
            action = agent.predict_action(torch.tensor(state).unsqueeze(0))
            next_state, reward, done, _,_ = env.step(action)  # 更新环境，返回transition
            rewards += reward
            state = next_state
            if done:
                break
        print(f'epoch: {epoch} : reward: {rewards}')

epoch: 0 : reward: 10.0
epoch: 100 : reward: 10.0
epoch: 200 : reward: 20.0
epoch: 300 : reward: 10.0
epoch: 400 : reward: 200.0
epoch: 500 : reward: 95.0
epoch: 600 : reward: 11.0
epoch: 700 : reward: 32.0
epoch: 800 : reward: 20.0
epoch: 900 : reward: 195.0


In [337]:
np.random.choice([1,3,4,1])

np.int64(1)

In [338]:
qq = torch.randn(4,2)

In [339]:
qq

tensor([[ 0.5260, -0.9328],
        [ 1.6860, -1.6825],
        [-1.6531,  0.9750],
        [ 1.0478, -0.9798]])

In [340]:
indexs = torch.LongTensor(
    [
        [1],
        [0],
        [1],
        [1]
    ]
)

In [341]:
indexs

tensor([[1],
        [0],
        [1],
        [1]])

In [342]:
torch.gather(qq,1,indexs)

tensor([[-0.9328],
        [ 1.6860],
        [ 0.9750],
        [-0.9798]])

In [343]:
qq.gather(1,indexs)

tensor([[-0.9328],
        [ 1.6860],
        [ 0.9750],
        [-0.9798]])