In [None]:
from datetime import datetime as dt
from pathlib import Path

import gym
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch import nn

from ddpg_agent import Agent, LR_ACTOR, LR_CRITIC, DEVICE

In [None]:
# Gym environment id handed to gym.make() below.
env_name = 'CarRacing-v0'
# Module-level environment instance; train_agent() reads this global directly
# instead of receiving the environment as an argument.
# NOTE(review): verbose=0 presumably silences the environment's console output - confirm.
env = gym.make(env_name, verbose=0)
# First entry of the action space's shape; passed to Agent(action_dim=...).
action_dim = env.action_space.shape[0]

In [None]:
def train_agent(agent=None,
                n_episodes=1_000,
                n_timesteps=1_000,
                solved_reward=900,
                log_interval=100,
                warm_start_from=None,
                warm_start_seed=42,
                render=False):      # Rendering is laggy.
    """Train an agent on the module-level ``env``, checkpointing after every episode.

    A new timestamped directory is created under ``tmp/``. After each episode
    the episode reward is appended to ``history.csv`` in that directory and the
    agent's ``actor_local`` / ``critic_local`` state dicts are overwritten in
    ``checkpoint_actor.pth`` / ``checkpoint_critic.pth``.

    Args:
        agent: Object providing ``reset()``, ``act(state)``,
            ``step(state, action, reward, next_state, done)`` and the
            ``actor_local`` / ``critic_local`` modules. Required unless
            ``warm_start_from`` is given, in which case it is ignored and a
            new ``Agent`` is built.
        n_episodes: Number of the last episode to run. Episodes already
            present in a warm-start history count towards this total.
        n_timesteps: Maximum number of environment steps per episode.
        solved_reward: Per-episode reward threshold. Training stops once the
            sum of the last ``log_interval`` rewards exceeds
            ``log_interval * solved_reward``.
        log_interval: Size of the trailing window used for the solved check
            and for the running average that is printed.
        warm_start_from: Name of an existing checkpoint directory under
            ``tmp/`` to resume from, or ``None`` to train ``agent`` as given.
        warm_start_seed: Seed passed to ``Agent`` when warm starting.
        render: If true, ``env.render()`` is called after every step.

    Returns:
        list: Reward of every episode, including those loaded from a
        warm-start history.

    Raises:
        AssertionError: If neither ``agent`` nor ``warm_start_from`` is given.
    """
    # Validate before touching the filesystem so a bad call leaves no empty
    # checkpoint directory behind.
    if not warm_start_from:
        assert agent is not None, \
            'Pass an agent, or set warm_start_from to resume from a checkpoint.'

    # NOTE(review): ':' in the directory name is not a valid filename character on Windows.
    ckpt_dirname = dt.now().strftime('%d-%m-%Y_%H:%M:%S')
    ckpt_dirpath = Path('tmp', ckpt_dirname)
    ckpt_dirpath.mkdir(parents=True)
    print(f'Persisting checkpoints to: {ckpt_dirpath}')

    history_path = ckpt_dirpath / 'history.csv'
    actor_ckpt_path = ckpt_dirpath / 'checkpoint_actor.pth'
    critic_ckpt_path = ckpt_dirpath / 'checkpoint_critic.pth'

    if warm_start_from:
        warm_start_dirpath = Path('tmp', warm_start_from)

        agent = Agent(action_dim=action_dim, seed=warm_start_seed)
        # map_location lets a checkpoint saved on one device be restored on another.
        # Assumes DEVICE is a torch.device or device string - TODO confirm in ddpg_agent.
        # NOTE(review): only the local actor/critic weights are restored here; if Agent
        # keeps separate target networks they keep their fresh initialisation - confirm
        # this is intended.
        agent.actor_local.load_state_dict(
            torch.load(warm_start_dirpath / 'checkpoint_actor.pth', map_location=DEVICE))
        agent.critic_local.load_state_dict(
            torch.load(warm_start_dirpath / 'checkpoint_critic.pth', map_location=DEVICE))

        # history.csv is written below without a header row, so it must be read with
        # header=None; otherwise the first episode is swallowed as column names.
        prior_history = pd.read_csv(warm_start_dirpath / 'history.csv',
                                    header=None, index_col=0)
        reward_history = prior_history.to_numpy().ravel().tolist()

        # Carry the earlier episodes over so the new checkpoint directory holds the
        # complete history and can itself be used for a later warm start.
        with open(history_path, 'w') as history_fl:
            history_fl.writelines(
                '{},{}\n'.format(past_episode, past_reward)
                for past_episode, past_reward in enumerate(reward_history, start=1))
    else:
        reward_history = []

    for episode in range(len(reward_history) + 1, n_episodes + 1):
        agent.reset()
        state = env.reset()
        ep_reward = 0
        for _ in range(n_timesteps):
            action = agent.act(state)
            next_state, reward, is_done, _info = env.step(action)

            agent.step(state, action, reward, next_state, is_done)
            state = next_state
            ep_reward += reward

            if render:
                env.render()
            if is_done:
                break

        reward_history.append(ep_reward)
        # Append and close every episode so the log survives an interrupted run.
        with open(history_path, 'a') as history_fl:
            history_fl.write('{},{}\n'.format(episode, ep_reward))

        torch.save(agent.actor_local.state_dict(), actor_ckpt_path)
        torch.save(agent.critic_local.state_dict(), critic_ckpt_path)

        window_total = sum(reward_history[-log_interval:])
        window_size = min(log_interval, len(reward_history))

        # The threshold always assumes a full window of log_interval episodes, even
        # while fewer than that have been played.
        if window_total > (log_interval * solved_reward):
            print('Solved!')
            break

        print(f'Episode #{episode}\t'
              f'Current episode reward: {round(ep_reward)}\t'
              f'Running average reward (previous {window_size} episodes): '
              f'{round(window_total / window_size)}',
              end='\r')

    return reward_history

### Train a DDPG agent

Run the code cells below to train the agent.

In [None]:
# Fresh agent for a from-scratch run; seed=42 matches train_agent's warm_start_seed default.
agent = Agent(action_dim=action_dim, seed=42)

In [None]:
# Train with train_agent's defaults (up to 1,000 episodes of at most 1,000 steps each).
# The returned list of per-episode rewards is plotted in the next cell.
history = train_agent(agent)

In [None]:
# Learning curve: reward obtained in each training episode (episodes numbered from 1).
fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(range(1, len(history) + 1), history)
ax.set_xlabel('Episode #')
ax.set_ylabel('Reward')
plt.show()