In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Main

In [None]:
## Izhikevich model code adapted from https://medium.com/geekculture/the-izhikevich-neuron-model-fb5d953b41e5
## (with the inhibitory-neuron population removed; all neurons here are excitatory)
class SpikingHN:
    """Hopfield-style associative memory built from Izhikevich spiking neurons.

    Patterns are stored in the weight matrix with a low-activity Hebbian rule
    (`train`); recall is simulated by injecting a start pattern as input current
    on the first time step and letting the network evolve (`forward`).
    """

    def __init__(self, N):
        """Create a network of `N` Izhikevich neurons with all-zero weights.

        Args:
            N: Number of neurons.
        """
        self.N = N                          # Number of neurons
        self.W = np.zeros((N, N))           # Weight matrix (variable "S" in original Izhikevich formulation)

        self.r = np.random.rand(N, 1)       # Per-neuron random factor in [0, 1), used to diversify `c` and `d`
        self.a = 0.02 * np.ones((N, 1))     # Time scale of membrane recovery `u`
        self.b = 0.2 * np.ones((N, 1))      # Sensitivity of `u` to fluctuations in membrane potential `v`
        self.c = -65 + 15 * self.r**2       # After-spike reset value for `v` (in [-65, -50))
        self.d = 8 - 6 * self.r**2          # After-spike increment added to `u`

    def train(self, patterns, a=0.35, b=0.35):
        """Store `patterns` in the weights using the low-activity Hebbian rule.

        For each +/-1 pattern p, zeta = (p + 1) / 2 maps it to 0/1 and the
        increment (zeta_i - b) * (zeta_j - a) is accumulated into W_ij
        (self-connections are kept at zero). The increment is scaled and then
        *added* to the existing weights. Calling `train` several times is
        therefore equivalent to one call with all patterns combined.

        Args:
            patterns: Iterable of length-N patterns with entries +1/-1.
            a: Target activity level; must lie strictly between 0 and 1.
            b: Constant between 0 and 1, subtracted from the row
               (postsynaptic) term.

        Raises:
            ValueError: If `a` is not strictly between 0 and 1.
        """
        ## Training regime for low-activity patterns (which are more biologically plausible)
        ## Weight update equation taken from https://neuronaldynamics.epfl.ch/online/Ch17.S2.html#Ch17.E27
        if not 0 < a < 1:
            # a == 0 or a == 1 would divide by zero below; other values flip the sign of c_prime.
            raise ValueError(f"activity level `a` must be in (0, 1), got {a}")
        activity = a        # Target activity level
        b_const = b         # A constant between 0 and 1
        c_prime = 1 / (2 * activity * (1 - activity) * self.N)    # A constant > 0

        # Accumulate this call's increment separately so previously stored
        # weights are not re-scaled by c_prime a second time.
        delta = np.zeros((self.N, self.N))
        for pattern in np.asarray(patterns, dtype=float):
            zeta = ((pattern + 1) / 2).reshape(-1, 1)   # Derived from p^{\mu}_i = 2 \zeta^{\mu}_i - 1
            delta += (zeta - b_const) @ (zeta.T - activity)
        np.fill_diagonal(delta, 0)

        # NOTE(review): the linked rule multiplies the Hebbian sum by c';
        # dividing, as done here, scales it by 2*a*(1-a)*N instead. Kept as-is
        # because the spiking dynamics may be tuned to this scale -- confirm intent.
        self.W += delta / c_prime

    def forward(self, start_pattern, time_steps=1000, record_neuron=10):
        """Simulate recall starting from `start_pattern`.

        Args:
            start_pattern: Length-N external input current, injected on the
                first time step only; later steps receive small uniform noise.
            time_steps: Number of simulation steps. Each step advances `v` in
                two 0.5 ms half-steps.
            record_neuron: Index of the neuron whose membrane potential is
                recorded in `voltage_across_time`. The default of 10 is the
                previously hard-coded value.

        Returns:
            output_pattern: Array of shape (N,). An entry is +1 if that
                neuron's v > 30 after the final step, otherwise -1.
            firings_across_time: List with one `[times, indices]` pair per
                step. `indices` holds the neurons that fired at that step, and
                `times` is an equally long array filled with the step number.
            voltage_across_time: Array of shape (time_steps,) holding
                `record_neuron`'s v at the start of each step (before reset).
        """
        v = -65 * np.ones((self.N, 1))  # Initialize membrane potential
        u = self.b * -65                # Initialize membrane recovery

        firings_across_time = []
        voltage_across_time = []

        for t in range(1, time_steps + 1):
            # Configure input currents for all neurons
            if t == 1:
                # Initial external input (the starting pattern). Cast to float (and
                # copy): an integer pattern would make the in-place `I +=` below
                # fail with a NumPy casting error.
                I = np.array(start_pattern, dtype=float).reshape(-1, 1)
            else:
                # External input at all other time steps (just noise)
                I = 0.1 * (np.random.rand(self.N, 1) - 0.5)

            # Neurons whose membrane potential `v` went above 30 mV spiked this step
            fired = np.flatnonzero(v[:, 0] > 30)
            firings_across_time.append([np.full_like(fired, t), fired])
            voltage_across_time.append(float(v[record_neuron, 0]))

            # Reset activity of neurons that have fired (indices are unique, so vectorised is equivalent)
            v[fired] = self.c[fired]
            u[fired] += self.d[fired]

            # Update input currents using weights and membrane potentials of fired neurons
            # Inspired by https://www.fabriziomusacchio.com/blog/2024-05-19-izhikevich_network_model/#input-current
            # NOTE(review): `v[fired]` has already been reset to `c` (negative,
            # roughly -65..-50 mV), so this term inverts the sign of the Hebbian
            # weights. Izhikevich's original network adds sum(S[:, fired], 2)
            # instead. Confirm whether pre-reset voltages were intended.
            I += self.W[:, fired] @ v[fired]

            # Update membrane potential `v` and recovery var `u`
            # Note: for `v` we have to do 0.5ms increments for numerical stability
            v += 0.5 * (0.04 * v**2 + 5 * v + 140 - u + I)
            v += 0.5 * (0.04 * v**2 + 5 * v + 140 - u + I)
            u = u + self.a * (self.b * v - u)

        # Take the whole column: indexing the (N, 1) result with [0] would keep only neuron 0.
        output_pattern = np.where(v[:, 0] > 30, 1, -1)
        return output_pattern, firings_across_time, np.array(voltage_across_time)