In [49]:
import numpy as np 
import tensorflow as tf
import matplotlib.pyplot as plt 

import collections
import hashlib
import numbers

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import _Linear

Here I try to use TensorFlow's native RNN cell type to implement a spiking network layer. An RNN cell is any object that has an input, a state, and a call function that takes the input and the current state and produces the output and the next state:
https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/RNNCell

In this implementation, we assume that spikes arrive from $m$ input cells, and we train both $W_{in}$ and $W_{rec}$.
The state of the network consists of the following three variables: $v_{membrane},\ t_{reset},\ g(t)$

In [99]:
# String constants "W_in" / "W_rec"; presumably intended as the variable names
# of the input and recurrent weight matrices.
# NOTE(review): neither constant is referenced in the visible part of this file.
_INPUT_WEIGHT_NAME = "W_in"
_RECURRENT_WEIGHT_NAME = "W_rec"

def _calcualte_crossings(x,threshold):
    """Return a 0/1 indicator of which entries of `x` reach `threshold`.

    The forward pass is the hard comparison `x >= threshold`, cast back to the
    dtype of `x`. Because that step function has no useful derivative, the
    backward pass uses the surrogate `max(1 - |x|, 0)` scaled by the upstream
    gradient.

    Args:
      x: tensor to test (originally documented as 2-D, `batch x n`).
      threshold: scalar value compared against every entry of `x`.

    Returns:
      A tensor with the same shape and dtype as `x` holding 1 where
      `x >= threshold` and 0 elsewhere.
    """
    @tf.custom_gradient
    def crossings(v):
        dtype = v.dtype
        # A scalar constant broadcasts against `v`, so neither the batch size
        # nor the width has to be statically known at graph-build time.
        fired = tf.greater_equal(v, tf.constant(threshold, dtype=dtype))

        def grad(dy):
            # Surrogate derivative 1 - |x|, clipped at zero.
            # NOTE(review): this is centred on x = 0, not on `threshold`; the
            # commented-out `calculate_crossing_op` further down this file
            # normalises x by the threshold first -- confirm which is intended.
            surrogate = tf.maximum(1 - tf.abs(v), 0.0)
            # Chain rule: the upstream gradient must be propagated, not dropped.
            return dy * surrogate

        return tf.cast(fired, dtype=dtype), grad

    return crossings(x)

_SNNStateTuple = collections.namedtuple(
    "SNNStateTuple",
    ("v_mem","spike","t_reset", "S_rec","S_in"))


class SNNStateTuple(_SNNStateTuple):
  """Five-field state container used by SNN cells.

  Fields, in order: `(v_mem, spike, t_reset, S_rec, S_in)`. `v_mem` is the
  hidden state, `spike` is the output, `S_rec` and `S_in` are the spike
  histories and `t_reset` is the refractory history.

  Only used when `state_is_tuple=True`.
  """
  __slots__ = ()

  @property
  def dtype(self):
    """Return the dtype shared by `v_mem` and `spike`.

    Raises:
      TypeError: if `v_mem` and `spike` have different dtypes.
    """
    membrane_dtype = self.v_mem.dtype
    spike_dtype = self.spike.dtype
    if membrane_dtype != spike_dtype:
      raise TypeError("Inconsistent internal state: %s vs %s" %
                      (str(membrane_dtype), str(spike_dtype)))
    return spike_dtype


