In [None]:
import os
import numpy as np
import tensorflow as tf
print(tf.__version__)

import sys
sys.path.append("agent/util")
from dnn_model import RunDistanceDataSetDecoder,RunDistanceModel

class Trainer:
    """Build, train, evaluate and export a ``RunDistanceModel``.

    Pairs a ``RunDistanceDataSetDecoder`` (whose ``load(path)`` yields a
    ``tf.data``-style dataset) with a ``RunDistanceModel`` compiled with the
    Adam optimizer and mean-squared-error loss.
    """

    def __init__(self,
            numSegmentInbound,
            numSegmentOutbound,
            unitsDict,
            dimVehiclesVec,
            readWeightPath=None
        ):
        """Create decoder and model, run one forward pass, compile, optionally load weights.

        Args:
            numSegmentInbound: forwarded to both the decoder and the model.
            numSegmentOutbound: forwarded to both the decoder and the model.
            unitsDict: layer-width configuration forwarded to ``RunDistanceModel``.
            dimVehiclesVec: forwarded to ``RunDistanceModel``.
            readWeightPath: optional checkpoint passed to ``load_weights``;
                ``None`` skips loading.
        """
        self.dec = RunDistanceDataSetDecoder(numSegmentInbound, numSegmentOutbound)

        self.model = RunDistanceModel(
            numSegmentInbound,
            numSegmentOutbound,
            unitsDict=unitsDict,
            dimVehiclesVec=dimVehiclesVec,
        )
        # One forward pass before summary()/load_weights(): presumably this is what
        # creates the variables of the (lazily built) subclassed model.
        self.call_sample()
        self.model.summary()
        self.model.compile(optimizer="adam", loss="mse")
        if readWeightPath is not None:
            self.model.load_weights(readWeightPath)
            print("weight loaded")

        print("------------")

    def call_sample(self):
        """Call the model on a ``[1, getInputDim()]`` tensor of ones and return its output."""
        input_dims = self.model.getInputDim()
        print("input_dims =", input_dims)

        sample = tf.ones([1, input_dims])
        print("-------")
        result = self.model(sample)
        print("-------")
        return result

    def save(self, save_path):
        """Export the model as a SavedModel at ``save_path`` with two signatures.

        ``"model"`` wraps ``self.model.call`` and ``"embedding"`` wraps
        ``self.model.embeddingLayer.call``; each takes a float32 batch of shape
        ``[None, <that object's getInputDim()>]``.
        """
        # get_concrete_function requires `call` to be a tf.function — presumably
        # decorated in dnn_model; confirm there.
        signatures = {
            "model": self.model.call.get_concrete_function(
                tf.TensorSpec([None, self.model.getInputDim()], tf.float32)
            ),
            "embedding": self.model.embeddingLayer.call.get_concrete_function(
                tf.TensorSpec([None, self.model.embeddingLayer.getInputDim()], tf.float32)
            ),
        }
        tf.saved_model.save(self.model, save_path, signatures=signatures)

    def _batched(self, tfrecordPath, batch_size):
        """Decode one TFRecord file, batch it, and prefetch one batch."""
        return self.dec.load(tfrecordPath).batch(batch_size).prefetch(1)

    def evaluate(self, evalTfrecordPath, batch_size=64):
        """Evaluate the compiled model on one TFRecord file.

        Args:
            evalTfrecordPath: TFRecord file to evaluate on.
            batch_size: evaluation batch size (default 64, as before).

        Returns:
            Whatever ``self.model.evaluate`` returns (the loss for a model compiled
            with a single loss and no metrics).
        """
        return self.model.evaluate(self._batched(evalTfrecordPath, batch_size))

    def train(self,
            trainTfrecordPathList,
            saveWeightPath,
            epochs,
            shuffle_buf_len,
            evalTfrecordPath=None,
            batch_size=64
        ):
        """Fit on the concatenation of several TFRecord files, then save the weights.

        Args:
            trainTfrecordPathList: non-empty sequence of TFRecord paths; their
                datasets are concatenated in order, then shuffled as one stream.
            saveWeightPath: destination passed to ``save_weights`` after fitting.
            epochs: number of training epochs.
            shuffle_buf_len: shuffle buffer size in elements.
            evalTfrecordPath: optional TFRecord file used as ``validation_data``.
            batch_size: batch size for training and validation (default 64).

        Raises:
            ValueError: if ``trainTfrecordPathList`` is empty.
        """
        paths = list(trainTfrecordPathList)
        if not paths:
            raise ValueError("trainTfrecordPathList must contain at least one path")

        train_ds = None
        for path in paths:
            part = self.dec.load(path)
            train_ds = part if train_ds is None else train_ds.concatenate(part)

        train_ds = train_ds.shuffle(shuffle_buf_len).batch(batch_size).prefetch(1)

        # Keras treats validation_data=None as "no validation", so a single fit call
        # covers both the with- and without-evaluation cases.
        eval_ds = None
        if evalTfrecordPath is not None:
            eval_ds = self._batched(evalTfrecordPath, batch_size)
        self.model.fit(train_ds, epochs=epochs, validation_data=eval_ds)

        self.model.save_weights(saveWeightPath)
        print("weights saved")

