# In[ ]:
def sample_discrete_forward(self, logits):
    """Draw a noisy top-k sample from ``logits`` (perturb-and-MAP).

    Fresh noise is obtained from ``self.sample_gumbel_k`` for the shape of
    ``logits`` and stored on ``self.samples`` so that
    ``sample_discrete_backward`` can later reuse exactly the same noise.

    Args:
        logits: Scores to perturb. The ``[:, -1]`` indexing below only makes
            sense for a 2-D tensor, presumably (batch, M).

    Returns:
        float32 tensor shaped like ``logits``: 1.0 where the perturbed score
        is >= the k-th largest perturbed score of its row (``k = self.k``),
        0.0 elsewhere. Ties at the threshold may select more than k entries.
    """
    noise = self.sample_gumbel_k(tf.shape(logits))
    self.samples = noise
    perturbed = logits + noise
    top_values, _ = tf.nn.top_k(perturbed, self.k, sorted=True)
    # Smallest of the k largest values per row, kept as a column for broadcasting.
    kth_value = tf.expand_dims(top_values[:, -1], -1)
    return tf.cast(tf.greater_equal(perturbed, kth_value), tf.float32)
    
def sample_discrete_backward(self, logits):
    """Recompute a top-k sample for ``logits`` using previously stored noise.

    Unlike ``sample_discrete_forward``, no new noise is drawn. The tensor
    stored on ``self.samples`` by the most recent forward call is added to
    ``logits`` instead.

    Args:
        logits: Scores to perturb. Like the forward sampler, this requires a
            2-D tensor (the ``[:, -1:]`` slice), presumably (batch, M).

    Returns:
        float32 tensor shaped like ``logits``: 1.0 where the perturbed score
        is >= the k-th largest perturbed score of its row (``k = self.k``),
        0.0 elsewhere.
    """
    perturbed = logits + self.samples
    # Slicing with -1: keeps the trailing axis, giving a (rows, 1) threshold.
    kth_value = tf.nn.top_k(perturbed, k=self.k, sorted=True).values[:, -1:]
    return tf.cast(perturbed >= kth_value, tf.float32)

@tf.custom_gradient
def sample_graph(self, logits, adjacency_matrix):
    """Select top-k nodes by perturb-and-MAP and return a graph with I-MLE gradients.

    Forward pass: a noisy top-k node selection ``z_train`` is drawn with
    ``self.sample_discrete_forward``. A deterministic top-k selection
    ``z_test`` is taken on the raw logits. ``K.in_train_phase`` then picks
    the training-time or test-time output.

    Backward pass (I-MLE): the upstream gradient on the output is summed
    along each row into a per-node gradient. The logits are shifted against
    that gradient (scaled by ``self._lambda``). A top-k MAP state is then
    recomputed with the same noise used in this call's forward pass. The
    gradient for ``logits`` is ``z_train - map_dy``. No gradient flows into
    ``adjacency_matrix``.

    Args:
        logits: Node logits computed by an upstream network. The top-k
            indexing below requires a 2-D tensor, presumably (batch, M).
            This could later be extended to (batch, M, d) to sample d graphs.
        adjacency_matrix: The input graph, assumed to be (batch, M, M).

    Returns:
        Through ``tf.custom_gradient``, callers receive only the output graph.
        The undecorated function returns ``(output, custom_grad)``.
    """

    def _top_k_mask(scores):
        # 1.0 where a score is >= the k-th largest score of its row, else 0.0.
        kth_largest = tf.nn.top_k(scores, self.k, sorted=True)[0][:, -1]
        return tf.cast(
            tf.greater_equal(scores, tf.expand_dims(kth_largest, -1)), tf.float32
        )

    # Sample discretely with perturb-and-MAP (training-time selection).
    z_train = tf.stop_gradient(self.sample_discrete_forward(logits))
    # sample_discrete_forward stores its noise on self.samples. Capture it now
    # so the backward pass reuses *this* call's noise even if the method runs
    # again (e.g. a shared layer) before gradients are computed.
    noise = self.samples
    # Gradients are returned only for the logits (see custom_grad).
    adjacency_matrix_train = tf.stop_gradient(adjacency_matrix)

    # Deterministic top-k selection on the raw logits (test-time selection).
    z_test = tf.stop_gradient(_top_k_mask(logits))
    adjacency_matrix_test = tf.stop_gradient(adjacency_matrix)

    # NOTE(review): z_train / z_test are not yet applied to the adjacency
    # matrices, so the output is currently the input graph unchanged. The
    # intended step is to zero out entries of nodes that were not selected
    # (the exact masking scheme is up to the author).

    # At training time use the sampled selection, at test time the top-k one.
    z_output = K.in_train_phase(adjacency_matrix_train, adjacency_matrix_test)

    def custom_grad(dy):
        # dy is the gradient w.r.t. the output graph, assumed (batch, M, M).
        # Aggregate the per-edge gradients into one gradient per node: node i
        # gets the sum of row i. This is reasonable for undirected graphs;
        # other aggregation schemes are possible for directed graphs.
        node_gradients = tf.reduce_sum(dy, axis=-1)
        # Match the logits' shape exactly (no hard-coded node count).
        node_gradients = tf.reshape(node_gradients, tf.shape(logits))

        # Perturb the logits (implicit differentiation) and recompute the MAP
        # state with the same noise that produced z_train.
        map_dy = _top_k_mask(logits - self._lambda * node_gradients + noise)
        # I-MLE gradient: difference between the two discrete states.
        grad = tf.math.subtract(z_train, map_dy)
        # The adjacency matrix is treated as a constant input: zero gradient.
        return grad, tf.zeros_like(adjacency_matrix)

    return z_output, custom_grad