# 基于值函数的算法
值函数算法隐式地学习策略：策略并不被直接参数化，而是由学到的值函数导出（例如 DQN 对动作值函数 $Q(s,a)$ 取 $\arg\max$）。算法的目标是最大化初始状态下状态值函数的期望：

$J(\theta)=\mathbb{E}[V^{\pi_{\theta}}(s_0)]$

In [22]:
# import all you want
from schorl_utils.envs import *
from schorl_utils.envs import get_device
from schorl_utils.buffer import ReplayBuffer, replaybatch
from schorl_utils.net import generate_mlpnet, show_net_structure
from schorl_utils.functions import Agent, Train

class DQN(Agent):
    """
    Deep Q-Network agent with an epsilon-greedy behaviour policy.

    The network takes a continuous state vector as input and outputs one
    Q-value per discrete action. A separate target network, synchronised
    with the online network every ``update_target_every`` updates, provides
    the bootstrap targets.
    """
    def __init__(self, 
            net,
            nums_action,
            device = get_device(),
            gamma = 0.9,
            epsilon = 0.1,
            lr = 2e-3,
            optim = torch.optim.Adam,
            loss = torch.nn.functional.mse_loss,
            datatype = torch.float,
            update_target_every = 5,
            ) -> None:
        """
        Args:
            net: module mapping a batch of states to per-action Q-values.
            nums_action: number of discrete actions.
            device: device the networks and batches are moved to.
            gamma: discount factor applied to the bootstrap target.
            epsilon: probability of picking a uniformly random action.
            lr: learning rate handed to the optimizer.
            optim: optimizer class, called as ``optim(parameters, lr)``.
            loss: callable ``loss(prediction, target)``.
            datatype: dtype used for state, reward and done tensors.
            update_target_every: number of updates between target syncs.
        """
        # Local import keeps this block self-contained.
        import copy

        self.update_target_every = update_target_every
        self.q_net = net.to(device)
        # nn.Module.to() returns the module itself, so the target network
        # must be an explicit copy; otherwise it aliases the online network.
        self.target_q_net = copy.deepcopy(self.q_net)
        self.device = device
        self.gamma = gamma
        self.nums_action = nums_action
        self.epsilon = epsilon
        self.count = 0
        self.type = datatype
        self.loss = loss
        # Only the online network is optimised; the target is updated by copy.
        self.optim = optim(self.q_net.parameters(), lr)
    
    def __call__(self, state:np.array):
        """
        Return an action for a single state (epsilon-greedy).

        With probability ``epsilon`` a random action index is returned,
        otherwise the argmax of the online network's Q-values.
        """
        if random.random() < self.epsilon:
            return np.random.randint(self.nums_action)

        # Add a leading batch dimension before feeding the network.
        state = torch.tensor(np.array([state]), dtype=self.type).to(self.device)
        with torch.no_grad():  # action selection needs no autograd graph
            return self.q_net(state).argmax().item()
    
    def update(self, batch:replaybatch):
        """
        Run one gradient step on a sampled batch of transitions.

        Reads ``states``, ``actions``, ``rewards``, ``next_states`` and
        ``dones`` from ``batch``. Returns the batch loss as a detached
        tensor so it can be recorded without keeping the graph alive.
        """
        state_batch = torch.tensor(batch.states, dtype=self.type).to(self.device)
        # gather() requires int64 indices.
        action_batch = torch.tensor(batch.actions, dtype=torch.long).view(-1,1).to(self.device)
        reward_batch = torch.tensor(batch.rewards, dtype=self.type).view(-1,1).to(self.device)
        next_state_batch = torch.tensor(batch.next_states, dtype=self.type).to(self.device)
        done_batch = torch.tensor(batch.dones, dtype=self.type).view(-1,1).to(self.device)

        # Q(s, a) of the actions that were actually taken.
        thisQ = self.q_net(state_batch).gather(1, action_batch)

        # The TD target is a constant w.r.t. the optimised parameters.
        with torch.no_grad():
            next_max = self.target_q_net(next_state_batch).max(1)[0].view(-1, 1)
            # (1 - done) removes the bootstrap term on terminal transitions.
            nextQ = reward_batch + self.gamma * next_max * (1 - done_batch)

        self.optim.zero_grad()
        batch_loss = torch.mean(self.loss(thisQ, nextQ))
        batch_loss.backward()
        self.optim.step()

        # count starts at 0, so the very first update also syncs the target.
        if self.count % self.update_target_every == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())
        self.count += 1
        
        return batch_loss.detach()


