In [12]:

import tensorflow as tf
from tensorflow.keras.layers import Flatten, Dense
from tensorflow.keras.models import Sequential



In [13]:
X = tf.keras.layers.Input((32, 10, 12))
X.shape[3]

12

In [17]:
N = X.shape[3]
Q = X.shape[2]
input_shape = X.shape[1:]  # Shape without batch size
X_train_flat = tf.reshape(X, (-1, Q * N))

# Create a sequential model
model = Sequential()

# Add a fully connected layer with sigmoid activation
model.add(Flatten(input_shape=input_shape))
model.add(Dense(64, activation='relu'))  # You can adjust the number of units
model.add(Dense(3, activation='sigmoid'))

ValueError: Failed to find data adapter that can handle input: <class 'keras.src.engine.keras_tensor.KerasTensor'>, (<class 'list'> containing values of types {"<class 'int'>"})

In [6]:


def conv2d(x, output_dims, kernel_size, stride=[1, 1],
           padding='SAME', use_bias=True, activation=tf.nn.relu,
           bn=False, bn_decay=None, is_training=None):
    input_dims = x.get_shape()[-1].value
    kernel_shape = kernel_size + [input_dims, output_dims]
    kernel = tf.Variable(
        tf.glorot_uniform_initializer()(shape=kernel_shape),
        dtype=tf.float32, trainable=True, name='kernel')
    x = tf.nn.conv2d(x, kernel, [1] + stride + [1], padding=padding)
    if use_bias:
        bias = tf.Variable(
            tf.zeros_initializer()(shape=[output_dims]),
            dtype=tf.float32, trainable=True, name='bias')
        x = tf.nn.bias_add(x, bias)
    if activation is not None:
        if bn:
            x = batch_norm(x, is_training=is_training, bn_decay=bn_decay)
        x = activation(x)
    return x


def batch_norm(x, is_training, bn_decay):
    input_dims = x.get_shape()[-1].value
    moment_dims = list(range(len(x.get_shape()) - 1))
    beta = tf.Variable(
        tf.zeros_initializer()(shape=[input_dims]),
        dtype=tf.float32, trainable=True, name='beta')
    gamma = tf.Variable(
        tf.ones_initializer()(shape=[input_dims]),
        dtype=tf.float32, trainable=True, name='gamma')
    batch_mean, batch_var = tf.nn.moments(x, moment_dims, name='moments')

    decay = bn_decay if bn_decay is not None else 0.9
    ema = tf.train.ExponentialMovingAverage(decay=decay)
    # Operator that maintains moving averages of variables.
    ema_apply_op = tf.cond(
        is_training,
        lambda: ema.apply([batch_mean, batch_var]),
        lambda: tf.no_op())

    # Update moving average and return current batch's avg and var.
    def mean_var_with_update():
        with tf.control_dependencies([ema_apply_op]):
            return tf.identity(batch_mean), tf.identity(batch_var)

    # ema.average returns the Variable holding the average of var.
    mean, var = tf.cond(
        is_training,
        mean_var_with_update,
        lambda: (ema.average(batch_mean), ema.average(batch_var)))
    x = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3)
    return x


def dropout(x, drop, is_training):
    x = tf.cond(
        is_training,
        lambda: tf.nn.dropout(x, rate=drop),
        lambda: x)
    return x


In [7]:

def placeholder(P, Q, N):
    X = tf.compat.v1.placeholder(
        shape=(None, P, N), dtype=tf.float32, name='X')
    TE = tf.compat.v1.placeholder(
        shape=(None, P + Q, 2), dtype=tf.int32, name='TE')
    label = tf.compat.v1.placeholder(
        shape=(None, Q, N), dtype=tf.float32, name='label')
    is_training = tf.compat.v1.placeholder(
        shape=(), dtype=tf.bool, name='is_training')
    return X, TE, label, is_training


def FC(x, units, activations, bn, bn_decay, is_training, use_bias=True, drop=None):
    if isinstance(units, int):
        units = [units]
        activations = [activations]
    elif isinstance(units, tuple):
        units = list(units)
        activations = list(activations)
    assert type(units) == list
    for num_unit, activation in zip(units, activations):
        if drop is not None:
            x = dropout(x, drop=drop, is_training=is_training)
        x = conv2d(
            x, output_dims=num_unit, kernel_size=[1, 1], stride=[1, 1],
            padding='VALID', use_bias=use_bias, activation=activation,
            bn=bn, bn_decay=bn_decay, is_training=is_training)
    return x


