## Solving Needle Master with Twin Delayed DDPG (TD3)
Code modified from https://github.com/nikhilbarhate99/TD3-PyTorch-BipedalWalker-v2 <br>


In [1]:
import numpy as np
import torch
import argparse
import os
import random
from context import needlemaster as nm
from environment import Environment
import utils
import TD3
import math
import matplotlib.pyplot as plt

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

ModuleNotFoundError: No module named 'context'

## Policy comparison: TD3 vs. DDPG

In [2]:

##############  TD3  ###############

class Actor(nn.Module):
    """Policy network mapping a state batch to actions in [-max_action, max_action]."""

    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()

        # Scale applied to the tanh output in forward().
        self.max_action = max_action

        # state_dim -> 400 -> 300 -> action_dim
        self.l1 = nn.Linear(state_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, action_dim)

    def forward(self, state):
        """Return the action for each row of `state`."""
        hidden = torch.relu(self.l1(state))
        hidden = torch.relu(self.l2(hidden))
        squashed = torch.tanh(self.l3(hidden))
        return squashed * self.max_action
        
class Critic(nn.Module):
    """Q-network producing one scalar per (state, action) row."""

    def __init__(self, state_dim, action_dim):
        super().__init__()

        # (state_dim + action_dim) -> 400 -> 300 -> 1
        self.l1 = nn.Linear(state_dim + action_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, 1)

    def forward(self, state, action):
        """Concatenate `state` and `action` along dim 1 and return the Q estimate."""
        features = torch.cat((state, action), dim=1)
        for layer in (self.l1, self.l2):
            features = F.relu(layer(features))
        return self.l3(features)
    
