In [None]:
import logging

import numpy as np
import tensorflow as tf
print(tf.__version__)
from tqdm.notebook import tqdm

from tf_agents.agents.dqn import dqn_agent
from tf_agents.networks import sequential
from tf_agents.utils import common
from tf_agents.policies import random_tf_policy
from tf_agents.trajectories import trajectory
from tf_agents.trajectories.time_step import time_step_spec,TimeStep,StepType
from tf_agents.specs import BoundedArraySpec,tensor_spec
from tf_agents.replay_buffers import tf_uniform_replay_buffer

import sys
sys.path.append("agent/util")
from simulation import BridgeAgent, EpisodeRunner,_BaseActionSolver,logger as simulation_logger
from signal_state_util import WorldSignalState,SignalState,LaneVehicleNumCalc
simulation_logger.setLevel(logging.ERROR)
from road_tracer import RoadTracer
from dnn_model import RunDistanceModel

class _DenseOneSignalQNet():
    """Base class that builds the dense Q-network and DQN agent for one signal.

    It owns the TF-Agents specs (time step and action), a uniformly random
    policy over those specs, and the construction of the Q-network and of the
    ``DqnAgent``. Subclasses compute the number of state features and pass it
    to ``__init__``.
    """
    def __init__(self,
             num_states,
             fc_layer_params = (100, 50, 30 , 20),
        ):
        """Store the network sizes and derive the specs.

        Args:
            num_states: length of the flat observation vector fed to the network.
            fc_layer_params: number of units of each hidden dense layer, in order.
        """
        # The Q-network emits one value per action; the count is fixed to 8 here.
        self.num_actions=8
        self.num_states=num_states
        self.fc_layer_params=fc_layer_params
        
        # Observation: a rank-1 tensor with ``num_states`` entries.
        self.time_step_spec=time_step_spec(tf.TensorSpec([self.num_states]))
        # Action: a scalar int64 bounded to [0, num_actions-1].
        self.action_spec=tensor_spec.from_spec(BoundedArraySpec((), np.int64,name="action", minimum=0, maximum=self.num_actions-1))
        self.random_policy = random_tf_policy.RandomTFPolicy(self.time_step_spec,self.action_spec)
        
    def _createNetwork(self):
        """Return a Sequential network: ReLU hidden layers followed by a linear Q-value layer."""
        dense_layers = [self.dense_layer(num_units) for num_units in self.fc_layer_params]
        # Output layer: no activation, one unit per action.
        q_values_layer = tf.keras.layers.Dense(
            self.num_actions,
            activation=None,
            kernel_initializer=tf.keras.initializers.RandomUniform(minval=-0.03, maxval=0.03),
            bias_initializer=tf.keras.initializers.Constant(-0.2)
        )
        return sequential.Sequential(dense_layers + [q_values_layer])
        
    @staticmethod
    def dense_layer(num_units):
        """Return one ReLU hidden layer with ``num_units`` units.

        The kernel uses VarianceScaling (scale=2.0, fan_in, truncated normal).
        This method used to be defined twice in the class with identical
        bodies; only this single definition is kept.
        """
        return tf.keras.layers.Dense(
            num_units,
            activation=tf.keras.activations.relu,
            kernel_initializer=tf.keras.initializers.VarianceScaling(
                scale=2.0,
                mode='fan_in',
                distribution='truncated_normal')
        )

    def createDqnAgent(self,
            learningRate,
            targetUpdatePeriod=10,
            epsilon=0.1,
        ):
        """Build and initialize a ``DqnAgent`` on a freshly created Q-network.

        Args:
            learningRate: learning rate given to the Adam optimizer.
            targetUpdatePeriod: passed to the agent as ``target_update_period``.
            epsilon: passed to the agent as ``epsilon_greedy``.

        Returns:
            The initialized agent, with ``agent.train`` wrapped by
            ``common.function`` and its train step counter set to 0.
        """
        q_net=self._createNetwork()
        
        agent = dqn_agent.DqnAgent(
            self.time_step_spec,
            self.action_spec,
            q_network=q_net,
            optimizer=tf.keras.optimizers.Adam(learning_rate=learningRate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            train_step_counter=tf.Variable(0),
            target_update_period=targetUpdatePeriod,
            epsilon_greedy=epsilon,
        )
        agent.initialize()
        
        # (Optional) Optimize by wrapping some of the code in a graph using TF function.
        agent.train = common.function(agent.train)
        
        agent.train_step_counter.assign(0)
        
        return agent

class EmbeddingOneSignalQNet(_DenseOneSignalQNet):
    """Q-network whose state is built from a ``RunDistanceModel`` embedding.

    The state vector is the current time, the mean of the embedded vehicle
    vectors, and the 8-element vector returned alongside them by the model.
    """
    def __init__(self,
            checkpointPath,
            numSegmentInbound=9,
            numSegmentOutbound=9,
            segmentLength=25,
            fc_layer_params = (100,50,30,20),
        ):
        """Load the embedding model and size the Q-network input accordingly.

        Args:
            checkpointPath: path given to ``RunDistanceModel.load_weights``.
            numSegmentInbound: first constructor argument of ``RunDistanceModel``.
            numSegmentOutbound: second constructor argument of ``RunDistanceModel``.
            segmentLength: forwarded to ``calcEmbeddingVec`` on every ``calcState`` call.
            fc_layer_params: hidden layer sizes of the Q-network.
        """
        self.segmentLength=segmentLength
        
        # The embedding is performed as preprocessing (it is not trained during
        # Q-learning; its weights are used as loaded, frozen).
        self.model=RunDistanceModel(
                numSegmentInbound,
                numSegmentOutbound,
            )
        self.model.load_weights(checkpointPath)
        
        numTimeState=1
        numEmbeddingState=self.model.getEmbeddedVehiclesVecLength()
        numPhaseProbState=8
        numState=numTimeState+numEmbeddingState+numPhaseProbState
        super().__init__(numState,fc_layer_params)
        
    def calcState(self,tracer,interId,current_time):
        """Return the state vector for intersection ``interId`` as a 1-D numpy array.

        Args:
            tracer: road tracer forwarded to ``RunDistanceModel.calcEmbeddingVec``.
            interId: id of the intersection whose state is computed.
            current_time: simulation time placed in the first state slot.
        """
        _embeddedVec,_prob = self.model.calcEmbeddingVec(tracer,interId,self.segmentLength)
        
        # The time slot uses the ``current_time`` argument. The previous code
        # read ``observations.current_time``, but no ``observations`` name
        # exists in this scope, so every call raised NameError.
        state=np.concatenate([
            np.array([current_time]), #shape = [1] 
            _embeddedVec.mean(axis=0), #shape = [Dv] 
            _prob, #shape = [8] 
        ]) #shape = [Dv+9] 
        return state
    
class LaneVehicleNumDenseOneSignalQNet(_DenseOneSignalQNet):
    """Q-network whose state is time + per-lane vehicle counts + phase encoding."""
    def __init__(self,
            fc_layer_params = (100, 50, 30 , 20),
        ):
        """Size the network input: 1 time slot, 12 lane slots, 8 phase slots."""
        timeWidth, laneWidth, phaseWidth = 1, 12, 8
        super().__init__(timeWidth+laneWidth+phaseWidth, fc_layer_params)

    def calcState(self,tracer,interId,current_time):
        """Return the state for ``interId`` as ``[time] + lane counts + phase encoding``.

        Args:
            tracer: road tracer providing the lane counts and the world signal state.
            interId: id of the intersection whose state is computed.
            current_time: simulation time placed in the first state slot.
        """
        # The official lane_vehicle_num includes vehicles that have not yet
        # moved into their correct lane, which causes errors, so it is not used.
        # A lane_vehicle_num computed from route information is used instead.
        vehiclesPerLane = tracer.calcNumVehicleOnLane(interId)
        assert len(vehiclesPerLane) == 12

        signal = tracer.worldSignalState.signalStateDict[interId]
        passableEncoding = signal.getPassableEncodingWithoutRightTurn()
        assert len(passableEncoding) == 8

        return [current_time] + vehiclesPerLane + passableEncoding

class SegmentedLaneVehicleNumAndSpeedDenseOneSignalQNet(_DenseOneSignalQNet):
    """Q-network whose state is time + segmented lane counts and speeds + phase encoding."""
    def __init__(self,
            fc_layer_params = (100, 50, 30 , 20),
            numSegment=5,
            segmentLength=25,
        ):
        """Remember the segmentation settings and size the network input.

        Args:
            fc_layer_params: hidden layer sizes of the Q-network.
            numSegment: segment count forwarded to the tracer in ``calcState``.
            segmentLength: segment length forwarded to the tracer in ``calcState``.
        """
        self.segmentLength = segmentLength
        self.numSegment = numSegment

        # Counts and speeds each contribute 12*numSegment entries; one time
        # slot comes first and 8 phase slots come last.
        perFeatureWidth = 12*numSegment
        super().__init__(1 + perFeatureWidth + perFeatureWidth + 8, fc_layer_params)

    def calcState(self,tracer,interId,current_time):
        """Return the state for ``interId`` as ``[time] + counts + speeds + phase encoding``.

        Args:
            tracer: road tracer providing the segmented lane data and the world signal state.
            interId: id of the intersection whose state is computed.
            current_time: simulation time placed in the first state slot.
        """
        segmentedCounts, segmentedSpeeds = tracer.calcVehicleNumAndSpeedOnSegmentedInboundLane(
            interId, self.numSegment, self.segmentLength)

        signal = tracer.worldSignalState.signalStateDict[interId]
        passableEncoding = signal.getPassableEncodingWithoutRightTurn()
        assert len(passableEncoding) == 8

        return [current_time] + segmentedCounts + segmentedSpeeds + passableEncoding
    
class DebugPrinter:
    """Console tracing helper for action decisions, rewards and training loss.

    ``enablePrint`` is the master switch; each ``prohibitPrint*`` flag mutes
    one category while the master switch is on.
    """
    def __init__(self,agent):
        """Keep a reference to ``agent`` (used to query its Q-network) and set default flags."""
        self.agent=agent
        self.enablePrint=True
        self.prohibitPrintActionDicision=False
        self.prohibitPrintReward=False
        self.prohibitPrintLoss=False
        
        # Loss is printed only on steps that are a multiple of this interval.
        self.printLossIntervalStep=100
        # Every (step, loss) pair given to printLoss, whether printed or not.
        self.trainResultList=[]
        
    def printActionDicision(self,timeStep,select_phase,policyName,current_time):
        """Print the chosen phase and, for non-random policies, the Q-network input/output.

        Nothing is printed for a LAST time step.
        """
        if self.enablePrint and not self.prohibitPrintActionDicision:
            step_type=timeStep.step_type.numpy()[0]
            if step_type!=StepType.LAST:
                print("****************** action dicision: time={} *****************".format(current_time))
                # The random policy does not consult the Q-network, so its
                # input/output dump is skipped for it.
                if policyName!="random_tf_policy":
                    print("---q network input----------")
                    print("observation",timeStep.observation.numpy())
                    print("---q network output-----")
                    outputs,_=self.agent._q_network(timeStep.observation,step_type=timeStep.step_type)
                    print("q_value",outputs.numpy())
                print("---dicision----------")
                print("{} : select_phase={}".format(policyName,select_phase))
                print("-------------")

    def printReward(self,last_select_phase,reward,observations,roadNet):
        """Print the reward, the current delay index and the previously selected phase.

        Args:
            last_select_phase: phase chosen at the previous decision.
            reward: numeric reward; ints are accepted.
            observations: object exposing ``current_time`` and ``vehicleDS``.
            roadNet: object exposing ``roadDataSet``.
        """
        # Gate on the reward flag. The previous code tested
        # ``prohibitPrintActionDicision`` here, so ``prohibitPrintReward`` had
        # no effect and muting action output also muted rewards.
        if self.enablePrint and not self.prohibitPrintReward:
            current_time=observations.current_time
            delayIndex=observations.vehicleDS.calcDelayIndex(roadNet.roadDataSet)
            print("****************** reward: time={} *****************".format(current_time))
            # float(): the "{:.3}" spec raises ValueError for an int such as 0.
            print("reward={:.3}, delayIndex={:.3}, last_select_phase={}".format(
                float(reward),
                float(delayIndex),
                last_select_phase
            ))
            print("--------")
            
    def printLoss(self,step,train_loss):
        """Record ``(step, train_loss)`` and print it every ``printLossIntervalStep`` steps."""
        self.trainResultList.append((step,train_loss))
        if self.enablePrint and not self.prohibitPrintLoss:
            if step % self.printLossIntervalStep == 0:
                print('step = {0}: loss = {1}'.format(step, train_loss))

class RewardCalc:
    """Compute the scalar reward for one intersection from simulator observations.

    Supported ``rewardType`` values are ``"di"``, ``"inbound_out"`` and
    ``"into_outbound"``.
    """
    def __init__(self,rewardType="di"):
        """Remember which reward formula ``calc`` should use."""
        self.rewardType=rewardType
    def calc(self,observations,interId,roadNet):
        """Return the reward for ``interId``.

        Args:
            observations: object exposing ``vehicleDS`` (for ``"di"``) or
                ``rewards`` (for the other two modes).
            interId: key into ``observations.rewards``.
            roadNet: object exposing ``roadDataSet`` (used by ``"di"`` only).

        Returns:
            ``"di"``: the reciprocal of the delay index.
            ``"inbound_out"``: sum of column 1 over rows 0-11, divided by 30.
            ``"into_outbound"``: sum of column 0 over rows 12-23, divided by 30.
            Both array modes return 0 when ``observations.rewards`` is None.

        Raises:
            Exception: if ``rewardType`` is not one of the supported values.
        """
        if self.rewardType=="di":
            di=observations.vehicleDS.calcDelayIndex(roadNet.roadDataSet)
            return 1/di
        elif self.rewardType=="inbound_out" or self.rewardType=="into_outbound":
            if observations.rewards is not None:
                arrInOut=np.array(observations.rewards[interId])
                # Use ``self.rewardType``: the bare name ``rewardType`` used
                # previously does not exist in this scope and raised NameError.
                # NOTE(review): the row/column meaning of the rewards array is
                # not visible here; the indices are kept exactly as before.
                idx1=1 if self.rewardType=="inbound_out" else 0
                idx0=0 if self.rewardType=="inbound_out" else 12
                return arrInOut[idx0:idx0+12,idx1].sum()/30
            else:
                return 0
        else:
            raise Exception("not expected")

class ReplayBuffer:
    """Wrapper around a TF-Agents uniform replay buffer plus a sampling iterator."""
    def __init__(self,agent,replayBufferMaxLength,batchSize):
        """Create the buffer for ``agent``'s collect data spec and open a dataset iterator.

        Args:
            agent: agent whose ``collect_data_spec`` defines the stored items.
            replayBufferMaxLength: ``max_length`` of the underlying buffer.
            batchSize: ``sample_batch_size`` used when sampling for training.
        """
        # Items are added one at a time, hence the storage batch size of 1.
        self.replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=agent.collect_data_spec,
            batch_size=1,
            max_length=replayBufferMaxLength,
        )
        sampledDataset = self.replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batchSize,
            num_steps=2,
        )
        self.dataset = sampledDataset.prefetch(3)
        self.iterator = iter(self.dataset)

    def getNextExperience(self):
        """Return the next sampled batch; the second element of the dataset item is dropped."""
        experience, _unused = next(self.iterator)
        return experience

    def add(self,traj):
        """Append ``traj`` to the underlying buffer via ``add_batch``."""
        self.replay_buffer.add_batch(traj)

    def getLength(self):
        """Return the buffer's current ``num_frames()`` as a numpy value."""
        frameCount = self.replay_buffer.num_frames()
        return frameCount.numpy()
    