def STEmbedding(SE, TE, T, D, bn, bn_decay, is_training):
    '''
    spatio-temporal embedding
    SE:     [N, D]
    TE:     [batch_size, P + Q, 2] (dayofweek, timeofday)
    T:      num of time steps in one day
    D:      output dims
    retrun: [batch_size, P + Q, N, D]
    '''
    # spatial embedding
    SE = tf.expand_dims(tf.expand_dims(SE, axis=0), axis=0)
    SE = FC(
        SE, units=[D, D], activations=[tf.nn.relu, None],
        bn=bn, bn_decay=bn_decay, is_training=is_training)
    # temporal embedding
    dayofweek = tf.one_hot(TE[..., 0], depth=7)
    timeofday = tf.one_hot(TE[..., 1], depth=T)
    TE = tf.concat((dayofweek, timeofday), axis=-1)
    TE = tf.expand_dims(TE, axis=2)
    TE = FC(
        TE, units=[D, D], activations=[tf.nn.relu, None],
        bn=bn, bn_decay=bn_decay, is_training=is_training)
    return tf.add(SE, TE)


def spatialAttention(X, STE, K, d, bn, bn_decay, is_training):
    '''
    spatial attention mechanism
    X:      [batch_size, num_step, N, D]
    STE:    [batch_size, num_step, N, D]
    K:      number of attention heads
    d:      dimension of each attention outputs
    return: [batch_size, num_step, N, D]
    '''
    D = K * d
    X = tf.concat((X, STE), axis=-1)
    # [batch_size, num_step, N, K * d]
    query = FC(
        X, units=D, activations=tf.nn.relu,
        bn=bn, bn_decay=bn_decay, is_training=is_training)
    key = FC(
        X, units=D, activations=tf.nn.relu,
        bn=bn, bn_decay=bn_decay, is_training=is_training)
    value = FC(
        X, units=D, activations=tf.nn.relu,
        bn=bn, bn_decay=bn_decay, is_training=is_training)
    # [K * batch_size, num_step, N, d]
    query = tf.concat(tf.split(query, K, axis=-1), axis=0)
    key = tf.concat(tf.split(key, K, axis=-1), axis=0)
    value = tf.concat(tf.split(value, K, axis=-1), axis=0)
    # [K * batch_size, num_step, N, N]
    attention = tf.matmul(query, key, transpose_b=True)
    attention /= (d ** 0.5)
    attention = tf.nn.softmax(attention, axis=-1)
    # [batch_size, num_step, N, D]
    X = tf.matmul(attention, value)
    X = tf.concat(tf.split(X, K, axis=0), axis=-1)
    X = FC(
        X, units=[D, D], activations=[tf.nn.relu, None],
        bn=bn, bn_decay=bn_decay, is_training=is_training)
    return X


