In [72]:
import numpy as np 
import tensorflow as tf
import matplotlib.pyplot as plt 

import collections
import hashlib
import numbers

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import _Linear

here  I try to use the native RNN cell type in Tensorflow to implement a spiking network layer. An RNN is defined as any cell that has input, states, and a call function that takes the input and state, and produce next_state, and output;
https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/RNNCell

in the implementation, we assume spikes from m input cells are arrived, and train both $W_{in}$ and $W_{rec}$
state of the network are the three following variables $v_{membrane},\ t_{reset},\ g(t)$

In [90]:
class SpikingRnnCell(tf.contrib.rnn.RNNCell):
    # input and output of the network are spikes
    # states of the network are membrane voltage and synaptic input S 
    def __init__(self,num_units,tau_m=5.0,v_theta=1.0,v_reset=0.0,tau_s=5.0,tau_refract=3.0,reuse=None):
        super(SpikingRnnCell,self).__init__(_reuse=reuse)
        self.num_units=num_units
        self.tau_m=tau_m
        self.v_theta=v_theta
        self.v_reset=v_reset
        self.tau_s=tau_s
        self.tau_refract=tau_refract
        
        # variables 
        
    @property
    def state_size(self):
        return self.num_units
    @property
    def output_size(self):
        return 1
    
        
    #call routine is used by tensorflow to compute the output and next state of the network,
    def __call__(self,inputs,state):
        
        #output is some funtion of states
        # first slice up the state into three vectors, 
        v_mem=tf.slice(state,[0,0],[self.num_units,1])
        g_mem=tf.slice(state,[self.num_units,0],[self.num_units,1])
        t_mem=tf.slice(state,[2*self.num_units,0],[self.num_units,1])
        #
        spike=self.calculate_crossing_op(v_mem)
        v_update=tf.subtract(v_mem,tf.multiply(tf.cast(Spike,tf.float32),
                                               tf.constant(self.v_theta,shape=[self.num_units,1],dtype=tf.float32)))
        #
        spike_rec=tf.clip_by_value(tf.subtract(tf.transpose(tf.tile(spike,[1,self.num_units])),
                                            tf.eye(self.num_units,dtype=tf.float32)),0.0,100.0)
        spike_in=tf.transpose(tf.tile(tf.case(input,tf.float32),[1,self.num_units]))
        g_update=_Linear([spike_in,spike_rec],self.num_units,False)
        
        
        g_update=tf.add(tf.reduce_sum(tf.multiply(self.W_rec,spike_rec), 1, keepdims=True),
                        tf.reduce_sum(tf.multiply(self.W_in,spike_in), 1, keepdims=True))
        #
        dg_mem=tf.subtract(g_mem,tf.divide(g_mem,self.tau_s))
        g_mem_new=tf.add(g_update,tf.subtract(g_mem,dg_mem))
        #
        I_input=tf.add(tf.multiply(g_mem_new,v_update))
        #                
        t_subtract=tf.subtract(t_mem,tf.constant(1.0,shape=[self.num_units,1]))
        t_margin=tf.clip_by_value(t_subtract,0.0,100.0)
        t_mem_new=tf.add(t_margin,tf.multiply(spike,tf.constant(self.tau_refract,shape=[self.num_units,1],dtype=tf.float32)))
        update_trace=tf.cast(tf.equal(t_mem_new,tf.constant(0.0,shape=[self.num_units,1])),tf.float32)
        #
        dv_mem=tf.add(tf.constant(self.v_reset,shape=[self.num_units,1],dtype=tf.float32),
                      tf.multiply(update_trace,tf.divide(tf.subtract(I_input,v_update),tau_m)))
        v_mem_new=tf.add(v_update,dv_mem)

        
        
        return spike, tf.concat([v_mem_new,g_mem_new,t_mem_new],0)
    
    ## crossing fucntion 
    @tf.custom_gradient
    def calculate_crossing_op(self,x):
        x_norm=tf.divide(tf.subtract(x,tf.constant(self.v_theta,shape=[self.state_size,1])),
                         tf.constant(self.v_theta,shape=[self.state_size,1]))
        temp=tf.greater_equal(x,tf.constant(self.v_theta,shape=[self.state_size,1],dtype=tf.float32))
        def grad(dy):            
            return tf.maximum(tf.constant(0.0,dtype=tf.float32),tf.subtract(tf.constant(1.0,dtype=tf.float32),tf.abs(x_norm)))  
        return temp, grad
        

In [94]:
 snn_cell = SpikingRnnCell(4)



In [95]:
initial_state=snn_cell.zero_state(1,dtype=tf.float32)
initial_state

<tf.Tensor 'SpikingRnnCellZeroState_3/zeros:0' shape=(1, 4) dtype=float32>

In [86]:
def SNN(x):
    # Define a lstm cell with tensorflow
    snn_cell = SpikingRnnCell(1)

    # Get lstm cell output
    # outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)
    outputs, states = tf.nn.dynamic_rnn(cell=snn_cell, inputs=x, dtype=tf.float32)
    
    return output

In [87]:
timesteps=1000
num_input=1
# Need to clear the default graph before moving forward
tf.reset_default_graph()
graph = tf.Graph()
with graph.as_default():
    tf.set_random_seed(1)
    # tf Graph input
    X = tf.placeholder("float", [None,num_input,timesteps])

    spikes = SNN(X)

    
    # Initialize the variables (i.e. assign their default value)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

TypeError: Failed to convert object of type <class '__main__.SpikingRnnCell'> to Tensor. Contents: <__main__.SpikingRnnCell object at 0xb32fecbe0>. Consider casting elements to a supported type.

In [66]:
from tensorflow.python.ops import clip_ops