class SignalDQNActionSolver(_BaseActionSolver):
    """Action solver that chooses each signal's phase with one shared DQN.

    The DQN outputs the Q-values that decide the best phase for a single
    signal. It is shared by the different signals in the city: the same DQN
    is used for both training and inference on all of them.
    """
    def __init__(self,
            qNet,
            learningRate=1e-3,
            targetUpdatePeriod=10,
            epsilon=0.1,
            discount=0.9,
            replayBufferMaxLength = 320,
            batchSize=64,
        ):
        """Create the agent from ``qNet`` and remember the replay-buffer settings.

        Args:
            qNet: object providing ``createDqnAgent``, ``calcState`` and ``random_policy``.
            learningRate: forwarded to ``qNet.createDqnAgent``.
            targetUpdatePeriod: forwarded to ``qNet.createDqnAgent``.
            epsilon: forwarded to ``qNet.createDqnAgent``.
            discount: discount written into every non-LAST time step.
            replayBufferMaxLength: capacity of each per-intersection replay buffer.
            batchSize: sample batch size of each per-intersection replay buffer.
        """
        super().__init__()
        self.discount=discount
        self.qNet=qNet
        self.agent = qNet.createDqnAgent(
            learningRate,
            targetUpdatePeriod,
            epsilon
        )
        self.debugPrinter=DebugPrinter(self.agent)
        self.rewardCalc=RewardCalc("di")
        
        self.replayBufferMaxLength=replayBufferMaxLength
        self.batchSize=batchSize
        # interId -> ReplayBuffer, filled in by startFirstEpisode.
        self.interIdToBufDict={}

    def startFirstEpisode(self,signalizedInterIdList):
        """Create one replay buffer per signalized intersection."""
        # The buffers persist across episodes.
        for interId in signalizedInterIdList:
            self.interIdToBufDict[interId] = ReplayBuffer(
                self.agent,
                self.replayBufferMaxLength,
                self.batchSize
            )
        
    def getTrainProgress(self):
        """Return the most recent ``(trainStep, trainLoss)``, or ``(0, 0.)`` before any training."""
        li=self.debugPrinter.trainResultList
        if len(li)>=1:
            trainStep,trainLoss = li[-1]
        else:
            trainStep=0
            trainLoss=0.
        return trainStep,trainLoss
    
    def getBufferLength(self):
        """Return the length of the first replay buffer, or 0 when none exist yet."""
        if len(self.interIdToBufDict) == 0:
            return 0
        # Return the first buffer's length (each buffer has the same length).
        firstBuf=next(iter(self.interIdToBufDict.values()))
        return firstBuf.getLength()
    
    def _createTimeStep(self,observations,interId,stepType,exitType):
        """Build the batched ``TimeStep`` (batch size 1) for ``interId``.

        The reward is non-zero only for MID steps and the discount is zero
        only for LAST steps. ``exitType`` is accepted but not used here.
        """
        tracer=self._createRoadTracer(observations)
        state=self.qNet.calcState(tracer,interId,observations.current_time)
        
        extra_reward = self.rewardCalc.calc(observations,interId,self.roadNet)
        assert(0<=extra_reward and extra_reward<=1)
        # 0.0 rather than 0: the debug printer formats the reward with
        # "{:.3}", which raises ValueError for an int.
        reward=extra_reward if stepType==StepType.MID else 0.0
        discount = self.discount if stepType!=StepType.LAST else 0
        
        if stepType!=StepType.FIRST:
            signalState=self.worldSignalState.signalStateDict[interId]
            self.debugPrinter.printReward(
                # Actions are 0-based; phases are reported 1-based.
                signalState.prevPolicyStepForDQN.action.numpy()[0]+1,
                reward,
                observations,
                self.roadNet,
            )
        
        return TimeStep(
            tf.constant([stepType]),
            tf.constant([reward],dtype=tf.float32),
            tf.constant([discount],dtype=tf.float32),
            tf.constant([state],dtype=tf.float32)
        )
    
    def _collect(self,stepType,interId,currentTimeStep):
        """Store the transition (previous step -> ``currentTimeStep``) in the replay buffer.

        On a FIRST step the previous step is taken from the last episode's
        signal state; when there is no previous episode nothing is stored.
        """
        if stepType==StepType.FIRST:
            if self.lastEpisodeWorldSignalState is not None:
                lastSignalState=self.lastEpisodeWorldSignalState.signalStateDict[interId]
            else:
                lastSignalState = None
        else:
            lastSignalState=self.worldSignalState.signalStateDict[interId]

        if lastSignalState is not None:
            # save to the replay buffer.
            traj=trajectory.from_transition(
                lastSignalState.prevTimeStepForDQN,
                lastSignalState.prevPolicyStepForDQN,
                currentTimeStep
            )
            self.interIdToBufDict[interId].add(traj)

    def _train(self,interId):
        """Run one training step on a batch sampled from ``interId``'s buffer and log the loss."""
        # Sample a batch of data from the buffer and update the agent's network.
        experience = self.interIdToBufDict[interId].getNextExperience()
        train_loss = self.agent.train(experience).loss
        self.debugPrinter.printLoss(
            self.agent.train_step_counter.numpy(),
            train_loss.numpy()
        )
        
    def _action(self,timeStep,signalState,policy,current_time):
        """Query ``policy``, remember the step on ``signalState`` and return the 1-based phase."""
        policyStep = policy.action(timeStep)
        # Actions are 0-based; phases are 1-based.
        select_phase=policyStep.action.numpy()[0]+1
        
        # Saved so the next call to _collect can build the transition.
        signalState.setPreviousStepInfo(
            prevTimeStepForDQN=timeStep,
            prevPolicyStepForDQN=policyStep,
        )
        
        self.debugPrinter.printActionDicision(
            timeStep,
            select_phase,
            policy.name,
            current_time
        )
        
        return select_phase
    
    def decideActions(self,
            observations,
            prevActCountInEpisode,
            runType="eval",
            exitType=False,
            debug=False
        ):
        """Decide the phase of every signal for the current step.

        Args:
            observations: simulator observations for this step.
            prevActCountInEpisode: 0 marks the first step of an episode.
            runType: ``"eval"`` (greedy policy, no collection or training),
                ``"train"`` (collect policy, collection and training) or
                ``"random"`` (random policy, collection only).
            exitType: ``None`` or ``False`` means the episode continues; any
                other value marks this step as the episode's LAST step.
            debug: enables debug printing in ``"eval"`` mode.

        Returns:
            Dict mapping interId to the new phase, containing only the signals
            whose phase changed. Always empty on a LAST step.

        Raises:
            Exception: if ``runType`` is not one of the supported values.
        """
        if runType=="eval":
            policy = self.agent.policy
            collect=False
            train=False
            self.debugPrinter.enablePrint=True and debug
        elif runType=="train":
            policy = self.agent.collect_policy
            collect=True
            train=True
            # ``False and debug`` is always False: printing stays off here.
            self.debugPrinter.enablePrint=False and debug
        elif runType=="random":
            policy = self.qNet.random_policy
            collect=True
            train=False
            self.debugPrinter.enablePrint=False and debug
        else:
            raise Exception("not expected")
        
        if prevActCountInEpisode==0:
            stepType = StepType.FIRST 
        else:
            # The parameter defaults to False, yet the previous test was only
            # ``exitType is not None``; with the default every non-first step
            # became LAST. Treat both None and False as "not exiting".
            # NOTE(review): confirm against the values EpisodeRunner passes.
            isExit = exitType is not None and exitType is not False
            stepType = StepType.LAST if isExit else StepType.MID
        
        actions={}
        for interId,signalState in self.worldSignalState.signalStateDict.items():
            #episode's FIRST or MID or LAST step
            
            #convert observations into input for Q-network
            currentTimeStep=self._createTimeStep(observations,interId,stepType,exitType)
            
            # Must run before _action, which overwrites the stored previous step.
            if collect:
                self._collect(stepType,interId,currentTimeStep)
                
            select_phase=self._action(currentTimeStep,signalState,policy,observations.current_time) #need call even if step is LAST for saving last action for next episode
            
            if stepType!=StepType.LAST:
                if select_phase != signalState.phase:
                    signalState.changePhase(select_phase,observations.current_time)
                    actions[interId]=select_phase
                    
        if stepType!=StepType.FIRST and train:
            for interId in self.worldSignalState.signalStateDict:
                self._train(interId)
            
        return actions


