In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [20]:
class SpikingModule():
    '''
    Purpose: A single module of the songbird system (e.g. HVC or RA), simulated as a
    recurrent population of threshold neurons whose recurrent weights adapt via a
    pair-based STDP rule.

    Weight conventions (fixed by the matrix products in ``step`` and ``get_output``):
    - input_matrix[k, i]: weight from input channel k to neuron i,
      shape (input_count, neuron_count)
    - internal_connectivity_matrix[i, j]: weight from presynaptic neuron j to
      postsynaptic neuron i, shape (neuron_count, neuron_count)
    - output_weights[i, m]: weight from neuron i to output m,
      shape (neuron_count, output_count)
    '''

    def __init__(self, neuron_count, input_count, output_count, decay_rate, threshold, learning_rate, time_constant):
        '''
        Purpose:
        - Build a module with uniformly random (np.random.rand) initial weights and
          membrane potentials.

        Args:
        - neuron_count (int): number of neurons in the module.
        - input_count (int): number of input channels.
        - output_count (int): number of readout channels returned by ``get_output``.
        - decay_rate (float): factor every membrane potential is multiplied by each step.
        - threshold (float): a neuron fires when its potential strictly exceeds this.
        - learning_rate (float): STDP amplitude used by ``step``.
        - time_constant (float): STDP decay constant (in steps) used by ``step``.

        Attributes:
        - input_matrix, internal_connectivity_matrix, output_weights: see class docstring.
        - neuron_states: current membrane potential of each neuron.
        - clock: number of steps simulated so far.
        - firing_history: one entry per simulated step holding the indices of the
          neurons that fired on that step (the tuple returned by ``np.where``).
        '''
        self.input_matrix = np.random.rand(input_count, neuron_count)
        self.internal_connectivity_matrix = np.random.rand(neuron_count, neuron_count)
        self.neuron_states = np.random.rand(neuron_count)
        # Default readout weights so get_output() works even if set_output_weights()
        # is never called; this is where output_count is used.
        self.output_weights = np.random.rand(neuron_count, output_count)
        self.clock = 0
        self.firing_history = []
        self.decay_rate = decay_rate
        self.learning_rate = learning_rate
        self.time_constant = time_constant
        self.threshold = threshold

    def stdp(self, learning_rate, time_constant):
        '''
        Purpose:
        - Apply a pair-based Spike-Timing-Dependent Plasticity (STDP) update to
          internal_connectivity_matrix, then clip every weight to [-1, 1].

        For each synapse j -> i (i != j) with last spike times t_i, t_j and
        delta_t = t_i - t_j (post minus pre):
        - delta_t > 0 (pre before post): W[i, j] += learning_rate * exp(-delta_t / time_constant)
        - delta_t < 0 (post before pre): W[i, j] -= learning_rate * exp( delta_t / time_constant)
        - delta_t == 0, or either neuron has never fired: unchanged.
        A synapse is only updated when at least one of its two neurons fired on the
        most recent step (clock - 1), so each spike pairing is applied once rather
        than re-applied on every later step.

        Args:
        - learning_rate (float): amplitude of each weight change.
        - time_constant (float): decay constant (in steps) of the STDP window.
        '''
        # Work on a float copy so we never mutate an array the caller handed us.
        weights = np.array(self.internal_connectivity_matrix, dtype=float)
        n = weights.shape[0]

        # Last spike time per neuron; neurons that never fired get NaN.
        last_spike = np.array(
            [np.nan if t is None else t for t in (self._find_last_spike(k) for k in range(n))],
            dtype=float,
        )

        # delta_t[i, j] = t_post(i) - t_pre(j) for the synapse j -> i.
        delta_t = last_spike[:, None] - last_spike[None, :]

        just_fired = last_spike == self.clock - 1  # NaN compares False
        eligible = (just_fired[:, None] | just_fired[None, :]) & ~np.isnan(delta_t)
        np.fill_diagonal(eligible, False)  # no self-plasticity

        potentiate = eligible & (delta_t > 0)
        depress = eligible & (delta_t < 0)
        weights[potentiate] += learning_rate * np.exp(-delta_t[potentiate] / time_constant)
        weights[depress] -= learning_rate * np.exp(delta_t[depress] / time_constant)

        self.internal_connectivity_matrix = np.clip(weights, -1, 1)

    def _find_last_spike(self, neuron_index):
        '''
        Purpose:
        - Find the last time step at which the specified neuron fired.

        Args:
        - neuron_index (int): The index of the neuron in question.

        Returns:
        - last_spike (int or None): The time step of the last spike for the specified
          neuron, or None if the neuron has not fired. Entry k of firing_history is
          taken to be step ``clock - len(firing_history) + k``.
        '''
        # Scan newest-first so the first hit is the most recent spike.
        for steps_back, fired_neurons in enumerate(reversed(self.firing_history)):
            if np.isin(neuron_index, fired_neurons):
                return self.clock - steps_back - 1
        return None

    def step(self, input):
        '''
        Purpose:
        - Advance the module by one time step. The step runs these stages in order:
          1. Integrate: previous potential + recurrent drive + input drive.
          2. Multiply by decay_rate.
          3. Rectify negative potentials to 0.
          4. Mark neurons whose potential exceeds threshold as fired.
          5. Reset fired neurons to 0 and record them in firing_history.
          6. Advance the clock.
          7. Apply STDP using this module's learning_rate and time_constant.

        Args:
        - input (array-like): input vector of length input_count.
        '''
        # NOTE(review): the recurrent term includes the diagonal (self-connections),
        # which stdp never updates.
        recurrent = self.internal_connectivity_matrix.dot(self.neuron_states)
        external = self.input_matrix.T.dot(input)
        potentials = (self.neuron_states + recurrent + external) * self.decay_rate
        potentials[potentials < 0] = 0

        fired_neurons = np.where(potentials > self.threshold)
        potentials[fired_neurons] = 0
        self.neuron_states = potentials
        self.firing_history.append(fired_neurons)
        self.clock += 1

        # Use the parameters the module was constructed with (not hard-coded values).
        self.stdp(self.learning_rate, self.time_constant)

    def set_input_weights(self, input_weights):
        '''
        Purpose:
        - Set the input weights (shape (input_count, neuron_count)). The values are
          copied to a float array so later updates never alias the caller's array.
        Args:
        - input_weights (array-like): The input weights to be set.
        '''
        self.input_matrix = np.array(input_weights, dtype=float)

    def set_internal_weights(self, internal_weights):
        '''
        Purpose:
        - Set the recurrent weights (shape (neuron_count, neuron_count)). The values
          are copied to a float array so STDP updates never mutate the caller's array.
        Args:
        - internal_weights (array-like): The internal weights to be set.
        '''
        self.internal_connectivity_matrix = np.array(internal_weights, dtype=float)

    def set_output_weights(self, output_weights):
        '''
        Purpose:
        - Set the readout weights (shape (neuron_count, output_count)), copied to a
          float array.
        Args:
        - output_weights (array-like): The output weights to be set.
        '''
        self.output_weights = np.array(output_weights, dtype=float)

    def get_output(self):
        '''
        Purpose:
        - Compute the module readout as a linear projection of the current membrane
          potentials (fired neurons were already reset to 0 by ``step``).

        Returns:
        - output (np.ndarray): vector of length output_count.
        '''
        return self.output_weights.T.dot(self.neuron_states)

