In [1]:
import numpy as np

class RateCodedSNN:
    """Toy rate-coded spiking layer with a fixed per-step firing threshold.

    Input intensities in [0, 1] are encoded as binary spike trains whose spike
    count is proportional to the intensity. At every time step an output neuron
    spikes when the weighted sum of that step's input spikes exceeds a
    threshold; no membrane potential is carried between steps. Weights are
    updated with a rate-based Hebbian rule and kept in [0, 1].
    """

    def __init__(self, n_inputs, n_outputs, lr=0.01):
        """Create a layer with random initial weights.

        Args:
            n_inputs: Number of input neurons.
            n_outputs: Number of output neurons.
            lr: Learning rate used by ``stdp`` (default 0.01).
        """
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        # Shape (n_outputs, n_inputs); values drawn uniformly from [0, 0.5).
        self.weights = np.random.rand(n_outputs, n_inputs) * 0.5
        self.lr = lr

    def encode_rate(self, value, max_rate=20, duration=20):
        """Encode a scalar intensity as a randomly ordered binary spike train.

        Args:
            value: Intensity, nominally in [0, 1]. Values outside that range
                are clipped, so a negative value yields no spikes and a value
                above 1 yields ``max_rate`` spikes.
            max_rate: Number of spikes emitted for ``value == 1``.
            duration: Length of the spike train in time steps.

        Returns:
            Float array of shape ``(duration,)`` holding
            ``int(value * max_rate)`` ones (at most ``duration``, at least 0)
            at random positions and zeros elsewhere.
        """
        fraction = min(max(float(value), 0.0), 1.0)
        # Never let the count go negative: ``spikes[:n]`` with n < 0 would
        # fill all but the last |n| slots instead of none.
        n_spikes = max(0, int(fraction * max_rate))
        spikes = np.zeros(duration)
        spikes[:n_spikes] = 1  # slicing past ``duration`` simply fills all
        np.random.shuffle(spikes)
        return spikes

    def forward(self, input_rates, duration=20, threshold=0.5):
        """Simulate the layer and return each output neuron's firing rate.

        Args:
            input_rates: Sequence of ``n_inputs`` floats in [0, 1], one per
                input neuron.
            duration: Number of simulated time steps (default 20, >= 1).
            threshold: An output neuron spikes at a step when its weighted
                input for that step is strictly greater than this value
                (default 0.5).

        Returns:
            Float array of shape ``(n_outputs,)``: the fraction of time steps
            on which each output neuron spiked.

        Raises:
            ValueError: If ``duration`` is less than 1, or if the number of
                input rates does not match the weight matrix.
        """
        if duration < 1:
            raise ValueError("duration must be at least 1")
        # max_rate == duration so an input value v spikes on ~v of the steps
        # (the defaults reproduce the original fixed 20-spike / 20-step setup).
        input_spikes = np.array(
            [self.encode_rate(v, max_rate=duration, duration=duration)
             for v in input_rates]
        ).reshape(-1, duration)  # (n_inputs, duration), 2-D even if empty
        # (n_outputs, n_inputs) @ (n_inputs, duration): drive at every step.
        drive = self.weights @ input_spikes
        output_spikes = (drive > threshold).astype(float)
        return output_spikes.mean(axis=1)

    def stdp(self, input_rates, output_rates):
        """Apply a rate-based Hebbian update, then clip weights to [0, 1].

        Despite the name, spike timing is not used: each weight changes by
        ``lr * input_rates[j] * output_rates[i]``. There is no depression
        term, so with non-negative rates weights only grow until clipped at 1.

        Args:
            input_rates: ``n_inputs`` presynaptic rates.
            output_rates: ``n_outputs`` postsynaptic rates, e.g. from
                ``forward``.
        """
        pre = self.lr * np.asarray(input_rates, dtype=float)
        post = np.asarray(output_rates, dtype=float)
        # outer(post, pre)[i, j] == output_rates[i] * (lr * input_rates[j]),
        # the same product the element-wise loop computed.
        self.weights = np.clip(self.weights + np.outer(post, pre), 0, 1)

# Example usage: feed random input rates through the layer for ten steps,
# applying the learning rule after every forward pass.
if __name__ == "__main__":
    net = RateCodedSNN(n_inputs=3, n_outputs=2)
    for step_no in range(1, 11):
        rates_in = np.random.rand(net.n_inputs)  # one value in [0, 1) per input
        rates_out = net.forward(rates_in)
        net.stdp(rates_in, rates_out)
        print(
            f"Step {step_no}: input_rates={rates_in.round(2)}, "
            f"output_rates={rates_out.round(2)}"
        )

Step 1: input_rates=[0.98 0.36 0.37], output_rates=[0.55 0.1 ]
Step 2: input_rates=[0.98 0.12 0.18], output_rates=[0.2 0. ]
Step 3: input_rates=[0.07 0.04 0.14], output_rates=[0. 0.]
Step 4: input_rates=[0.59 0.55 0.85], output_rates=[0.65 0.35]
Step 5: input_rates=[0.76 0.34 0.29], output_rates=[0.35 0.05]
Step 6: input_rates=[0.57 0.14 0.2 ], output_rates=[0.15 0.  ]
Step 7: input_rates=[0.87 0.78 0.16], output_rates=[0.7 0.1]
Step 8: input_rates=[0.37 0.6  0.39], output_rates=[0.3 0.2]
Step 9: input_rates=[0.63 0.88 0.44], output_rates=[0.65 0.25]
Step 10: input_rates=[0.08 0.2  0.96], output_rates=[0.2 0.2]
