In [None]:
from collections import deque
from ipywidgets import interact
from IPython.display import display
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm
from xvfbwrapper import Xvfb
import atexit
import gym
import numpy as np
import os
import pandas as pd
import PIL.Image
import random
import tensorflow as tf

In [None]:
# Start an Xvfb virtual display unless this kernel already started one and
# DISPLAY is set. Re-running the cell is therefore a no-op in the common case.
if globals().get('virtual_display') is None or 'DISPLAY' not in os.environ:
    try:
        virtual_display = Xvfb()
        virtual_display.start()
    except Exception as error:
        # Best effort: report the failure and carry on without a virtual
        # display instead of crashing on `None.stop` further down.
        virtual_display = None
        print('Could not start xvfb: {!r}'.format(error))
    else:
        # Register the shutdown hook only for a display that really started.
        atexit.register(virtual_display.stop)
        print('Started xvfb: DISPLAY={!r}'.format(os.environ['DISPLAY']))
else:
    print('Using DISPLAY={!r}'.format(os.environ['DISPLAY']))

In [None]:
def display_np_image(np_image):
    """Show an array as an inline image.

    The array is converted with ``PIL.Image.fromarray`` using mode ``'RGB'``
    and handed to IPython's ``display``.
    """
    image = PIL.Image.fromarray(np_image, 'RGB')
    display(image)

In [None]:
# Create the CartPole-v1 environment shared by the training and rendering cells below.
env = gym.make('CartPole-v1')

In [None]:
# Reset once up front; as the cell's last expression, the returned value is shown as the cell output.
env.reset()

In [None]:
# Hyperparameters shared by the training functions and cells below.
GAMMA = 0.95  # discount applied to the next-state term in get_updated_q_value
LEARNING_RATE = 0.001  # optimizer learning rate — NOTE(review): confirm the model cell actually reads this constant
MEMORY_SIZE = 1000000  # maxlen of the replay-memory deque
N_ACTIONS = env.action_space.n  # size of the environment's action space; bounds the random action choice
N_RUNS = 1  # NOTE(review): not referenced by any visible cell — confirm whether it is still needed
MAX_TIME = 200  # upper bound on the number of steps taken by one call to run()
BATCH_SIZE = 200  # transitions sampled per replay; replay is skipped until this many memories exist

# Epsilon-greedy schedule: exploration_rate starts at MAX, is multiplied by
# DECAY after every replay step, and is never allowed below MIN.
EXPLORATION_MAX = 1.0
EXPLORATION_MIN = 0.01
EXPLORATION_DECAY = 0.998


def select_action(model, state, exploration_rate):
    """Pick an action epsilon-greedily.

    A uniform random draw below ``exploration_rate`` selects a random action
    index in ``range(N_ACTIONS)``. Otherwise the state is wrapped into a
    batch of one, passed to ``model.predict``, and the index of the largest
    predicted value is returned.
    """
    explore = np.random.rand() < exploration_rate
    if explore:
        return random.randrange(N_ACTIONS)
    # predict() expects a batch, so add a leading axis and unwrap row 0 afterwards.
    batch = state[np.newaxis, :]
    predicted = model.predict(batch)
    return np.argmax(predicted[0])


def get_updated_q_value(action, reward, done, q_value1, q_value2, gamma=None):
    """Build the training target for one stored transition.

    Args:
        action: index of the action that was taken.
        reward: reward received for that action.
        done: whether the transition ended the episode.
        q_value1: predicted Q-values for the state the action was taken in.
        q_value2: predicted Q-values for the resulting next state.
        gamma: discount factor; defaults to the module-level ``GAMMA``.

    Returns:
        A copy of ``q_value1`` whose entry for ``action`` is replaced by the
        target value. ``q_value1`` itself is not modified.
    """
    if gamma is None:
        gamma = GAMMA
    if done:
        # Terminal transitions get a fixed penalty, regardless of the reward.
        q_update = -10
    else:
        # Bootstrap from the VALUE of the best next action (np.max), not from
        # its index (np.argmax), which would only ever yield 0..n_actions-1.
        q_update = reward + gamma * np.max(q_value2)
    new_q_value = q_value1.copy()
    new_q_value[action] = q_update
    return new_q_value


def replay_memories(model, memories, exploration_rate):
    """Fit the model on one random batch of stored transitions.

    Nothing happens until at least ``BATCH_SIZE`` memories exist. Otherwise a
    batch is sampled, targets are built with ``get_updated_q_value``, the
    model is fitted once, and the exploration rate is decayed (never below
    ``EXPLORATION_MIN``).

    Returns:
        ``(memories, exploration_rate)`` — the same memories object and the
        possibly decayed exploration rate.
    """
    if len(memories) < BATCH_SIZE:
        return memories, exploration_rate

    batch = random.sample(memories, BATCH_SIZE)
    states = [m['state'] for m in batch]
    next_states = [m['next_state'] for m in batch]
    n_states = len(states)

    # One predict() call covers both halves: current states first, next states second.
    predictions = model.predict(np.vstack((states, next_states)))
    current_q = predictions[:n_states]
    next_q = predictions[n_states:]

    targets = [
        get_updated_q_value(m['action'], m['reward'], m['done'], q_now, q_next)
        for m, q_now, q_next in zip(batch, current_q, next_q)
    ]
    model.fit(np.array(states), np.array(targets), verbose=0)

    decayed = exploration_rate * EXPLORATION_DECAY
    return memories, max(EXPLORATION_MIN, decayed)