def temporalAttention(X, STE, K, d, bn, bn_decay, is_training, mask=True):
    '''
    temporal attention mechanism
    X:      [batch_size, num_step, N, D]
    STE:    [batch_size, num_step, N, D]
    K:      number of attention heads
    d:      dimension of each attention outputs
    return: [batch_size, num_step, N, D]
    '''
    D = K * d
    X = tf.concat((X, STE), axis=-1)
    # [batch_size, num_step, N, K * d]
    query = FC(
        X, units=D, activations=tf.nn.relu,
        bn=bn, bn_decay=bn_decay, is_training=is_training)
    key = FC(
        X, units=D, activations=tf.nn.relu,
        bn=bn, bn_decay=bn_decay, is_training=is_training)
    value = FC(
        X, units=D, activations=tf.nn.relu,
        bn=bn, bn_decay=bn_decay, is_training=is_training)
    # [K * batch_size, num_step, N, d]
    query = tf.concat(tf.split(query, K, axis=-1), axis=0)
    key = tf.concat(tf.split(key, K, axis=-1), axis=0)
    value = tf.concat(tf.split(value, K, axis=-1), axis=0)
    # query: [K * batch_size, N, num_step, d]
    # key:   [K * batch_size, N, d, num_step]
    # value: [K * batch_size, N, num_step, d]
    query = tf.transpose(query, perm=(0, 2, 1, 3))
    key = tf.transpose(key, perm=(0, 2, 3, 1))
    value = tf.transpose(value, perm=(0, 2, 1, 3))
    # [K * batch_size, N, num_step, num_step]
    attention = tf.matmul(query, key)
    attention /= (d ** 0.5)
    # mask attention score
    if mask:
        batch_size = tf.shape(X)[0]
        num_step = X.get_shape()[1].value
        N = X.get_shape()[2].value
        mask = tf.ones(shape=(num_step, num_step))
        mask = tf.linalg.LinearOperatorLowerTriangular(mask).to_dense()
        mask = tf.expand_dims(tf.expand_dims(mask, axis=0), axis=0)
        mask = tf.tile(mask, multiples=(K * batch_size, N, 1, 1))
        mask = tf.cast(mask, dtype=tf.bool)
        attention = tf.compat.v2.where(
            condition=mask, x=attention, y=-2 ** 15 + 1)
    # softmax
    attention = tf.nn.softmax(attention, axis=-1)
    # [batch_size, num_step, N, D]
    X = tf.matmul(attention, value)
    X = tf.transpose(X, perm=(0, 2, 1, 3))
    X = tf.concat(tf.split(X, K, axis=0), axis=-1)
    X = FC(
        X, units=[D, D], activations=[tf.nn.relu, None],
        bn=bn, bn_decay=bn_decay, is_training=is_training)
    return X


def gatedFusion(HS, HT, D, bn, bn_decay, is_training):
    '''
    gated fusion
    HS:     [batch_size, num_step, N, D]
    HT:     [batch_size, num_step, N, D]
    D:      output dims
    return: [batch_size, num_step, N, D]
    '''
    XS = FC(
        HS, units=D, activations=None,
        bn=bn, bn_decay=bn_decay,
        is_training=is_training, use_bias=False)
    XT = FC(
        HT, units=D, activations=None,
        bn=bn, bn_decay=bn_decay,
        is_training=is_training, use_bias=True)
    z = tf.nn.sigmoid(tf.add(XS, XT))
    H = tf.add(tf.multiply(z, HS), tf.multiply(1 - z, HT))
    H = FC(
        H, units=[D, D], activations=[tf.nn.relu, None],
        bn=bn, bn_decay=bn_decay, is_training=is_training)
    return H


def STAttBlock(X, STE, K, d, bn, bn_decay, is_training, mask=True):
    HS = spatialAttention(X, STE, K, d, bn, bn_decay, is_training)
    HT = temporalAttention(X, STE, K, d, bn, bn_decay, is_training, mask=mask)
    H = gatedFusion(HS, HT, K * d, bn, bn_decay, is_training)
    return tf.add(X, H)


def transformAttention(X, STE_P, STE_Q, K, d, bn, bn_decay, is_training):
    '''
    transform attention mechanism
    X:      [batch_size, P, N, D]
    STE_P:  [batch_size, P, N, D]
    STE_Q:  [batch_size, Q, N, D]
    K:      number of attention heads
    d:      dimension of each attention outputs
    return: [batch_size, Q, N, D]
    '''
    D = K * d
    # query: [batch_size, Q, N, K * d]
    # key:   [batch_size, P, N, K * d]
    # value: [batch_size, P, N, K * d]
    query = FC(
        STE_Q, units=D, activations=tf.nn.relu,
        bn=bn, bn_decay=bn_decay, is_training=is_training)
    key = FC(
        STE_P, units=D, activations=tf.nn.relu,
        bn=bn, bn_decay=bn_decay, is_training=is_training)
    value = FC(
        X, units=D, activations=tf.nn.relu,
        bn=bn, bn_decay=bn_decay, is_training=is_training)
    # query: [K * batch_size, Q, N, d]
    # key:   [K * batch_size, P, N, d]
    # value: [K * batch_size, P, N, d]
    query = tf.concat(tf.split(query, K, axis=-1), axis=0)
    key = tf.concat(tf.split(key, K, axis=-1), axis=0)
    value = tf.concat(tf.split(value, K, axis=-1), axis=0)
    # query: [K * batch_size, N, Q, d]
    # key:   [K * batch_size, N, d, P]
    # value: [K * batch_size, N, P, d]
    query = tf.transpose(query, perm=(0, 2, 1, 3))
    key = tf.transpose(key, perm=(0, 2, 3, 1))
    value = tf.transpose(value, perm=(0, 2, 1, 3))
    # [K * batch_size, N, Q, P]
    attention = tf.matmul(query, key)
    attention /= (d ** 0.5)
    attention = tf.nn.softmax(attention, axis=-1)
    # [batch_size, Q, N, D]
    X = tf.matmul(attention, value)
    X = tf.transpose(X, perm=(0, 2, 1, 3))
    X = tf.concat(tf.split(X, K, axis=0), axis=-1)
    X = FC(
        X, units=[D, D], activations=[tf.nn.relu, None],
        bn=bn, bn_decay=bn_decay, is_training=is_training)
    return X


