In [1]:
import numpy as np
import matplotlib.pyplot as plt
from itertools import zip_longest
SEED = 8  # NOTE(review): not referenced anywhere in the visible notebook and nothing here draws random numbers — confirm it is still needed

In [2]:
class SingleLayerSNN:
    """
    A single-layer spiking neural network of leaky integrate-and-fire (LIF) neurons.

    Every input neuron is simulated from its input trace; the spikes it emits are
    weighted and summed into the drive of each output neuron, which is then simulated
    in turn. ``train`` adjusts the weights with a rate-based Hebbian rule with decay.
    """

    def __init__(self, inputs, weights, trainings, Cm=4, Rm=5, V_thresh=30, V_rest=-65, V_spike=80, dT=0.01, rate=1):
        """
        Stores the network's data and LIF parameters.

                Parameters:
                        inputs (double[][][]): A 3d numpy array of input traces,
                                indexed [input set][input neuron][timestep]
                        weights (double[][]): A 2d numpy array of initial weights,
                                indexed [input neuron][output neuron]; updated in place by train()
                        trainings (double[][][]): A 3d numpy array of teaching traces added to the
                                output neurons' drive, indexed [input set][output neuron][timestep]
                        Cm, Rm: membrane capacitance and resistance used by LIF()
                        V_thresh (mV): membrane voltage above which a neuron spikes
                        V_rest (mV): resting voltage, also the post-spike reset base
                        V_spike (mV): value written into a trace at a spike
                        dT (ms): simulation timestep
                        rate (s): time window that reported spike rates are scaled to
                Returns:
                        None
        """
        self.inputs = inputs
        self.weights = weights
        self.trainings = trainings
        self.Cm = Cm
        self.Rm = Rm
        self.V_thresh = V_thresh
        self.V_rest = V_rest
        self.V_spike = V_spike
        self.dT = dT  # ms
        self.rate = rate  # sec
        # number of spikes produced by the most recent LIF() call
        self._LIF_spikes = 0

    def LIF(self, I):
        """
        Runs a leaky integrate-and-fire simulation on one neuron and returns its membrane trace.

                Parameters:
                        I (double[]): A 1d array of the input drive (injected current) per timestep

                Returns:
                        V (double[]): A numpy array, one sample per element of I, of the membrane
                                voltage in mV; spikes appear as samples equal to V_spike
        """
        I = np.asarray(I)

        # Size the trace from the input itself rather than via np.arange over float time:
        # a float-step arange can yield one extra sample (e.g. 3 steps of 0.1 ms) and the
        # loop would then index past the end of I.
        V = np.full(I.size, self.V_rest, dtype=float)

        self._LIF_spikes = 0
        did_spike = False
        V_prev = self.V_rest

        for t in range(I.size):
            # using "I - (V - V_rest)/Rm = Cm * dV/dT"
            dV = (I[t] - (V_prev - self.V_rest) / self.Rm) / self.Cm

            # after a spike the membrane restarts from rest instead of the spike value
            base = self.V_rest if did_spike else V_prev
            V[t] = base + dV * self.dT

            # check if membrane voltage exceeded threshold (spike)
            did_spike = V[t] > self.V_thresh
            if did_spike:
                V[t] = self.V_spike
                self._LIF_spikes += 1

            V_prev = V[t]

        return V

    def voltage_to_output(self, V_input):
        """
        Converts a membrane trace into what the neuron transmits: V_spike wherever the
        trace reaches V_spike (a spike) and 0 everywhere else.

                Parameters:
                        V_input (double[]): membrane trace, e.g. from LIF()
                Returns:
                        (double[]): float array of the same length
        """
        V_input = np.asarray(V_input, dtype=float)
        return np.where(V_input < self.V_spike, 0.0, float(self.V_spike))

    def voltage_to_spike_rate(self, voltages, dT=None, rate=None):
        """
        Computes the spike rate of a membrane trace in spikes per ``rate`` seconds.

                Parameters:
                        voltages (double[]): membrane trace; samples >= V_spike count as spikes
                        dT (ms, optional): timestep of the trace; falls back to self.dT when falsy
                        rate (s, optional): reporting window; falls back to self.rate when falsy
                Returns:
                        (float): spike rate, or 0.0 for an empty trace
        """
        if not dT:
            dT = self.dT
        if not rate:
            rate = self.rate

        voltages = np.asarray(voltages)
        if voltages.size == 0:
            return 0.0

        spike_count = np.count_nonzero(voltages >= self.V_spike)
        total_time_ms = voltages.size * dT

        # spikes/ms -> spikes/s -> spikes per `rate` seconds
        return spike_count / total_time_ms * 1000 * rate

    def feed_forward(self, inputs, trainings=None):
        """
        Simulates the network on every input set and returns the voltages of the input and
        output neurons.

                Parameters:
                        inputs: sequence of input sets, each a sequence of equal-length
                                per-input-neuron traces
                        trainings (optional): sequence of teaching-trace sets, one per input set,
                                each holding one trace per output neuron. A missing set (or a
                                non-sequence entry) means no teaching current for that input set.
                                Traces shorter than the simulation are zero-padded, longer ones
                                truncated. Teaching sets beyond the last input set are ignored.
                Returns:
                        (all_input_voltages, all_output_voltages): numpy arrays of membrane
                                traces indexed [input set][neuron][timestep]
        """
        if trainings is None:
            trainings = []

        all_input_voltages = []
        all_output_voltages = []

        for input_set, training_set in zip_longest(inputs, trainings):
            if input_set is None:
                # more teaching sets than input sets: nothing left to simulate
                break

            input_voltages = np.array([self.LIF(V_input) for V_input in input_set])

            # A neuron only passes a signal on when it spikes. The spike trains do not
            # depend on the weights, so build each one once rather than once per output.
            input_spikes = [self.voltage_to_output(V) for V in input_voltages]

            # column j of the weight matrix holds the weights into output neuron j
            output_inputs = []
            for weight_set in np.asarray(self.weights).T:
                weighted_sum = np.zeros(len(input_set[0]))
                for spikes, weight in zip(input_spikes, weight_set):
                    weighted_sum = weighted_sum + spikes * weight
                output_inputs.append(weighted_sum)
            output_inputs = np.array(output_inputs)

            # inject training current into the output neurons if provided for this set
            if isinstance(training_set, (list, np.ndarray)):
                for i, (output_input, training_input) in enumerate(zip(output_inputs, training_set)):
                    if isinstance(training_input, (list, np.ndarray)):
                        n_steps = len(output_input)
                        training_input = np.asarray(training_input, dtype=float)[:n_steps]
                        # pad along time so the teaching trace matches the simulation length
                        padded_training_input = np.pad(training_input, (0, n_steps - len(training_input)), "constant")
                        output_inputs[i] = output_input + padded_training_input

            # run LIF on output neurons
            output_voltages = np.array([self.LIF(V_input) for V_input in output_inputs])

            all_input_voltages.append(input_voltages)
            all_output_voltages.append(output_voltages)

        return np.array(all_input_voltages), np.array(all_output_voltages)

    def train(self, epochs=100, a_corr=0.0002, w_max=500, w_decay=2):
        """
        Trains the weights for ``epochs`` passes over ``self.inputs``.

        Each pass runs feed_forward with the teaching traces in ``self.trainings``. Then, for
        every input set, each weight w[i][j] (input neuron i -> output neuron j) changes by
        ``a_corr * input_rate_i * output_rate_j - w_decay`` (Hebb with decay) and is clamped
        to [0, w_max]. Weights are updated in place on ``self.weights`` and printed each epoch.

                Parameters:
                        epochs (int): number of passes over the training data
                        a_corr (float): Hebbian correlation learning rate
                        w_max (float): upper bound on any weight
                        w_decay (float): amount subtracted from every weight per input set
                Returns:
                        None
        """
        for epoch in range(epochs):
            print(f'Epoch: {epoch + 1}')

            all_input_voltages, all_output_voltages = self.feed_forward(self.inputs, self.trainings)

            # apply learning rule
            for input_voltages, output_voltages in zip(all_input_voltages, all_output_voltages):
                # output rates do not depend on the input neuron, so compute them once per set
                output_rates = [self.voltage_to_spike_rate(V) for V in output_voltages]

                for i, (input_voltage_set, weight_set) in enumerate(zip(input_voltages, self.weights)):
                    input_rate = self.voltage_to_spike_rate(input_voltage_set)

                    for j, (output_rate, weight) in enumerate(zip(output_rates, weight_set)):
                        # adjust the weight using Hebb with decay
                        weight_change = a_corr * input_rate * output_rate - w_decay
                        new_weight = weight + weight_change
                        if new_weight < 0:
                            new_weight = 0
                        elif new_weight > w_max:
                            new_weight = w_max
                        self.weights[i][j] = new_weight

            print(self.weights)

    def predict(self, inputs=None):
        """
        Runs input sets through the network without teaching current and prints the raw
        output voltages followed by the spike rate of every input and output neuron.

                Parameters:
                        inputs (double[][][], optional): input sets to run; defaults to the
                                inputs given to the constructor
                Returns:
                        None
        """
        if inputs is None:
            inputs = self.inputs

        all_input_voltages, all_output_voltages = self.feed_forward(inputs)
        print('all_output_voltages')
        print(all_output_voltages)
        for x, (input_voltages, output_voltages) in enumerate(zip(all_input_voltages, all_output_voltages)):
            print('input set:', x)
            for i, input_voltage_set in enumerate(input_voltages):
                print(f'\tinput {i}: {self.voltage_to_spike_rate(input_voltage_set)} spikes/{self.rate}s')
            print()
            for i, output_voltage_set in enumerate(output_voltages):
                print(f'\toutput {i}: {self.voltage_to_spike_rate(output_voltage_set)} spikes/{self.rate}s')
            print()