class TD3:
    """Twin Delayed DDPG (TD3) agent.

    Holds one actor and two critics, each with a target copy and its own Adam
    optimizer. Networks and training tensors are placed on the module-level
    ``device``.
    """

    # (attribute name, checkpoint file suffix), in save order. The "crtic"
    # spelling is kept on purpose so existing checkpoint files still load.
    _CHECKPOINTS = (
        ("actor", "actor"),
        ("actor_target", "actor_target"),
        ("critic_1", "crtic_1"),
        ("critic_1_target", "critic_1_target"),
        ("critic_2", "crtic_2"),
        ("critic_2_target", "critic_2_target"),
    )
    _ACTOR_CHECKPOINTS = _CHECKPOINTS[:2]

    def __init__(self, lr, state_dim, action_dim, max_action):
        """Build the networks, sync each target with its online copy, create optimizers.

        Args:
            lr: learning rate shared by the three Adam optimizers.
            state_dim: size of a state vector.
            action_dim: size of an action vector.
            max_action: bound applied to actions (actor output scale and
                clamp range for smoothed target actions).
        """
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)

        self.critic_1 = Critic(state_dim, action_dim).to(device)
        self.critic_1_target = Critic(state_dim, action_dim).to(device)
        self.critic_1_target.load_state_dict(self.critic_1.state_dict())
        self.critic_1_optimizer = optim.Adam(self.critic_1.parameters(), lr=lr)

        self.critic_2 = Critic(state_dim, action_dim).to(device)
        self.critic_2_target = Critic(state_dim, action_dim).to(device)
        self.critic_2_target.load_state_dict(self.critic_2.state_dict())
        self.critic_2_optimizer = optim.Adam(self.critic_2.parameters(), lr=lr)

        self.max_action = max_action

    def select_action(self, state):
        """Return the actor's action for one state as a flat NumPy array.

        Args:
            state: array exposing ``reshape``; it is flattened to a batch of one.
        """
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            return self.actor(state).cpu().numpy().flatten()

    @staticmethod
    def _soft_update(online, target, polyak):
        """Polyak-average ``online`` into ``target``: t <- polyak*t + (1-polyak)*o."""
        with torch.no_grad():
            for param, target_param in zip(online.parameters(), target.parameters()):
                target_param.copy_((polyak * target_param) + ((1 - polyak) * param))

    def update(self, replay_buffer, n_iter, batch_size, gamma, polyak, policy_noise, noise_clip, policy_delay):
        """Run ``n_iter`` TD3 gradient iterations on batches from ``replay_buffer``.

        Args:
            replay_buffer: object whose ``sample(batch_size)`` returns
                ``(state, action, reward, next_state, done)``.
            n_iter: number of iterations; 0 makes this a no-op.
            batch_size: transitions per batch; reward and done are reshaped
                to ``(batch_size, 1)``.
            gamma: discount factor.
            polyak: weight kept on the old target parameters in the soft update.
            policy_noise: std of the Gaussian noise added to target actions.
            noise_clip: clamp bound for that noise.
            policy_delay: the actor and all targets are updated every
                ``policy_delay`` iterations.
        """
        for i in range(n_iter):
            # Sample a batch of transitions from replay buffer:
            state, action_, reward, next_state, done = replay_buffer.sample(batch_size)
            state = torch.FloatTensor(state).to(device)
            action = torch.FloatTensor(action_).to(device)
            reward = torch.FloatTensor(reward).reshape((batch_size, 1)).to(device)
            next_state = torch.FloatTensor(next_state).to(device)
            done = torch.FloatTensor(done).reshape((batch_size, 1)).to(device)

            # The bootstrap target is a constant for the critic losses, so it
            # is computed without building a graph through the target networks.
            with torch.no_grad():
                # Select next action according to target policy (smoothed by clipped noise):
                noise = torch.FloatTensor(action_).normal_(0, policy_noise).to(device)
                noise = noise.clamp(-noise_clip, noise_clip)
                next_action = self.actor_target(next_state) + noise
                next_action = next_action.clamp(-self.max_action, self.max_action)

                # Compute target Q-value from the smaller of the two target critics:
                target_Q1 = self.critic_1_target(next_state, next_action)
                target_Q2 = self.critic_2_target(next_state, next_action)
                target_Q = torch.min(target_Q1, target_Q2)
                target_Q = reward + ((1 - done) * gamma * target_Q)

            # Optimize Critic 1:
            current_Q1 = self.critic_1(state, action)
            loss_Q1 = F.mse_loss(current_Q1, target_Q)
            self.critic_1_optimizer.zero_grad()
            loss_Q1.backward()
            self.critic_1_optimizer.step()

            # Optimize Critic 2:
            current_Q2 = self.critic_2(state, action)
            loss_Q2 = F.mse_loss(current_Q2, target_Q)
            self.critic_2_optimizer.zero_grad()
            loss_Q2.backward()
            self.critic_2_optimizer.step()

            # Delayed policy updates:
            if i % policy_delay == 0:
                # Actor loss: maximise critic 1's value of the actor's actions.
                actor_loss = -self.critic_1(state, self.actor(state)).mean()

                # Optimize the actor
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()

                # Polyak averaging update of every target network:
                self._soft_update(self.actor, self.actor_target, polyak)
                self._soft_update(self.critic_1, self.critic_1_target, polyak)
                self._soft_update(self.critic_2, self.critic_2_target, polyak)

    @staticmethod
    def _checkpoint_path(directory, name, suffix):
        """Return the checkpoint path ``<directory>/<name>_<suffix>.pth``."""
        return '%s/%s_%s.pth' % (directory, name, suffix)

    def _load_networks(self, directory, name, entries):
        """Load the state dict of each ``(attribute, suffix)`` pair in ``entries``."""
        for attr, suffix in entries:
            # Deserialize onto CPU storage so checkpoints written on a GPU can
            # be read on a CPU-only machine; load_state_dict then copies the
            # values into the existing parameters.
            weights = torch.load(
                self._checkpoint_path(directory, name, suffix),
                map_location=lambda storage, loc: storage,
            )
            getattr(self, attr).load_state_dict(weights)

    def save(self, directory, name):
        """Write all six networks' state dicts to ``directory`` (must already exist)."""
        for attr, suffix in self._CHECKPOINTS:
            torch.save(getattr(self, attr).state_dict(), self._checkpoint_path(directory, name, suffix))

    def load(self, directory, name):
        """Restore all six networks from the files written by ``save``."""
        self._load_networks(directory, name, self._CHECKPOINTS)

    def load_actor(self, directory, name):
        """Restore only the actor and its target; the critics are left untouched."""
        self._load_networks(directory, name, self._ACTOR_CHECKPOINTS)