class SNNCell(tf.contrib.rnn.RNNCell):
  """Spiking neural network (SNN) cell built on the TensorFlow RNNCell API.

  Work in progress: the conductance and membrane updates are not implemented
  yet, so the recurrent update in `call` is still the GRU-style update
  (cf. http://arxiv.org/abs/1406.1078) that this class was adapted from.
  With `state_is_tuple=True` the membrane potential `v_mem` plays the role of
  the hidden state and threshold crossings are stored in the `spike` slot of
  the returned state.

  Args:
    num_units: int, The number of units in the cell.
    tau_m: stored on the instance; not used by `call` yet (presumably the
      membrane time constant).
    v_theta: firing threshold compared against `v_mem` to detect spikes.
    v_reset: stored on the instance; not used by `call` yet.
    tau_s: stored on the instance; not used by `call` yet.
    tau_refract: stored on the instance; not used by `call` yet.
    activation: Nonlinearity to use for the candidate state. Default: `tanh`.
    reuse: (optional) Python boolean describing whether to reuse variables
     in an existing scope.  If not `True`, and the existing scope already has
     the given variables, an error is raised.
    kernel_initializer: (optional) The initializer to use for the weight and
     projection matrices.
    bias_initializer: (optional) The initializer to use for the bias.
    state_is_tuple: if True, the state is an `SNNStateTuple`; otherwise it is
     a single `[batch_size x num_units]` tensor.
  """

  def __init__(self,
               num_units,
               tau_m=5.0,
               v_theta=1.0,
               v_reset=0.0,
               tau_s=5.0,
               tau_refract=3.0,
               activation=None,
               reuse=None,
               kernel_initializer=None,
               bias_initializer=None,
               state_is_tuple=False):
    super(SNNCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self.tau_m=tau_m
    self.v_theta=v_theta
    self.v_reset=v_reset
    self.tau_s=tau_s
    self.tau_refract=tau_refract
    self._activation = activation or math_ops.tanh
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer
    self._state_is_tuple= state_is_tuple
    # Reserved for the SNN weights; not used by `call` yet.
    self._weight_linear = None
    # Linear layers are built lazily on the first `call`; they must exist as
    # attributes beforehand or the `is None` checks raise AttributeError.
    self._gate_linear = None
    self._candidate_linear = None

  @property
  def state_size(self):
    """Size(s) of the state: an `SNNStateTuple` in tuple mode, else an int."""
    if not self._state_is_tuple:
      return self._num_units
    n = self._num_units
    # S_rec and S_in are documented as [batch, n, n] tensors. A TensorShape
    # is a single leaf of the state structure, whereas a Python list [n, n]
    # would be read as two separate [batch, n] states.
    history = tf.TensorShape([n, n])
    return SNNStateTuple(n, n, n, history, history)

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    """Run one step of the cell.

    Args:
      inputs: `2-D` tensor with shape `[batch_size x input_size]`.
      state: if `state_is_tuple` is True, an `SNNStateTuple`
        `(v_mem, spike, t_reset, S_rec, S_in)` whose first three entries are
        `[batch_size x num_units]` and whose last two are
        `[batch_size x num_units x num_units]`; otherwise a single
        `[batch_size x num_units]` tensor.

    Returns:
      A pair `(output, new_state)`. `output` is the updated hidden state. In
      tuple mode `new_state` is an `SNNStateTuple` carrying the updated
      hidden state as `v_mem`, the spikes detected on the incoming `v_mem`,
      and `t_reset`, `S_rec`, `S_in` passed through unchanged; otherwise
      `new_state` is the same tensor as `output`.
    """
    if self._state_is_tuple:
      v_mem, _, t_reset, S_rec, S_in = state
      # `_calcualte_crossings` returns the spike tensor itself (not a
      # callable), so it is invoked directly on every step.
      spike = _calcualte_crossings(v_mem, self.v_theta)
      # Reset by subtraction: remove v_theta from every unit that fired.
      # TODO: not consumed yet -- the conductance and membrane updates that
      # should use it are still missing.
      v_update = tf.subtract(v_mem, tf.scalar_mul(self.v_theta, spike))
      hidden = v_mem
    else:
      logging.warn("%s: Please use state tuple ", self)
      hidden = state

    # GRU-style placeholder update inherited from the template this class
    # was adapted from; it operates on the hidden-state tensor only.
    if self._gate_linear is None:
      bias_ones = self._bias_initializer
      if self._bias_initializer is None:
        bias_ones = init_ops.constant_initializer(1.0, dtype=inputs.dtype)
      with vs.variable_scope("gates"):  # Reset gate and update gate.
        self._gate_linear = _Linear(
            [inputs, hidden],
            2 * self._num_units,
            True,
            bias_initializer=bias_ones,
            kernel_initializer=self._kernel_initializer)

    value = math_ops.sigmoid(self._gate_linear([inputs, hidden]))
    r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)

    r_state = r * hidden
    if self._candidate_linear is None:
      with vs.variable_scope("candidate"):
        self._candidate_linear = _Linear(
            [inputs, r_state],
            self._num_units,
            True,
            bias_initializer=self._bias_initializer,
            kernel_initializer=self._kernel_initializer)
    c = self._activation(self._candidate_linear([inputs, r_state]))
    new_h = u * hidden + (1 - u) * c

    if self._state_is_tuple:
      # The returned state must have the same structure as `state_size`.
      return new_h, SNNStateTuple(new_h, spike, t_reset, S_rec, S_in)
    return new_h, new_h


