[source](https://github.com/sfujim/TD3/blob/master/DDPG.py)<br>
[background](https://spinningup.openai.com/en/latest/algorithms/ddpg.html#background)

In [19]:
%reload_ext autoreload
%autoreload 2
import torch
import torch.nn.functional as F 
import random
import numpy as np
from EXITrl.trainer import Trainer
from EXITrl.helpers import get_simple_model, get_state_action_shape_from_env, ExperienceReplay2, update_params, device, convert_to_tensor
from EXITrl.nn_wrapper import NNWrapper
import gym
import copy

![image](../media/DDPG.svg)

In [20]:
class Actor(torch.nn.Module):
    """Deterministic policy network: maps a state to a bounded action.

    Three fully connected layers (state_dim -> 400 -> 300 -> action_dim) with
    ReLU activations on the hidden layers. The output is squashed by tanh and
    scaled, so every action component lies in [-max_action, max_action].

    Args:
        state_dim: size of the state vector fed to the first layer.
        action_dim: size of the action vector produced by the last layer.
        max_action: scale applied to the tanh output.
    """

    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()

        # Output scale, stored as a plain attribute (not a parameter/buffer).
        self.max_action = max_action

        # Layer names l1/l2/l3 are the state_dict keys used by checkpoints.
        self.l1 = torch.nn.Linear(state_dim, 400)
        self.l2 = torch.nn.Linear(400, 300)
        self.l3 = torch.nn.Linear(300, action_dim)

    def forward(self, state):
        """Return the action for `state`, bounded to [-max_action, max_action]."""
        hidden = torch.relu(self.l1(state))
        hidden = torch.relu(self.l2(hidden))
        squashed = torch.tanh(self.l3(hidden))
        return self.max_action * squashed


class Critic(torch.nn.Module):
    """Q-network: scores a (state, action) pair with a single scalar value.

    The state alone passes through the first layer (state_dim -> 400). The
    action is then concatenated onto that hidden representation before the
    second layer (400 + action_dim -> 300), and the last layer produces one
    Q-value per row.

    Args:
        state_dim: size of the state vector.
        action_dim: size of the action vector.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()

        # Layer names l1/l2/l3 are the state_dict keys used by checkpoints.
        self.l1 = torch.nn.Linear(state_dim, 400)
        # The action joins at the second layer (unlike the author's earlier
        # implementation, per the original note), hence the wider input.
        self.l2 = torch.nn.Linear(400 + action_dim, 300)
        self.l3 = torch.nn.Linear(300, 1)

    def forward(self, state, action):
        """Return Q(state, action); inputs are concatenated along dim 1."""
        state_features = torch.relu(self.l1(state))
        joint = torch.cat((state_features, action), dim=1)
        hidden = torch.relu(self.l2(joint))
        return self.l3(hidden)

In [21]:
# Release the environment left over from a previous run of this cell, if any.
# On a fresh kernel `env` is not defined yet (NameError). Cleanup is
# best-effort, so ordinary errors are ignored, but KeyboardInterrupt and
# SystemExit are no longer swallowed as they were by the bare `except:`.
try:
    env.close()
except Exception:
    pass
env = gym.make('Pendulum-v0')

class DDPG(Trainer):
    """Deep Deterministic Policy Gradient agent built on the project `Trainer`.

    Holds an actor and a critic (each wrapped in `NNWrapper`), a frozen target
    copy of each, and a replay buffer. `_loop` collects one episode of
    experience and triggers `update` after every step once the buffer holds
    at least `start_timesteps` transitions.

    Args:
        save_name: path used by `_save` / `_load` for the actor's state_dict.
        env: gym-style environment; passed to `Trainer.__init__`.
        num_episodes: passed to `Trainer.__init__`.
        start_timesteps: number of stored transitions to collect with random
            actions before the policy is used and training begins.
        gamma: discount factor used in the bootstrapped Q target.
        tau: mixing coefficient handed to `update_params` for the targets.
        explore_noise: std of the Gaussian exploration noise, expressed as a
            fraction of `max_action`.
        max_steps: upper bound on environment steps per episode (default 1000,
            the value that was previously hard-coded).
    """

    def __init__(self, save_name, env, num_episodes, start_timesteps, gamma, tau,
                 explore_noise, max_steps=1000):
        super().__init__(env, num_episodes)
        self.state_shape, self.action_shape = get_state_action_shape_from_env(env)

        # Uses the first component of the upper action bound as the scale for
        # all action dimensions.
        # NOTE(review): assumes a symmetric, uniform action bound - confirm for
        # environments other than Pendulum.
        max_action = float(env.action_space.high[0])

        # 1. Actor and its target copy
        self.actor = NNWrapper(
            model=Actor(self.state_shape, self.action_shape, max_action),
            lr=1e-4
        )
        self.actor_target = copy.deepcopy(self.actor)

        # 2. Critic and its target copy
        self.critic = NNWrapper(
            model=Critic(self.state_shape, self.action_shape),
            lr=1e-4
        )
        self.critic_target = copy.deepcopy(self.critic)

        # Replay buffer
        # TODO add more recall???
        self.experience_replay = ExperienceReplay2(num_experience=1e4, num_recall=1e2)

        # Constants
        self.save_name = save_name
        self.start_timesteps = start_timesteps
        self.gamma = gamma
        self.tau = tau
        self.explore_noise = explore_noise
        self.max_action = max_action
        self.max_steps = max_steps

    def _loop(self, episode) -> float:
        """Run one episode, storing transitions and training as it goes.

        Returns:
            The sum of rewards collected in the episode. This is returned both
            when the environment reports `done` and when the `max_steps` cap
            is reached first.
        """
        # Use the trainer's own environment rather than a notebook global, so
        # the agent works with whichever env it was constructed with.
        state = self.env.reset()
        total_reward = 0

        for _ in range(self.max_steps):
            if self.experience_replay.num_current_experience < self.start_timesteps:
                # Warm-up: sample random actions from the action space.
                action = self.env.action_space.sample()
            else:
                # Policy action plus Gaussian exploration noise, clipped back
                # into the valid range. `.cpu()` mirrors `_policy` so this also
                # works when the model lives on an accelerator.
                action = (
                    self.actor.forward(state).detach().cpu().numpy() +
                    np.random.normal(0, self.max_action * self.explore_noise, size=self.action_shape)
                ).clip(-self.max_action, self.max_action)

            _state, reward, done, _ = self.env.step(action)

            # add more dimension to reward and done for broadcasting after recall
            self.experience_replay.remember(state, action, [reward], _state, [done])

            if self.experience_replay.num_current_experience >= self.start_timesteps:
                self.update()

            state = _state
            total_reward += reward
            if done:
                break

        return total_reward

    def update(self):
        """Perform one critic step, one actor step, and a target update."""
        states, actions, rewards, _states, dones = self.experience_replay.recall()

        # Compute the target Q value from the frozen target networks; it is
        # detached so no gradient flows into them.
        target_Q = self.critic_target.forward(_states, self.actor_target.forward(_states))
        target_Q = rewards + (self.gamma * (1-dones) * target_Q).detach()

        # Get current Q estimate
        current_Q = self.critic.forward(states, actions)

        # Compute critic loss
        critic_loss = F.mse_loss(current_Q, target_Q)
        self.critic.update(critic_loss)

        # Actor objective: maximise Q(s, actor(s)), i.e. minimise its negative.
        # This must step the ACTOR's optimizer; stepping the critic here would
        # leave the actor untrained and push the critic to inflate its Q-values.
        # NOTE(review): backward() through the critic also leaves gradients on
        # the critic's parameters - presumably NNWrapper.update zeroes gradients
        # before each backward pass; confirm.
        actor_loss = -self.critic.forward(states, self.actor.forward(states)).mean()
        self.actor.update(actor_loss)

        # update frozen target
        update_params(self.actor.model, self.actor_target.model, self.tau)
        update_params(self.critic.model, self.critic_target.model, self.tau)

    def _save(self, reward):
        """Write the actor's state_dict to `save_name` (`reward` is unused)."""
        torch.save(self.actor.model.state_dict(), self.save_name)

    def _load(self):
        """Load the actor's state_dict from `save_name` onto `device`."""
        checkpoint = torch.load(self.save_name, map_location=device)
        self.actor.model.load_state_dict(checkpoint)

    def play(self, num_episode=3):
        """Switch the actor to eval mode and delegate to `Trainer.play`."""
        self.actor.model.eval()
        super().play(num_episode)

    def _policy(self, state):
        """Return the actor's noise-free action for `state` as a numpy array."""
        return self.actor.forward(state).detach().cpu().numpy()
    
# Build the DDPG agent for the Pendulum environment created above, then train.
s = DDPG(
    env=env,
    save_name="checkpoint/Pendulum-v0-DDPG.pth",
    num_episodes=500,
    start_timesteps=2048,   # transitions collected with random actions first
    explore_noise=0.1,      # noise std as a fraction of the max action
    gamma=.99,              # discount factor
    tau=0.001,              # target-network mixing coefficient
)
# The meaning of the positional flag is defined by Trainer.train (not shown here).
s.train(True)


  result = entry_point.load(False)


[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
Episode 10	Last reward: -1034.26	Average reward: -1057.78 	other{}                    
Episode 20	Last reward: -1154.35	Average reward: -1518.45 	other{}                    
Episode 28	Last reward: -1385.64	Average reward: -1469.12 	other{}

KeyboardInterrupt: 

In [None]:
# Run the agent through DDPG.play(), which puts the actor in eval mode and
# uses its default of 3 episodes.
s.play()