In [2]:
# IMPORTS
##########################

import agent
import environment
import doubledqn
import tools
import memory

import tensorflow as tf

import numpy as np
import matplotlib.pyplot as plt
import time
import itertools

Using TensorFlow backend.


In [None]:
import importlib
importlib.reload(agent)
importlib.reload(environment)
importlib.reload(doubledqn)
importlib.reload(tools)
importlib.reload(memory)

In [2]:
# MAIN
##################################

num_actions = 2
state_shape = (1,11) # State var in rows
memory_size = 20000
gamma = 0.9
target_update_frequency = 30000
num_init_samples_mem = 1000
batch_size = 50
max_episode_length = 100000
optimizer = 'adam'
loss = "mse"
eps = 0.1
env_name = "Simple_Cross"
experiment_id = "Test_linear_model"
monitoring = True # Store variables for TensorBoard monitoring and model_checkpoints

# Define logs directory if monitoring enabled
if monitoring:
    output_dir = tools.get_output_folder("./Logs",experiment_id)
    summary_writer = tf.summary.FileWriter(logdir=output_dir)
else:
    output_dir = None
    summary_writer = None

# Initialize Q-networks (value and target)
q_network = agent.get_model('simple',(state_shape[1],),num_actions)
target_q_network = agent.get_model('simple',(state_shape[1],),num_actions)

# Initialize environment
sumo_env =  environment.Env(    "cross.net.xml",
                                "cross.rou.xml",
                                state_shape,
                                num_actions,
                                use_gui=False
                           )

# Initialize replay memory
mem = memory.ReplayMemory(    memory_size,
                                 state_shape,
                                 num_actions
                             )

# Initialize Double DQN algorithm
ddqn = doubledqn.DoubleDQN(     q_network,
                                target_q_network,
                                mem,
                                gamma,
                                target_update_frequency,
                                num_init_samples_mem,
                                batch_size,
                                optimizer,
                                loss,
                                max_episode_length,
                                sumo_env,
                                output_dir,
                                experiment_id,
                                summary_writer
                            )

# Fill Replay Memory
ddqn.fill_replay(sumo_env)

Instructions for updating:
Colocations handled automatically by placer.
Filling experience replay memory...
...Done


In [3]:
# Train
_ = ddqn.train(  sumo_env, 2, "epsGreedy", eps=eps)

Instructions for updating:
Use tf.cast instead.
Running episode 827 / 1000

KeyboardInterrupt: 

In [103]:
def get_chunks(iterable, chunks=1):
    # This is from http://stackoverflow.com/a/2136090/2073595
    lst = list(iterable)
    return [lst[i::chunks] for i in range(chunks)]

def take(n, iterable):
    "Return first n items of the iterable as a list"
    return list(itertools.islice(iterable, n))

def worker(chunked_param_list):
    best_run_time = 100000
    for params in chunked_param_list:
        
        # Initialize Q-networks (value and target)
        q_network = agent.get_model('simple',(state_shape[1],),num_actions)
        target_q_network = agent.get_model('simple',(state_shape[1],),num_actions)

        # Initialize environment
        sumo_env =  environment.Env(    "cross.net.xml",
                                "cross.rou.xml",
                                state_shape,
                                num_actions,
                                use_gui=False
                           )

        # Initialize replay memory
        mem = memory.ReplayMemory(    memory_size,
                                 state_shape,
                                 num_actions
                             )
        
        # Initialize Double DQN algorithm
        ddqn = doubledqn.DoubleDQN(     q_network = q_network,
                                target_q_network = target_q_network,
                                memory = mem,
                                gamma = params[2],
                                target_update_freq = params[1],
                                num_burn_in = num_init_samples_mem,
                                batch_size = params[0],
                                optimizer = optimizer,
                                loss_func = loss,
                                max_ep_length = max_episode_length,
                                env_name = sumo_env,
                                output_dir = output_dir,
                                experiment_id = experiment_id,
                                summary_writer = summary_writer)
        
        # Fill Replay Memory
        ddqn.fill_replay(sumo_env)
        
        # Train
        _ = ddqn.train(sumo_env, 2, "epsGreedy", eps = params[3])
        
        run_time = ddqn.evaluate_cv(sumo_env, "greedy")
        if run_time < best_run_time:
            best_run_time = run_time
            best_params = params
            
    return best_run_time

def gridsearch(param_grid = param_grid):
    jobs = []
    chunked_param_list = get_chunks(param_grid, chunks = multiprocessing.cpu_count())
    pool = multiprocessing.Pool()
    results = pool.map(worker, chunked_param_list)
    pool.close()
    pool.join()
    print(results)
    # Now combine the results
    sorted_results = reversed(sorted(results, key=lambda x: x[0]))
    return(sorted_results)  # Winner

