In [1]:
import datetime as dt
import os
import shutil

import gym
import numpy as np
from torch import nn, optim

In [2]:
from observers import (
    WindowMetricLogger,
    WindowStepMetricLogger,
    StateAnalysisLogger,
    TensorboardScalarLogger,
    TensorboardHistogramLogger
)
from agents import (
    DQNAgent,
    EpsilonDecreasingStrategy
)
from training import (
    QLearningTrainer,
    QLearningContext,
    episode_value_accessor
)
from common import (
    Discretizer,
    Tensorboard
)

In [3]:
# Load the TensorBoard notebook extension (provides the %tensorboard magic for
# viewing the logs written under TENSORBOARD_LOGDIR).
%load_ext tensorboard

In [4]:
# Gym environment id used for the whole notebook (agent, trainer and model
# input/output sizes are all derived from `env`).
ENV_ID = 'CartPole-v1'
env = gym.make(ENV_ID)

In [5]:
BATCH_SIZE = 128
HIDDEN_UNITS = 32

# Network sizes come from the environment: one input per observation
# component, one output per discrete action (presumably Q-values consumed by
# DQNAgent — confirm in agents.DQNAgent).
n_inputs = env.observation_space.shape[0]
n_actions = env.action_space.n

# Three hidden ReLU layers of HIDDEN_UNITS each.
model = nn.Sequential(
    nn.Linear(n_inputs, HIDDEN_UNITS),
    nn.ReLU(),
    nn.Linear(HIDDEN_UNITS, HIDDEN_UNITS),
    nn.ReLU(),
    nn.Linear(HIDDEN_UNITS, HIDDEN_UNITS),
    nn.ReLU(),
    nn.Linear(HIDDEN_UNITS, n_actions),
)

# Epsilon schedule: starts at 1.0 and is floored at 0.1; see
# EpsilonDecreasingStrategy for how `decay` is applied.
exploration = EpsilonDecreasingStrategy(
    initial_epsilon=1.0,
    min_epsilon=0.1,
    decay=0.01,
)

agent = DQNAgent(
    env=env,
    strategy=exploration,
    model=model,
    optimizer=optim.Adam(model.parameters(), lr=0.001),
    loss=nn.MSELoss(),
    discount=0.99,
    memory_size=20000,
    batch_size=BATCH_SIZE,
    skip=None,
    target_update_frequency=10,
)

In [6]:
# Parent directory for this experiment's TensorBoard runs. The path segment
# names the Gym environment, so it must match the id passed to gym.make
# (CartPole-v1); it was previously mislabelled "cartpole-v0", which would mix
# these runs with runs of a different environment. The trailing "4" is
# presumably an experiment number — TODO confirm.
TENSORBOARD_LOGDIR = "./logs/cartpole-v1/4"

In [7]:
def train_observers(logdir=None):
    """Build the observers that log per-episode training metrics to TensorBoard.

    Each call creates a new ``Tensorboard`` writer pointed at a subdirectory of
    ``logdir`` named after the current local time (second resolution), so
    separate training runs presumably show up as separate runs in TensorBoard.

    Args:
        logdir: Parent log directory. Defaults to ``TENSORBOARD_LOGDIR``,
            looked up at call time so edits to that constant take effect.

    Returns:
        list: scalar loggers for mean loss, mean target, summed reward,
        epsilon, mean action and mean gradient norm, plus a 100-bin histogram
        of the loss values.
    """
    if logdir is None:
        logdir = TENSORBOARD_LOGDIR
    run_name = dt.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    tb = Tensorboard(os.path.join(logdir, run_name))
    return [
        TensorboardScalarLogger(tb=tb, name='loss', apply=episode_value_accessor('loss', np.mean)),
        TensorboardScalarLogger(tb=tb, name='target', apply=episode_value_accessor('target', np.mean)),
        # Reward is summed (episode return) rather than averaged.
        TensorboardScalarLogger(tb=tb, name='reward', apply=episode_value_accessor('reward', np.sum)),
        # No aggregation function passed: uses episode_value_accessor's default.
        TensorboardScalarLogger(tb=tb, name='epsilon', apply=episode_value_accessor('epsilon')),
        TensorboardScalarLogger(tb=tb, name='action', apply=episode_value_accessor('action', np.mean)),
        TensorboardScalarLogger(tb=tb, name='norm', apply=episode_value_accessor('gradient_norm', np.mean)),
        TensorboardHistogramLogger(tb=tb, name='loss_hist', apply=episode_value_accessor('loss'), bins=100),
    ]

In [8]:
# Trainer drives the agent through episodes of `env`.
trainer = QLearningTrainer(env=env, agent=agent)

In [9]:
EPOCHS = 1000

# A fresh observer list (and therefore a new timestamped log directory) is
# built for every execution of this cell.
trainer.train(epochs=EPOCHS, observers=train_observers())

KeyboardInterrupt: 

In [None]:
# Optional cleanup: delete ALL TensorBoard logs under ./logs.
# Guarded so that "Restart & Run All" does not wipe the logs the training cell
# above just produced; set CLEAN_LOGS = True and run this cell to delete.
# ignore_errors=True mirrors `rm -rf` (no error if the directory is missing).
CLEAN_LOGS = False
if CLEAN_LOGS:
    shutil.rmtree('logs', ignore_errors=True)