In [None]:
#####################################
# Run configuration for the "test1" simulator scenario.
simulator_cfg_file = "cfg/simulator_test1.cfg"
metric_period = 1
# Passed as breakReplayBufferLength to the random-policy run below.
initialBufferLength=100
# Passed as breakNumEpsiode to the training run below.
numTrainEpisode = 1000
# Passed to both runLoop calls; presumably the delay index at which a run stops early - verify against EpisodeRunner.
earlyStoppingDelayIndex=1.6

# Checkpoint loaded by EmbeddingOneSignalQNet for its embedding model.
embeddingWeightPath="ckpt/test2_9_9_10000_1_1.ckpt"

DEBUG=False
#####################################

# Alternative state representation, kept for reference:
#dqnRunner=SignalDQNActionSolver(SegmentedLaneVehicleNumAndSpeedDenseOneSignalQNet())
dqnRunner=SignalDQNActionSolver(EmbeddingOneSignalQNet(embeddingWeightPath))
runner=EpisodeRunner(dqnRunner,simulator_cfg_file,metric_period,debug=DEBUG)


# Phase 1: run with the random policy and initialBufferLength as the break length.
print("collecting experience using random policy")
runner.runLoop("random",earlyStoppingDelayIndex,breakReplayBufferLength=initialBufferLength,tqdm=tqdm)