In [5]:
# AND-gate training data. Each boolean operand is one-hot encoded across two input
# neurons as [true, false]; the two output neurons are [AND is true, AND is false].
# Inputs and teaching labels are both derived from the truth table so they cannot drift apart.
units_of_time = 3750  # timesteps per trace
ON_LEVEL = 80         # drive level of an active trace; inactive traces are 0
EPOCHS = 50


def one_hot_traces(*bits):
    """Encode each boolean in ``bits`` as a [true, false] pair of constant traces."""
    on, off = [ON_LEVEL] * units_of_time, [0] * units_of_time
    traces = []
    for bit in bits:
        traces += [on, off] if bit else [off, on]
    return traces


truth_table = [(True, True), (True, False), (False, True), (False, False)]

inputs = np.array([one_hot_traces(x, y) for x, y in truth_table])

# current that gets injected to the output neurons: [AND true, AND false] per input set
trainings = np.array([one_hot_traces(x and y) for x, y in truth_table])

# one weight per (input neuron, output neuron), all starting at .5
weights = np.full((inputs.shape[1], trainings.shape[1]), .5)

and_network = SingleLayerSNN(inputs=inputs, weights=weights, trainings=trainings)

and_network.train(EPOCHS)
and_network.predict(inputs)


Epoch: 1
[[3.62 3.62]
 [0.   9.24]
 [3.62 4.12]
 [0.   8.74]]