In [21]:
import numpy as np

# Configuration: every tunable value lives here so the weight shapes below
# cannot drift out of sync with the constructor arguments.
SEED = 0
N_NEURONS = 10
N_INPUTS = 5
N_OUTPUTS = 3
N_STEPS = 100

# SpikingModule draws its initial weights/states from NumPy's legacy global RNG
# (np.random.rand), and so do the draws below; seed it so the run is reproducible.
np.random.seed(SEED)

# Initialize the module with the desired parameters
module = SpikingModule(
    neuron_count=N_NEURONS,
    input_count=N_INPUTS,
    output_count=N_OUTPUTS,
    decay_rate=0.99,
    threshold=1,
    learning_rate=0.01,
    time_constant=20,
)

# Set the input, internal, and output weights explicitly
input_weights = np.random.rand(N_INPUTS, N_NEURONS)
internal_weights = np.random.rand(N_NEURONS, N_NEURONS)
output_weights = np.random.rand(N_NEURONS, N_OUTPUTS)

module.set_input_weights(input_weights)
module.set_internal_weights(internal_weights)
module.set_output_weights(output_weights)

# Drive the module with uniform random input, keeping every readout for later analysis.
output_history = np.zeros((N_STEPS, N_OUTPUTS))
for t in range(N_STEPS):
    input_data = np.random.rand(N_INPUTS)
    module.step(input_data)
    output_data = module.get_output()
    output_history[t] = output_data
    print(f"Time step {t}: Output {output_data}")

Time step 0: Output [0. 0. 0.]
Time step 1: Output [0. 0. 0.]
Time step 2: Output [2.20495319 2.29417458 2.1602368 ]
Time step 3: Output [0. 0. 0.]
Time step 4: Output [0.52735152 0.3677796  0.36542979]
Time step 5: Output [0. 0. 0.]
Time step 6: Output [0.69919744 1.08610639 1.0130894 ]
Time step 7: Output [0. 0. 0.]
Time step 8: Output [0.56089234 0.43395413 0.46635134]
Time step 9: Output [0. 0. 0.]
Time step 10: Output [0.47426948 0.33075972 0.32864643]
Time step 11: Output [0.91520326 0.94185466 0.93982754]
Time step 12: Output [0. 0. 0.]
Time step 13: Output [0. 0. 0.]
Time step 14: Output [0. 0. 0.]
Time step 15: Output [0.4159612  0.29009501 0.28824154]
Time step 16: Output [0.99644605 1.21221424 1.17449356]
Time step 17: Output [0. 0. 0.]
Time step 18: Output [1.15124291 1.6884907  1.20502752]
Time step 19: Output [0. 0. 0.]
Time step 20: Output [0.5837929  0.40714232 0.40454101]
Time step 21: Output [0. 0. 0.]
Time step 22: Output [0. 0. 0.]
Time step 23: Output [1.41442255 1