In [1]:
from ff_environment import ForceField
from agent import Agent
from collections import deque
import numpy as np
import torch

In [2]:
# Create the force-field environment and reset it; the object returned by
# reset() carries the initial observation (read as `env_info.state` below).
env = ForceField()
env_info = env.reset()

In [3]:
# Size of each action, as reported by the environment
action_size = env.action_size
print('Size of each action:', action_size)

# Examine the state space using the observation returned by the initial reset
state = env_info.state
state_size = len(state)
print('The agent observes a state with length: {}'.format(state_size))
print('The starting state looks like:', state)

Size of each action: 2
The agent observes a state with length: 4
The starting state looks like: [0.5 1.  0.  0. ]


In [4]:
# Instantiate the agent, sized to the environment's state and action
# dimensions. The seed is fixed at 2 (presumably for reproducible runs;
# confirm against agent.Agent).
agent = Agent(state_size, action_size, random_seed=2)

In [5]:
# train the agent with ddpg
def ddpg(n_episodes=5000, max_t=1000, print_every=1000, solved_score=0.07):
    """Train the module-level ``agent`` on the module-level ``env`` with DDPG.

    Each episode resets the environment and the agent, then alternates
    ``agent.act`` -> ``env.step`` -> ``agent.step`` until the environment
    reports ``done`` or ``max_t`` steps have elapsed.

    Args:
        n_episodes: Maximum number of training episodes.
        max_t: Maximum number of environment steps per episode.
        print_every: Reporting/checkpoint interval in episodes. It is also the
            length of the rolling window used for the average reward.
        solved_score: Rolling-average reward at which training stops early and
            the ``*_solved.pth`` weights are written.

    Returns:
        A tuple ``(scores, trajectories, actions_tracker)`` of per-episode
        lists:
          * ``scores`` - total reward collected in each episode.
          * ``trajectories`` - for each episode, ``state[:2]`` of the reset
            state followed by ``env_info.pos`` after every step.
          * ``actions_tracker`` - for each episode, ``state[2:]`` of the reset
            state followed by the action taken at every step.

    Side effects:
        Writes ``actor_model.pth`` / ``critic_model.pth`` every
        ``print_every`` episodes and after the final episode, and
        ``actor_solved.pth`` / ``critic_solved.pth`` when ``solved_score`` is
        reached. Progress is printed to stdout.
    """
    scores = []
    trajectories = []
    actions_tracker = []
    scores_deque = deque(maxlen=print_every)   # rolling window of recent scores

    for i_episode in range(n_episodes):
        env_info = env.reset()
        state = env_info.state             # current state
        score = 0                          # accumulated reward for this episode
        trajectory = [state[:2]]           # first two state components seed the trajectory
        actions = [state[2:]]              # remaining state components seed the action log
        agent.reset()                      # reset noise process for action exploration

        for _ in range(max_t):
            action = agent.act(state)

            env_info = env.step(action)    # send action to environment
            next_state = env_info.state    # get next state
            reward = env_info.reward       # get reward
            done = env_info.done           # see if trial is finished

            agent.step(state, action, reward, next_state, done)

            score += reward
            state = next_state
            trajectory.append(env_info.pos)
            actions.append(action)

            if done:
                break

        # np.mean is kept so the stored score has the same type as before.
        episode_score = np.mean(score)
        scores_deque.append(episode_score)
        scores.append(episode_score)
        trajectories.append(trajectory)
        actions_tracker.append(actions)

        # Compute the rolling average once and reuse it for reporting and the
        # early-stop check.
        average_reward = np.mean(scores_deque)
        print('\rEpisode {} \tAverage Reward: {:.2f}'.format(i_episode, average_reward), end="")

        # Checkpoint periodically, and always after the final episode so the
        # weights on disk reflect the end of training.
        is_last_episode = i_episode == n_episodes - 1
        if i_episode % print_every == 0 or is_last_episode:
            torch.save(agent.actor_local.state_dict(), 'actor_model.pth')
            torch.save(agent.critic_local.state_dict(), 'critic_model.pth')
            print('\rEpisode {} \tAverage Reward: {:.2f}'.format(i_episode, average_reward))

        if average_reward >= solved_score:
            print('\nEnvironment solved in {:d} episodes!\t Average Score: {:.2f}'.format(i_episode, average_reward))
            torch.save(agent.actor_local.state_dict(), 'actor_solved.pth')
            torch.save(agent.critic_local.state_dict(), 'critic_solved.pth')
            break

    return scores, trajectories, actions_tracker

# Run training with the default hyperparameters and keep the per-episode
# scores, position trajectories and action logs for later inspection.
scores, trajectories, actions_tracker = ddpg()

Episode 0 	Average Reward: -1.69Episode 0 	Average Reward: -1.69
Episode 1 	Average Reward: -4.14Episode 2 	Average Reward: -4.36Episode 3 	Average Reward: -3.99Episode 4 	Average Reward: -3.52Episode 5 	Average Reward: -3.20Episode 6 	Average Reward: -2.92Episode 7 	Average Reward: -2.71Episode 8 	Average Reward: -2.55Episode 9 	Average Reward: -2.41



Episode 1000 	Average Reward: -1.70
Episode 2000 	Average Reward: -2.48
Episode 3000 	Average Reward: -1.67
Episode 4000 	Average Reward: -0.90
Episode 4999 	Average Reward: -3.23

In [7]:
# Display the positions recorded during the second-to-last training episode.
trajectories[-2]

[array([0.5, 1. ]),
 (0.42416413128376007, 0.8014600425958633),
 (0.2995334193110466, 0.49080153554677963),
 (0.1393531896173954, 0.11065014824271202),
 (-0.04938287101686001, -0.3173066843301058),
 (-0.2623219592496753, -0.7812156463041902),
 (-0.4961089021526277, -1.2737566572614014),
 (-0.7477115679066628, -1.7896811387035996),
 (-1.0142132559558377, -2.324729488347657),
 (-1.292814779735636, -2.875206359254662),
 (-1.5809119211917277, -3.437854487710865),
 (-1.876179263854283, -4.0098453109530965),
 (-2.176629543733725, -4.588797760261514),
 (-2.480638630902831, -5.172784647471417),
 (-2.786939021307262, -5.760313744905943),
 (-3.094589090028421, -6.3502847679274055),
 (-3.4029277516797265, -6.941930345089077),
 (-3.7115230313636403, -7.534750293389379),
 (-4.0201210330773165, -8.128447552867442),
 (-4.328599696400204, -8.722871328566328),
 (-4.636929924310465, -9.3179700946678),
 (-4.945144285768876, -9.91375516235179),
 (-5.253313141569798, -10.510274059471222),
 (-5.561527119331