In [None]:
import numpy as np


class MazeEnv:
    """Deterministic grid-world maze with four-directional moves.

    Cell codes: 0 empty, 1 wall, 2 start, 3 goal. Positions are ``(row, col)``
    tuples of plain Python ints.
    """

    # Action index -> (d_row, d_col), in the order up, down, left, right.
    MOVES = ((-1, 0), (1, 0), (0, -1), (0, 1))

    def __init__(self, grid):
        """Build the maze from a 2-D grid of cell codes.

        Raises:
            ValueError: if ``grid`` is not 2-D or lacks a start (2) or goal (3) cell.
        """
        self.grid = np.array(grid)
        if self.grid.ndim != 2:
            raise ValueError(f"grid must be 2-D, got {self.grid.ndim}-D")
        # If a code appears more than once, the first cell in row-major order wins.
        self.start = self._locate(2, "start")
        self.goal = self._locate(3, "goal")
        self.reset()

    def _locate(self, code, label):
        """Return the first ``(row, col)`` holding ``code``, as plain ints."""
        cells = np.argwhere(self.grid == code)
        if len(cells) == 0:
            raise ValueError(f"grid has no {label} cell (value {code})")
        row, col = cells[0]
        return int(row), int(col)

    def reset(self):
        """Put the agent back on the start cell and return its position."""
        self.pos = self.start
        return self.pos

    def step(self, action):
        """Try to move one cell in direction ``action`` (an index into MOVES).

        Moves that would leave the grid or enter a wall leave the agent in
        place. The reward is 1 when the agent is on the goal cell after the
        move, otherwise -0.01.

        Returns:
            ``(pos, reward, done)`` where ``done`` is True on the goal cell.

        Raises:
            IndexError: if ``action`` is not in ``range(len(MOVES))``; negative
                indices are rejected instead of wrapping around.
        """
        if not 0 <= action < len(self.MOVES):
            raise IndexError(
                f"action must be in [0, {len(self.MOVES)}), got {action}"
            )
        d_row, d_col = self.MOVES[action]
        row, col = self.pos[0] + d_row, self.pos[1] + d_col
        n_rows, n_cols = self.grid.shape
        if 0 <= row < n_rows and 0 <= col < n_cols and self.grid[row, col] != 1:
            self.pos = (row, col)
        done = self.pos == self.goal
        reward = 1 if done else -0.01
        return self.pos, reward, done


class SNN:
    """Single-layer spiking network trained with trace-based STDP.

    Inputs are binarised into spikes (``x > 0``); output neuron ``i`` spikes
    when ``w[i] @ spikes`` exceeds ``threshold``. Exponentially decaying pre-
    and post-synaptic traces drive the plasticity rule in :meth:`stdp`, which
    can optionally be gated by a scalar reward (three-factor STDP).
    """

    def __init__(self, n_inputs, n_outputs, *, lr=0.01, decay=0.9,
                 threshold=0.5, init_scale=0.5, rng=None):
        """Create the network.

        Args:
            n_inputs: Number of input channels.
            n_outputs: Number of output neurons.
            lr: Plasticity learning rate.
            decay: Per-step multiplicative decay of both spike traces.
            threshold: An output neuron spikes when its weighted input exceeds this.
            init_scale: Initial weights are drawn uniformly from [0, init_scale).
            rng: Optional ``numpy.random.Generator``. When omitted, the global
                NumPy RNG is used (``np.random.rand``), so ``np.random.seed``
                still controls initialisation.
        """
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.decay = decay
        self.threshold = threshold
        if rng is None:
            self.w = np.random.rand(n_outputs, n_inputs) * init_scale
        else:
            self.w = rng.random((n_outputs, n_inputs)) * init_scale
        self.pre_trace = np.zeros(n_inputs)
        self.post_trace = np.zeros(n_outputs)
        self.lr = lr

    def reset_traces(self):
        """Zero both spike traces, e.g. at an episode boundary."""
        self.pre_trace = np.zeros(self.n_inputs)
        self.post_trace = np.zeros(self.n_outputs)

    def forward(self, x):
        """Propagate one time step and update the spike traces.

        Args:
            x: Length-``n_inputs`` input; entries > 0 count as input spikes.

        Returns:
            Float array of length ``n_outputs``, 1.0 where a neuron spiked.
        """
        spikes = (np.asarray(x) > 0).astype(float)
        out_spikes = (np.dot(self.w, spikes) > self.threshold).astype(float)
        # Decay-and-accumulate: recent spikes dominate, older ones fade out.
        self.pre_trace = self.decay * self.pre_trace + spikes
        self.post_trace = self.decay * self.post_trace + out_spikes
        return out_spikes

    def stdp(self, x, y, reward=1.0):
        """Apply one STDP weight update, scaled by ``reward``.

        For every synapse (output i, input j)::

            dw[i, j] = lr * reward * (pre_trace[j] * y[i] - x[j] * post_trace[i])

        The first term potentiates when output i fires after recent input j;
        the second depresses when input j fires after recent output i. ``x``
        is used as given (it is not thresholded as in :meth:`forward`). With
        the default ``reward=1.0`` this is plain STDP; ``reward=0`` disables
        learning and a negative reward reverses the update. Weights are
        clipped to [0, 1] afterwards.
        """
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        # Vectorised form of the per-synapse rule above (outer products).
        dw = self.lr * reward * (
            np.outer(y, self.pre_trace) - np.outer(self.post_trace, x)
        )
        self.w = np.clip(self.w + dw, 0, 1)


