In [15]:
import torch
import torch.autograd
import torch.optim as optim
import torch.nn as nn
from models import *
from utils import *

class DDPGagent:
    """Deep Deterministic Policy Gradient (DDPG) agent.

    Owns an actor and a critic network, a target copy of each that tracks the
    live network through soft updates, a replay memory, and one Adam optimizer
    per trainable network.

    Args:
        hidden_size: Hidden size forwarded to the ``Actor`` and ``Critic`` constructors.
        actor_learning_rate: Adam learning rate for the actor.
        critic_learning_rate: Adam learning rate for the critic.
        gamma: Discount applied to the bootstrapped next-state value.
        tau: Soft-update factor; targets become ``tau * live + (1 - tau) * target``.
        max_memory_size: Forwarded to the ``Memory`` constructor.
        num_states: Size of the state vector (default 3, the previous hard-coded value).
        num_actions: Size of the action vector (default 100, the previous hard-coded value).
    """

    def __init__(self, hidden_size=256, actor_learning_rate=1e-4, critic_learning_rate=1e-3,
                 gamma=0.99, tau=1e-2, max_memory_size=50000, num_states=3, num_actions=100):
        # Params
        self.num_states = num_states
        self.num_actions = num_actions
        self.gamma = gamma
        self.tau = tau

        # Networks
        self.actor = Actor(self.num_states, hidden_size, self.num_actions)
        self.actor_target = Actor(self.num_states, hidden_size, self.num_actions)
        # NOTE(review): the critic receives num_actions as its third argument
        # (presumably its output width -- confirm against models.Critic). A
        # Q-network usually has a single output.
        self.critic = Critic(self.num_states + self.num_actions, hidden_size, self.num_actions)
        self.critic_target = Critic(self.num_states + self.num_actions, hidden_size, self.num_actions)

        # Targets start as exact copies of the live networks.
        self._hard_update(self.actor_target, self.actor)
        self._hard_update(self.critic_target, self.critic)

        # Training
        self.memory = Memory(max_memory_size)
        self.critic_criterion = nn.MSELoss()
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_learning_rate)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_learning_rate)

    @staticmethod
    def _hard_update(target_net, source_net):
        """Copy every parameter of ``source_net`` into ``target_net`` in place."""
        with torch.no_grad():
            for target_param, param in zip(target_net.parameters(), source_net.parameters()):
                target_param.copy_(param)

    @staticmethod
    def _soft_update(target_net, source_net, tau):
        """Blend in place: ``target <- tau * source + (1 - tau) * target``."""
        with torch.no_grad():
            for target_param, param in zip(target_net.parameters(), source_net.parameters()):
                target_param.copy_(param * tau + target_param * (1.0 - tau))

    def get_action(self, state):
        """Run the actor on a single state and return one scalar.

        Args:
            state: NumPy array holding one un-batched state.

        Returns:
            Element ``[0, 0]`` of the actor output as a NumPy scalar.

        NOTE(review): only the first component of the action vector is
        returned; with ``num_actions > 1`` the rest is discarded -- confirm
        that this is intended.
        """
        state = torch.from_numpy(state).float().unsqueeze(0)  # add a batch axis
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            action = self.actor(state)
        return action.numpy()[0, 0]

    def update(self, batch_size):
        """Sample a batch from memory and take one optimisation step.

        Updates the actor, then the critic, then soft-updates both targets.

        Args:
            batch_size: Number of transitions requested from ``self.memory``.

        NOTE(review): the fifth value returned by ``memory.sample``
        (presumably the done flags) is ignored, so the target bootstraps from
        ``next_Q`` for every transition, including terminal ones -- confirm.
        """
        states, actions, rewards, next_states, _ = self.memory.sample(batch_size)
        states = torch.FloatTensor(states)
        actions = torch.FloatTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)

        # Critic loss
        Qvals = self.critic(states, actions)
        # The target is a constant for this step. Building it without autograd
        # keeps critic_loss.backward() from backpropagating through the target
        # networks and leaving stale gradients on their parameters.
        with torch.no_grad():
            next_actions = self.actor_target(next_states)
            next_Q = self.critic_target(next_states, next_actions)
            Qprime = rewards + self.gamma * next_Q
        critic_loss = self.critic_criterion(Qvals, Qprime)

        # Actor loss: maximise the critic's value of the actor's own actions.
        policy_loss = -self.critic(states, self.actor(states)).mean()

        # update networks
        self.actor_optimizer.zero_grad()
        policy_loss.backward()
        self.actor_optimizer.step()

        # policy_loss.backward() also wrote gradients into the critic;
        # zero_grad() clears them before the critic's own backward pass.
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # update target networks
        self._soft_update(self.actor_target, self.actor, self.tau)
        self._soft_update(self.critic_target, self.critic, self.tau)

In [16]:
# Build the agent with every constructor argument left at its default.
agent = DDPGagent()

In [17]:
# Hand-written 3-element observation; its length matches DDPGagent.num_states (3).
state = np.array([0,1,0])

In [18]:
# Convert the observation to a float32 tensor with a leading batch axis.
# torch.as_tensor + reshape accepts both the original ndarray and an
# already-converted tensor, so re-running this cell is harmless; the
# deprecated Variable wrapper is no longer needed.
state = torch.as_tensor(state, dtype=torch.float32).reshape(1, -1)

In [26]:
# Call the module itself rather than .forward() so nn.Module hooks still run.
action = agent.actor(state)

In [24]:
# Keep only element [0, 0] of the actor output.
# NOTE(review): this rebinds `action` from a tensor to a NumPy scalar, so the
# cell cannot be re-run and later cells no longer see the full tensor.
action = action.detach().numpy()[0,0]

In [25]:
# Display the scalar extracted in the previous cell.
action

0.008645908

In [27]:
# Recompute from `state` instead of reusing `action`: an earlier cell rebinds
# `action` to a NumPy scalar, which has no .detach(), so a top-to-bottom run
# would fail here.
agent.actor(state).detach().numpy()

array([[ 8.64590798e-03,  5.05254306e-02,  6.88214749e-02,
         6.84536994e-02,  5.63129410e-02, -4.08938117e-02,
         1.62545457e-01, -1.93836302e-01,  1.81970615e-02,
         2.53921360e-01, -6.58428520e-02,  3.52929235e-02,
         3.48545536e-02, -6.70218691e-02, -6.72469884e-02,
        -5.60472999e-03,  6.42209128e-02,  1.35742575e-01,
         1.59838691e-01, -1.93911176e-02, -1.14204340e-01,
        -2.06554905e-01,  4.50151712e-02, -1.55493617e-05,
        -2.77051572e-02, -1.88837592e-02,  8.53493810e-02,
         6.52661547e-02,  1.50352707e-02,  2.02887431e-01,
         1.31309286e-01,  1.07920811e-01, -6.15683086e-02,
         5.57933301e-02,  1.63974136e-01,  8.74372497e-02,
         1.28084257e-01, -3.76115330e-02,  9.03327763e-03,
        -6.62251487e-02,  7.38179833e-02, -1.37300029e-01,
         1.37285382e-01,  2.19473705e-01, -4.14418690e-02,
         5.86240552e-02, -1.89701304e-01,  1.36127705e-02,
         7.91407656e-03, -6.32426888e-02, -5.81539162e-0