def GMAN(X, TE, SE,  T, bn, bn_decay, is_training):
    '''
    GMAN
    X:       [batch_size, P, N]
    TE:      [batch_size, P + Q, 2] (time-of-day, day-of-week)
    SE:      [N, K * d]
    P:       number of history steps
    Q:       number of prediction steps
    T:       one day is divided into T steps
    L:       number of STAtt blocks in the encoder/decoder
    K:       number of attention heads
    d:       dimension of each attention head outputs
    return:  [batch_size, Q, N]
    '''

    P = config.AGENT.HISTORY_STEPS
    Q = config.AGENT.PREDICTION_STEPS
    L = config.AGENT.NUMBER_OF_STATT_BLOCKS
    K = config.AGENT.NUMBER_OF_ATTENTION_HEADS
    d = config.AGENT.HEAD_ATTENTION_OUTPUT_DIM

    D = K * d
    # input
    X = tf.expand_dims(X, axis=-1)

    X = FC(
        X, units=[D, D], activations=[tf.nn.relu, None],
        bn=bn, bn_decay=bn_decay, is_training=is_training)
    # STE
    STE = STEmbedding(SE, TE, T, D, bn, bn_decay, is_training)
    STE_P = STE[:, : P]
    STE_Q = STE[:, P:]
    # encoder
    for _ in range(L):
        X = STAttBlock(X, STE_P, K, d, bn, bn_decay, is_training)
    # transAtt
    X = transformAttention(
        X, STE_P, STE_Q, K, d, bn, bn_decay, is_training)
    # decoder
    for _ in range(L):
        X = STAttBlock(X, STE_Q, K, d, bn, bn_decay, is_training)
    # output
    X = FC(
        X, units=[D, 1], activations=[tf.nn.relu, None],
        bn=bn, bn_decay=bn_decay, is_training=is_training,
        use_bias=True, drop=0.1)
    return tf.squeeze(X, axis=3)


def GMAN_gen(X, TE, SE, Y, T, bn, bn_decay, is_training):
    '''
    GMAN
    X:       [batch_size, P, N]
    TE:      [batch_size, P + Q, 2] (time-of-day, day-of-week)
    SE:      [N, K * d]
    Y:       labels/conditions [batch_size, C] c = # of classes (3)
    P:       number of history steps
    Q:       number of prediction steps
    T:       one day is divided into T steps
    L:       number of STAtt blocks in the encoder/decoder
    K:       number of attention heads
    d:       dimension of each attention head outputs
    return:  [batch_size, Q, N]
    '''

    P = config.AGENT.HISTORY_STEPS
    Q = config.AGENT.PREDICTION_STEPS
    L = config.AGENT.NUMBER_OF_STATT_BLOCKS
    K = config.AGENT.NUMBER_OF_ATTENTION_HEADS
    d = config.AGENT.HEAD_ATTENTION_OUTPUT_DIM

    D = K * d
    # input
    X = tf.expand_dims(X, axis=-1)

    X = tf.concat(X, Y)  # concat the label

    X = FC(
        X, units=[D, D], activations=[tf.nn.relu, None],
        bn=bn, bn_decay=bn_decay, is_training=is_training)
    # STE
    STE = STEmbedding(SE, TE, T, D, bn, bn_decay, is_training)
    STE_P = STE[:, : P]
    STE_Q = STE[:, P:]
    # encoder
    for _ in range(L):
        X = STAttBlock(X, STE_P, K, d, bn, bn_decay, is_training)
    # transAtt
    X = transformAttention(
        X, STE_P, STE_Q, K, d, bn, bn_decay, is_training)
    # decoder
    for _ in range(L):
        X = STAttBlock(X, STE_Q, K, d, bn, bn_decay, is_training)
    # output
    X = FC(
        X, units=[D, 1], activations=[tf.nn.relu, None],
        bn=bn, bn_decay=bn_decay, is_training=is_training,
        use_bias=True, drop=0.1)
    return tf.squeeze(X, axis=3)


