## Training a DQN

First we train an agent on CartPole. At every time step the agent is given a representation of the state (cart position, cart velocity, pole angle, pole angular velocity) and must push the cart left or right to keep the pole upright. The agent receives a reward of +1 for every time step the pole stays "alive", i.e. remains within a certain angle of vertical.

In [1]:
import gym

def evaluate(env_name, agent, history_length=4, epsilon=0.1, num_episodes=10, render=True):
    """Evaluate ``agent`` with an epsilon-greedy policy and report episode rewards.

    For each episode, the agent's greedy action (``agent.get_action``) is taken
    with probability ``1 - epsilon``; otherwise a random action is sampled from
    the environment's action space. Each episode's total reward and frame count
    are printed, followed by the mean reward over all episodes.

    Args:
        env_name: Gym environment id passed to ``gym.make`` (e.g. "CartPole-v1").
        agent: Object exposing ``get_action(state)``.
        history_length: Number of frames stacked into one state; forwarded to
            ``first_state`` / ``process_action`` and used for frame counting.
        epsilon: Probability of taking a uniformly random action.
        num_episodes: Number of evaluation episodes (default 10, as before).
        render: Whether to call ``env.render()`` every step (default True).

    Returns:
        The mean episode reward over ``num_episodes`` episodes.

    Raises:
        ValueError: If ``num_episodes`` is less than 1.
    """
    import random

    if num_episodes < 1:
        raise ValueError("num_episodes must be at least 1")

    env = gym.make(env_name)
    rewards = []
    try:
        for _ in range(num_episodes):
            # Use this function's own history_length; the original read an
            # undefined global `args`.
            state = first_state(env, env_name, history_length)
            ep_reward = 0
            done = False
            frame = 0
            while not done:
                if random.random() < epsilon:
                    action = env.action_space.sample()
                else:
                    action = agent.get_action(state)
                state_next, reward, done = process_action(
                    env, env_name, action, history_length=history_length, prev_image=state
                )
                if render:
                    env.render()
                ep_reward += reward
                state = state_next
                # NOTE(review): presumably process_action advances the env by
                # history_length frames per call — confirm against dqn.run.
                frame += history_length
            rewards.append(ep_reward)
            print(ep_reward, frame)
    finally:
        # Always release the environment (and any render window it opened).
        env.close()

    mean_reward = sum(rewards) / len(rewards)
    print(mean_reward)
    return mean_reward

We load a previously trained model from a saved checkpoint on disk:

In [9]:
import dqn
from dqn.agent import *
from dqn.replay import *
from dqn.run import *

# CODE TO TRAIN AGENT
# NOTE(review): `env` is not defined in this notebook; presumably it should be
# gym.make(env_name) — confirm against the signature of dqn.run.train.
# cartpole_agent = train(
#     env,
#     env_name=env_name,
#     history_length=1,
#     num_episodes=25_000,
#     minibatch_size=32,
#     exp_buffer_size=1_000_000,
#     epsilon_init=1.0,
#     epsilon_final=0.01,
#     epsilon_final_frame=500_000,
#     replay_start_frame=50_000,
#     q_target_update_freq=1_000,
#     learning_rate=1e-4,
#     momentum=0.95,
#     discount_factor=0.99,
#     save_every=10,
#     dirname=""
# )

env_name = "CartPole-v1"
# The run directory name appears to encode, in order, the hyperparameters of
# the training call above (history_length, num_episodes, minibatch_size, ...).
dirname = (
    "runs/20211104_195309_CartPole-v1_1_25000_32_1000000_1_"
    "0.01_500000_50000_1000_0.0001_0.95_0.99"
)

cartpole_agent = DQNAgent(env_name)
# NOTE(review): running this cell previously failed with
# "ModuleNotFoundError: No module named 'agent'". That is presumably caused by an
# absolute `import agent` inside the dqn package, or by a checkpoint pickled when
# the agent module was importable as top-level `agent` — confirm before relying
# on this cell.
cartpole_agent.load_networks(dirname, checkpoint=21900)
evaluate(env_name, cartpole_agent, history_length=1)

ModuleNotFoundError: No module named 'agent'