In [None]:
# ---- model parameters ----
numSegmentInbound, numSegmentOutbound = 18, 18
# Layer widths forwarded to RunDistanceModel; the "inbound" entry lists four widths of 64.
unitsDict = {"inbound": [64] * 4}
dimVehiclesVec = 64
# Checkpoint to warm-start from; Trainer skips load_weights when this is None.
readWeightPath = "ckpt/round2_18_18_5_r1.ckpt"

trainer = Trainer(
    numSegmentInbound,
    numSegmentOutbound,
    readWeightPath=readWeightPath,
    dimVehiclesVec=dimVehiclesVec,
    unitsDict=unitsDict,
)

In [None]:
#train
#######################
### training/evaluation parameter
trainTfrecordPathList = [
    "tfrecord/round2_18_18_5.tfrecord",
    "tfrecord/round2_18_18_6.tfrecord",
    "tfrecord/round2_18_18_7.tfrecord",
]
evalTfrecordPath = None
saveWeightPath = "ckpt/round2_18_18_5&6&7_r2.ckpt"
epochs = 50
# Presumably the record count of each TFRecord file, so the buffer spans the whole
# concatenated dataset (a full in-memory shuffle) — TODO confirm per file.
RECORDS_PER_FILE = 308381
shuffle_buf_len = RECORDS_PER_FILE * len(trainTfrecordPathList)
# Fix the global TF seed right before building the pipeline so the shuffle order
# (and other ops without an explicit seed) is reproducible when this cell is re-run.
SEED = 42
tf.random.set_seed(SEED)
#######################
trainer.train(
    trainTfrecordPathList,
    saveWeightPath,
    epochs,
    shuffle_buf_len,
    evalTfrecordPath
)

In [None]:
# ---- evaluation ----
# NOTE(review): this file is also the first entry of trainTfrecordPathList in the
# training cell, so the reported loss is measured on training data — confirm a
# held-out file is not intended here.
evalTfrecordPath = "tfrecord/round2_18_18_5.tfrecord"
trainer.evaluate(evalTfrecordPath=evalTfrecordPath)

In [None]:
#save model
# The export directory is named after the checkpoint file loaded in the model-parameter
# cell, without its extension. os.path.basename (rather than split("/")) also strips
# Windows-style separators, so the export never lands in a nested directory.
# NOTE(review): if the training cell ran, the in-memory weights are the ones written to
# saveWeightPath, yet the export is still named after readWeightPath — confirm intended.
savedModelName = os.path.splitext(os.path.basename(readWeightPath))[0]
trainer.save(os.path.join("saved_model", savedModelName))

In [None]:
#load model
# Reload the export (same directory naming as the save cell) and smoke-test its
# "embedding" signature on a single all-ones row.
savedModelName = os.path.splitext(os.path.basename(readWeightPath))[0]
imported = tf.saved_model.load(os.path.join("saved_model", savedModelName))
embedding_func = imported.signatures["embedding"]
print("embedding", embedding_func)
# Take the width from the same getInputDim() that Trainer.save used for the signature's
# TensorSpec, instead of a hard-coded 597 that breaks when the model parameters change.
embeddingInputDim = trainer.model.embeddingLayer.getInputDim()
# NOTE(review): "output_3" is presumably the auto-generated key of the signature's
# fourth output — confirm it is the tensor of interest.
print(
    embedding_func(tf.ones([1, embeddingInputDim]))["output_3"]
)