In [3]:
# define the grid search parameters
batch_size = [10, 20, 40, 60, 80, 100]
target_update_frequency = [10, 50, 100]
gamma = [0.6, 0.7, 0.8, 0.9]
eps = [0.1, 0.2]

param_grid = itertools.product(batch_size, target_update_frequency, gamma, eps)

In [4]:
param_grid

<itertools.product at 0xb2d28c900>

In [5]:
def take(n, iterable):
    "Return first n items of the iterable as a list"
    return list(itertools.islice(iterable, n))

In [6]:
take(20, param_grid)

[(10, 10, 0.6, 0.1),
 (10, 10, 0.6, 0.2),
 (10, 10, 0.7, 0.1),
 (10, 10, 0.7, 0.2),
 (10, 10, 0.8, 0.1),
 (10, 10, 0.8, 0.2),
 (10, 10, 0.9, 0.1),
 (10, 10, 0.9, 0.2),
 (10, 50, 0.6, 0.1),
 (10, 50, 0.6, 0.2),
 (10, 50, 0.7, 0.1),
 (10, 50, 0.7, 0.2),
 (10, 50, 0.8, 0.1),
 (10, 50, 0.8, 0.2),
 (10, 50, 0.9, 0.1),
 (10, 50, 0.9, 0.2),
 (10, 100, 0.6, 0.1),
 (10, 100, 0.6, 0.2),
 (10, 100, 0.7, 0.1),
 (10, 100, 0.7, 0.2)]

In [None]:
gridsearch()

In [19]:
tools.generate_routefile()
sumo_env =  environment.Env(    "cross.net.xml",
                                "cross.rou.xml",
                                state_shape,
                                num_actions,
                                use_gui=False
                           )
ddqn.train(  sumo_env, 1, "epsGreedy", eps=eps)
tools.compute_mean_duration(output_dir)

 Retrying in 1 seconds


In [11]:
import pandas as pd
pd.DataFrame(data)

Unnamed: 0,action,it,next_state,q_values,reward,state
0,0,1,"[[0.0, 0.0, 0.0, 1.0, 19.44, 16.37972210690379...","[[109.82671, -259.65408]]",-3.0,"[[0.0, 0.0, 0.0, 0.0, 19.44, 11.48842193381861..."
1,0,2,"[[0.0, 0.0, 0.0, 2.0, 19.44, 11.90125696071423...","[[-143.08862, -315.6332]]",-11.0,"[[0.0, 0.0, 0.0, 1.0, 19.44, 16.37972210690379..."
2,0,3,"[[0.0, 1.0, 0.0, 3.0, 19.44, 6.486392838847504...","[[-267.5583, -467.6484]]",-28.0,"[[0.0, 0.0, 0.0, 2.0, 19.44, 11.90125696071423..."
3,0,4,"[[0.0, 1.0, 0.0, 3.0, 19.44, 7.276162695880048...","[[-473.28513, -669.4078]]",-40.0,"[[0.0, 1.0, 0.0, 3.0, 19.44, 6.486392838847504..."
4,0,5,"[[0.0, 1.0, 0.0, 3.0, 19.44, 9.706275405400543...","[[-292.64462, -602.15265]]",-40.0,"[[0.0, 1.0, 0.0, 3.0, 19.44, 7.276162695880048..."
5,0,6,"[[0.0, 2.0, 0.0, 4.0, 19.44, 8.45633633997602,...","[[-129.23796, -538.3714]]",-54.0,"[[0.0, 1.0, 0.0, 3.0, 19.44, 9.706275405400543..."
6,0,7,"[[0.0, 2.0, 0.0, 4.0, 19.44, 7.233086045204033...","[[47.91474, -527.0916]]",-60.0,"[[0.0, 2.0, 0.0, 4.0, 19.44, 8.45633633997602,..."
7,0,8,"[[0.0, 4.0, 0.0, 4.0, 19.44, 4.961622690319503...","[[302.29648, -470.5439]]",-71.0,"[[0.0, 2.0, 0.0, 4.0, 19.44, 7.233086045204033..."
8,0,9,"[[0.0, 5.0, 0.0, 6.0, 19.44, 6.898443424791099...","[[428.50565, -494.7269]]",-98.0,"[[0.0, 4.0, 0.0, 4.0, 19.44, 4.961622690319503..."
9,0,10,"[[0.0, 6.0, 0.0, 6.0, 19.44, 6.006485639890047...","[[518.337, -521.4004]]",-111.0,"[[0.0, 5.0, 0.0, 6.0, 19.44, 6.898443424791099..."


In [None]:
pd.DataFrame(data).state[22]

5748