In [None]:
from rl.memory import SequentialMemory
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam
from rl.agents.cem import CEMAgent
from rl.agents.dqn import DQNAgent
from rl.policy import BoltzmannQPolicy
from rl.memory import EpisodeParameterMemory
from rl.callbacks import FileLogger, ModelIntervalCheckpoint
from nes_py.wrappers import JoypadSpace
import gym
from rl.core import Processor
from Contra.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT, RIGHT_ONLY
from PIL import Image
import argparse
import json
import matplotlib.pyplot as plt

In [None]:
def visualize_log(filename, figsize=None, output=None):
    """Plot every metric stored in a JSON training log against the episode number.

    The log must be a JSON object with an ``'episode'`` key holding the x values.
    Every other key is drawn in its own subplot; the subplots are stacked
    vertically in sorted key order and share the x axis.

    Args:
        filename: Path of the JSON log file to read.
        figsize: Optional ``(width, height)`` forwarded to ``plt.subplots``.
            Defaults to a width of 15 and a height of 5 per metric.
        output: Optional path to save the figure to. When ``None`` the figure
            is displayed with ``plt.show()`` instead.

    Raises:
        ValueError: If the log has no ``'episode'`` key, or contains no other
            key to plot.
    """
    with open(filename, 'r') as log_file:
        data = json.load(log_file)
    if 'episode' not in data:
        raise ValueError('Log file "{}" does not contain the "episode" key.'.format(filename))
    episodes = data['episode']

    # Get value keys. The x axis is shared and is the number of episodes.
    keys = sorted(set(data.keys()) - {'episode'})
    if not keys:
        # Without this guard the failure would surface as an opaque error from
        # matplotlib when asked for a grid with zero rows.
        raise ValueError('Log file "{}" does not contain any metric to plot.'.format(filename))

    if figsize is None:
        figsize = (15., 5. * len(keys))

    # squeeze=False makes `subplots` return a 2-D grid even for a single row.
    # With the default squeeze, one metric yields a bare Axes object and
    # indexing it fails.
    fig, axarr = plt.subplots(len(keys), sharex=True, figsize=figsize, squeeze=False)
    axes = [row[0] for row in axarr]
    for ax, key in zip(axes, keys):
        ax.plot(episodes, data[key])
        ax.set_ylabel(key)
    # Only the bottom subplot carries the shared x label.
    axes[-1].set_xlabel('episodes')
    fig.tight_layout()
    if output is None:
        plt.show()
    else:
        fig.savefig(output)



In [None]:
ENV_NAME = 'Contra-v0'
SEED = 123

# Button combinations handed to JoypadSpace; the wrapped environment's action
# space is where `nb_actions` is read from below.
CUSTOM_MOVEMENT = [
    ['NOOP'],
    # combinations that include 'right'
    ['right'],
    ['right', 'A'],
    ['right', 'B'],
    ['right', 'A', 'up'],
    ['right', 'B', 'up'],
    ['right', 'A', 'B', 'up'],
    # buttons only
    ['A'],
    ['B'],
    ['A', 'B'],
    # combinations that include 'down'
    ['down', 'A'],
    ['down', 'B'],
    ['down', 'A', 'B'],
    # combinations that include 'up'
    ['up', 'A'],
    ['up', 'A', 'B'],
]

# Create the environment, restrict it to the custom action set and seed both
# numpy and the environment.
env = JoypadSpace(gym.make(ENV_NAME), CUSTOM_MOVEMENT)
np.random.seed(SEED)
env.seed(SEED)

# Report the number of actions and the observation shape.
observation_shape = env.observation_space.shape
nb_actions = env.action_space.n
print(nb_actions)
print(observation_shape)
obs_dim = observation_shape[0]


# Next, we build a very simple model: the observation is flattened, passed through
# three 16-unit ReLU layers and mapped to one linear output per action.
model = Sequential([
    # The leading 1 presumably corresponds to the `window_length=1` of the agent's memory.
    Flatten(input_shape=(1,) + env.observation_space.shape),
    Dense(16),
    Activation('relu'),
    Dense(16),
    Activation('relu'),
    Dense(16),
    Activation('relu'),
    Dense(nb_actions),
    Activation('linear'),
])
# `summary()` prints the table itself and returns None, so wrapping it in print()
# would only append a stray "None" to the output.
model.summary()

# Assemble the DQN agent from the model, a sequential replay memory and a Boltzmann
# policy, then compile it with Adam. Any built-in Keras optimizer or metric can be
# used here; mean absolute error is tracked as an extra metric.
policy = BoltzmannQPolicy()
memory = SequentialMemory(window_length=1, limit=50000)
dqn = DQNAgent(
    model=model,
    nb_actions=nb_actions,
    policy=policy,
    memory=memory,
    nb_steps_warmup=10,
    target_model_update=1e-2,
)
optimizer = Adam(lr=1e-3)
dqn.compile(optimizer, metrics=['mae'])

# Okay, now it's time to learn something! We visualize the training here for show, but this
# slows down training quite a lot. You can always abort the training prematurely by
# interrupting the kernel (Ctrl + C when run from a terminal).

# Training artifacts are written below OUTPUT_DIR instead of a hard-coded absolute path
# in one user's home directory, so the notebook runs on any machine. Point OUTPUT_DIR at
# an existing directory to store them elsewhere.
OUTPUT_DIR = '.'
weights_filename = '{}/dqn_contra_weights.h5f'.format(OUTPUT_DIR)
checkpoint_weights_filename = '{}/dqn_contra_weights_step.h5f'.format(OUTPUT_DIR)
log_filename = '{}/dqn_contra_log.json'.format(OUTPUT_DIR)

# Short demo settings. Settings previously noted for a long run: checkpoint
# interval=250000, nb_steps=1750000, log_interval=10000.
callbacks = [
    ModelIntervalCheckpoint(checkpoint_weights_filename, interval=250),
    FileLogger(log_filename, interval=100),
]

dqn.fit(env, callbacks=callbacks, nb_steps=50000, log_interval=100, visualize=True)

# After training is done, we save the final weights.
dqn.save_weights(weights_filename, overwrite=True)

# Plot the stats that the `FileLogger` callback recorded during training.
# The log path is passed directly rather than parsed from the command line: inside a
# notebook `sys.argv` holds the kernel launcher's own arguments, so
# `argparse.ArgumentParser.parse_args()` would reject them and raise SystemExit.
# Pass `output=` or `figsize=` here to save the figure or change its size.
visualize_log(log_filename)


# Finally, evaluate our algorithm for 5 episodes.
try:
    dqn.test(env, nb_episodes=5, visualize=True)
finally:
    # Release the environment (and the render window opened by visualize=True)
    # even if evaluation is interrupted. Re-run the setup code to use `env` again.
    env.close()