In [None]:
import numpy as np

class SimpleSNN:
    """Minimal single-layer spiking network with trace-based STDP learning.

    ``forward`` thresholds a weighted sum of binarized input spikes and
    updates exponentially decaying pre-/post-synaptic traces; ``stdp`` then
    adjusts the weights from those traces and clips them to ``[0, 1]``.
    Call ``stdp`` after ``forward`` with the same input, because it reads
    the traces that ``forward`` just updated.
    """

    def __init__(self, n_inputs, n_outputs, *, lr=0.01, threshold=0.5,
                 decay=0.9, init_scale=0.5, rng=None):
        """Create the network with random non-negative weights.

        Args:
            n_inputs: Number of input neurons.
            n_outputs: Number of output neurons.
            lr: Learning rate applied to each STDP weight change.
            threshold: An output neuron spikes when its weighted input is
                strictly greater than this value.
            decay: Per-step multiplicative decay of the spike traces.
            init_scale: Initial weights are drawn uniformly from
                ``[0, init_scale)``.
            rng: Optional ``numpy.random.Generator`` for reproducible
                initialisation. When omitted, the global ``np.random`` state
                is used, so ``np.random.seed`` keeps controlling it.
        """
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        shape = (n_outputs, n_inputs)
        if rng is None:
            self.weights = np.random.rand(*shape) * init_scale
        else:
            self.weights = rng.random(shape) * init_scale
        self.pre_trace = np.zeros(n_inputs)
        self.post_trace = np.zeros(n_outputs)
        self.lr = lr
        self.threshold = threshold
        self.decay = decay

    @staticmethod
    def _to_spikes(x):
        """Binarize *x* (array-like) to a float vector of 0.0/1.0 (1 where > 0)."""
        return (np.asarray(x) > 0).astype(float)

    def forward(self, x):
        """Propagate one time step of input spikes and update the traces.

        Args:
            x: Array-like of length ``n_inputs``; entries > 0 count as spikes.

        Returns:
            Float array of length ``n_outputs`` with 1.0 where the output
            neuron fired and 0.0 elsewhere.
        """
        spikes = self._to_spikes(x)
        out = np.dot(self.weights, spikes)
        out_spikes = (out > self.threshold).astype(float)
        # Traces decay then accumulate this step's spikes, so they already
        # include the current spikes when stdp() reads them.
        self.pre_trace = self.decay * self.pre_trace + spikes
        self.post_trace = self.decay * self.post_trace + out_spikes
        return out_spikes

    def stdp(self, x, y):
        """Apply one trace-based STDP update to ``self.weights``.

        Potentiation is ``y[i] * pre_trace[j]`` and depression is
        ``post_trace[i] * x[j]``. Because both traces already contain this
        step's spikes, the "+1" contributions of a coincident pre/post pair
        cancel out. Weights are clipped to ``[0, 1]`` afterwards.

        Args:
            x: The input passed to the preceding ``forward`` call; it is
                binarized the same way ``forward`` binarizes it.
            y: Output spikes returned by that ``forward`` call.
        """
        pre_spikes = self._to_spikes(x)
        post_spikes = np.asarray(y, dtype=float)
        ltp = np.outer(post_spikes, self.pre_trace)   # [i, j] = y[i] * pre[j]
        ltd = np.outer(self.post_trace, pre_spikes)   # [i, j] = post[i] * x[j]
        self.weights = np.clip(self.weights + self.lr * (ltp - ltd), 0, 1)

# Example usage
if __name__ == "__main__":
    snn = SimpleSNN(n_inputs=3, n_outputs=2)
    for step in range(10):
        x = np.random.randint(0, 2, 3)  # random input spikes
        y = snn.forward(x)
        snn.stdp(x, y)
        print(f"Step {step+1}: input={x}, output=