# Neuron Models

In [None]:
import numpy as np
from dataclasses import dataclass, field
from typing import Callable

## Perceptron
The perceptron is a binary neuron with no internal state: its output is produced instantly by thresholding a dot product, giving 1 when $w \cdot x + b > 0$ and 0 otherwise. The hyperplane $w \cdot x + b = 0$ is the decision boundary used for classification.

In [None]:
def perceptron(w, x, b=0.0):
    """Binary threshold unit.

    Computes the affine score ``w . x + b`` and thresholds it at zero.

    :param w: weight vector (parameters): how much each feature matters
    :param x: input vector (features): the data you feed in
    :param b: bias term (offset): constant added to shift the decision boundary
    :return: ``1`` if ``w . x + b > 0``, otherwise ``0``. A point lying exactly
        on the decision boundary (score == 0) is classified as ``0``.

    Examples:
        >>> perceptron([1, 1], [0.2, 0.3])
        1
        >>> perceptron([1, 1], [-0.2, -0.3])
        0
    """
    score = np.dot(w, x) + b
    # Strict inequality: the boundary itself belongs to the 0 class.
    return 1 if score > 0 else 0

## Rate Neuron

Each neuron is a "function" that computes a weighted sum and then applies a nonlinearity. Without the nonlinearity, stacking neurons is pointless. All the layers collapse into one, and the network is only as powerful as a single linear (affine) map, i.e. a hyperplane. For two linear layers:
$$ y = W_2(W_1 x) = (W_2 W_1)\,x = W x, \qquad W = W_2 W_1 $$
Common choices of nonlinearity are ReLU, tanh, sigmoid, and softplus.

In [None]:
def rate_neuron(w, x, b=0.0, nonlinearity="relu"):
    """Rate neuron: weighted sum followed by a pointwise nonlinearity.

    Computes ``z = w . x + b`` and returns ``f(z)``.

    :param w: weight vector (parameters): how much each feature matters
    :param x: input vector (features): the data you feed in
    :param b: bias term (offset): constant added to shift the decision boundary
    :param nonlinearity: name of the activation ``f`` applied to ``z``; one of
        ``"relu"``, ``"tanh"``, ``"sigmoid"``, ``"softplus"``.
    :return: the activation ``f(z)`` as a NumPy float scalar.
    :raises ValueError: if ``nonlinearity`` is not one of the supported names.
    """
    z = float(np.dot(w, x) + b)
    if nonlinearity == "relu":
        # Gate negative inputs: max(0, z).
        return np.maximum(0.0, z)
    if nonlinearity == "tanh":
        # Bounded in [-1, 1]; good for stable dynamics in reservoir/recurrent networks.
        return np.tanh(z)
    if nonlinearity == "sigmoid":
        # Logistic 1 / (1 + e^{-z}), bounded in [0, 1]; good for probability-like
        # or on/off outputs. Branch on the sign so exp() only ever receives a
        # non-positive argument and cannot overflow.
        if z >= 0:
            return 1.0 / (1.0 + np.exp(-z))
        ez = np.exp(z)
        return ez / (1.0 + ez)
    if nonlinearity == "softplus":
        # Smooth version of ReLU: log(1 + e^z). logaddexp(0, z) evaluates this
        # without overflowing for large z.
        return np.logaddexp(0.0, z)
    raise ValueError("Unknown nonlinearity")

## Leaky Rate Neuron
The previously defined rate neuron has no notion of time or memory; the leaky neuron adds state (memory).
$$ x[t] = x[t-1] + \frac{\Delta t}{\tau}\left(-x[t-1] + I[t-1]\right) $$
Here $x$ is the internal state of the neuron, which decays toward 0 at each timestep, while input to the neuron pushes it up or down. The dataclass below implements a leaky rate neuron that can be used in a network. Its input is supplied externally at every step, so external inputs (e.g. stimulus pulses) drive its activity.

In [None]:
def relu(z: np.ndarray) -> np.ndarray:
    """Rectified linear unit: elementwise ``max(0, z)``."""
    return np.maximum(0.0, z)


