In [None]:
import importlib
from collections import deque

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import HTML


%load_ext autoreload
%autoreload 2

In [None]:
import simulation
import controller
from simulation import *
from main_bus import MainBus

### Initialization

In [None]:
# Length of one simulated day in ticks (24 * 60 == 1440); the training loop
# advances `day_time` until it reaches this value.
whole_day = 24 * 60

# Parameters

replay_length = 100000  # capacity of the bounded replay buffer below
iterations = 144  # num of ticks simulated between two training steps
epochs = 1000  # num of simulated days; the model is trained several times per day
batch_size = 32  # num of samples drawn from replay memory per training step
animate = False  # NOTE(review): not referenced by the cells visible here
add_arrivals_noise = False  # forwarded to simulation.reset() by the training loop
save_every = 10  # checkpoint period, in epochs

# Model
# NOTE(review): the name `simulation` is rebound from the imported module to
# this instance; later cells rely on the instance.
loggers = list()
simulation = Simulation(MainBus, loggers=loggers)

# Training state: oldest transitions are dropped once `replay_length` is hit.
replay_memory = deque(maxlen=replay_length)
training_results, training_loss, delivered_passeners = [], [], []

### Training

In [None]:
# Main training loop
#
# Every epoch simulates one day (`whole_day` ticks) in chunks of `iterations`
# ticks. After each chunk the controller's transitions are appended to the
# bounded `replay_memory` deque and, once enough samples exist, the
# destination model is trained on one random mini-batch.

for epoch in range(epochs):
    print('Epoch #%d' % (epoch+1))
    day_time = 0
    simulation.reset(add_arrivals_noise) # resets everything

    while day_time < whole_day:
        # simulate one chunk and keep whatever execute() returns
        training_results.append(simulation.execute(iterations=iterations))

        # store data into replay memory
        # NOTE(review): if the controller does not clear its own replay_memory
        # between execute() calls, the same transitions are re-added on every
        # chunk -- confirm against the controller implementation.
        replay_memory.extend(simulation.controller.replay_memory)

        if len(replay_memory) > batch_size:
            # Get random samples (np.random.choice draws with replacement by
            # default, so a batch may contain the same transition twice)
            training_idx = np.random.choice(len(replay_memory), size=batch_size)
            training_samples = [replay_memory[i] for i in training_idx]

            # Train DQN
            simulation.controller.destination_model.train(training_samples)

        day_time += iterations

    # Per-epoch metrics, plotted after training.
    training_loss.append(simulation.controller.get_total_cost()) # collect loss
    delivered_passeners.append(simulation.controller.num_passengers_delivered)
    print('\r\tdelivered:{}'.format(simulation.controller.num_passengers_delivered))

    # Checkpoint every `save_every` epochs, and always after the final epoch:
    # with the periodic condition alone the weights learned after the last
    # multiple of `save_every` would never be written to disk.
    is_last_epoch = epoch == epochs - 1
    if epoch % save_every == 0 or is_last_epoch:
        simulation.controller.save_destination_model('decision_model')
        print('### CHECKPOINT ###')

In [None]:
# Per-epoch training curves. The training loop appends one value to each list
# per epoch, so the x axis of both figures is the epoch index.

# Total controller cost collected at the end of every epoch.
fig, ax = plt.subplots()
ax.plot(training_loss, 'r-')
ax.set(title='Loss', xlabel='Epoch', ylabel='Total cost')
plt.show()

# Number of passengers the controller reported as delivered in every epoch.
fig, ax = plt.subplots()
ax.plot(delivered_passeners, 'g-')
ax.set(title='Delivered', xlabel='Epoch', ylabel='Passengers delivered')
plt.show()

In [None]:
# Run a fresh 700-tick simulation with animation enabled.
# reset() is called without arguments here, so it falls back to its own
# default instead of the `add_arrivals_noise` flag used during training.
simulation.reset()
simulation.execute(animate=True, iterations=700)

In [None]:
# Embed the simulation's animation in the notebook as an HTML5 video.
# NOTE(review): `simulation.anim` is presumably set by execute(animate=True),
# so the animated run has to be executed first -- confirm.
video_markup = simulation.anim.to_html5_video()
HTML(video_markup)