# Phase 2: run in "train" mode, with eval_every_n_episode=10.
# NOTE: ``breakNumEpsiode`` is the keyword spelling used by this call; do not "fix" it here without changing runLoop.
print("training with collecting experience using dqn epsilon-greedy policy")
runner.runLoop("train",earlyStoppingDelayIndex,breakNumEpsiode=numTrainEpisode,eval_every_n_episode=10,tqdm=tqdm)

In [None]:
#####################################
# Run configuration for the "test2" simulator scenario.
# NOTE: this cell rebinds ``dqnRunner`` and ``runner``; cells below see whichever training cell ran last.
simulator_cfg_file = "cfg/simulator_test2.cfg"
metric_period = 1
# Passed as breakReplayBufferLength to the random-policy run below.
initialBufferLength=100
# Passed as breakNumEpsiode to the training run below.
numTrainEpisode = 1000
# Passed to both runLoop calls; presumably the delay index at which a run stops early - verify against EpisodeRunner.
earlyStoppingDelayIndex=1.6

# Checkpoint loaded by EmbeddingOneSignalQNet for its embedding model.
embeddingWeightPath="ckpt/test2_9_9_10000_1_1.ckpt"

DEBUG=False
#####################################

# Alternative state representation, kept for reference:
#dqnRunner=SignalDQNActionSolver(LaneVehicleNumDenseOneSignalQNet())
dqnRunner=SignalDQNActionSolver(EmbeddingOneSignalQNet(embeddingWeightPath))
runner=EpisodeRunner(dqnRunner,simulator_cfg_file,metric_period,debug=DEBUG)