def GMAN_disc(X, TE, SE, Y, T, bn, bn_decay, is_training):
    '''
    GMAN
    X:       [batch_size, P, N]
    TE:      [batch_size, P + Q, 2] (time-of-day, day-of-week)
    SE:      [N, K * d]
    Y:       labels/conditions [batch_size, C] c = # of classes (3)
    P:       number of history steps
    Q:       number of prediction steps
    T:       one day is divided into T steps
    L:       number of STAtt blocks in the encoder/decoder
    K:       number of attention heads
    d:       dimension of each attention head outputs
    return:  [batch_size, Q, N]
    '''

    P = config.AGENT.HISTORY_STEPS
    Q = config.AGENT.PREDICTION_STEPS
    L = config.AGENT.NUMBER_OF_STATT_BLOCKS
    K = config.AGENT.NUMBER_OF_ATTENTION_HEADS
    d = config.AGENT.HEAD_ATTENTION_OUTPUT_DIM

    D = K * d
    # input
    X = tf.expand_dims(X, axis=-1)

    X = tf.concat(X, Y)  # concat the label

    X = FC(
        X, units=[D, D], activations=[tf.nn.relu, None],
        bn=bn, bn_decay=bn_decay, is_training=is_training)
    # STE
    STE = STEmbedding(SE, TE, T, D, bn, bn_decay, is_training)
    STE_P = STE[:, : P]
    STE_Q = STE[:, P:]
    # encoder
    for _ in range(L):
        X = STAttBlock(X, STE_P, K, d, bn, bn_decay, is_training)
    # transAtt
    X = transformAttention(
        X, STE_P, STE_Q, K, d, bn, bn_decay, is_training)
    # decoder
    for _ in range(L):
        X = STAttBlock(X, STE_Q, K, d, bn, bn_decay, is_training)
    # output
    X = FC(
        X, units=[D, 1], activations=[tf.nn.relu, None],
        bn=bn, bn_decay=bn_decay, is_training=is_training,
        use_bias=True, drop=0.1)
    return tf.squeeze(X, axis=3)


def mae_loss(pred, label):
    mask = tf.not_equal(label, 0)
    mask = tf.cast(mask, tf.float32)
    mask /= tf.reduce_mean(mask)
    mask = tf.compat.v2.where(
        condition=tf.math.is_nan(mask), x=0., y=mask)
    loss = tf.abs(tf.subtract(pred, label))
    loss *= mask
    loss = tf.compat.v2.where(
        condition=tf.math.is_nan(loss), x=0., y=loss)
    loss = tf.reduce_mean(loss)
    return loss


In [10]:

# log string
def log_string(log, string):
    log.write(string + '\n')
    log.flush()
    print(string)


# metric
def metric(pred, label):
    with np.errstate(divide='ignore', invalid='ignore'):
        mask = np.not_equal(label, 0)
        mask = mask.astype(np.float32)
        mask /= np.mean(mask)
        mae = np.abs(np.subtract(pred, label)).astype(np.float32)
        rmse = np.square(mae)
        mape = np.divide(mae, label)
        mae = np.nan_to_num(mae * mask)
        mae = np.mean(mae)
        rmse = np.nan_to_num(rmse * mask)
        rmse = np.sqrt(np.mean(rmse))
        mape = np.nan_to_num(mape * mask)
        mape = np.mean(mape)
    return mae, rmse, mape


def seq2instance(data, P, Q):
    num_step, dims = data.shape
    num_sample = num_step - P - Q + 1
    x = np.zeros(shape=(num_sample, P, dims))
    y = np.zeros(shape=(num_sample, Q, dims))
    for i in range(num_sample):
        x[i] = data[i: i + P]
        y[i] = data[i + P: i + P + Q]
    return x, y