In [20]:
class DqnTrain(Train):
    """
    Training driver that couples an agent with a replay buffer.

    Each environment step is stored in the buffer; once the buffer holds
    more than ``batchsize`` transitions, every step also samples a batch
    and calls ``agent.update`` on it.
    """
    
    def __init__(self, env, agent, replaybuffer, batchsize, tblogpath) -> None:
        super().__init__(env=env, agent=agent, tblogpath=tblogpath)
        self.replaybuffer = replaybuffer
        self.batchsize = batchsize

    def run_episode(self):
        """
        Run one episode and train the agent along the way.

        Override this method to support a different environment interaction.
        Supports both the 4-tuple ``(obs, reward, done, info)`` and the
        5-tuple ``(obs, reward, terminated, truncated, info)`` step results.

        Return:
            dict with ``'accumulated_reward'`` (sum of rewards of the episode)
            and ``'loss_mean'`` (mean loss over the updates performed in this
            episode, 0.0 when no update happened).
        """
        reset_out = self.env.reset()
        # Some gym versions return (obs, info) from reset(); unwrap that form.
        if (isinstance(reset_out, tuple) and len(reset_out) == 2
                and isinstance(reset_out[1], dict)):
            state = reset_out[0]
        else:
            state = reset_out

        accumulated_reward = 0
        accumulated_loss = 0.0
        num_updates = 0
        done = False
        while not done:
            action = self.agent(state)
            next_state, reward, terminated, *rest = self.env.step(action)
            # 5-tuple API: rest == [truncated, info]; 4-tuple API: rest == [info].
            truncated = rest[0] if len(rest) > 1 else False
            # The episode must stop on truncation too (e.g. a time limit),
            # otherwise the loop keeps stepping a finished environment.
            done = bool(terminated or truncated)

            # Store the third step element as the done flag, as before.
            self.replaybuffer.put(state, action, reward, next_state, terminated)
            state = next_state
            accumulated_reward += reward

            if len(self.replaybuffer) > self.batchsize:
                batch = self.replaybuffer.sample(self.batchsize)
                # float() keeps a plain number instead of accumulating tensors.
                accumulated_loss += float(self.agent.update(batch))
                num_updates += 1
        self.env.close()

        # Average over the updates actually performed (avoids the former
        # off-by-one in the divisor) and guard against zero updates.
        loss_mean = accumulated_loss / num_updates if num_updates else 0.0
        return {'accumulated_reward': accumulated_reward, 'loss_mean': loss_mean}


In [21]:
import gym

# Training hyperparameters.
num_episodes = 1000
batch_size = 64

env = gym.make('CartPole-v1', new_step_api=True)

# Network dimensions are read from the environment's spaces.
obs_dim = env.observation_space.shape[0]
n_actions = env.action_space.n

# MLP described by its layer sizes: observation -> 128 -> one output per action.
net = generate_mlpnet(mlp_layers=[obs_dim, 128, n_actions])

# Replay buffer constructed with 1000 as its argument.
replaybuffer = ReplayBuffer(1000)

# DQN agent; every other hyperparameter keeps its default.
dqnagent = DQN(net=net, loss=F.mse_loss, nums_action=n_actions)

# Trainer wiring env, agent and buffer together; './tensorlog' is passed as tblogpath.
train = DqnTrain(
    env=env,
    agent=dqnagent,
    replaybuffer=replaybuffer,
    batchsize=batch_size,
    tblogpath='./tensorlog',
)

train.train(num_episodes)


  deprecation(
  deprecation(
  7%|▋         | 72/1000 [00:03<00:46, 19.75it/s, episode=71]


KeyboardInterrupt: 

In [None]:
# Save the trained online Q-network.
# NOTE(review): presumably this writes a state_dict, since the evaluation cell
# loads the file with load_state_dict — confirm against Agent.save_net.
model_path = './model/dqnCartpole.pt'
dqnagent.save_net(dqnagent.q_net, model_path)

In [None]:
# Start TensorBoard on ./tensorlog (the tblogpath given to DqnTrain), port 8123.
# NOTE(review): the command appears to occupy the cell until it is interrupted.
!tensorboard --logdir=./tensorlog --port 8123
# Then open a web browser and visit 127.0.0.1:8123

^C


In [None]:
import gym

# Play one rendered episode with the saved network, always acting greedily.
env = gym.make('CartPole-v1', new_step_api=True,render_mode='human')
device = get_device()

net = generate_mlpnet(mlp_layers=[env.observation_space.shape[0], 128, env.action_space.n ])
# map_location lets a checkpoint saved on another device load here.
model = torch.load('./model/dqnCartpole.pt', map_location=device)
net.load_state_dict(model)
# The input tensor is moved to `device`, so the network must live there too.
net.to(device)
net.eval()

try:
    state = env.reset()
    done = False
    with torch.no_grad():  # inference only, no autograd graph needed
        while not done:
            env.render()
            # Add a leading batch dimension before feeding the network.
            obs = torch.tensor(np.array([state]), dtype=torch.float).to(device)
            action = net(obs).argmax().item()
            state, _reward, terminated, *rest = env.step(action)
            # 5-tuple API: rest == [truncated, info]; 4-tuple API: rest == [info].
            truncated = rest[0] if len(rest) > 1 else False
            # Stop on truncation as well, not only on termination.
            done = bool(terminated or truncated)
finally:
    # Release the render window even if the rollout raises.
    env.close()