def neighbour_inputs(grid, pos, moves):
    """Encode which moves from ``pos`` are open.

    Returns a float vector with 1.0 at index ``a`` when ``pos + moves[a]`` is
    inside ``grid`` and not a wall (cell value 1), else 0.0.
    """
    grid = np.asarray(grid)
    n_rows, n_cols = grid.shape
    x = np.zeros(len(moves))
    for a, (d_row, d_col) in enumerate(moves):
        row, col = pos[0] + d_row, pos[1] + d_col
        if 0 <= row < n_rows and 0 <= col < n_cols and grid[row, col] != 1:
            x[a] = 1.0
    return x


def select_action(out_spikes, open_moves, rng):
    """Choose an action from the output spikes, breaking ties at random.

    ``np.argmax`` returns the first index on ties, so a silent output layer
    would always choose action 0. With the default initial weights (< 0.5)
    a single active input can never exceed the 0.5 threshold, so e.g. at the
    demo maze's start cell nothing fires and argmax keeps picking a blocked
    move. Instead:

    * if any neuron fired, pick uniformly among the firing neurons;
    * otherwise pick uniformly among the open moves (``open_moves > 0``);
    * if no move is open either, pick uniformly among all actions.
    """
    candidates = np.flatnonzero(np.asarray(out_spikes) > 0)
    if candidates.size == 0:
        candidates = np.flatnonzero(np.asarray(open_moves) > 0)
    if candidates.size == 0:
        candidates = np.arange(len(out_spikes))
    return int(rng.choice(candidates))


if __name__ == "__main__":
    # 0: empty, 1: wall, 2: start, 3: goal
    grid = [
        [2, 0, 1, 0, 3],
        [1, 0, 1, 0, 1],
        [0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0],
    ]
    # Same action order as MazeEnv.step: up, down, left, right.
    moves = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    n_episodes = 50
    max_steps = 50
    rng = np.random.default_rng(0)  # single seeded source: weights + exploration

    env = MazeEnv(grid)
    snn = SNN(n_inputs=len(moves), n_outputs=len(moves), rng=rng)

    # NOTE(review): the observation is only "which neighbours are open", so
    # distinct cells can look identical (e.g. (2, 1) and (2, 3) both encode
    # as [1, 0, 1, 1]), and each update credits only the action taken at that
    # step, so the goal reward does not propagate back to earlier moves.
    # Reaching the goal reliably is therefore not guaranteed.
    for episode in range(n_episodes):
        state = env.reset()
        snn.reset_traces()  # don't carry spike history across episodes
        done = False
        steps = 0
        while not done and steps < max_steps:
            x = neighbour_inputs(env.grid, state, moves)
            y = snn.forward(x)
            action = select_action(y, x, rng)
            next_state, reward, done = env.step(action)
            # Credit the executed action (it may come from exploration rather
            # than a firing neuron) and gate the update by the reward it earned.
            taken = np.zeros(len(moves))
            taken[action] = 1.0
            snn.stdp(x, taken, reward=reward)
            state = next_state
            steps += 1
        print(f"Episode {episode+1}: steps={steps}, reached goal={done}")

Episode 1: steps=50, reached goal=False
Episode 2: steps=50, reached goal=False
Episode 3: steps=50, reached goal=False
Episode 4: steps=50, reached goal=False
Episode 5: steps=50, reached goal=False
Episode 6: steps=50, reached goal=False
Episode 7: steps=50, reached goal=False
Episode 8: steps=50, reached goal=False
Episode 9: steps=50, reached goal=False
Episode 10: steps=50, reached goal=False
Episode 11: steps=50, reached goal=False
Episode 12: steps=50, reached goal=False
Episode 13: steps=50, reached goal=False
Episode 14: steps=50, reached goal=False
Episode 15: steps=50, reached goal=False
Episode 16: steps=50, reached goal=False
Episode 17: steps=50, reached goal=False
Episode 18: steps=50, reached goal=False
Episode 19: steps=50, reached goal=False
Episode 20: steps=50, reached goal=False
Episode 21: steps=50, reached goal=False
Episode 22: steps=50, reached goal=False
Episode 23: steps=50, reached goal=False
Episode 24: steps=50, reached goal=False
Episode 25: steps=50, rea