def run(env, model, memories, exploration_rate):
    """Play one episode of at most ``MAX_TIME`` steps, training as it goes.

    Each step picks an action with ``select_action``, stores the transition
    in ``memories`` and, unless the episode just ended, calls
    ``replay_memories``.

    Returns:
        ``(states, actions, rewards, dones, memories, exploration_rate)``.
        ``states``, ``rewards`` and ``dones`` have an entry for the initial
        state (reward ``None``, done ``False``) followed by one entry per
        step. ``actions`` has one entry per step plus a trailing ``None``.
    """
    state = env.reset()
    states, rewards, dones = [state], [None], [False]
    actions = []
    for _ in tqdm(range(MAX_TIME)):
        action = select_action(model, state, exploration_rate)
        next_state, reward, done, _info = env.step(action)
        transition = dict(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=done,
        )
        memories.append(transition)
        actions.append(action)
        states.append(next_state)
        rewards.append(reward)
        dones.append(done)
        state = next_state
        if done:
            # The terminal transition is stored, but no replay follows it.
            break
        memories, exploration_rate = replay_memories(model, memories, exploration_rate)
    # Pad actions so it lines up index-for-index with states.
    actions.append(None)
    return states, actions, rewards, dones, memories, exploration_rate


def show_scores(states, score_smoothing=20):
    """Plot a moving average of the per-trial scores.

    A trial's score is the length of its recorded state list.

    Args:
        states: list with one state list per trial.
        score_smoothing: moving-average window in trials. It is clamped to
            the range 1..number of trials, so short histories still give a
            meaningful curve.

    Nothing is plotted when ``states`` is empty.
    """
    scores = [len(x) for x in states]
    if not scores:
        # np.convolve rejects empty input; there is nothing to show anyway.
        return
    # With a kernel longer than the data, mode='valid' would swap the operands
    # and return something that is no longer a moving average of the scores.
    window = max(1, min(int(score_smoothing), len(scores)))
    kernel = np.ones((window,)) / window
    scores_smoothed = np.convolve(scores, kernel, mode='valid')
    plt.plot(scores_smoothed)

In [None]:
# Q-network: observation vector in, one linear output per action.
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(20, input_shape=(env.observation_space.shape[0],), activation='relu'),
    tf.keras.layers.Dense(20, activation='relu'),
    tf.keras.layers.Dense(N_ACTIONS, activation='linear'),
])
# The learning rate comes from the hyperparameter cell rather than a second
# hard-coded copy. No 'accuracy' metric: the outputs are regression targets
# trained with MSE, so classification accuracy carries no meaning here.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss='mse',
)

# Training state that persists across runs of the training cell.
memories = deque(maxlen=MEMORY_SIZE)
exploration_rate = EXPLORATION_MAX
# Per-trial histories; each run appends one list to each of these.
states = []
actions = []
rewards = []
dones = []
run_count = 0

In [None]:
# Play 30 more trials, appending each trial's trajectory to the global
# history lists. Re-running this cell continues training from where it left off.
for episode_index in range(30):
    print(f'Trial {run_count}   (exploration rate: {exploration_rate})')
    episode = run(env, model, memories, exploration_rate)
    more_states, more_actions, more_rewards, more_dones, memories, exploration_rate = episode
    print(more_actions)
    states.append(more_states)
    actions.append(more_actions)
    rewards.append(more_rewards)
    dones.append(more_dones)
    run_count += 1

show_scores(states)

In [None]:
# Re-plot the smoothed scores for all trials recorded so far, without training further.
show_scores(states)

In [None]:
def render(state):
    """Render a stored state as an RGB frame.

    Overwrites the state of the wrapped environment (``env.env``) with
    ``state`` and returns ``env.render(mode='rgb_array')``.
    """
    wrapped = env.env
    wrapped.state = state
    frame = env.render(mode='rgb_array')
    return frame


@interact(
    trial=(0, len(states) - 1),
    time=(0, MAX_TIME - 1),
)
def f(trial=1105, time=0):
    """Interactive frame viewer: show step ``time`` of trial ``trial``.

    Nothing is drawn when ``time`` is past the end of the selected trial.
    """
    trajectory = states[trial]
    if time >= len(trajectory):
        return
    display_np_image(render(trajectory[time]))

In [None]:
from jupyter_renderer_widget import Renderer

# Replay one recorded trial in the Renderer widget. The requested index (980)
# is clamped to the last trial recorded so far.
trial = min(980, len(states) - 1)
trial_states = states[trial]
last_frame = len(trial_states) - 1
display(Renderer(lambda t: render(trial_states[t]), last_frame))

In [None]:
# Disabled scratch cell: a scripted rollout that alternates actions 0 and 1
# for at most 100 steps.
# NOTE(review): if enabled, this rebinds the global `rewards` and `dones`
# lists that the training cell appends to and the DataFrame cell reads.
if False:
    observations = [env.reset()]
    rewards = [None]
    dones = [False]
    for i in range(100):
        action = i % 2  # FIXME
        observation, reward, done, info = env.step(action)
        observations.append(observation)
        rewards.append(reward)
        dones.append(done)
        if done:
            break

In [None]:
# Disabled scratch cell: dump the recorded actions, rewards and states to out.json.
# NOTE(review): `json` is not imported in the notebook's import cell, so this
# would raise NameError if enabled. The stored states presumably are numpy
# arrays, which json.dumps cannot serialize as-is — confirm before enabling.
if False:
    with open('out.json', 'w') as f:
        f.write(json.dumps({
            'actions': actions,
            'rewards': rewards,
            'states': states,

        }))

In [None]:
# Flatten the recorded trajectories into one row per (trial, step), with each
# component of the state in its own observation<k> column.
rows = []
for trial, trial_states in enumerate(states):
    for step, observation in enumerate(trial_states):
        row = {
            'trial': trial,
            'step': step,
            'reward': rewards[trial][step],
            'done': dones[trial][step],
        }
        row.update({f'observation{k}': value for k, value in enumerate(observation)})
        rows.append(row)
pd.DataFrame(rows)