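## Adding Custom Operations to TensorFlow Using Python

This notebook wraps ordinary Python/NumPy functions as TensorFlow ops with `tf.py_func` and attaches hand-written gradient functions to them.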

In [1]:
import uuid

import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops

In [2]:
# Credit to @harpone for the example
# https://gist.github.com/harpone/3453185b41d8d985356cbe5e57d67342
def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
    """Wrap a Python function as a TensorFlow op with a custom gradient.

    Args:
        func: Python callable operating on numpy arrays; forwarded to
            ``tf.py_func``.
        inp: list of input tensors.
        Tout: TensorFlow dtype (or list of dtypes) of the outputs.
        stateful: forwarded to ``tf.py_func``; False creates a
            ``PyFuncStateless`` op instead of a ``PyFunc`` op.
        name: optional op name.
        grad: gradient function ``grad(op, *output_grads)`` that replaces the
            default (non-differentiable) gradient of the py_func op. If None,
            no override is installed.

    Returns:
        Whatever ``tf.py_func`` returns (a list of tensors when ``Tout`` is a
        list).
    """
    if grad is None:
        # PyFunc ops are non-differentiable by default, so there is nothing
        # to override; avoid registering a throwaway gradient entry.
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
    # Gradient names must be unique: RegisterGradient raises KeyError when a
    # name is registered twice, and a random integer can collide.
    rnd_name = 'PyFuncGrad' + uuid.uuid4().hex
    tf.RegisterGradient(rnd_name)(grad)
    g = tf.get_default_graph()
    # stateful=False yields op type "PyFuncStateless", which must be mapped
    # too or the custom gradient is silently ignored.
    with g.gradient_override_map({"PyFunc": rnd_name,
                                  "PyFuncStateless": rnd_name}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

In [3]:
# Credit to @harpone for the example
# https://gist.github.com/harpone/3453185b41d8d985356cbe5e57d67342
def square(x, name=None):
    """Elementwise square computed by ``np.square`` inside a py_func op.

    Args:
        x: tensor, or a value convertible to one.
        name: optional name scope; defaults to "MySquareFunc".

    Returns:
        Tensor with the same dtype and static shape as ``x``. Its gradient
        is supplied by ``square_grad``.
    """
    with ops.name_scope(name, "MySquareFunc", [x]) as name:
        x = tf.convert_to_tensor(x, name="x")
        # np.square keeps the input dtype, so declare the output with x's
        # dtype instead of hard-coding float32 (which breaks float64 inputs).
        sqr_x = py_func(np.square,
                        [x],
                        [x.dtype],
                        name=name,
                        grad=square_grad)[0]
        # py_func discards static shape; np.square preserves it.
        sqr_x.set_shape(x.get_shape())
        return sqr_x
def square_grad(op, grad):
    """Gradient of ``square``: d(x**2)/dx = 2x, times the upstream gradient.

    Args:
        op: the forward op; its first input is the tensor that was squared.
        grad: upstream gradient with respect to the op's output.

    Returns:
        The gradient with respect to the squared input, ``grad * 2 * x``.
    """
    squared_input = op.inputs[0]
    scaled_upstream = grad * 2
    return scaled_upstream * squared_input

In [4]:
# Credit to @harpone for the example
# https://gist.github.com/harpone/3453185b41d8d985356cbe5e57d67342
with tf.Session() as sess:
    x = tf.constant([1., 2.])
    y = square(x)
    dy_dx = tf.gradients(y, x)[0]
    # A single sess.run evaluates everything in one graph execution instead
    # of re-running the py_func for each separate .eval() call.
    x_val, y_val, dy_dx_val = sess.run([x, y, dy_dx])
    print(x_val, y_val, dy_dx_val)
    # The custom gradient must match the analytic derivative 2x.
    np.testing.assert_allclose(dy_dx_val, 2 * x_val)

[1. 2.] [1. 4.] [2. 4.]


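## Implementing a Heap Using Custom Ops

The same `py_func` wrapper can expose a stateful Python data structure to a graph. Below, a min-heap priority queue is driven by one action vector per instance, whose argmax selects push, peek, or pop.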

In [5]:
from heapq import heappush, heappop

In [6]:
class TFMinHeapPQ(object):
    """This class enables a min heap priority queue
    data structure to interface with TensorFlow
    computational graphs.

    It holds ``instances`` independent heaps. Each call to ``interact``
    picks, per instance, one action by argmax over a length-3 action vector:
    0 = push the operand row, 1 = peek at the minimum, 2 = pop the minimum.
    Priorities are given by the user-supplied function ``h`` applied to each
    stored vector (smaller ``h`` comes out first).
    """

    # Length of each instance's action vector, and the action indices.
    NUM_ACTIONS = 3
    PUSH, PEEK, POP = 0, 1, 2

    class Node(object):
        """Container for a vector that compares by ``outer.h(x)``.

        Note: defining ``__eq__`` without ``__hash__`` makes Node unhashable
        in Python 3.
        """
        def __init__(self, x, outer):
            self.x = x
            self.outer = outer

        def _priority(self):
            # Priority of the stored vector under the owning queue's function.
            return self.outer.h(self.x)

        def __str__(self):
            return str(self.x)

        def __repr__(self):
            return repr(self.x)

        def __eq__(self, y):
            return self._priority() == y._priority()

        def __ne__(self, y):
            return self._priority() != y._priority()

        def __gt__(self, y):
            return self._priority() > y._priority()

        def __lt__(self, y):
            return self._priority() < y._priority()

        def __ge__(self, y):
            return self._priority() >= y._priority()

        def __le__(self, y):
            return self._priority() <= y._priority()

    def __init__(self, n, i, v, h):
        """Initialize the data structure.

        Args:
            n: name, used as the TensorFlow name scope.
            i: number of independent heap instances.
            v: length of each element vector.
            h: priority ("hashing") function mapping a vector to a scalar.
        """
        self.name = n
        self.instances = i
        self.vsize = v
        self.h = h
        self.reset()

    def reset(self):
        """Empty every heap instance."""
        self.pq = [[] for _ in range(self.instances)]

    def __str__(self):
        return str(self.pq)

    def __repr__(self):
        return repr(self.pq)

    def _interact(self, operator, operand):
        """Apply one action per instance and return the resulting vectors.

        Args:
            operator: array reshapeable to (instances, NUM_ACTIONS); the
                argmax of row i selects push (0), peek (1) or pop (2).
            operand: array reshapeable to (instances, vsize); row i is the
                vector pushed when instance i's action is push.

        Returns:
            float64 array of shape (instances, vsize), matching the
            ``tf.float64`` output declared in ``interact``. Row i is the
            pushed vector (push), the current minimum (peek) or the removed
            minimum (pop). Peek or pop on an empty heap yields a zero row
            instead of raising inside the TensorFlow op.
        """
        operator = np.reshape(operator, (self.instances, self.NUM_ACTIONS))
        operand = np.reshape(operand, (self.instances, self.vsize))
        action = np.argmax(operator, axis=-1)
        result = np.zeros((self.instances, self.vsize), dtype=np.float64)
        for i in range(self.instances):
            heap = self.pq[i]
            if action[i] == self.PUSH:
                # Copy: the array handed in by py_func may alias a TensorFlow
                # buffer that is reused after this call returns.
                vector = np.array(operand[i, :], copy=True)
                heappush(heap, TFMinHeapPQ.Node(vector, self))
                result[i] = vector
            elif heap:
                if action[i] == self.PEEK:
                    result[i] = heap[0].x
                else:
                    result[i] = heappop(heap).x
            # Otherwise: peek/pop on an empty heap leaves the zero row.
        return result

    def _interact_grad(self, op, grad):
        """Currently this op is not differentiable, so gradients are zero.

        Returns zeros shaped and typed like each input (operator, operand).
        The upstream ``grad`` has the output's shape and dtype, which need
        not match either input, so it cannot be reused here.
        """
        operator, operand = op.inputs[0], op.inputs[1]
        return tf.zeros_like(operator), tf.zeros_like(operand)

    def interact(self, operator, operand):
        """Connect the data structure to delayed computation,
        as part of a computational graph.

        Args:
            operator: tensor of action scores, reshapeable to
                (instances, NUM_ACTIONS).
            operand: float tensor reshapeable to (instances, vsize).

        Returns:
            float64 tensor of shape (instances, vsize) produced by
            ``_interact`` when the graph runs.
        """
        with ops.name_scope(
                self.name,
                self.name,
                [operator, operand]) as name:
            result = py_func(
                self._interact,
                [operator, operand],
                [tf.float64],
                name=name,
                grad=self._interact_grad)[0]
            # py_func drops static shape; _interact always returns this one.
            result.set_shape((self.instances, self.vsize))
            return result

In [18]:
with tf.Session() as sess:
    # Seeded RNG so priorities and inserted vectors are reproducible.
    rng = np.random.RandomState(0)
    weights = rng.normal(0, 1, (1, 2))
    tf_pq = TFMinHeapPQ(
        "MinHeap",
        1,
        2,
        lambda x: np.sum(x * weights))

    # Ops fetched in one sess.run have no guaranteed execution order, so
    # chain each interaction on the previous one with a control dependency.
    y = []
    for _ in range(10):
        operator = tf.constant([[1., 0., 0.]])
        operand = tf.constant(rng.normal(0, 1, (1, 2)))
        with tf.control_dependencies(y[-1:]):
            y += [tf_pq.interact(operator, operand)]
    sess.run(y)
    print("After Insert:", tf_pq.pq)

    y = []
    for _ in range(10):
        operator = tf.constant([[0., 0., 1.]])
        operand = tf.constant(np.zeros((1, 2)))
        with tf.control_dependencies(y[-1:]):
            y += [tf_pq.interact(operator, operand)]
    popped = sess.run(y)
    print("After Delete:", tf_pq.pq)
    # Min-heap pops must come out in non-decreasing priority order.
    print("Popped priorities:", [float(tf_pq.h(p)) for p in popped])

After Insert: [[array([1.8347881 , 1.27536726]), array([1.80034747, 0.23048911]), array([ 2.11328604, -0.22488661]), array([0.98964783, 0.17513749]), array([-0.1929099 ,  1.59344752]), array([-0.80268199,  0.10572079]), array([-0.69127304,  0.271473  ]), array([ 0.03404603, -0.67181822]), array([-0.42221106,  0.42005549]), array([0.36772811, 0.26428235])]]
After Delete: [[]]