In [11]:
def train(env):
    """Train a TD3 agent on the Gym environment whose id is ``env``.

    Runs up to ``max_episodes`` episodes. At the end of each episode the policy
    is updated with ``t`` iterations, where ``t`` is the index of the episode's
    last step. Each episode's reward is written to ``log.txt`` as an
    ``episode,reward`` line. Checkpoints go to ``./preTrained/<env>``, both
    when the stop condition is met and after every episode past 500.

    NOTE(review): ``gym`` and ``ReplayBuffer`` are used here but not imported
    in this cell; they must already exist in the notebook session.

    Args:
        env: Gym environment id, e.g. ``'LunarLanderContinuous-v2'``.
    """
    ######### Hyperparameters #########
    env_name = env
    log_interval = 10           # print avg reward after interval
    random_seed = 0
    gamma = 0.99                # discount for future rewards
    batch_size = 100            # num of transitions sampled from replay buffer
    lr = 0.001
    exploration_noise = 0.1
    polyak = 0.995              # target policy update parameter (1-tau)
    policy_noise = 0.2          # target policy smoothing noise
    noise_clip = 0.5
    policy_delay = 2            # delayed policy updates parameter
    max_episodes = 1000         # max num of episodes
    max_timesteps = 2000        # max timesteps in one episode
    directory = "./preTrained/{}".format(env_name) # save trained models
    filename = "TD3_{}_{}".format(env_name, random_seed)
    ###################################

    gym_env = gym.make(env_name)
    state_dim = gym_env.observation_space.shape[0]
    print("state dim: " + str(state_dim))
    action_dim = gym_env.action_space.shape[0]
    print("action dim: " + str(action_dim))
    max_action = float(gym_env.action_space.high[0])
    print("max action: " + str(max_action))

    # Seed before the networks are built so their initial weights are covered
    # by the seed too. A seed of 0 is falsy and therefore means "do not seed".
    if random_seed:
        print("Random Seed: {}".format(random_seed))
        gym_env.seed(random_seed)
        torch.manual_seed(random_seed)
        np.random.seed(random_seed)

    policy = TD3(lr, state_dim, action_dim, max_action)
    replay_buffer = ReplayBuffer()

    # torch.save does not create missing directories.
    os.makedirs(directory, exist_ok=True)

    # logging variables:
    avg_reward = 0
    ep_reward = 0

    # The context manager closes the log on every exit path, including
    # exceptions such as an interrupted run.
    with open("log.txt", "w+") as log_f:
        # training procedure:
        for episode in range(1, max_episodes+1):
            state = gym_env.reset()
            print("**************************************")
            for t in range(max_timesteps):
                # select action and add exploration noise:
                action = policy.select_action(state)
                action = action + np.random.normal(0, exploration_noise, size=action_dim)
                action = action.clip(gym_env.action_space.low, gym_env.action_space.high)

                print("action: " + str(action))
                # take action in env:
                next_state, reward, done, _ = gym_env.step(action)
                print("state: " +str(next_state))
                replay_buffer.add((state, action, reward, next_state, float(done)))
                state = next_state

                avg_reward += reward
                ep_reward += reward

                # if episode is done then update policy:
                if done or t==(max_timesteps-1):
                    policy.update(replay_buffer, t, batch_size, gamma, polyak, policy_noise, noise_clip, policy_delay)
                    break

            # logging updates:
            log_f.write('{},{}\n'.format(episode, ep_reward))
            log_f.flush()
            ep_reward = 0

            # if avg reward >= 300 then save and stop training.
            # avg_reward is the reward summed since the last log interval, so
            # this compares a running sum divided by log_interval.
            if (avg_reward/log_interval) >= 300:
                print("########## Model received ###########")
                policy.save(directory, filename)
                break

            # Past episode 500, checkpoint after every episode.
            if episode > 500:
                policy.save(directory, filename)

            # print avg reward every log interval:
            if episode % log_interval == 0:
                avg_reward = int(avg_reward / log_interval)
                print("Episode: {}\tAverage Reward: {}".format(episode, avg_reward))
                avg_reward = 0


In [12]:
### main function
# Run TD3 training on the LunarLanderContinuous-v2 Gym environment.
train('LunarLanderContinuous-v2')

state dim: 8
action dim: 2
max action: 1.0
**************************************
action: [-0.15306012 -0.06722448]
state: [ 0.00592918  1.4066474   0.29985452 -0.10779373 -0.00678923 -0.06722295
  0.          0.        ]
