In [None]:
import gym
import matplotlib.pyplot as plt
from pathlib import Path
import torch

from ddpg_agent import Agent

In [None]:
# Build the environment once at module level; train_agent() below reads this
# global `env` directly instead of taking it as an argument.
# NOTE(review): no seed is applied to the environment in this notebook, so runs
# are presumably not fully reproducible even though the agent is seeded -- confirm.
env_name = 'CarRacing-v0'
env = gym.make(env_name, verbose=0)
# First entry of the action space's shape; passed to Agent(action_dim=...) later.
action_dim = env.action_space.shape[0]

In [None]:
def train_agent(agent,
                n_episodes=4_000,
                n_timesteps=1_000,
                solved_reward=900,
                log_interval=100,
                render=False,       # Rendering is laggy.
                checkpoint_dir='models'):
    """Train `agent` on the module-level `env` and return the per-episode rewards.

    Each episode resets the agent and the environment, then alternates
    `agent.act` / `env.step` / `agent.step` until the environment reports the
    episode as done or `n_timesteps` steps have been taken. After every episode
    the state dicts of `agent.actor_local` and `agent.critic_local` are written
    to `checkpoint_dir`, overwriting the previous checkpoint.

    Training stops early (printing 'Solved!') once the summed reward of the
    most recent `log_interval` episodes exceeds `log_interval * solved_reward`.
    Note that the threshold is not scaled down while fewer than `log_interval`
    episodes have been played.

    Args:
        agent: Object providing `reset()`, `act(state)`,
            `step(state, action, reward, next_state, is_done)` and the
            attributes `actor_local` / `critic_local` (each with `state_dict()`).
        n_episodes: Maximum number of episodes to run.
        n_timesteps: Maximum number of environment steps per episode.
        solved_reward: Average reward over the window that counts as solved.
        log_interval: Number of most recent episodes in the running window.
        render: If true, call `env.render()` after every step.
        checkpoint_dir: Directory that receives 'checkpoint_actor.pth' and
            'checkpoint_critic.pth'. Created (with parents) if it is missing.

    Returns:
        List with the total reward of every completed episode, in order.
    """
    # torch.save does not create missing parent directories, so make sure the
    # target exists once up front instead of failing after the first episode.
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    actor_path = checkpoint_dir / 'checkpoint_actor.pth'
    critic_path = checkpoint_dir / 'checkpoint_critic.pth'

    reward_history = []

    for episode in range(1, n_episodes+1):
        agent.reset()
        state = env.reset()
        ep_reward = 0
        for timestep in range(n_timesteps):
            action = agent.act(state)
            next_state, reward, is_done, _ = env.step(action)

            agent.step(state, action, reward, next_state, is_done)
            state = next_state
            ep_reward += reward

            if render:
                env.render()
            if is_done:
                break

        reward_history.append(ep_reward)

        torch.save(agent.actor_local.state_dict(), actor_path)
        torch.save(agent.critic_local.state_dict(), critic_path)

        # Running-window statistics, computed once and shared by the solved
        # check and the progress line.
        recent_total = sum(reward_history[-log_interval:])
        if recent_total > (log_interval * solved_reward):
            print('Solved!')
            break

        window_size = min(log_interval, len(reward_history))
        # end='\r' keeps the progress report on a single, continuously
        # overwritten line.
        print(f'Episode #{episode}\t'
              f'Current episode reward: {round(ep_reward)}\t'
              f'Running average reward (previous {window_size} episodes): '
              f'{round(recent_total / window_size)}',
              end='\r')

    return reward_history

### Train a DDPG agent

Run the code cells below to train the agent.

In [None]:
# Instantiate the DDPG agent sized to the environment's action vector.
# NOTE(review): `seed` presumably seeds the agent's random number generators --
# confirm in ddpg_agent.Agent.
agent = Agent(action_dim=action_dim, seed=42)

In [None]:
# Train with the default settings of train_agent (up to 4,000 episodes of up to
# 1,000 steps each); `history` holds the total reward of every episode played.
history = train_agent(agent)

In [None]:
# Reward curve: total reward of each training episode, numbered from 1.
# Drawn through the Axes object returned by plt.subplots() rather than the
# implicit pyplot state, and titled so the figure is readable on its own.
fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(range(1, len(history) + 1), history)
ax.set(title='DDPG training: reward per episode',
       xlabel='Episode #',
       ylabel='Reward')
plt.show()