def loadData(dataset_file, attribute):
    # Traffic
    df = pd.read_hdf(os.path.join(config.ROOT_DIR, dataset_file),
                     key='data').loc[:, attribute]
    Traffic = df.values
    # train/val/test
    num_step = df.shape[0]
    train_steps = round(config.AGENT.TRAIN_RATIO * num_step)
    test_steps = round(config.AGENT.TEST_RATIO * num_step)
    val_steps = num_step - train_steps - test_steps
    train = Traffic[: train_steps]
    val = Traffic[train_steps: train_steps + val_steps]
    test = Traffic[-test_steps:]
    # X, Y
    trainX, trainY = seq2instance(
        train, config.AGENT.HISTORY_STEPS, config.AGENT.PREDICTION_STEPS)
    valX, valY = seq2instance(
        val, config.AGENT.HISTORY_STEPS, config.AGENT.PREDICTION_STEPS)
    testX, testY = seq2instance(
        test, config.AGENT.HISTORY_STEPS, config.AGENT.PREDICTION_STEPS)
    # normalization
    mean, std = np.mean(trainX), np.std(trainX)
    trainX = (trainX - mean) / std
    valX = (valX - mean) / std
    testX = (testX - mean) / std

    se_file = os.path.join(config.ROOT_DIR, config.PATH_TO_RECORDS, 'SE.txt')

    # spatial embedding
    f = open(se_file, mode='r')
    lines = f.readlines()
    temp = lines[0].split(' ')
    N, dims = int(temp[0]), int(temp[1])
    SE = np.zeros(shape=(N, dims), dtype=np.float32)
    for line in lines[1:]:
        temp = line.split(' ')
        index = int(temp[0])
        SE[index] = temp[1:]

    # temporal embedding
    Time = df.index
    dayofweek = np.reshape(Time.weekday, newshape=(-1, 1))
    timeofday = (Time.hour * 3600 + Time.minute * 60 + Time.second) \
        // Time.freq.delta.total_seconds()
    timeofday = np.reshape(timeofday, newshape=(-1, 1))
    Time = np.concatenate((dayofweek, timeofday), axis=-1)
    # train/val/test
    train = Time[: train_steps]
    val = Time[train_steps: train_steps + val_steps]
    test = Time[-test_steps:]
    # shape = (num_sample, P + Q, 2)
    trainTE = seq2instance(train, config.AGENT.HISTORY_STEPS,
                           config.AGENT.PREDICTION_STEPS)
    trainTE = np.concatenate(trainTE, axis=1).astype(np.int32)
    valTE = seq2instance(val, config.AGENT.HISTORY_STEPS,
                         config.AGENT.PREDICTION_STEPS)
    valTE = np.concatenate(valTE, axis=1).astype(np.int32)
    testTE = seq2instance(test, config.AGENT.HISTORY_STEPS,
                          config.AGENT.PREDICTION_STEPS)
    testTE = np.concatenate(testTE, axis=1).astype(np.int32)

    return (trainX, trainTE, trainY, valX, valTE, valY, testX, testTE, testY,
            SE, mean, std)