action: [-0.11248199  0.18978595]
state: [ 0.00889387  1.4036224   0.29986483 -0.13446108 -0.01014807 -0.06718305
  0.          0.        ]
action: [0.02129641 0.09936495]
state: [ 0.01183596  1.4004643   0.29772872 -0.14039133 -0.01363043 -0.06965341
  0.          0.        ]
action: [-0.11858706  0.22017383]
state: [ 0.01477814  1.3967061   0.2977389  -0.1670644  -0.01711213 -0.0696402
  0.          0.        ]
action: [ 0.14570109 -0.02601942]
state: [ 0.01763401  1.3928577   0.28954697 -0.17109199 -0.02103433 -0.07845116
  0.          0.        ]
action: [-0.20215981 -0.0223011 ]
state: [ 0.02048988  1.3884094   0.2895587  -0.19776554 -0.02495554 -0.07843148
  0.          0.        ]
action: [ 0.01446622 -0.06123382]
state: [ 0.02338037  1.3840187   0.2928935  -0.

**************************************
action: [0.78083513 0.93587538]
state: [ 0.01239023  1.4137601   0.6263441   0.07151929 -0.0163234  -0.18249401
  0.          0.        ]
action: [0.96916881 0.89010238]
state: [ 0.01850691  1.4157971   0.62296265  0.09037569 -0.02749385 -0.22342975
  0.          0.        ]
action: [1.         0.69081621]
state: [ 0.02451372  1.4180379   0.6138553   0.09929822 -0.04051683 -0.26048392
  0.          0.        ]
action: [1.         0.95199673]
state: [ 0.03078384  1.420665    0.64174217  0.11629941 -0.05509819 -0.29165414
  0.          0.        ]
action: [0.9255386  0.80418653]
state: [ 0.03719873  1.4237589   0.6577245   0.13682996 -0.07117268 -0.3215199
  0.          0.        ]
action: [0.92064729 1.        ]
state: [ 0.04380102  1.4271699   0.67832166  0.15064421 -0.08911368 -0.35885367
  0.          0.        ]
action: [0.8059412  0.88798291]
state: [ 0.05039968  1.430614    0.679911    0.15175346 -0.10902648 -0.39829248
  0.          0.      

**************************************
action: [ 1.         -0.92774293]
state: [ 2.4358749e-03  1.4116817e+00  1.1106189e-01  2.4062034e-02
 -1.4928260e-03  7.1124133e-04  0.0000000e+00  0.0000000e+00]
action: [ 1.         -0.95723286]
state: [3.48033896e-03 1.41228175e+00 1.02476396e-01 2.66721379e-02
 4.52107226e-04 3.89021225e-02 0.00000000e+00 0.00000000e+00]
action: [ 0.97367009 -0.93177568]
state: [0.00427399 1.4134429  0.07590602 0.05160009 0.00387908 0.06854589
 0.         0.        ]
action: [ 0.76940961 -1.        ]
state: [0.00480804 1.4153643  0.04821602 0.0853727  0.00901068 0.10264155
 0.         0.        ]
action: [ 0.82307741 -1.        ]
state: [0.00525045 1.4171584  0.03705816 0.07968353 0.01610447 0.1418889
 0.         0.        ]
action: [ 1.         -0.95296104]
state: [0.00571346 1.4192518  0.03701678 0.09290661 0.02529338 0.18379505
 0.         0.        ]
action: [ 1. -1.]
state: [0.00595245 1.4213197  0.01286334 0.09169495 0.03621059 0.21836415
 0.         0.

**************************************
action: [ 0.99915859 -1.        ]
state: [ 1.4633179e-03  1.3947403e+00  5.8817416e-02 -3.5889378e-01
 -3.8150602e-04  1.2373556e-02  0.0000000e+00  0.0000000e+00]
action: [ 0.93163331 -1.        ]
state: [ 0.00183325  1.3868587   0.03462438 -0.35029906  0.00195122  0.04665908
  0.          0.        ]
action: [ 0.86930358 -1.        ]
state: [ 0.0020566   1.3789763   0.01809349 -0.35034016  0.00614726  0.08392844
  0.          0.        ]
action: [ 1. -1.]
state: [ 0.00216217  1.3715807   0.00437814 -0.32873428  0.0122559   0.1221841
  0.          0.        ]
action: [ 1.         -0.92500906]
state: [ 0.00212421  1.3649428  -0.0116981  -0.29509717  0.02007074  0.15631124
  0.          0.        ]
action: [ 0.99489184 -0.92397427]
state: [ 0.001859    1.3589644  -0.03598667 -0.2658647   0.02943035  0.18720946
  0.          0.        ]
action: [ 0.82510216 -0.95932856]
state: [ 1.3228416e-03  1.3535949e+00 -6.4698569e-02 -2.3889525e-01
  4.0393006e

**************************************
action: [ 0.99924416 -0.96153199]
state: [-0.00996733  1.4280262  -0.5066221   0.40212542  0.01346652  0.15382396
  0.          0.        ]
action: [ 1. -1.]
state: [-0.01517267  1.4375738  -0.53005785  0.42421138  0.02288656  0.1884181
  0.          0.        ]
action: [ 1. -1.]
state: [-0.02048693  1.4478238  -0.5429299   0.45534578  0.03426344  0.22755861
  0.          0.        ]
action: [ 0.95264914 -0.99325049]
state: [-0.02590008  1.4585686  -0.5548128   0.47717828  0.04761147  0.26698557
  0.          0.        ]
action: [ 1.         -0.86810985]
state: [-0.03143377  1.4700856  -0.56856674  0.5113183   0.06265631  0.3009243
  0.          0.        ]
action: [ 0.98807531 -0.96439846]
state: [-0.03690539  1.4817781  -0.5645988   0.5188414   0.07994164  0.34573862
  0.          0.        ]
action: [ 0.98333135 -0.81936821]
state: [-0.04248018  1.4939774  -0.57657015  0.54106003  0.09889635  0.3791294
  0.          0.        ]
action: [ 1.    

**************************************
action: [ 0.91454327 -1.        ]
state: [ 0.00980349  1.3929052   0.48737544 -0.39053383 -0.00950672 -0.07514997
  0.          0.        ]
action: [ 0.88370349 -0.96959787]
state: [ 0.0144599   1.3842386   0.46770042 -0.3851975  -0.01153984 -0.04066608
  0.          0.        ]
action: [ 0.93747648 -0.98516231]
state: [ 0.01885834  1.3761783   0.4402875  -0.3582378  -0.01197333 -0.00867071
  0.          0.        ]
action: [ 0.82474004 -1.        ]
state: [ 0.02323246  1.3682984   0.4357646  -0.3501995  -0.01034986  0.03247207
  0.          0.        ]
action: [ 1.         -0.84588234]
state: [ 0.02772989  1.3607569   0.44599628 -0.33516154 -0.0066339   0.07432584
  0.          0.        ]
action: [ 1. -1.]
state: [ 3.2336138e-02  1.3535120e+00  4.5451108e-01 -3.2197776e-01
 -5.9747411e-04  1.2073980e-01  0.0000000e+00  0.0000000e+00]
action: [ 1.         -0.91537224]
state: [ 0.03693018  1.3472351   0.45134383 -0.2789926   0.00737436  0.15945147

**************************************
action: [ 1.         -0.89313789]
state: [ 0.01548557  1.4002597   0.7697232  -0.22567934 -0.01665097 -0.15090446
  0.          0.        ]
action: [ 0.99419917 -0.88574314]
state: [ 0.02311745  1.3954005   0.768884   -0.21603893 -0.02228251 -0.11264108
  0.          0.        ]
action: [ 0.90747903 -0.85418834]
state: [ 0.03058605  1.3907809   0.7510732  -0.20538987 -0.02645952 -0.08354807
  0.          0.        ]
action: [ 1.        -0.9443197]
state: [ 0.03782511  1.3867066   0.7266065  -0.18112612 -0.02912142 -0.0532429
  0.          0.        ]
action: [ 1.         -0.89707205]
state: [ 0.04520626  1.3835132   0.73861194 -0.14193621 -0.02961147 -0.00980236
  0.          0.        ]
action: [ 0.94799134 -0.84742628]
state: [ 0.05250568  1.381164    0.72883296 -0.10439005 -0.02851894  0.02185297
  0.          0.        ]
action: [ 0.95211261 -1.        ]
state: [ 0.05962544  1.3795574   0.7091004  -0.0713532  -0.02567046  0.05697465
  0.      

**************************************
action: [ 0.96840906 -0.64361303]
state: [-0.01152382  1.3962237  -0.5802356  -0.32441947  0.014986    0.16506878
  0.          0.        ]
action: [ 0.96252934 -0.53094268]
state: [-0.01717415  1.3897314  -0.574831   -0.28866938  0.02467632  0.19382481
  0.          0.        ]
action: [ 1.         -0.43195538]
state: [-0.02296057  1.3835405  -0.5878438  -0.2753344   0.03378499  0.18219025
  0.          0.        ]
action: [ 0.97695001 -0.51693693]
state: [-0.02882185  1.3776106  -0.5963203  -0.26381382  0.04386954  0.20170924
  0.          0.        ]
action: [ 1.         -0.35929596]
state: [-0.0348526   1.3718206  -0.6125478  -0.25762948  0.05324728  0.18757215
  0.          0.        ]
action: [ 0.97055023 -0.34104798]
state: [-0.04101858  1.3663639  -0.62558275 -0.24286151  0.06213839  0.17783882
  0.          0.        ]
action: [ 0.93174268 -0.21841682]
state: [-0.04731913  1.361588   -0.63863504 -0.21264037  0.07064205  0.17008898
  0.   

**************************************
action: [ 1. -1.]
state: [ 0.00815439  1.3936713   0.41198954 -0.37571752 -0.00710667 -0.04821062
  0.          0.        ]
action: [ 0.98645491 -0.94080714]
state: [ 0.01221189  1.3860375   0.40625072 -0.33928365 -0.00759497 -0.00976697
  0.          0.        ]
action: [ 1.         -0.94584659]
state: [ 0.01602707  1.3792186   0.3804823  -0.30305353 -0.00656754  0.02055038
  0.          0.        ]
action: [ 0.85723822 -1.        ]
state: [ 0.01987495  1.3728144   0.38156548 -0.28461912 -0.00338171  0.06372203
  0.          0.        ]
action: [ 0.9969598 -1.       ]
state: [ 0.0238224   1.3671243   0.3891912  -0.2528885   0.00210988  0.1098422
  0.          0.        ]
action: [ 1. -1.]
state: [ 0.0277997   1.361874    0.3899689  -0.23338467  0.00978522  0.1535208
  0.          0.        ]
action: [ 1.         -0.89101279]
state: [ 0.03176327  1.3568547   0.38667607 -0.22316952  0.01935669  0.19144711
  0.          0.        ]
action: [ 0.82905

KeyboardInterrupt: 

### Store the results as video

In [9]:
from pyvirtualdisplay import Display

# Start a non-visible virtual display (visible=0); presumably this lets
# env.render() work on a machine without a screen.
virtual_display = Display(visible=0, size=(1400, 900))
virtual_display.start()

import gym
from gym import wrappers

env_name = "LunarLanderContinuous-v2"
random_seed = 0
n_episodes = 3
lr = 0.002
max_timesteps = 2000
render = True
save_gif = False

# Checkpoint location; must match what the training cell saved.
filename = "TD3_{}_{}".format(env_name, random_seed)
directory = "./preTrained/{}".format(env_name)

env = gym.make(env_name)
# Monitor writes its output under results/<env_name>; force=True is passed so
# an existing output directory is reused.
env = wrappers.Monitor(env, "results/"+env_name, force = True)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

policy = TD3(lr, state_dim, action_dim, max_action)

# Only the actor is needed to act; the critics keep their fresh initial weights.
policy.load_actor(directory, filename)


for ep in range(1, n_episodes+1):
    ep_reward = 0
    state = env.reset()
    for t in range(max_timesteps):
        # Deterministic action: no exploration noise during evaluation.
        action = policy.select_action(state)
        state, reward, done, _ = env.step(action)
        ep_reward += reward
        if render:
            env.render()
            if save_gif:
                # NOTE(review): `Image` is not imported in this cell (presumably
                # PIL.Image); it must be in scope before enabling save_gif.
                img = env.render(mode = 'rgb_array')
                img = Image.fromarray(img)
                img.save('./gif/{}.jpg'.format(t))
        if done:
            break

    print('Episode: {}\tReward: {}'.format(ep, int(ep_reward)))

# Close once, after the last episode. Closing inside the loop would leave
# every episode after the first running on an already-closed environment.
env.close()


Episode: 1	Reward: 219
Episode: 2	Reward: 271
Episode: 3	Reward: 253


### Play videos

In [5]:
import os
import io
import base64
from IPython.display import display, HTML

def ipython_show_video(path):
    """Show the MP4 video at `path` inline within an IPython notebook.

    The file is base64-encoded and embedded in an HTML ``<video>`` element.

    Args:
        path: filesystem path of the video file.

    Raises:
        NameError: if `path` is not an existing regular file. The exception
            type is kept for backward compatibility.
    """
    if not os.path.isfile(path):
        raise NameError("Cannot access: {}".format(path))

    # Read-only binary mode is sufficient, and the context manager guarantees
    # the file handle is closed.
    with open(path, 'rb') as video_file:
        encoded = base64.b64encode(video_file.read())

    display(HTML(
        data="""
        <video alt="test" controls>
        <source src="data:video/mp4;base64,{0}" type="video/mp4" />
        </video>
        """.format(encoded.decode('ascii'))
    ))

# Play a recorded BipedalWalker-v2 episode; the path is hard-coded to one specific video file.
ipython_show_video("results/BipedalWalker-v2/openaigym.video.0.3265.video000000.mp4")

In [10]:
import os
import io
import base64
from IPython.display import display, HTML

def ipython_show_video(path):
    """Show the MP4 video at `path` inline within an IPython notebook.

    The file is base64-encoded and embedded in an HTML ``<video>`` element.

    Args:
        path: filesystem path of the video file.

    Raises:
        NameError: if `path` is not an existing regular file. The exception
            type is kept for backward compatibility.
    """
    if not os.path.isfile(path):
        raise NameError("Cannot access: {}".format(path))

    # Read-only binary mode is sufficient, and the context manager guarantees
    # the file handle is closed.
    with open(path, 'rb') as video_file:
        encoded = base64.b64encode(video_file.read())

    display(HTML(
        data="""
        <video alt="test" controls>
        <source src="data:video/mp4;base64,{0}" type="video/mp4" />
        </video>
        """.format(encoded.decode('ascii'))
    ))

# Play a recorded LunarLanderContinuous-v2 episode; the path is hard-coded to one specific video file.
ipython_show_video("results/LunarLanderContinuous-v2/openaigym.video.0.3433.video000000.mp4")

### DDPG

In [7]:
def test():
    """Run a saved TD3 actor on BipedalWalker-v2 for a few rendered episodes.

    Loads the actor checkpoint ``TD3_<env>_<seed>`` from ``./preTrained/<env>``,
    plays ``n_episodes`` episodes without exploration noise, and prints each
    episode's total reward.
    """
    env_name = "BipedalWalker-v2"
    random_seed = 0
    n_episodes = 3
    lr = 0.002
    max_timesteps = 2000
    render = True
    save_gif = False

    filename = "TD3_{}_{}".format(env_name, random_seed)
    print(str(filename))
    directory = "./preTrained/{}".format(env_name)

    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    policy = TD3(lr, state_dim, action_dim, max_action)

    # Only the actor is needed to act; the critics keep their initial weights.
    policy.load_actor(directory, filename)

    for ep in range(1, n_episodes+1):
        ep_reward = 0
        state = env.reset()
        for t in range(max_timesteps):
            action = policy.select_action(state)
            state, reward, done, _ = env.step(action)
            ep_reward += reward
            if render:
                env.render()
                if save_gif:
                    # NOTE(review): `Image` is not imported in this cell
                    # (presumably PIL.Image); it must be in scope before
                    # enabling save_gif.
                    img = env.render(mode = 'rgb_array')
                    img = Image.fromarray(img)
                    img.save('./gif/{}.jpg'.format(t))
            if done:
                break

        print('Episode: {}\tReward: {}'.format(ep, int(ep_reward)))

    # Close once, after the last episode. Closing inside the loop would leave
    # every episode after the first running on an already-closed environment.
    env.close()


In [3]:

################ DDPG ##################

# Use CUDA when it is available, otherwise the CPU (rebinds the notebook-wide `device`).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Re-tuned version of Deep Deterministic Policy Gradients (DDPG)
# Paper: https://arxiv.org/abs/1509.02971


class Actor(nn.Module):
    """DDPG policy network: state batch in, actions scaled by max_action out."""

    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()

        # state_dim -> 400 -> 300 -> action_dim
        self.l1 = nn.Linear(state_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, action_dim)

        self.max_action = max_action

    def forward(self, x):
        """Map states `x` to actions in [-max_action, max_action]."""
        hidden = F.relu(self.l2(F.relu(self.l1(x))))
        return self.max_action * torch.tanh(self.l3(hidden))


class Critic(nn.Module):
    """DDPG Q-network: one scalar value per (state, action) row."""

    def __init__(self, state_dim, action_dim):
        super().__init__()

        # (state_dim + action_dim) -> 400 -> 300 -> 1
        self.l1 = nn.Linear(state_dim + action_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, 1)

    def forward(self, x, u):
        """Score states `x` paired with actions `u` (concatenated along dim 1)."""
        joint = torch.cat((x, u), dim=1)
        hidden = torch.relu(self.l1(joint))
        hidden = torch.relu(self.l2(hidden))
        return self.l3(hidden)


class DDPG(object):
    """Deep Deterministic Policy Gradient agent with one actor and one critic.

    Each network has a target copy. Optimizers are Adam with default settings.
    Networks and training tensors are placed on the module-level ``device``.
    """

    def __init__(self, state_dim, action_dim, max_action):
        """Build actor/critic, sync each target with its online copy, create optimizers."""
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters())

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = Critic(state_dim, action_dim).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters())

    def select_action(self, state):
        """Return the actor's action for one state as a flat NumPy array."""
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        return self.actor(state).cpu().data.numpy().flatten()

    def train(self, replay_buffer, iterations, batch_size=100, discount=0.99, tau=0.005):
        """Run ``iterations`` DDPG updates on batches sampled from ``replay_buffer``.

        Args:
            replay_buffer: object whose ``sample(batch_size)`` returns
                ``(state, next_state, action, reward, done)`` in that order.
                NOTE(review): the TD3 class in this notebook unpacks its
                buffer as (state, action, reward, next_state, done); confirm
                which layout the buffer in use actually produces.
            iterations: number of gradient iterations.
            batch_size: transitions per sampled batch.
            discount: discount factor for the bootstrap target.
            tau: weight of the online parameters in the soft target update.
        """
        for it in range(iterations):

            # Sample replay buffer
            x, y, u, r, d = replay_buffer.sample(batch_size)
            state = torch.FloatTensor(x).to(device)
            action = torch.FloatTensor(u).to(device)
            next_state = torch.FloatTensor(y).to(device)
            # 1 where the episode continues, 0 where it ended.
            not_done = torch.FloatTensor(1 - d).to(device)
            # assumes r and d already have the critic's output shape,
            # presumably (batch, 1); they are not reshaped here. TODO confirm
            reward = torch.FloatTensor(r).to(device)

            # Compute the target Q value. It is a constant for the critic
            # loss, so no graph is built through the target networks.
            with torch.no_grad():
                target_Q = self.critic_target(next_state, self.actor_target(next_state))
                target_Q = reward + (not_done * discount * target_Q)

            # Get current Q estimate
            current_Q = self.critic(state, action)

            # Compute critic loss
            critic_loss = F.mse_loss(current_Q, target_Q)

            # Optimize the critic
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # Compute actor loss
            actor_loss = -self.critic(state, self.actor(state)).mean()

            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update the frozen target models
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    def save(self, filename, directory):
        """Write the actor and critic state dicts to ``<directory>/<filename>_*.pth``."""
        torch.save(self.actor.state_dict(), '%s/%s_actor.pth' % (directory, filename))
        torch.save(self.critic.state_dict(), '%s/%s_critic.pth' % (directory, filename))

    def load(self, filename, directory):
        """Restore the actor and critic from the files written by ``save``.

        Tensors are deserialized onto the CPU so checkpoints written on a GPU
        also load on CPU-only machines; ``load_state_dict`` then copies the
        values into the existing parameters.
        """
        self.actor.load_state_dict(
            torch.load('%s/%s_actor.pth' % (directory, filename), map_location="cpu"))
        self.critic.load_state_dict(
            torch.load('%s/%s_critic.pth' % (directory, filename), map_location="cpu"))