@dataclass
class LeakyRateNeuron:
    """Single leaky rate neuron (leaky integrator + static nonlinearity).

    This model maintains an internal state x(t) that integrates input I(t) over time
    while exponentially decaying ("leaking") toward 0.

    Continuous-time form:
        dx/dt = (-x + I(t)) / tau
        r(t)  = f(gain * x(t) + bias)

    Discrete-time Euler update used here:
        x <- x + (dt/tau) * (-x + I_in)
        r <- f(gain * x + bias)

    Parameters
    ----------
    dt : float
        Simulation time step in seconds. Must be > 0.
    tau : float
        Leak time constant in seconds. Must be > 0.
        Larger tau => slower dynamics / more memory.
    gain : float
        Scales the state before applying the nonlinearity.
    bias : float
        Additive offset before applying the nonlinearity.
    f : callable
        Nonlinearity mapping an array-like input to array-like output.
        Common choices: ReLU, tanh, sigmoid, softplus.
    x : float
        Internal state (memory) of the neuron.

    Raises
    ------
    ValueError
        If ``dt`` or ``tau`` is not strictly positive.

    Notes
    -----
    This class is network-ready: the input to `step(I_in)` is provided externally,
    so I_in can come from synapses (W @ r), an external stimulus, noise, etc.

    The forward-Euler update multiplies the leak by (1 - dt/tau) each step, so
    it is stable only for dt/tau < 2 and non-oscillatory for dt/tau <= 1.
    """
    dt: float = field(default=1e-3)
    tau: float = field(default=20e-3)
    gain: float = field(default=1.0)
    bias: float = field(default=0.0)
    f: Callable[[np.ndarray], np.ndarray] = field(default=relu)
    x: float = field(default=0.0)

    def __post_init__(self) -> None:
        # Fail at construction (as LIFNeuron does) rather than with a
        # ZeroDivisionError or silently wrong dynamics inside step().
        if self.tau <= 0:
            raise ValueError("tau must be > 0")
        if self.dt <= 0:
            raise ValueError("dt must be > 0")
        self.x = float(self.x)

    def reset(self, x0: float = 0.0) -> None:
        """Set the internal state to ``x0`` (default 0.0)."""
        self.x = float(x0)

    def step(self, I_in: float) -> float:
        """Advance one timestep with external input I_in and return the rate.

          x <- x + (dt/tau)*(-x + I_in)
          r <- f(gain*x + bias)

        dt/tau is recomputed every call, so changes to ``dt``/``tau`` made
        after construction take effect immediately.
        """
        alpha = self.dt / self.tau
        self.x = self.x + alpha * (-self.x + float(I_in))
        # Wrap in an array so f may be any NumPy-style ufunc/callable.
        return float(self.f(np.array(self.gain * self.x + self.bias)))


## Spiking Neuron
The spiking neuron is almost the same as the leaky rate neuron, but its output is an event (a spike) that is either 0 or 1. This is a meaningful shift because information is now represented by spike timing. Outputs are sparse (more efficient), patterns in spike times matter (temporal coding), and the model is a natural fit for event-driven systems.

In [None]:
@dataclass
class LIFNeuron:
    """Single LIF spiking neuron (leaky integrator + threshold + reset).

    Continuous-time form (current-based LIF):
        dv/dt = (-(v - v_rest) + R * I(t)) / tau_m

    Discrete-time Euler update:
        v <- v + (dt/tau_m) * (-(v - v_rest) + R * I_in)

    Spiking rule:
        if v >= v_th: spike = 1, then v <- v_reset
        optional refractory: hold at v_reset for a fixed time after spiking

    Parameters
    ----------
    dt : float
        Simulation time step in seconds.
    tau_m : float
        Membrane time constant in seconds. Larger => slower voltage dynamics.
    v_rest : float
        Resting voltage (baseline the voltage leaks toward).
    v_reset : float
        Voltage after a spike.
    v_th : float
        Spike threshold.
    R : float
        Input gain (often interpreted as membrane resistance). In toy units,
        it just scales how strongly I_in pushes the voltage.
    refractory : float
        Refractory period in seconds (0 disables refractory behavior).
    v : float
        Current membrane voltage (internal state).

    Notes
    -----
    This class is network-ready: provide I_in from synapses, external stimulus,
    noise, etc. The `step()` returns (spike, v).
    """

    dt: float = 1e-3
    tau_m: float = 20e-3
    v_rest: float = 0.0
    v_reset: float = 0.0
    v_th: float = 1.0
    R: float = 1.0
    refractory: float = 5e-3

    v: float = 0.0

    # Derived / internal fields
    _alpha: float = field(init=False, repr=False)
    _refrac_steps: int = field(init=False, repr=False)
    _refrac_count: int = field(init=False, repr=False, default=0)

    def __post_init__(self) -> None:
        if self.tau_m <= 0:
            raise ValueError("tau_m must be > 0")
        if self.dt <= 0:
            raise ValueError("dt must be > 0")

        self._alpha = self.dt / self.tau_m
        self._refrac_steps = int(round(self.refractory / self.dt)) if self.refractory > 0 else 0
        self._refrac_count = 0

        # Default initial voltage: start at rest unless user already set v
        # (If you prefer explicit, remove this line and require calling reset.)
        if self.v == 0.0 and self.v_rest != 0.0:
            self.v = float(self.v_rest)

    def reset(self, v0: float | None = None) -> None:
        """Reset internal state (voltage and refractory counter)."""
        self.v = float(self.v_rest if v0 is None else v0)
        self._refrac_count = 0

    def step(self, I_in: float) -> tuple[int, float]:
        """Advance one timestep with external input current/drive I_in.

        Returns
        -------
        spike : int
            1 if the neuron spiked this step, else 0.
        v : float
            The membrane voltage after the update (and after reset if spiked).
        """
        # Refractory: hold at reset, do not integrate
        if self._refrac_count > 0:
            self._refrac_count -= 1
            self.v = float(self.v_reset)
            return 0, self.v

        # Leak + integrate
        self.v = self.v + self._alpha * (-(self.v - self.v_rest) + self.R * float(I_in))

        # Threshold -> spike + reset
        if self.v >= self.v_th:
            self.v = float(self.v_reset)
            if self._refrac_steps > 0:
                self._refrac_count = self._refrac_steps
            return 1, self.v

        return 0, self.v
