In [12]:
import tensorflow as tf
import numpy as np
from configs import ParseParams
from DataGenerator import DataGenerator
from env import Env

In [13]:
class AttentionVRPCritic(object):
    """A generic attention module for the attention in vrp model"""
    def __init__(self, dim, use_tanh=False, C=10,_name='Attention',_scope=''):

        self.use_tanh = use_tanh
        self._scope = _scope

        with tf.compat.v1.variable_scope(_scope+_name):
            # self.v: is a variable with shape [1 x dim]
            self.v = tf.compat.v1.get_variable('v',[1,dim],
                       initializer=tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
            self.v = tf.expand_dims(self.v,2)
            
        self.emb_d = tf.compat.v1.layers.Conv1D(dim,1,_scope=_scope+_name +'/emb_d') #conv1d
        self.project_d = tf.compat.v1.layers.Conv1D(dim,1,_scope=_scope+_name +'/proj_d') #conv1d_1
        
        self.project_query = tf.compat.v1.layers.Dense(dim,_scope=_scope+_name +'/proj_q') #
        self.project_ref = tf.compat.v1.layers.Conv1D(dim,1,_scope=_scope+_name +'/proj_e') #conv1d_2

        self.C = C  # tanh exploration parameter
        self.tanh = tf.nn.tanh
        
    def __call__(self, query, ref, env):
        """
        This function gets a query tensor and ref rensor and returns the logit op.
        Args: 
            query: is the hidden state of the decoder at the current
                time step. [batch_size x dim]
            ref: the set of hidden states from the encoder. 
                [batch_size x max_time x dim]

            env: keeps demand ond load values and help decoding. Also it includes mask.
                env.mask: a matrix used for masking the logits and glimpses. It is with shape
                         [batch_size x max_time]. Zeros in this matrix means not-masked nodes. Any 
                         positive number in this mask means that the node cannot be selected as next 
                         decision point.
                env.demands: a list of demands which changes over time.

        Returns:
            e: convolved ref with shape [batch_size x max_time x dim]
            logits: [batch_size x max_time]
        """
        # we need the first demand value for the critic
        demand = env.input_data[:,:,-1]
        max_time = tf.shape(input=demand)[1]

        # embed demand and project it
        # emb_d:[batch_size x max_time x dim ]
        emb_d = self.emb_d(tf.expand_dims(demand,2))
        # d:[batch_size x max_time x dim ]
        d = self.project_d(emb_d)


        # expanded_q,e: [batch_size x max_time x dim]
        e = self.project_ref(ref)
        q = self.project_query(query) #[batch_size x dim]
        expanded_q = tf.tile(tf.expand_dims(q,1),[1,max_time,1])

        # v_view:[batch_size x dim x 1]
        v_view = tf.tile( self.v, [tf.shape(input=e)[0],1,1]) 
        
        # u : [batch_size x max_time x dim] * [batch_size x dim x 1] = 
        #       [batch_size x max_time]
        u = tf.squeeze(tf.matmul(self.tanh(expanded_q + e + d), v_view),2)

        if self.use_tanh:
            logits = self.C * self.tanh(u)
        else:
            logits = u  

        return e, logits

In [14]:
#tf.compat.v1.disable_eager_execution()
args, prt = ParseParams()
batch_size = 2
nodes = 5
cust = nodes-1
args['batch_size'] = batch_size
args['n_nodes'] = nodes
args['n_cust'] = cust
env = Env( args)

node_num = args['n_nodes']
emb_dim = 30
lstm_dim = args['hidden_dim']
batch_num = 2
process = AttentionVRPCritic( lstm_dim, "P")


Created train iterator.
Loading dataset for vrp-size-1000-len-5-test.txt...


In [15]:
ref = tf.cast( np.random.randn( batch_num, node_num, emb_dim), tf.float32 )
query = tf.cast( np.random.randn(batch_num, lstm_dim ), tf.float32 )

o1, logits = process( query, ref, env )
print( logits.numpy())
logprob = tf.nn.log_softmax(logits)
print( logprob.numpy())
prob = tf.exp(logprob)
print(prob.numpy())
idx = tf.expand_dims(tf.argmax(input=prob, axis=1),1)
print( idx.numpy())

[[9.419918  8.952817  9.180444  7.7896266 7.4645944]
 [8.766563  9.583124  9.295224  9.779091  8.14705  ]]
[[-1.0120552 -1.4791563 -1.2515295 -2.6423466 -2.9673789]
 [-2.1102276 -1.2936668 -1.5815668 -1.0977001 -2.729741 ]]
[[0.3634712  0.22782984 0.28606695 0.07119401 0.05143796]
 [0.12121038 0.27426326 0.20565262 0.33363754 0.06523618]]
[[0]
 [3]]