#@tf.custom_gradient   
#def calculate_crossing_op(self,x):
#    x_norm=tf.divide(tf.subtract(x,tf.constant(self.v_theta,shape=[self.state_size,1])),
#                     tf.constant(self.v_theta,shape=[self.state_size,1]))
#    temp=tf.greater_equal(x,tf.constant(self.v_theta,shape=[self.state_size,1],dtype=tf.float32))
#    def grad(dy):            
#        return tf.maximum(tf.constant(0.0,dtype=tf.float32),tf.subtract(tf.constant(1.0,dtype=tf.float32),tf.abs(x_norm)))  
#    return temp, grad

example implementation 

In [101]:
# Batch size = 2, sequence length = 3, number input = 1, shape=(2, 3, 1)
# NOTE(review): values231 is defined but not used in this cell.
values231 = np.array([
    [[1], [2], [3]],
    [[2], [3], [4]]
])

# Batch size = 3, sequence length = 5, number input = 2, shape=(3, 5, 2)
values352 = np.array([
    [[1, 4], [2, 5], [3, 6], [4, 7], [5, 8]],
    [[2, 5], [3, 6], [4, 7], [5, 8], [6, 9]],
    [[3, 6], [4, 7], [5, 8], [6, 9], [7, 10]]
])

import tensorflow as tf
# Start from an empty graph so re-running the cell does not pile up variables.
tf.reset_default_graph()

# NOTE: despite its name, tf_values231 holds the (3, 5, 2) `values352` data.
tf_values231 = tf.constant(values352, dtype=tf.float32)
# NOTE: the variable is called lstm_cell but the cell is an SNNCell, built
# here in non-tuple state mode.
lstm_cell = SNNCell(num_units=100,state_is_tuple=False)
outputs, state = tf.nn.dynamic_rnn(cell=lstm_cell, dtype=tf.float32, inputs=tf_values231)

print(outputs)
# Recorded output for this input:
#   Tensor("rnn/transpose_1:0", shape=(3, 5, 100), dtype=float32)
# NOTE(review): the `state.c` / `state.h` prints below are disabled; they look
# like leftovers from an LSTM example -- presumably `state` is a single tensor
# when state_is_tuple=False, confirm before re-enabling.
#print(state.c)
#print(state.h)
cell_outputs=[]
with tf.Session() as sess:
    # Variables created by the cell must be initialised before evaluation.
    sess.run(tf.global_variables_initializer())
    output_run, state_run = sess.run([outputs, state])
    cell_outputs.append(output_run)
    
# Bare expression: displayed as the notebook cell's result.
output_run

Tensor("rnn/transpose_1:0", shape=(3, 5, 100), dtype=float32)


array([[[-0.09491698,  0.04844373,  0.15038799, ...,  0.04932952,
         -0.0933247 ,  0.12681684],
        [-0.19831467,  0.14181629,  0.2857799 , ...,  0.0807085 ,
         -0.16672042,  0.2721757 ],
        [-0.30884397,  0.2543453 ,  0.40748757, ...,  0.10056059,
         -0.2279777 ,  0.41125804],
        [-0.42058885,  0.36732054,  0.516152  , ...,  0.11257821,
         -0.28062946,  0.53184116],
        [-0.52734804,  0.47143427,  0.6124262 , ...,  0.11921699,
         -0.3262758 ,  0.6309334 ]],

       [[-0.14542075,  0.06875345,  0.2145503 , ...,  0.03947791,
         -0.0980949 ,  0.16520159],
        [-0.27697694,  0.18345661,  0.3850823 , ...,  0.0647111 ,
         -0.17341049,  0.32615378],
        [-0.3999127 ,  0.30627543,  0.519294  , ...,  0.08065322,
         -0.23507242,  0.46613288],
        [-0.51322204,  0.42067796,  0.62571687, ...,  0.09016851,
         -0.28729486,  0.5807098 ],
        [-0.6145111 ,  0.52154833,  0.71157813, ...,  0.09531144,
         -0.33

<tf.Tensor 'IdentityN_1:0' shape=(1, 1) dtype=float32>