# Discrete

In [None]:
from glob import glob
from gnwrapper import Animation
import gym
import numpy as np

from tensorflow.keras.layers import Activation, Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

from rl.agents.dqn import DQNAgent
from rl.memory import SequentialMemory
from rl.policy import BoltzmannQPolicy

In [None]:
# Run configuration shared by every cell below.
ENV_NAME = "CartPole-v0"  # Gym environment id
SEED = 123  # seed handed to NumPy and to the environment
STEPS = 1_000  # training step budget; also used as the replay memory limit
VISUALIZE = True  # forwarded as `visualize` to agent.fit / agent.test
WEIGHTS = f"../data/{ENV_NAME}/weights.h5f"  # path used to load/save agent weights

In [None]:
# Create the Gym environment and wrap it in Animation, then seed both
# NumPy's global RNG and the environment with the same SEED.
base_env = gym.make(ENV_NAME)
env = Animation(base_env)
np.random.seed(SEED)
env.seed(SEED)

# Size of the discrete action space; used as the width of the model's output layer.
nb_actions = env.action_space.n

In [None]:
# Assemble the Q-network: flatten the (window=1, *observation) input, pass it
# through three 16-unit ReLU hidden layers, and emit one linear output per action.
layer_stack = [Flatten(input_shape=(1, *env.observation_space.shape))]
for _ in range(3):
    layer_stack += [Dense(16), Activation("relu")]
layer_stack += [Dense(nb_actions), Activation("linear")]

model = Sequential(layer_stack)
model.summary()

In [None]:
# Build agent.
# Replay memory holds at most STEPS transitions; window_length=1 matches the
# leading (1,) dimension of the model's input shape.
memory = SequentialMemory(limit=STEPS, window_length=1)
policy = BoltzmannQPolicy()

agent = DQNAgent(
    model=model,
    nb_actions=nb_actions,
    memory=memory,
    nb_steps_warmup=10,
    target_model_update=1e-2,
    policy=policy,
)

# `learning_rate` is the supported keyword; the old `lr` alias is deprecated
# and rejected by newer Keras optimizer implementations.
agent.compile(Adam(learning_rate=1e-3), metrics=["mae"])

In [None]:
# Load previously saved weights if any path on disk starts with the weights
# prefix (WEIGHTS with its ".h5f" suffix removed).
weights_prefix = WEIGHTS.replace(".h5f", "")
saved_weight_paths = glob(f"{weights_prefix}*")
if saved_weight_paths:
    print("Loading weights...")
    agent.load_weights(WEIGHTS)

In [None]:
# Train.
# A fresh wrapped environment is created for the training run. It is seeded
# again here because the seed applied to the earlier instance does not carry
# over to a new one; without this the run is not reproducible.
env = Animation(gym.make(ENV_NAME))
env.seed(SEED)
agent.fit(env, nb_steps=STEPS, visualize=VISUALIZE, verbose=2)

In [None]:
# Write the agent's current weights to WEIGHTS; overwrite=True allows an
# existing file at that path to be replaced.
agent.save_weights(
    WEIGHTS,
    overwrite=True,
)

In [None]:
# Evaluate.
# A fresh wrapped environment is created for evaluation and seeded explicitly,
# so the five test episodes are reproducible; a newly created instance does
# not inherit the seed given to an earlier one.
env = Animation(gym.make(ENV_NAME))
env.seed(SEED)
agent.test(env, nb_episodes=5, visualize=VISUALIZE)