Epoch: 2
[[ 6.74  6.74]
 [ 0.   17.48]
 [ 6.74  7.24]
 [ 0.   16.98]]
Epoch: 3
[[ 9.86  9.86]
 [ 0.   25.72]
 [ 9.86 10.36]
 [ 0.   25.22]]
Epoch: 4
[[12.98 12.98]
 [ 0.   33.96]
 [12.98 13.48]
 [ 0.   33.46]]
Epoch: 5
[[16.1 16.1]
 [ 0.  42.2]
 [16.1 16.6]
 [ 0.  41.7]]
Epoch: 6
[[19.22 19.22]
 [ 0.   50.44]
 [19.22 19.72]
 [ 0.   49.94]]
Epoch: 7
[[22.34 22.34]
 [ 0.   58.68]
 [22.34 22.84]
 [ 0.   58.18]]
Epoch: 8
[[25.46 25.46]
 [ 0.   66.92]
 [25.46 25.96]
 [ 0.   66.42]]
Epoch: 9
[[28.58 28.58]
 [ 0.   75.16]
 [28.58 29.08]
 [ 0.   74.66]]
Epoch: 10
[[31.7 31.7]
 [ 0.  83.4]
 [31.7 32.2]
 [ 0.  82.9]]
Epoch: 11
[[34.82 34.82]
 [ 0.   91.64]
 [34.82 35.32]
 [ 0.   91.14]]
Epoch: 12
[[37.94 37.94]
 [ 0.   99.88]
 [37.94 38.44]
 [ 0.   99.38]]
Epoch: 13
[[ 41.06  41.06]
 [  0.   108.12]
 [ 41.06  41.56]
 [  0.   107.62]]
Epoch: 14
[[ 44.18  44.18]
 [  0.   116.36]
 [ 44.18  44.68]
 [  0.   115.86]]
Epoch: 15
[[ 47.3  47.3

In [4]:
# Sanity check of the two helpers the network's feed-forward step relies on:
# np.pad zero-extends a shorter trace so it can be added element-wise to a longer one,
# and zip_longest fills exhausted sequences with None instead of truncating.
a = np.array([1, 2, 3, 4])
b = np.array([1, 2])
b = np.pad(b, (0, a.size - b.size), mode="constant")
c = [2, 3, 5]

print(a + b)

for row in zip_longest(a, b, c):
    print(*row)

[2 4 3 4]
1 1 2
2 2 3
3 0 5
4 0 None
