# Modules: Datasets and Graphs

In [1]:
import tensorflow as tf
import tensorflow as tf
from tensorflow.keras.layers import Layer

from msp.datasets import make_data, load_sample_data, make_sparse_data
from msp.graphs import MSPGraph, MSPSparseGraph
from msp.models.encoders import GGCNEncoder
from msp.models.decoders import AttentionDecoder

%load_ext autoreload
%autoreload 2

In [2]:
dataset = make_sparse_data(40, msp_size=(1,2), random_state=2021)

In [3]:
batch_size = 2
batch_dataset = dataset.batch(batch_size)
inputs = list(batch_dataset.take(1))[0]
inputs

MSPSparseGraph(adj_matrix=<tf.Tensor: shape=(2, 3, 3), dtype=uint8, numpy=
array([[[0, 1, 1],
        [1, 0, 1],
        [1, 1, 0]],

       [[0, 1, 1],
        [1, 0, 1],
        [1, 1, 0]]], dtype=uint8)>, node_features=<tf.Tensor: shape=(2, 3, 5), dtype=float64, numpy=
array([[[0.60597828, 0.73336936, 0.13894716, 0.31267308, 1.        ],
        [0.12816238, 0.17899311, 0.75292543, 0.66216051, 1.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ]],

       [[0.82309863, 0.73222503, 0.06905627, 0.67212894, 1.        ],
        [0.82801437, 0.20446939, 0.61748895, 0.61770101, 1.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ]]])>, edge_features=<tf.Tensor: shape=(2, 3, 3, 3), dtype=float64, numpy=
array([[[[0.        , 0.        , 0.        ],
         [0.57430428, 0.37116084, 0.        ],
         [0.        , 0.        , 1.        ]],

        [[0.57430428, 0.37116084, 0.        ],
         [0.        , 0.        , 0.       

# Modules: Layers and Encoders

In [4]:
units = 6
layers = 2
ecoder_model = GGCNEncoder(units, layers)
embedded_inputs = ecoder_model(inputs)

In [5]:
embedded_inputs.node_embed

<tf.Tensor: shape=(2, 3, 6), dtype=float32, numpy=
array([[[ 0.615415  ,  1.2802217 ,  0.99935806, -0.5541735 ,
          3.416307  ,  1.355025  ],
        [ 1.1635169 , -0.16933191,  0.7606487 , -0.3814553 ,
          1.1589245 ,  0.81719756],
        [ 0.        ,  0.01952665,  0.        ,  0.        ,
          0.43814158,  1.0351074 ]],

       [[ 0.832458  ,  0.50137305,  1.155561  , -0.44193715,
          3.733396  ,  1.3800619 ],
        [ 1.3427941 , -0.11154047,  1.1022799 , -0.15452659,
          2.8134623 ,  0.7567066 ],
        [ 0.04526918,  0.        ,  0.        ,  0.        ,
          0.7266619 ,  0.9510701 ]]], dtype=float32)>

In [6]:
embedded_inputs.edge_embed

<tf.Tensor: shape=(2, 3, 3, 6), dtype=float32, numpy=
array([[[[ 1.46499956e+00,  5.61989486e-01,  0.00000000e+00,
           8.34677815e-01,  1.68942082e+00,  3.30655384e+00],
         [ 5.46871126e-03,  8.56781602e-01, -8.03075433e-02,
           8.83548975e-01,  1.39451098e+00,  2.42211223e+00],
         [-1.17827654e-01,  1.21429563e-04,  7.58905053e-01,
           4.72448260e-01,  9.66189981e-01,  1.62234616e+00]],

        [[ 1.93997490e+00,  6.81192160e-01, -8.03075433e-02,
           5.46523690e-01,  1.42587686e+00,  2.64375997e+00],
         [ 9.12253916e-01,  3.94608229e-01,  0.00000000e+00,
           4.11679834e-01,  1.47785544e+00,  2.36596799e+00],
         [ 5.65907955e-01, -1.71363965e-01,  7.58905053e-01,
          -3.51288259e-01,  6.02509797e-01,  1.19404054e+00]],

        [[ 2.28166389e+00,  4.95497584e-01,  7.58905053e-01,
           1.16787457e+00, -2.98210740e-01,  3.38283730e+00],
         [ 1.03474855e+00,  4.99601722e-01,  7.58905053e-01,
           7.7197516

# Modules: Decoder

In [7]:
model = AttentionDecoder(units, aggregation_graph='mean', n_heads=3)
makespan, ll, pi = model(embedded_inputs)
pi

Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.


<tf.Tensor: shape=(2, 3), dtype=int64, numpy=
array([[0, 2, 1],
       [0, 1, 2]])>

# All together

In [8]:
n_instances = 640
instance_size = (2, 8)
batch_size = 64
units = 64  # embedding dims
layers = 2

dataset = make_sparse_data(n_instances, msp_size=instance_size, random_state=2021)
batch_dataset = dataset.batch(batch_size)
inputs = list(batch_dataset.take(1))[0]

ecoder_model = GGCNEncoder(units, layers)
embedded_inputs = ecoder_model(inputs)

n_heads = 8
model = AttentionDecoder(units, aggregation_graph='mean', n_heads=n_heads)
makespan, ll, pi = model(embedded_inputs)
pi

<tf.Tensor: shape=(64, 10), dtype=int64, numpy=
array([[2, 8, 7, 3, 1, 4, 9, 0, 6, 5],
       [1, 9, 2, 5, 6, 4, 7, 3, 8, 0],
       [6, 4, 3, 9, 7, 8, 5, 2, 1, 0],
       [2, 3, 7, 6, 8, 9, 1, 5, 4, 0],
       [4, 8, 3, 2, 1, 7, 9, 6, 0, 5],
       [5, 4, 9, 7, 8, 2, 3, 1, 6, 0],
       [9, 5, 6, 7, 4, 8, 3, 2, 0, 1],
       [9, 3, 8, 1, 2, 7, 5, 4, 6, 0],
       [5, 9, 2, 7, 6, 8, 3, 0, 1, 4],
       [5, 7, 4, 1, 3, 6, 2, 8, 9, 0],
       [3, 7, 8, 9, 1, 6, 2, 4, 5, 0],
       [8, 5, 2, 3, 9, 6, 7, 4, 1, 0],
       [5, 3, 1, 2, 6, 9, 7, 8, 4, 0],
       [2, 9, 1, 4, 7, 6, 5, 3, 8, 0],
       [6, 7, 3, 5, 8, 9, 2, 1, 4, 0],
       [6, 2, 7, 9, 5, 1, 3, 8, 4, 0],
       [6, 3, 4, 1, 9, 2, 7, 8, 5, 0],
       [7, 2, 9, 6, 1, 3, 4, 8, 5, 0],
       [2, 6, 5, 9, 3, 7, 8, 1, 4, 0],
       [6, 5, 2, 4, 7, 8, 9, 3, 1, 0],
       [0, 3, 1, 7, 8, 9, 5, 6, 4, 2],
       [7, 1, 9, 6, 4, 3, 2, 8, 5, 0],
       [5, 8, 4, 3, 1, 7, 9, 2, 6, 0],
       [1, 5, 6, 3, 9, 7, 2, 4, 8, 0],
       [0, 8, 9,

In [9]:
inputs.adj_matrix[0]


tf.math.reduce_all(tf.math.equal(inputs.adj_matrix[1], inputs.adj_matrix[2]))

<tf.Tensor: shape=(), dtype=bool, numpy=False>

In [10]:
import sys
import numpy
numpy.set_printoptions(threshold=sys.maxsize)

print(inputs.adj_matrix[0].numpy())

[[0 1 0 1 1 0 0 1 1 1]
 [1 0 1 1 1 1 1 1 1 1]
 [0 1 0 1 1 1 1 1 1 1]
 [1 1 1 0 1 1 1 1 1 1]
 [1 1 1 1 0 1 1 1 1 1]
 [0 1 1 1 1 0 1 1 1 1]
 [0 1 1 1 1 1 0 1 1 1]
 [1 1 1 1 1 1 1 0 1 1]
 [1 1 1 1 1 1 1 1 0 0]
 [1 1 1 1 1 1 1 1 0 0]]


In [12]:
print(inputs.adj_matrix[7].numpy())

[[0 1 1 1 1 1 1 1 1 1]
 [1 0 0 1 0 1 0 1 1 1]
 [1 0 0 1 1 1 1 1 1 1]
 [1 1 1 0 1 1 1 1 1 1]
 [1 0 1 1 0 1 1 1 1 1]
 [1 1 1 1 1 0 1 1 1 1]
 [1 0 1 1 1 1 0 1 1 1]
 [1 1 1 1 1 1 1 0 1 1]
 [1 1 1 1 1 1 1 1 0 0]
 [1 1 1 1 1 1 1 1 0 0]]


In [7]:
from typing import NamedTuple

from tensorflow import Tensor

class MSPState:

    def __init__(self, inputs):
        """ """
        self.adj_matrix = inputs.adj_matrix
        self.node_embed = inputs.node_embed
        
        batch_size, num_nodes, node_embed_dims = self.node_embed.shape
        
        self._first_node = tf.zeros((batch_size,1), dtype=tf.int64)
        self._last_node = self._first_node
        self._visited = tf.zeros((batch_size,1,num_nodes), dtype=tf.int64)
        self._makespan = tf.zeros((batch_size,1))

        self.i = tf.zeros(1, dtype=tf.int64) # # Vector with length num_steps
        self.ids = tf.range(5, delta=1, dtype=tf.int64)[:, None] #  # Add steps dimension
        #self._step_num = tf.zeros(1, dtype=tf.int64)


    @property
    def first_node(self):
        return self._first_node

    


    

        # self._node_embed = None
        # self._edge_embed = None
        # self._adj_matrix = None
        # self._first_job = None
        # self._last_job = None

    # @property
    # def node_embed(self):
    #     return self._node_embed

    # @node_embed.setter
    # def node_embed(self, node_embed):
    #     self._node_embed = node_embed

    # @property
    # def adj_matrix(self):
    #     return self._adj_matrix

    # @adj_matrix.setter
    # def adj_matrix(self, adj_matrix):
    #     self._adj_matrix = adj_matrix

    # def initialize(self, inputs):
    #     self.adj_matrix = inputs.adj_matrix
    #    return self




class MSPSolution:

    def __init__(self):
        pass

    # sequence: Tensor
    # makespan: Tensor
    # a: int

In [8]:
import torch
torch.arange(5, dtype=torch.int64)[:, None], tf.range(5, delta=1, dtype=tf.int64)[:, None]

(tensor([[0],
         [1],
         [2],
         [3],
         [4]]), <tf.Tensor: shape=(5, 1), dtype=int64, numpy=
 array([[0],
        [1],
        [2],
        [3],
        [4]])>)

In [9]:
class MSPSparseGraph2(NamedTuple):
    """ """
    adj_matrix: Tensor
    node_features: Tensor
    edge_features: Tensor
    alpha: Tensor

    @property
    def num_node(self):
        return self.adj_matrix.shape[-2]

In [10]:
from collections import namedtuple

class A(NamedTuple):

    a: int

    @property
    def num(self):
        return self.a

    # jane._replace(a=26)


class C(
    NamedTuple(
        'SSSSS',
        A._field_types.items()
    ),
    A
):
    pass



class B(
    NamedTuple(
        'SSSB',
        [
            *A._field_types.items(),
            ('weight', int)
        ]
    ),
    A
):
    @property
    def h(self):
        return self.weight


# class B(namedtuple('SSSB', [*A._fields, "weight"]), A):
#     #[*A._fields, "weight"]

#     @property
#     def h(self):
#         return self.weight

    

    # def __getitem__(self, key):
    #     return self[key]

In [11]:
#B(A(a=90)._asdict(), weight=90)
#B(**A(a=9090)._asdict(), weight=90).num
#Ba = 2343, weight=23).num
#dir(B)
#C(a=90)
B(*(34,), weight=34)

B(a=34, weight=34)

In [12]:
list(A._field_types.items())

[('a', int)]

In [14]:
A._field_types

OrderedDict([('a', int)])

In [15]:
def fun(a):
    print(a)

fun(**inputs._asdict())


TypeError: fun() got an unexpected keyword argument 'adj_matrix'

In [16]:
[*A._field_types, ('v','d')]

['a', ('v', 'd')]

In [17]:
# for attr, val in inputs._asdict().items():
#     print(attr, val)

In [18]:
from typing import NamedTuple
class AttentionModelFixed(NamedTuple):
    """
    Context for AttentionModel decoder that is fixed during decoding so can be precomputed/cached
    This class allows for efficient indexing of multiple Tensors at once
    """
    node_embeddings: torch.Tensor
    context_node_projected: torch.Tensor
    glimpse_key: torch.Tensor
    glimpse_val: torch.Tensor
    logit_key: torch.Tensor

    def __getitem__(self, key):
        if tf.is_tensor(key) or isinstance(key, slice):
            return AttentionModelFixed(
                node_embeddings=self.node_embeddings[key],
                context_node_projected=self.context_node_projected[key],
                glimpse_key=self.glimpse_key[:, key],  # dim 0 are the heads
                glimpse_val=self.glimpse_val[:, key],  # dim 0 are the heads
                logit_key=self.logit_key[key]
            )
        return super(AttentionModelFixed, self).__getitem__(key)

  


False

In [236]:
"""

"""
import math

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import Model

from msp.layers import GGCNLayer
from msp.graphs import MSPSparseGraph
from msp.solutions import MSPState

class AttentionDecoder(Model):

    def __init__(self, 
                 units,
                 *args, 
                 activation='relu',
                 aggregation_graph='mean',
                 n_heads=2, # make it 8
                 mask_inner=True,
                 tanh_clipping=10,
                 decode_type='sampling',
                 extra_logging=False,
                 **kwargs):
        """ """
        super(AttentionDecoder, self).__init__(*args, **kwargs)
        self.aggregation_graph = aggregation_graph
        self.n_heads = n_heads
        self.mask_inner = mask_inner
        self.tanh_clipping = tanh_clipping
        self.decode_type = decode_type
        self.extra_logging = extra_logging

        embedding_dim = units
        
        self.W_placeholder = self.add_weight(shape=(2*embedding_dim,),
                                initializer='random_uniform', #Placeholder should be in range of activations (think)
                                name='W_placeholder',
                                trainable=True)

        graph_embed_shape = tf.TensorShape((None, units))
        self.fixed_context_layer = tf.keras.layers.Dense(units, use_bias=False)
        self.fixed_context_layer.build(graph_embed_shape)

        # For each node we compute (glimpse key, glimpse value, logit key) so 3 * embedding_dim
        project_node_embeddings_shape = tf.TensorShape((None, None, None, units))
        self.project_node_embeddings = tf.keras.layers.Dense(3*units, use_bias=False)
        self.project_node_embeddings.build(project_node_embeddings_shape)

        #
        # Embedding of first and last node
        step_context_dim = 2*units
        project_step_context_shape = tf.TensorShape((None, None, step_context_dim))
        self.project_step_context = tf.keras.layers.Dense(embedding_dim, use_bias=False)
        self.project_step_context.build(project_step_context_shape)

        assert embedding_dim % n_heads == 0
        # Note n_heads * val_dim == embedding_dim so input to project_out is embedding_dim

        project_out_shape = tf.TensorShape((None, None, 1, embedding_dim))
        self.project_out = tf.keras.layers.Dense(embedding_dim, use_bias=False)
        self.project_out.build(project_out_shape)


        # self.context_layer = tf.keras.layers.Dense(units, use_bias=False)
        # self.mha_layer = None
        

        # dynamic router


        
    def call(self, inputs, training=False, return_pi=False):
        """ """
        state = MSPState(inputs)

        node_embedding = inputs.node_embed

        # AttentionModelFixed(node_embedding, fixed_context, *fixed_attention_node_data)
        fixed = self._precompute(node_embedding)


        # for i in range(num_steps):
            # i == 0 should be machine 
            # AttentionCell(inputs, states)

        outputs = []
        sequences = []

        # i = 0
        while not state.all_finished():
            # B x 1 x V
            # Get log probabilities of next action
            log_p, mask = self._get_log_p(fixed, state)
            
            selected = self._select_node(
                    tf.squeeze(tf.exp(log_p), axis=-2), tf.squeeze(mask, axis=-2)) # Squeeze out steps dimension

            state.update(selected)

            outputs.append(log_p[:, 0, :])
            sequences.append(selected)
            
            # if i == 1:
            #     break
            # i+=1
        
        _log_p, pi = tf.stack(outputs, axis=1), tf.stack(sequences, axis=1)

        if self.extra_logging:
            self.log_p_batch = _log_p
            self.log_p_sel_batch = tf.gather(tf.squeeze(_log_p,axis=-2), pi, batch_dims=1)

        # # Get predicted costs
        # cost, mask = self.problem.get_costs(nodes, pi)
        mask = None

        ###################################################
        # Need Clarity #############################################################
        # loglikelihood 
        ll = self._calc_log_likelihood(_log_p, pi, mask)

        ## Just for checking
        return_pi = True    
        if return_pi:
            return state.makespan, ll, pi

        return state.makespan, ll

        

        



    def _precompute(self, node_embedding, num_steps=1):

        graph_embed = self._get_graph_embed(node_embedding)

        fixed_context = self.fixed_context_layer(graph_embed)
        # fixed context = (batch_size, 1, embed_dim) to make broadcastable with parallel timesteps
        fixed_context = tf.expand_dims(fixed_context, axis=-2)

        glimpse_key_fixed, glimpse_val_fixed, logit_key_fixed  = tf.split(
            self.project_node_embeddings(tf.expand_dims(node_embedding, axis=-3)),
            num_or_size_splits=3,
            axis=-1
        )

        # No need to rearrange key for logit as there is a single head
        fixed_attention_node_data = (
            self._make_heads(glimpse_key_fixed, num_steps),
            self._make_heads(glimpse_val_fixed, num_steps),
            logit_key_fixed
        )
        return AttentionModelFixed(node_embedding, fixed_context, *fixed_attention_node_data)

    def _get_graph_embed(self, node_embedding):
        """ """
        if self.aggregation_graph == "sum":
            graph_embed = tf.reduce_sum(node_embedding, axis=-2)
        elif self.aggregation_graph == "max":
            graph_embed = tf.reduce_max(node_embedding, axis=-2)
        elif self.aggregation_graph == "mean":
            graph_embed = tf.reduce_mean(node_embedding, axis=-2)
        else:  # Default: dissable graph embedding
            graph_embed = tf.reduce_sum(node_embedding, axis=-2) * 0.0

        return graph_embed

    def _make_heads(self, v, num_steps=None):

        assert num_steps is None or v.shape[1] == 1 or v.shape[1] == num_steps
        batch_size, _, num_nodes, embed_dims = v.shape
        num_steps = num_steps if num_steps else 1
        head_dims = embed_dims//self.n_heads

        # M x B x N x V x (H/M)
        return tf.transpose(
            tf.broadcast_to(
                tf.reshape(v, shape=[batch_size, v.shape[1], num_nodes, self.n_heads, head_dims]),
                shape=[batch_size, num_steps, num_nodes, self.n_heads, head_dims]
            ),
            perm=[3, 0, 1, 2, 4]
        )

    def _get_log_p(self, fixed, state, normalize=True):
        # Compute query = context node embedding
        
        # B x 1 x H
        query = fixed.context_node_projected + \
                self.project_step_context(self._get_parallel_step_context(fixed.node_embeddings, state))
        
        # Compute keys and values for the nodes
        glimpse_K, glimpse_V, logit_K = self._get_attention_node_data(fixed, state)

        # Compute the mask, for masking next action based on previous actions
        mask = state.get_mask()
        graph_mask = state.get_graph_mask()

        # Compute logits (unnormalized log_p)
        log_p, glimpse = self._one_to_many_logits(query, glimpse_K, glimpse_V, logit_K, mask, graph_mask)

        # [B x N x V]
        # log-softmax activation function so that we get log probabilities over actions
        if normalize:
            log_p = tf.nn.log_softmax(log_p/1.0, axis=-1)

        assert not tf.reduce_any(tf.math.is_nan(log_p)), "Log probabilities over the nodes should be defined"

        return log_p, mask


    def _get_parallel_step_context(self, node_embedding, state, from_depot=False):
        """
        Returns the context per step, optionally for multiple steps at once 
        (for efficient evaluation of the model)
        """
        # last_node at time t
        last_node = state.get_current_node()
        batch_size, num_steps = last_node.shape

        if num_steps == 1:  # We need to special case if we have only 1 step, may be the first or not
            if state.i.numpy()[0] == 0:
                # First and only step, ignore prev_a (this is a placeholder)
                # B x 1 x 2H
                return tf.broadcast_to(self.W_placeholder[None, None, :], 
                                       shape=[batch_size, 1, self.W_placeholder.shape[-1]])
            else:
                return tf.concat(
                    [
                        tf.gather(node_embedding,state.first_node,batch_dims=1), 
                        tf.gather(node_embedding,last_node,batch_dims=1)  
                    ],
                    axis=-1
                )
                
                # print('$'*20)
                # node_embedding = torch.from_numpy(node_embedding.numpy())
                # f = torch.from_numpy(state.first_node.numpy())
                # l = torch.from_numpy(state.last_node.numpy())
                # GG = node_embedding\
                # .gather(
                #     1,
                #     torch.cat((f, l), 1)[:, :, None].expand(batch_size, 2, node_embedding.size(-1))
                # ).view(batch_size, 1, -1)
                # print(GG)
                # print('$'*20)
                # ##############################################
                # # PENDING
                # ##############################################
                # pass

    def _get_attention_node_data(self, fixed, state):
        return fixed.glimpse_key, fixed.glimpse_val, fixed.logit_key

    def _one_to_many_logits(self, query, glimpse_K, glimpse_V, logit_K, mask, graph_mask=None):
        batch_size, num_steps, embed_dim = query.shape
        query_size = key_size = val_size = embed_dim // self.n_heads

        # M x B x N x 1 x (H/M)
        # Compute the glimpse, rearrange dimensions to (n_heads, batch_size, num_steps, 1, key_size)
        glimpse_Q = tf.transpose(
            tf.reshape(
                query, # B x 1 x H
                shape=[batch_size, num_steps, self.n_heads, 1, query_size]
            ),
            perm=[2, 0, 1, 3, 4]
        )

        # [M x B x N x 1 x (H/M)] X [M x B x N x (H/M) x V] = [M x B x N x 1 x V]
        # Batch matrix multiplication to compute compatibilities (n_heads, batch_size, num_steps, graph_size)
        compatibility = tf.matmul(glimpse_Q, tf.transpose(glimpse_K, [0,1,2,4,3])) / math.sqrt(query_size)
        
        mask_temp = tf.cast(tf.broadcast_to(mask[None, :, :, None, :], shape=compatibility.shape), dtype=tf.double)
        compatibility = tf.cast(compatibility, dtype=tf.double) + (mask_temp * -1e9)

        graph_mask_temp = tf.cast(tf.broadcast_to(graph_mask[None, :, :, None, :], shape=compatibility.shape), dtype=tf.double)
        compatibility = tf.cast(compatibility, dtype=tf.double) + (graph_mask_temp * -1e9)

        compatibility = tf.cast(compatibility, dtype=tf.float32)        

        # compatibility[tf.broadcast_to(mask[None, :, :, None, :], shape=compatibility.shape)] = -1e10
        # compatibility[tf.broadcast_to(graph_mask[None, :, :, None, :], shape=compatibility.shape)] = -1e10

        # attention weights a(c,j): 
        attention_weights = tf.nn.softmax(compatibility, axis=-1)


        # [M x B x N x 1 x V] x [M x B x N x V x (H/M)] = [M x B x N x 1 x (H/M)]
        heads = tf.matmul(attention_weights, glimpse_V)
       
        # B x N x 1 x H
        # Project to get glimpse/updated context node embedding (batch_size, num_steps, embedding_dim)
        glimpse = self.project_out(
            tf.reshape(
                tf.transpose(heads, perm=[1, 2, 3, 0, 4]),
                shape=[batch_size, num_steps, 1, self.n_heads*val_size]
            )
        )

        # B x N x 1 x H
        # Now projecting the glimpse is not needed since this can be absorbed into project_out
        # final_Q = self.project_glimpse(glimpse)
        final_Q = glimpse


        # [B x N x 1 x H] x [B x 1 x H x V] = [B x N x 1 x V] --> [B x N x V] (Squeeze) 
        # Batch matrix multiplication to compute logits (batch_size, num_steps, graph_size)
        # logits = 'compatibility'
        logits = tf.squeeze(tf.matmul(final_Q, tf.transpose(logit_K, perm=[0,1,3,2])),
                            axis=-2) / math.sqrt(final_Q.shape[-1])

        logits = logits + ( tf.cast(graph_mask, dtype=tf.float32) * -1e9)
        logits = tf.math.tanh(logits) * self.tanh_clipping
        logits = logits + ( tf.cast(mask, dtype=tf.float32) * -1e9)

        # logits[graph_mask] = -1e10 
        # logits = torch.tanh(logits) * self.tanh_clipping
        # logits[mask] = -1e10
        
        return logits, tf.squeeze(glimpse, axis=-2)

    
    def _select_node(self, probs, mask):
        assert tf.reduce_all(probs == probs) == True, "Probs should not contain any nans"

        if self.decode_type == "greedy":
            selected = tf.math.argmax(probs, axis=1)
            assert not tf.reduce_any(tf.cast(tf.gather_nd(mask, tf.expand_dims(selected, axis=-1), batch_dims=1), dtype=tf.bool)), "Decode greedy: infeasible action has maximum probability"

        elif self.decode_type == "sampling":
            dist = tfp.distributions.Multinomial(total_count=1, probs=probs)
            selected = tf.argmax(dist.sample(), axis=1)

            # Check if sampling went OK
            while tf.reduce_any(tf.cast(tf.gather_nd(mask, tf.expand_dims(selected, axis=-1), batch_dims=1), dtype=tf.bool)):
                print('Sampled bad values, resampling!')
                selected = tf.argmax(dist.sample(), axis=1)

        else:
            assert False, "Unknown decode type"
        return selected

    
    def _calc_log_likelihood(self, _log_p, a, mask):
        
        # Get log_p corresponding to selected actions
        batch_size, steps_count = a.shape
        indices = tf.concat([
        tf.expand_dims(tf.broadcast_to(tf.range(steps_count, dtype=tf.int64), shape=a.shape), axis=-1),
        tf.expand_dims(a, axis=-1)],
        axis=-1
        )
        log_p = tf.gather_nd(_log_p, indices, batch_dims=1)
        

        # _log_p = torch.from_numpy(_log_p.numpy())
        # a = torch.from_numpy(a.numpy())
        # AA = _log_p.gather(2, a.unsqueeze(-1)).squeeze(-1)
        # print(AA)
        # print(log_p)
        # print('DONE')
        

        # Get log_p corresponding to selected actions
        # log_p = tf.gather(tf.squeeze(_log_p,axis=-2), a, batch_dims=1) #_log_p.gather(2, a.unsqueeze(-1)).squeeze(-1)

        # Optional: mask out actions irrelevant to objective so they do not get reinforced
        if mask is not None:
            log_p[mask] = 0

        # Why??????
        # assert (log_p > -1000).data.all(), "Logprobs should not be -inf, check sampling procedure!"

        # Calculate log_likelihood
        return tf.reduce_sum(log_p, axis=1) # log_p.sum(1)

       



In [237]:
# import numpy as np

In [238]:
# # x = tf.constant([[True,  True], [False, False]])
# # assert not tf.reduce_any(x), "asdf"
# tf.keras.losses.SparseCategoricalCrossentropy(
#     from_logits=True, reduction='none')([]).dtype

In [239]:
# t = torch.tensor([[[2,3],[5,2]], [[6,1],[0,4]]])
# indices = torch.tensor([[[0, 0]],[[0, 0]]])
# torch.gather(t, 1, indices)
 
# # t = tf.constant([[[2,3],[5,2]], [[6,1],[0,4]]])
# # indices = tf.constant([[0,0,0]])
# # tf.gather_nd(t, indices=indices).numpy()

In [240]:
# # #tf.gather_nd()
# # t = tf.constant([[[1,0],[1,0]], [[0,1],[1,1]]])
# # t = embedded_inputs.adj_matrix
# # indices = tf.constant([[0], [0]])
# # indices = tf.expand_dims(indices, axis=-1)

# # ~tf.cast(tf.gather_nd(t, indices=indices, batch_dims=1), tf.bool)
# a = tf.constant([-3.0,-1.0, 0.0,1.0,3.0], dtype = tf.float32)
# b = tf.keras.activations.tanh(a) 
# c = tf.math.tanh(a)
# b.numpy(), c.numpy()


In [241]:
model = AttentionDecoder(6, aggregation_graph='mean', n_heads=3)
makespan, ll, pi = model(embedded_inputs)
pi

<tf.Tensor: shape=(2, 3), dtype=int64, numpy=
array([[2, 1, 0],
       [2, 1, 0]])>

In [178]:
tf.range(4, dtype=tf.int64)

<tf.Tensor: shape=(4,), dtype=int64, numpy=array([0, 1, 2, 3])>

In [163]:

indices = [[[1, 0]], [[0, 1]]]
params = [[['a0', 'b0'], ['c0', 'd0']],
            [['a1', 'b1'], ['c1', 'd1']]]
#output = [['c0'], ['b1']]
tf.gather

<tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy=
array([[[0, 0],
        [1, 1]],

       [[0, 0],
        [1, 1]]], dtype=int32)>

In [197]:
a = tf.Tensor(
[[[0,1]
  [1,2]],

 [[0,0]
#   [1 1]]], shape=(2, 2, 2), dtype=int64)
# tf.Tensor(
# [[[-1.0886807e+00 -1.2036482e+00 -1.0126851e+00]
#   [-9.8199100e+00 -1.0000000e+09 -5.4357959e-05]]

#  [[-5.4339290e-01 -1.6555539e+00 -1.4773605e+00]
#   [-1.0000000e+09 -7.7492166e-01 -6.1755723e-01]]], shape=(2, 2, 3), dtype=float32)

SyntaxError: invalid syntax (<ipython-input-197-8979dfe3162c>, line 2)

In [None]:
embeddings.gather(
                        1,
                        torch.cat((state.first_a, current_node), 1)[:, :, None].expand(batch_size, 2, embeddings.size(-1))
                    ).view(batch_size, 1, -1)

In [156]:
_log_p, pi#tf.expand_dims(pi,axis=-1)

(<tf.Tensor: shape=(2, 1, 3), dtype=float32, numpy=
 array([[[0.53031623, 0.3536483 , 0.11603544]],
 
        [[0.6530212 , 0.23223129, 0.11474745]]], dtype=float32)>,
 <tf.Tensor: shape=(2, 1), dtype=int64, numpy=
 array([[1],
        [0]])>)

In [158]:
_log_p

tf.gather(tf.squeeze(_log_p,axis=-2), pi, batch_dims=1)

<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[0.3536483],
       [0.6530212]], dtype=float32)>

In [161]:
_log_p = torch.Tensor(_log_p.numpy())
pi = torch.from_numpy(tf.cast(pi, dtype=tf.int64).numpy()) 

In [168]:
log_p = _log_p.gather(2, pi.unsqueeze(-1)).squeeze(-1)
log_p.sum(1) 
log_p


tensor([[0.3536],
        [0.6530]])

In [134]:
# pi.unsqueeze(-1)
# tf.cast(pi, dtype=tf.int64)
_log_p

tensor([[[0.2306, 0.2705, 0.4989]],

        [[0.2119, 0.2444, 0.5437]]])

In [135]:
# seq = [torch.Tensor(i.numpy()) for i in seq]
# torch.stack(seq, 1)
pi

tensor([[2],
        [0]])

In [23]:
last_node = tf.constant([[0], [1]], dtype=tf.int64)

cur_node = selected[:, tf.newaxis]
cur_node #cur_node

tf.gather_nd(
tf.gather_nd(embedded_inputs.edge_features, tf.concat([last_node, cur_node], axis=1), batch_dims=1),
tf.reshape(tf.range(2), shape=(2,1)), # feature_indexfor processing time
batch_dims=0
)


<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[0.5743043 , 0.37116084, 0.        ],
       [0.70901567, 0.8906554 , 0.        ]], dtype=float32)>

In [24]:
A = tf.gather_nd(embedded_inputs.edge_features, tf.concat([last_node, cur_node], axis=1), batch_dims=1)[:,0:1]
A

<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[0.5743043 ],
       [0.70901567]], dtype=float32)>

In [25]:
tf.reshape(tf.range(2), shape=(2,1))

<tf.Tensor: shape=(2, 1), dtype=int32, numpy=
array([[0],
       [1]], dtype=int32)>

In [26]:
embedded_inputs.node_features[:,]

<tf.Tensor: shape=(2, 3, 5), dtype=float32, numpy=
array([[[0.60597825, 0.73336935, 0.13894716, 0.3126731 , 1.        ],
        [0.12816237, 0.1789931 , 0.75292546, 0.6621605 , 1.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ]],

       [[0.82309866, 0.732225  , 0.06905627, 0.6721289 , 1.        ],
        [0.8280144 , 0.2044694 , 0.617489  , 0.617701  , 1.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ]]],
      dtype=float32)>

In [27]:
cur_node[:,tf.newaxis,:] #embedded_inputs.node_features[:, ]
tf.gather_nd(embedded_inputs.edge_features, tf.concat([last_node, cur_node], axis=1), batch_dims=1)[:,0:1] + \
tf.gather_nd(embedded_inputs.node_features, cur_node, batch_dims=1)[:,0:1]

<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[0.70246667],
       [1.5321143 ]], dtype=float32)>

In [28]:
cur_node

tf.rank(tf.ones(30,10))

t = tf.constant([[[1, 1, 1, 4], [2, 2, 2, 4]], [[3, 3, 3, 4], [4, 4, 4, 4]]])
tf.rank(t)




<tf.Tensor: shape=(), dtype=int32, numpy=3>

In [32]:
_visited = tf.zeros((batch_size,1,3), dtype=tf.uint8)
_visited

<tf.Tensor: shape=(2, 1, 3), dtype=uint8, numpy=
array([[[0, 0, 0]],

       [[0, 0, 0]]], dtype=uint8)>

In [45]:
cur_node #[:,:,None] #, tf.ones([2,1,1])

<tf.Tensor: shape=(2, 1), dtype=int64, numpy=
array([[1],
       [0]])>

In [51]:
tf.concat([tf.reshape(tf.range(2, dtype=tf.int64),cur_node.shape), cur_node], axis=1)

<tf.Tensor: shape=(2, 2), dtype=int64, numpy=
array([[0, 1],
       [1, 0]])>

In [58]:
_visited.dtype

tf.uint8

In [59]:
batch_size, _, _ = _visited.shape
tf.tensor_scatter_nd_update( 
    tf.squeeze(_visited, axis=-2),
    tf.concat([tf.reshape(tf.range(batch_size, dtype=tf.int64),cur_node.shape), cur_node], axis=1),
    tf.ones((batch_size,), dtype=_visited.dtype)
)[:,tf.newaxis,:]
    
    
    #_visited, cur_node[:,:,None], tf.ones([2,1,1], dtype=tf.uint8))

<tf.Tensor: shape=(2, 1, 3), dtype=uint8, numpy=
array([[[0, 1, 0]],

       [[1, 0, 0]]], dtype=uint8)>

In [68]:
cur_node = selected[:, tf.newaxis]
cur_node #cur_node

tf.concat([cur_node, cur_node], axis=1)

tf.gather_nd(
tf.gather_nd(embedded_inputs.edge_features, tf.concat([cur_node, cur_node], axis=1), batch_dims=1),
tf.zeros((2,1), dtype=tf.int32)
)

tf.gather_nd(embedded_inputs.edge_features, tf.concat([cur_node, cur_node], axis=1), batch_dims=1)

<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)>

In [67]:
tf.zeros((2,), dtype=tf.int32)
tf.range(2)

<tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 1], dtype=int32)>

In [39]:
embedded_inputs.edge_features

<tf.Tensor: shape=(2, 3, 3, 3), dtype=float32, numpy=
array([[[[0.        , 0.        , 0.        ],
         [0.5743043 , 0.37116084, 0.        ],
         [0.        , 0.        , 1.        ]],

        [[0.5743043 , 0.37116084, 0.        ],
         [0.        , 0.        , 0.        ],
         [0.        , 0.        , 1.        ]],

        [[0.        , 0.        , 1.        ],
         [0.        , 0.        , 1.        ],
         [0.        , 0.        , 0.        ]]],


       [[[0.        , 0.        , 0.        ],
         [0.70901567, 0.8906554 , 0.        ],
         [0.        , 0.        , 1.        ]],

        [[0.70901567, 0.8906554 , 0.        ],
         [0.        , 0.        , 0.        ],
         [0.        , 0.        , 1.        ]],

        [[0.        , 0.        , 1.        ],
         [0.        , 0.        , 1.        ],
         [0.        , 0.        , 0.        ]]]], dtype=float32)>

In [27]:
probs_ = tf.squeeze(tf.exp(log_p), axis=-2)
tf.squeeze(mask, axis=-2)

NameError: name 'log_p' is not defined

In [33]:
embedded_inputs.node_features

<tf.Tensor: shape=(2, 3, 5), dtype=float32, numpy=
array([[[0.60597825, 0.73336935, 0.13894716, 0.3126731 , 1.        ],
        [0.12816237, 0.1789931 , 0.75292546, 0.6621605 , 1.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ]],

       [[0.82309866, 0.732225  , 0.06905627, 0.6721289 , 1.        ],
        [0.8280144 , 0.2044694 , 0.617489  , 0.617701  , 1.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ]]],
      dtype=float32)>

In [28]:
probs = torch.Tensor(log_p.numpy()).exp()[:, 0, :]
mask_ = torch.Tensor(mask.numpy())[:, 0, :]

probs, probs.multinomial(1) #.squeeze(1)


# _, selected = probs.max(1)
# assert not mask_.gather(1, selected.unsqueeze(
#                 -1)).data.any(), "Decode greedy: infeasible action has maximum probability"
# mask_.gather(1, selected.unsqueeze(-1)) #.data.any()
probs_

NameError: name 'log_p' is not defined

In [275]:
probs.multinomial(1)

tensor([[2],
        [0]])

In [294]:
# p = [.2, .3, .5]
# dist = tfp.distributions.Multinomial(total_count=1, probs=probs_)

p = [[.1, .2, .7], [.3, .3, .4]]  # Shape [2, 3]

dist = tfp.distributions.Multinomial(total_count=1, probs=probs_)
selected = tf.argmax(dist.sample(), axis=1)

<tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 0])>

In [215]:
selected = tf.math.argmax(tf.squeeze(tf.exp(log_p), axis=-2), axis=1)

assert not tf.reduce_any(tf.cast(tf.gather_nd(mask_, tf.expand_dims(selected, axis=-1), batch_dims=1), dtype=tf.bool)), "Decode greedy: infeasible action has maximum probability"

True

In [206]:
tf.expand_dims(selected, axis=-1), mask_, 

(<tf.Tensor: shape=(2, 1), dtype=int64, numpy=
 array([[1],
        [0]])>, tensor([[0., 0., 0.],
         [0., 0., 0.]]), <tf.Tensor: shape=(1,), dtype=int32, numpy=array([2], dtype=int32)>)

In [305]:
probs_


AssertionError: 

In [138]:
tf.broadcast_to(mask[None, :, None, :, :], shape=compatibility.shape)

NameError: name 'compatibility' is not defined

In [None]:
compatibility

compatibility[graph_mask[None, :, :, None, :].expand_as(compatibility)] = -1e10

In [49]:
def _make_heads(v, num_steps=None):
    assert num_steps is None or v.size(1) == 1 or v.size(1) == num_steps
    n_heads = 2
    return (
        v.contiguous().view(v.size(0), v.size(1), v.size(2), n_heads, -1)
        .expand(v.size(0), v.size(1) if num_steps is None else num_steps, v.size(2), n_heads, -1)
        .permute(3, 0, 1, 2, 4)  # (n_heads, batch_size, num_steps, graph_size, head_dim)
    )

from torch import nn
W_placeholder = nn.Parameter(torch.Tensor(2 * 6))
W_placeholder[None, None, :].expand(2, 1, W_placeholder.size(-1))

tensor([[[-7.8044e+00,  4.5678e-41, -7.8044e+00,  4.5678e-41, -4.2528e-01,
           1.6527e-01, -1.7343e-01, -3.0582e-01, -5.5208e-02, -2.9920e-01,
          -1.3644e-01,  4.3799e-01]],

        [[-7.8044e+00,  4.5678e-41, -7.8044e+00,  4.5678e-41, -4.2528e-01,
           1.6527e-01, -1.7343e-01, -3.0582e-01, -5.5208e-02, -2.9920e-01,
          -1.3644e-01,  4.3799e-01]]], grad_fn=<ExpandBackward>)

In [109]:
tf.ones((3,4)).dtype

tf.float32

In [51]:
_make_heads2(A, num_steps=4)

NameError: name '_make_heads2' is not defined

In [52]:
_make_heads2(A, num_steps=2)

NameError: name '_make_heads2' is not defined

In [53]:
A.shape
tf.broadcast_to(A, tf.TensorShape([-1, 4, None, None]))

AttributeError: type object 'A' has no attribute 'shape'

In [54]:
tf.split(
tf.matmul(tf.ones((5,3))[None,:,:], tf.ones((3,3*3))),
3,
axis=-1

)

[<tf.Tensor: shape=(1, 5, 3), dtype=float32, numpy=
 array([[[3., 3., 3.],
         [3., 3., 3.],
         [3., 3., 3.],
         [3., 3., 3.],
         [3., 3., 3.]]], dtype=float32)>,
 <tf.Tensor: shape=(1, 5, 3), dtype=float32, numpy=
 array([[[3., 3., 3.],
         [3., 3., 3.],
         [3., 3., 3.],
         [3., 3., 3.],
         [3., 3., 3.]]], dtype=float32)>,
 <tf.Tensor: shape=(1, 5, 3), dtype=float32, numpy=
 array([[[3., 3., 3.],
         [3., 3., 3.],
         [3., 3., 3.],
         [3., 3., 3.],
         [3., 3., 3.]]], dtype=float32)>]

In [55]:
tf.matmul(tf.ones((5,3))[None,:,:], tf.ones((3,3*3)))

<tf.Tensor: shape=(1, 5, 9), dtype=float32, numpy=
array([[[3., 3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3., 3.]]], dtype=float32)>

Parameter containing:
tensor([ 1.1431e+27,  1.7241e+25,  9.1084e-44,  0.0000e+00, -6.2840e-10,
         3.0753e-41, -2.3967e-33,  4.5706e-41,  7.2697e+31,  1.8730e+31],
       requires_grad=True)

In [56]:
import torch
embeddings = torch.Tensor(embedded_inputs.node_features.numpy())

torch.cat((state.first_a, current_node), 1)[:, :, None]

NameError: name 'state' is not defined

In [57]:
embedded_inputs.node_features

<tf.Tensor: shape=(2, 3, 5), dtype=float32, numpy=
array([[[0.60597825, 0.73336935, 0.13894716, 0.3126731 , 1.        ],
        [0.12816237, 0.1789931 , 0.75292546, 0.6621605 , 1.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ]],

       [[0.82309866, 0.732225  , 0.06905627, 0.6721289 , 1.        ],
        [0.8280144 , 0.2044694 , 0.617489  , 0.617701  , 1.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ]]],
      dtype=float32)>

In [58]:
import torch
from torch import nn

W_placeholder = nn.Parameter(torch.Tensor(2 * 5))


W_placeholder[None, None, :].expand(batch_size, 1, W_placeholder.size(-1)).shape

torch.Size([2, 1, 10])

In [59]:
tf.constant(W_placeholder[None, None, :].detach().numpy()).shape[-1]

10

TypeError: broadcast_to() got an unexpected keyword argument 'dim'

In [74]:
tf.broadcast_to?

[0;31mSignature:[0m [0mtf[0m[0;34m.[0m[0mbroadcast_to[0m[0;34m([0m[0minput[0m[0;34m,[0m [0mshape[0m[0;34m,[0m [0mname[0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Broadcast an array for a compatible shape.

Broadcasting is the process of making arrays to have compatible shapes
for arithmetic operations. Two shapes are compatible if for each
dimension pair they are either equal or one of them is one. When trying
to broadcast a Tensor to a shape, it starts with the trailing dimensions,
and works its way forward.

For example,

>>> x = tf.constant([1, 2, 3])
>>> y = tf.broadcast_to(x, [3, 3])
>>> print(y)
tf.Tensor(
    [[1 2 3]
     [1 2 3]
     [1 2 3]], shape=(3, 3), dtype=int32)

In the above example, the input Tensor with the shape of `[1, 3]`
is broadcasted to output Tensor with shape of `[3, 3]`.

When doing broadcasted operations such as multiplying a tensor
by a scalar, broadcasting (usually) confers some time or sp

In [58]:
W_placeholder[None, None, :].shape

torch.Size([1, 1, 10])

In [None]:
        self.w = self.add_weight(shape=(input_shape[-1], self.units),
                                 initializer='random_normal',
                                 trainable=True)

In [41]:
tf.reduce_mean(embedded_inputs.node_features, axis=-2)

<tf.Tensor: shape=(2, 6), dtype=float32, numpy=
array([[ 1.2512864 , -0.11654303,  0.427403  ,  0.6303492 ,  1.0807096 ,
         0.7363432 ],
       [ 1.875542  , -0.47257408,  1.0240463 ,  0.74093527,  2.416504  ,
         1.2604753 ]], dtype=float32)>

In [28]:
import torch

In [39]:
torch.Tensor(embedded_inputs.node_features.numpy()).max(1)[0]

tensor([[1.6213, 0.0821, 0.7174, 1.1543, 2.0377, 1.2888],
        [2.3475, 0.0000, 1.2708, 1.3155, 3.2248, 2.1935]])

In [31]:
embedded_inputs.node_features.numpy()

array([[[ 1.6213439 , -0.49750236,  0.71744895,  0.73677236,
          2.0377152 ,  1.2887579 ],
        [ 1.1330049 ,  0.08210948,  0.13447253,  1.1542754 ,
          0.79937   ,  0.9202718 ],
        [ 0.9995105 ,  0.06576377,  0.43028754,  0.        ,
          0.40504336,  0.        ]],

       [[ 1.9401135 , -0.62696314,  1.2366322 ,  0.9072662 ,
          3.2248068 ,  2.1934643 ],
        [ 2.3475442 , -0.790759  ,  1.2707808 ,  1.3155396 ,
          2.8450027 ,  1.4678874 ],
        [ 1.338968  ,  0.        ,  0.5647262 ,  0.        ,
          1.1797022 ,  0.12007414]]], dtype=float32)

In [None]:
"""

"""
import tensorflow as tf
from tensorflow.keras import Model

from msp.layers import GGCNLayer
from msp.graphs import MSPSparseGraph
from msp.solutions import MSPState

class AttentionDecoder(Model):

    def __init__(self, 
                 units,
                 *args, 
                 activation='relu',
                 aggregation_graph='mean',
                 n_heads=8,
                 **kwargs):
        """ """
        super(AttentionDecoder, self).__init__(*args, **kwargs)
        self.aggregation_graph = aggregation_graph
        self.n_heads = n_heads

        embedding_dim = units
        self.W_placeholder = self.add_weight(shape=(2*embedding_dim,),
                                initializer='random_uniform', #Placeholder should be in range of activations (think)
                                trainable=True)

        graph_embed_shape = tf.TensorShape((None, units))
        self.fixed_context_layer = tf.keras.layers.Dense(units, use_bias=False)
        self.fixed_context_layer.build(graph_embed_shape)

        # For each node we compute (glimpse key, glimpse value, logit key) so 3 * embedding_dim
        project_node_embeddings_shape = tf.TensorShape((None, None, None, units))
        self.project_node_embeddings = tf.keras.layers.Dense(3*units, use_bias=False)
        self.project_node_embeddings.build(project_node_embeddings_shape)

        #
        # Embedding of first and last node
        step_context_dim = 2*units
        project_step_context_shape = tf.TensorShape((None, None, step_context_dim))
        self.project_step_context = tf.keras.layers.Dense(embedding_dim, use_bias=False)
        self.project_step_context.build(project_step_context_shape)
        
        #nn.Linear(step_context_dim, embedding_dim, bias=False)




        # self.context_layer = tf.keras.layers.Dense(units, use_bias=False)
        # self.mha_layer = None
        

        # dynamic router


        


        # self.initial_layer_1 = tf.keras.layers.Dense(units)
        # self.initial_layer_2 = tf.keras.layers.Dense(units)
        # self.ggcn_layers = [GGCNLayer(units, activation=activation)
        #                     for _ in range(layers)]
        
    def call(self, inputs, training=False):
        """ """
        state = MSPState(inputs)

        node_embedding = inputs.node_embed

        # AttentionModelFixed(node_embedding, fixed_context, *fixed_attention_node_data)
        fixed = self._precompute(node_embedding)
        return fixed 

        while not state.all_finished():

            log_p, mask = self._get_log_p(fixed, state)

        

        node_embed = inputs.node_features
        return self._get_parallel_step_context(node_embed)


    def _precompute(self, node_embedding, num_steps=1):

        graph_embed = self._get_graph_embed(node_embedding)

        fixed_context = self.fixed_context_layer(graph_embed)
        # fixed context = (batch_size, 1, embed_dim) to make broadcastable with parallel timesteps
        fixed_context = tf.expand_dims(fixed_context, axis=-2)

        glimpse_key_fixed, glimpse_val_fixed, logit_key_fixed  = tf.split(
            self.project_node_embeddings(tf.expand_dims(node_embedding, axis=-3)),
            num_or_size_splits=3,
            axis=-1
        )

        return glimpse_key_fixed

        # No need to rearrange key for logit as there is a single head
        fixed_attention_node_data = (
            self._make_heads(glimpse_key_fixed, num_steps),
            self._make_heads(glimpse_val_fixed, num_steps),
            logit_key_fixed
        )
        return AttentionModelFixed(node_embedding, fixed_context, *fixed_attention_node_data)

    def _get_graph_embed(self, node_embedding):
        """ """
        if self.aggregation_graph == "sum":
            graph_embed = tf.reduce_sum(node_embedding, axis=-2)
        elif self.aggregation_graph == "max":
            graph_embed = tf.reduce_max(node_embedding, axis=-2)
        elif self.aggregation_graph == "mean":
            graph_embed = tf.reduce_mean(node_embedding, axis=-2)
        else:  # Default: dissable graph embedding
            graph_embed = tf.reduce_sum(node_embedding, axis=-2) * 0.0

        return graph_embed

    def _make_heads(self, v, num_steps=None):

        assert num_steps is None or v.shape[1] == 1 or v.shape[1] == num_steps
        
        batch_size, _, num_nodes, embed_dims = v.shape
        num_steps = num_steps if num_steps else 1

        # M x B x N x V x H
        return tf.broadcast_to(
            tf.broadcast_to(v, shape=[batch_size, num_steps, num_nodes, embed_dims])[None,:,:,:,:],
            shape=[self.n_heads, batch_size, num_steps, num_nodes, embed_dims]
        )

    def _get_log_p(self, fixed, state, normalize=True):
        # Compute query = context node embedding
        
        # B x 1 x H
        query = fixed.context_node_projected + \
                self.project_step_context(self._get_parallel_step_context(fixed.node_embeddings, state))
        
        # Compute keys and values for the nodes
        glimpse_K, glimpse_V, logit_K = self._get_attention_node_data(fixed, state)

        graph_mask = None
        if self.mask_graph:
            # Compute the graph mask, for masking next action based on graph structure 
            graph_mask = state.get_graph_mask()  # Pending...........................................

        # Compute logits (unnormalized log_p)
        log_p, glimpse = self._one_to_many_logits(query, glimpse_K, glimpse_V, logit_K, mask, graph_mask)
        


    def _get_parallel_step_context(self, node_embedding, state, from_depot=False):
        """
        Returns the context per step, optionally for multiple steps at once 
        (for efficient evaluation of the model)
        """
        current_node = state.get_current_node()
        batch_size, num_steps = current_node.shape

        if num_steps == 1:  # We need to special case if we have only 1 step, may be the first or not
            if self.i.numpy()[0] == 0:
                # First and only step, ignore prev_a (this is a placeholder)
                # B x 1 x 2H
                return tf.broadcast_to(self.W_placeholder[None, None, :], 
                                       shape=[batch_size, 1, self.W_placeholder.shape[-1]])
                
            # else:
            #     return embeddings.gather(
            #         1,
            #         torch.cat((state.first_a, current_node), 1)[:, :, None].expand(batch_size, 2, embeddings.size(-1))
            #     ).view(batch_size, 1, -1)

    def _get_attention_node_data(self, fixed, state):
        return fixed.glimpse_key, fixed.glimpse_val, fixed.logit_key

    def _one_to_many_logits(self, query, glimpse_K, glimpse_V, logit_K, mask, graph_mask=None):
        pass




    

In [None]:
tf.boolean_mask

In [20]:
MSPSparseGraph._fields

('adj_matrix', 'node_features', 'edge_features', 'alpha')

In [21]:
dataset = make_sparse_data(4, msp_size=(1,2), random_state=2021)

In [25]:
graph = list(dataset.take(1))[0]
graph

MSPSparseGraph(adj_matrix=<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[0., 1., 1.],
       [1., 0., 1.],
       [1., 1., 0.]], dtype=float32)>, node_features=<tf.Tensor: shape=(3, 5), dtype=float64, numpy=
array([[0.27535218, 0.6795556 , 0.58162645, 0.97019879, 1.        ],
       [0.33048907, 0.96671416, 0.82589904, 0.30540018, 1.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ]])>, edge_features=<tf.Tensor: shape=(3, 3, 3), dtype=float64, numpy=
array([[[0.        , 0.        , 0.        ],
        [0.46459967, 0.12383996, 0.        ],
        [0.        , 0.        , 1.        ]],

       [[0.46459967, 0.12383996, 0.        ],
        [0.        , 0.        , 0.        ],
        [0.        , 0.        , 1.        ]],

       [[0.        , 0.        , 1.        ],
        [0.        , 0.        , 1.        ],
        [0.        , 0.        , 0.        ]]])>, alpha=<tf.Tensor: shape=(3, 1), dtype=float64, numpy=
array([[1.],
       [1.],
      

In [6]:
indices= tf.cast(tf.transpose(
    tf.concat([
        edge_index,
        tf.scatter_nd(
            tf.constant([[1],[0]]),
            edge_index,
            edge_index.shape
        )
    ], axis=-1)
), tf.int64)
indices

<tf.Tensor: shape=(6, 2), dtype=int64, numpy=
array([[0, 1],
       [0, 2],
       [1, 2],
       [1, 0],
       [2, 0],
       [2, 1]])>

In [7]:
adj_matrix = tf.sparse.SparseTensor(
    indices= indices,
    values = tf.ones([2*edge_index.shape[-1]]),
    dense_shape = [3,3]
)
adj_matrix

<tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7febac93e890>

In [13]:
tf.sparse.to_dense(tf.sparse.reorder(adj_matrix))

<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[0., 1., 1.],
       [1., 0., 1.],
       [1., 1., 0.]], dtype=float32)>

In [None]:
reorder

In [31]:
tf.ones([edge_index.shape[-1]])

<tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 1., 1.], dtype=float32)>

In [38]:
E = tf.sparse.SparseTensor(indices=[[0, 0], [1, 2], [2, 0]], values=[1, 2], dense_shape=[3, 4])
tf.sparse.to_dense()

ValueError: Dimensions 3 and 2 are not compatible

In [123]:
tf.ones((2,3))

<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32)>

In [246]:
class GGCNLayer(Layer):

    # @validate_hyperparams
    def __init__(self, 
                 units,
                 *args, 
                 activation='relu', 
                 use_bias=True, 
                 normalization='batch',
                 aggregation='mean',
                 **kwargs):
        """ """
        super(GGCNLayer, self).__init__(*args, **kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.normalization= normalization
        self.aggregation = aggregation

    def build(self, input_shape):
        """Create the state of the layer (weights)"""
        print('Build')
        node_features_shape = input_shape.node_features
        edge_featues_shape = input_shape.edge_features
        embedded_shape = tf.TensorShape((None, self.units))

        # _initial_projection_layer (think on it)

        with tf.name_scope('node'):
            with tf.name_scope('U'):
                self.U = tf.keras.layers.Dense(self.units, use_bias=self.use_bias)
                self.U.build(node_features_shape)

            with tf.name_scope('V'):
                self.V = tf.keras.layers.Dense(self.units, use_bias=self.use_bias)
                self.V.build(node_features_shape)

            with tf.name_scope('norm'):
                self.norm_h = {
                    "batch": tf.keras.layers.BatchNormalization(),
                    "layer": tf.keras.layers.LayerNormalization()
                }.get(self.normalization, None)
                if self.norm_h:
                    self.norm_h.build(embedded_shape)

        with tf.name_scope('edge'):
            with tf.name_scope('A'):
                self.A = tf.keras.layers.Dense(self.units, use_bias=self.use_bias)
                self.A.build(tf.TensorShape((None, node_features_shape[-1])))
            
            with tf.name_scope('B'):
                self.B = tf.keras.layers.Dense(self.units, use_bias=self.use_bias)
                self.B.build(node_features_shape)

            with tf.name_scope('C'):
                self.C = tf.keras.layers.Dense(self.units, use_bias=self.use_bias)
                self.C.build(edge_featues_shape)

            with tf.name_scope('norm'):
                self.norm_e = {
                    'batch': tf.keras.layers.BatchNormalization(),
                    'layer': tf.keras.layers.LayerNormalization(axis=-2)
                }.get(self.normalization, None)
                if self.norm_e:
                    self.norm_e.build(embedded_shape)
    
        super().build(input_shape)
 
    def call(self, inputs):
        """ """
        print('call')
        adj_matrix = inputs.adj_matrix
        h = inputs.node_features
        e = inputs.edge_features

        # Edges Featuers
        Ah = self.A(h)
        Bh = self.B(h)
        Ce = self.C(e)
        e = self._update_edges(e, [Ah, Bh, Ce])

        edge_gates = tf.sigmoid(e)

        # Nodes Features
        Uh = self.U(h)
        Vh = self.V(h)
        h = self._update_nodes(
            h,
            [Uh, self._aggregate(Vh, edge_gates, adj_matrix)]
        )

        outputs = MSPSparseGraph(adj_matrix, h, e, inputs.alpha)
        return inputs
        
    def _update_edges(self, e, transformations:list):
        """Update edges features"""
        Ah, Bh, Ce  = transformations
        e_new = tf.expand_dims(Ah, axis=1) + tf.expand_dims(Bh, axis=2) + Ce
        # Normalization
        batch_size, num_nodes, num_nodes, hidden_dim = e_new.shape
        if self.norm_e:
            e_new = tf.reshape(
                self.norm_e(
                    tf.reshape(e_new, [batch_size*num_nodes*num_nodes, hidden_dim])
                ), e_new.shape
            )
        # Activation
        e_new = self.activation(e_new)
        # Skip/residual Connection
        e_new = e + e_new
        return e_new

    def _update_nodes(self, h, transformations:list):
        """ """
        Uh, aggregated_messages = transformations
        h_new = tf.math.add_n([Uh, aggregated_messages])
        # Normalization
        batch_size, num_nodes, hidden_dim = h_new.shape
        if self.norm_h:
            h_new = tf.reshape(
                self.norm_h(
                    tf.reshape(h_new, [batch_size*num_nodes, hidden_dim])
                ), h_new.shape
            )
        # Activation
        h_new = self.activation(h_new)
        # Skip/residual Connection
        h_new = h + h_new
        return h_new

    def _aggregate(self, Vh, edge_gates, adj_matrix):
        """ """
        # Reshape as edge_gates
        Vh = tf.broadcast_to(
            tf.expand_dims(Vh, axis=1),
            edge_gates.shape
        )
        # Gating mechanism
        Vh = edge_gates * Vh
        
        # Enforce graph structure
        # mask = tf.broadcast_to(tf.expand_dims(adj_matrix,axis=-1), Vh.shape)
        # Vh[~mask] = 0

        # message aggregation
        if self.aggregation == 'mean':
            total_messages = tf.cast(
                tf.expand_dims(
                    tf.math.reduce_sum(adj_matrix, axis=-1),
                    axis=-1
                ),
                Vh.dtype
            )
            return tf.math.reduce_sum(Vh, axis=2) / total_messages
        
        elif self.aggregation == 'sum':
            return tf.math.reduce_sum(Vh, axis=2)


In [247]:
#

In [248]:

B, V, H = 1, 3, 2
h = tf.ones((B, V, H))


In [249]:
dataset = make_sparse_data(4, msp_size=(1,2), random_state=2020)
graphs = list(dataset.batch(2))
VVh = list(graphs)[0].edge_features
A = list(graphs)[0].adj_matrix
VVh.shape, A.shape

(TensorShape([2, 3, 3, 3]), TensorShape([2, 3, 3]))

In [252]:
# mask = tf.cast(tf.broadcast_to(
#     tf.expand_dims(A, axis=-1),
#     VVh.shape
# ), tf.bool)
# VVh[~mask] 
ggcn.variables

[<tf.Variable 'GGCN_Layer/node/U/kernel:0' shape=(5, 3) dtype=float32, numpy=
 array([[ 0.3418128 , -0.74846673,  0.4793586 ],
        [ 0.6570088 ,  0.8612539 ,  0.65150267],
        [-0.72393733, -0.22281748, -0.22313446],
        [ 0.79907125,  0.64225537, -0.03322494],
        [ 0.555322  ,  0.38448828,  0.36000997]], dtype=float32)>,
 <tf.Variable 'GGCN_Layer/node/U/bias:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>,
 <tf.Variable 'GGCN_Layer/node/V/kernel:0' shape=(5, 3) dtype=float32, numpy=
 array([[-0.26467794,  0.7648507 ,  0.719668  ],
        [ 0.2005971 , -0.62737715, -0.5316253 ],
        [-0.05412966, -0.2538466 , -0.8192339 ],
        [ 0.02352881,  0.5495698 ,  0.2812906 ],
        [ 0.08741534,  0.04532963,  0.46388656]], dtype=float32)>,
 <tf.Variable 'GGCN_Layer/node/V/bias:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>,
 <tf.Variable 'GGCN_Layer/node/norm/gamma:0' shape=(3,) dtype=float32, numpy=array([1., 1., 1.], d

In [251]:
units = graphs[0].edge_features.shape[-1]
ggcn = GGCNLayer(units=units, use_bias=True, name='GGCN_Layer')
output = ggcn(graphs[0])
output

Build
call


InvalidArgumentError: Incompatible shapes: [2,3,5] vs. [2,3,3] [Op:AddV2]

In [None]:
tf.cast(
    tf.expand_dims(
        tf.math.reduce_sum(list(graphs)[0].adj_matrix, axis=-1),
        axis=-1
    ),
    Vh.dtype
)
    tf.expand_dims(
    tf.math.reduce_sum(list(graphs)[0].adj_matrix, axis=-1),
    axis=-1).dtype

In [54]:
tf.math.reduce_max(graphs[0].edge_features, axis=-2 )[0]

<tf.Tensor: shape=(3, 3), dtype=float64, numpy=
array([[0.12417877, 0.31973648, 1.        ],
       [0.12417877, 0.31973648, 1.        ],
       [0.        , 0.        , 1.        ]])>

In [42]:
list(graphs)[0].adj_matrix

<tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
array([[[0., 1., 1.],
        [1., 0., 1.],
        [1., 1., 0.]],

       [[0., 1., 1.],
        [1., 0., 1.],
        [1., 1., 0.]]], dtype=float32)>

In [51]:
# ggcn = GGCNLayer(units=units, use_bias=True, name='GGCN_Layer')
# output = ggcn(graphs[0])
# output

In [11]:
graphs = list(make_data(1, msp_size=(1,2), random_state=2021).take(1))[0]

num_edges = 3


A = tf.sparse.SparseTensor(
    indices= tf.cast(tf.transpose(graphs['edge_index']), tf.int64),
    values = tf.ones([num_edges]),
    dense_shape = [3,3]
)

tf.sparse.to_dense(A)

<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[0., 1., 1.],
       [0., 0., 1.],
       [0., 0., 0.]], dtype=float32)>

In [34]:
tf.cast(tf.transpose(graphs['edge_index']), tf.int64)

<tf.Tensor: shape=(3, 2), dtype=int64, numpy=
array([[0, 1],
       [0, 2],
       [1, 2]])>

In [77]:
# ggcn.variables
tf.cast(tf.transpose(graphs['edge_index']), tf.int64)


<tf.Tensor: shape=(3, 2), dtype=int64, numpy=
array([[0, 1],
       [0, 2],
       [1, 2]])>

In [79]:
edge_index = tf.cast(tf.transpose(
    tf.concat([
        graphs['edge_index'],
        tf.scatter_nd(
            tf.constant([[1],[0]]),
            graphs['edge_index'],
            graphs['edge_index'].shape
        )
    ], axis=-1)
), tf.int32)
edge_index

<tf.Tensor: shape=(6, 2), dtype=int32, numpy=
array([[0, 1],
       [0, 2],
       [1, 2],
       [1, 0],
       [2, 0],
       [2, 1]], dtype=int32)>

In [137]:
model = tf.keras.Sequential([
    GGCNLayer(units=units, use_bias=True, name='GGCN_Layer')
])
model.compile(optimizer='sgd', loss='mse')
model.fit(dataset)

ValueError: in user code:

    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:855 train_function  *
        return step_function(self, iterator)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:845 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:1285 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:2833 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:3608 _call_for_each_replica
        return fn(*args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:838 run_step  **
        outputs = model.train_step(data)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:795 train_step
        y_pred = self(x, training=True)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/base_layer.py:1030 __call__
        outputs = call_fn(inputs, *args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/sequential.py:375 call
        self._build_graph_network_for_inferred_shape(inputs.shape, inputs.dtype)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/training/tracking/base.py:522 _method_wrapper
        result = method(self, *args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/sequential.py:281 _build_graph_network_for_inferred_shape
        input_shape = tuple(input_shape)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/tensor_shape.py:868 __iter__
        raise ValueError("Cannot iterate over a shape with unknown rank.")

    ValueError: Cannot iterate over a shape with unknown rank.


In [86]:
model.weights

ValueError: Weights for model sequential have not yet been created. Weights are created when the Model is first called on inputs or `build()` is called with an `input_shape`.