# Phase 1: run with the random policy and initialBufferLength as the break length.
print("collecting experience using random policy")
runner.runLoop("random",earlyStoppingDelayIndex,breakReplayBufferLength=initialBufferLength,tqdm=tqdm)

# Phase 2: run in "train" mode, with eval_every_n_episode=10.
# NOTE: ``breakNumEpsiode`` is the keyword spelling used by this call; do not "fix" it here without changing runLoop.
print("training with collecting experience using dqn epsilon-greedy policy")
runner.runLoop("train",earlyStoppingDelayIndex,breakNumEpsiode=numTrainEpisode,eval_every_n_episode=10,tqdm=tqdm)

In [None]:
# #####################################
# simulator_cfg_file = "cfg/simulator_warm_up.cfg"
# metric_period = 1
# initialBufferLength=100
# numTrainEpisode = 1000
# earlyStoppingDelayIndex=1.3

# DEBUG=False
# #####################################

# dqnRunner=SignalDQNActionSolver(LaneVehicleNumDenseOneSignalQNet())
# runner=EpisodeRunner(dqnRunner,simulator_cfg_file,metric_period,debug=DEBUG)

            
# print("collecting experience using random policy")
# runner.runLoop("random",earlyStoppingDelayIndex,breakReplayBufferLength=initialBufferLength,tqdm=tqdm)

# print("training with collecting experience using dqn epsilon-greedy policy")
# runner.runLoop("train",earlyStoppingDelayIndex,breakNumEpsiode=numTrainEpisode,eval_every_n_episode=10,tqdm=tqdm)

In [None]:
# Plot the evaluation records held by ``runner`` (created in a training cell above).
import pandas as pd
# One row per record in runner.evalRecordList, labelled with the five column names below.
df=pd.DataFrame(runner.evalRecordList,columns=['time','served','DI','train','loss'])
# Figure 1: the 'DI' and 'loss' columns together.
df[['DI','loss']].plot()
# Figure 2: the 'time' column.
df[['time']].plot()

In [None]:
# Final run in "eval" mode on the ``runner`` left by the last executed training cell, then export.
runner.runLoop("eval",tqdm=tqdm)
# NOTE(review): what export writes is defined in EpisodeRunner and not visible here.
runner.export(tqdm)