In [None]:
def train():

    start = time.time()

    dataset_file = os.path.join(config.PATH_TO_RECORDS, 'data', 'dataset.h5')

    log_file = os.path.join(config.PATH_TO_RECORDS,
                            f"{config.EXPERIMENT.SCENARIO_NAME}_gman_log")
    log = open(os.path.join(config.ROOT_DIR, log_file), 'w')

    model_file = os.path.join(config.PATH_TO_RECORDS,
                              f"{config.EXPERIMENT.SCENARIO_NAME}_gman_model")

    # load data
    log_string(log, 'loading data...')
    (trainX, trainTE, trainY, valX, valTE, valY, testX, testTE, testY, SE,
     mean, std) = loadData(dataset_file, config.AGENT.PREDICTED_ATTRIBUTE)
    log_string(log, 'trainX: %s\ttrainY: %s' %
                     (trainX.shape, trainY.shape))
    log_string(log, 'valX:   %s\t\tvalY:   %s' %
                     (valX.shape, valY.shape))
    log_string(log, 'testX:  %s\t\ttestY:  %s' %
                     (testX.shape, testY.shape))
    log_string(log, 'data loaded!')

    # train model
    log_string(log, 'compiling ..')
    T = 24 * 60 // config.AGENT.TIME_SLOT
    num_train, _, N = trainX.shape
    X, TE, label, is_training = placeholder(
        config.AGENT.HISTORY_STEPS, config.AGENT.PREDICTION_STEPS, N)
    global_step = tf.Variable(0, trainable=False)
    bn_momentum = tf.compat.v1.train.exponential_decay(
        0.5, global_step,
        decay_steps=config.AGENT.DECAY_EPOCH * num_train // config.AGENT.BATCH_SIZE,
        decay_rate=0.5, staircase=True)
    bn_decay = tf.minimum(0.99, 1 - bn_momentum)
    pred = GMAN(
        X,
        TE,
        SE,
        T,
        bn=True,
        bn_decay=bn_decay,
        is_training=is_training
    )
    pred = pred * std + mean
    loss = mae_loss(pred, label)
    tf.compat.v1.add_to_collection('pred', pred)
    tf.compat.v1.add_to_collection('loss', loss)
    learning_rate = tf.compat.v1.train.exponential_decay(
        config.AGENT.LEARNING_RATE, global_step,
        decay_steps=config.AGENT.DECAY_EPOCH * num_train // config.AGENT.BATCH_SIZE,
        decay_rate=0.7, staircase=True)
    learning_rate = tf.maximum(learning_rate, 1e-5)
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate)
    train_op = optimizer.minimize(loss, global_step=global_step)
    parameters = 0
    for variable in tf.compat.v1.trainable_variables():
        parameters += np.product([x.value for x in variable.get_shape()])
    log_string(log, 'trainable parameters: {:,}'.format(parameters))
    log_string(log, 'model compiled!')
    saver = tf.compat.v1.train.Saver()
    tf_config = tf.compat.v1.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    sess = tf.compat.v1.Session(config=tf_config)
    sess.run(tf.compat.v1.global_variables_initializer())
    log_string(log, '**** training model ****')
    num_val = valX.shape[0]
    wait = 0
    val_loss_min = np.inf
    for epoch in range(config.AGENT.MAX_EPOCH):
        if wait >= config.AGENT.PATIENCE:
            log_string(log, 'early stop at epoch: %04d' % (epoch))
            break
        # shuffle
        permutation = np.random.permutation(num_train)
        trainX = trainX[permutation]
        trainTE = trainTE[permutation]
        trainY = trainY[permutation]
        # train loss
        start_train = time.time()
        train_loss = 0
        num_batch = math.ceil(num_train / config.AGENT.BATCH_SIZE)
        for batch_idx in range(num_batch):
            start_idx = batch_idx * config.AGENT.BATCH_SIZE
            end_idx = min(num_train, (batch_idx + 1) * config.AGENT.BATCH_SIZE)
            feed_dict = {
                X: trainX[start_idx:end_idx],
                TE: trainTE[start_idx:end_idx],
                label: trainY[start_idx:end_idx],
                is_training: True}
            _, loss_batch = sess.run([train_op, loss], feed_dict=feed_dict)
            train_loss += loss_batch * (end_idx - start_idx)
        train_loss /= num_train
        end_train = time.time()
        # val loss
        start_val = time.time()
        val_loss = 0
        num_batch = math.ceil(num_val / config.AGENT.BATCH_SIZE)
        for batch_idx in range(num_batch):
            start_idx = batch_idx * config.AGENT.BATCH_SIZE
            end_idx = min(num_val, (batch_idx + 1) * config.AGENT.BATCH_SIZE)
            feed_dict = {
                X: valX[start_idx:end_idx],
                TE: valTE[start_idx:end_idx],
                label: valY[start_idx:end_idx],
                is_training: False}
            loss_batch = sess.run(loss, feed_dict=feed_dict)
            val_loss += loss_batch * (end_idx - start_idx)
        val_loss /= num_val
        end_val = time.time()
        log_string(
            log,
            '%s | epoch: %04d/%d, training time: %.1fs, inference time: %.1fs' %
            (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), epoch + 1,
             config.AGENT.MAX_EPOCH, end_train - start_train, end_val - start_val))
        log_string(
            log, 'train loss: %.4f, val_loss: %.4f' % (train_loss, val_loss))
        if val_loss <= val_loss_min:
            log_string(
                log,
                'val loss decrease from %.4f to %.4f, saving model to %s' %
                (val_loss_min, val_loss, os.path.join(config.ROOT_DIR, model_file)))
            wait = 0
            val_loss_min = val_loss
            saver.save(sess, os.path.join(config.ROOT_DIR, model_file))
        else:
            wait += 1

    # test model
    log_string(log, '**** testing model ****')
    log_string(log, 'loading model from %s' %
                     os.path.join(config.ROOT_DIR, model_file))
    saver = tf.compat.v1.train.import_meta_graph(
        os.path.join(config.ROOT_DIR, model_file) + '.meta')
    saver.restore(sess, os.path.join(config.ROOT_DIR, model_file))
    log_string(log, 'model restored!')
    log_string(log, 'evaluating...')
    num_test = testX.shape[0]
    trainPred = []
    num_batch = math.ceil(num_train / config.AGENT.BATCH_SIZE)
    for batch_idx in range(num_batch):
        start_idx = batch_idx * config.AGENT.BATCH_SIZE
        end_idx = min(num_train, (batch_idx + 1) * config.AGENT.BATCH_SIZE)
        feed_dict = {
            X: trainX[start_idx: end_idx],
            TE: trainTE[start_idx: end_idx],
            is_training: False}
        pred_batch = sess.run(pred, feed_dict=feed_dict)
        trainPred.append(pred_batch)
    trainPred = np.concatenate(trainPred, axis=0)
    valPred = []
    num_batch = math.ceil(num_val / config.AGENT.BATCH_SIZE)
    for batch_idx in range(num_batch):
        start_idx = batch_idx * config.AGENT.BATCH_SIZE
        end_idx = min(num_val, (batch_idx + 1) * config.AGENT.BATCH_SIZE)
        feed_dict = {
            X: valX[start_idx: end_idx],
            TE: valTE[start_idx: end_idx],
            is_training: False}
        pred_batch = sess.run(pred, feed_dict=feed_dict)
        valPred.append(pred_batch)
    valPred = np.concatenate(valPred, axis=0)
    testPred = []
    num_batch = math.ceil(num_test / config.AGENT.BATCH_SIZE)
    start_test = time.time()
    for batch_idx in range(num_batch):
        start_idx = batch_idx * config.AGENT.BATCH_SIZE
        end_idx = min(num_test, (batch_idx + 1) * config.AGENT.BATCH_SIZE)
        feed_dict = {
            X: testX[start_idx:end_idx],
            TE: testTE[start_idx:end_idx],
            is_training: False}
        pred_batch = sess.run(pred, feed_dict=feed_dict)
        testPred.append(pred_batch)
    end_test = time.time()
    testPred = np.concatenate(testPred, axis=0)
    train_mae, train_rmse, train_mape = metric(trainPred, trainY)
    val_mae, val_rmse, val_mape = metric(valPred, valY)
    test_mae, test_rmse, test_mape = metric(testPred, testY)
    log_string(log, 'testing time: %.1fs' % (end_test - start_test))
    log_string(log, '                MAE\t\tRMSE\t\tMAPE')
    log_string(log, 'train            %.2f\t\t%.2f\t\t%.2f%%' %
                     (train_mae, train_rmse, train_mape * 100))
    log_string(log, 'val              %.2f\t\t%.2f\t\t%.2f%%' %
                     (val_mae, val_rmse, val_mape * 100))
    log_string(log, 'test             %.2f\t\t%.2f\t\t%.2f%%' %
                     (test_mae, test_rmse, test_mape * 100))
    log_string(log, 'performance in each prediction step')
    MAE, RMSE, MAPE = [], [], []
    for q in range(config.AGENT.PREDICTION_STEPS):
        mae, rmse, mape = metric(testPred[:, q], testY[:, q])
        MAE.append(mae)
        RMSE.append(rmse)
        MAPE.append(mape)
        log_string(log, 'step: %02d         %.2f\t\t%.2f\t\t%.2f%%' %
                         (q + 1, mae, rmse, mape * 100))
    average_mae = np.mean(MAE)
    average_rmse = np.mean(RMSE)
    average_mape = np.mean(MAPE)
    log_string(
        log, 'average:         %.2f\t\t%.2f\t\t%.2f%%' %
             (average_mae, average_rmse, average_mape * 100))
    end = time.time()
    log_string(log, 'total time: %.1fmin' % ((end - start) / 60))
    sess.close()
    log.close()
