In [None]:
import logging
import os
import sys

import numpy as np
import tensorflow as tf
print(tf.__version__)
from tqdm.notebook import tqdm

sys.path.append("agent/util")
from simulation import EpisodeRunner,logger as simulation_logger
simulation_logger.setLevel(logging.ERROR)

from dnn_model import DataSetCreationActionSolver
from strategy import Strategy,RandomMixStrategy

In [None]:
#####################################
###parameters for simulator
simulator_cfg_file = "cfg/simulator_round2.cfg"
metric_period = 200
earlyStoppingDelayIndex=1.6
bufferLength=10000
DEBUG=False

###parameters for DNN agent
dnnCheckpointFileName="round2_18_18_5_r1.ckpt"
dnnNumSegmentInbound=18
dnnNumSegmentOutbound=18
dimVehiclesVec=64
unitsDict={
    'inbound':[64,64,64,64],
}

###parameters for rule base agent
maxDepth=4
timeThresForCalcReward=10.2
prohibitDecreasingGoalDistance=True
prohibitDecreasingSpeed=True

###parameters for mixing strategy 
dnnStrategyRatio=0.5
ruleStrategyRatio=0.5

###parameters for actionSolver
tfrecordPath="tfrecord/round2_18_18_8.tfrecord"
saveNumSegmentInbound = 18
saveNumSegmentOutbound = 18

#####################################
strategy=RandomMixStrategy({
    Strategy.createStrategy(
        'run_distance',
        {
            'maxDepth':maxDepth,
            'timeThresForCalcReward':timeThresForCalcReward,
            'prohibitDecreasingGoalDistance':prohibitDecreasingGoalDistance,
            'prohibitDecreasingSpeed':prohibitDecreasingSpeed,
        }
    ):ruleStrategyRatio,
    Strategy.createStrategy(
        'dnn_run_distance',{
            'checkpointPath':"ckpt/"+dnnCheckpointFileName,
            'numSegmentInbound':dnnNumSegmentInbound,
            'numSegmentOutbound':dnnNumSegmentOutbound,
            'unitsDict':unitsDict,
            'dimVehiclesVec':dimVehiclesVec,
        }
    ):dnnStrategyRatio,
})

dataSetCreator=DataSetCreationActionSolver(
    saveNumSegmentInbound,
    saveNumSegmentOutbound,
    strategy=strategy
)
runner=EpisodeRunner(
    dataSetCreator,
    simulator_cfg_file,
    metric_period,
    debug=DEBUG
)
            
runner.runLoop(
    "strategy",
    earlyStoppingDelayIndex,
    breakReplayBufferLength=bufferLength,
    tqdm=tqdm
)

dataSetCreator.createTFRecord(tfrecordPath)