In [6]:
from typing import Callable, Optional, Union
from typing_extensions import TypedDict
import warnings

from typeguard import check_type
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gs
from ezephys import stimtools as st
from ezephys import pltools

In [8]:
# Default width of one simulation timestep. Used as the default `dt` of
# get_exponential_kernel (where it steps along the same axis as the `*_ms`
# lengths, so presumably milliseconds -- confirm) and of
# TwoCompartmentNeuron.__call__.
TIMESTEP_WIDTH = 0.1

In [10]:
def get_exponential_kernel(
    tau_ms: float, kernel_length_ms: float, dt: float = TIMESTEP_WIDTH
) -> np.ndarray:
    """Return a normalized, exponentially decaying filter kernel.

    The kernel is ``exp(-t / tau_ms)`` sampled at ``t = 0, dt, 2*dt, ...``
    for ``t < kernel_length_ms``, scaled so that its entries sum to 1.

    Parameters
    ----------
    tau_ms
        Decay time constant, in the same units as ``kernel_length_ms``.
    kernel_length_ms
        Length of the kernel support (exclusive upper bound of ``t``).
    dt
        Sampling interval of the kernel.

    Returns
    -------
    np.ndarray
        1D array of kernel weights summing to 1; the first entry
        corresponds to ``t = 0``.

    Notes
    -----
    The samples come from ``np.arange`` with a (usually) non-integer
    step, so floating-point rounding can occasionally add one sample
    near ``kernel_length_ms``.
    """
    t = np.arange(0, kernel_length_ms, dt)
    irf_filter = np.exp(-t / tau_ms)
    # ndarray.sum() is vectorized; the builtin sum() would iterate over
    # numpy scalars one at a time in Python.
    return irf_filter / irf_filter.sum()


def get_sigmoid(loc: float, sensitivity: float, gain: float) -> Callable[[np.ndarray], np.ndarray]:
    """Build the logistic function ``x -> gain / (1 + exp(-(x - loc) / sensitivity))``.

    ``loc`` is the input at which the output equals ``gain / 2`` and
    ``sensitivity`` controls the width of the transition. For positive
    ``sensitivity`` the output approaches ``gain`` as ``x`` grows. The
    returned callable accepts scalars or (elementwise) numpy arrays.
    """
    def logistic(x):
        exponent = -(x - loc) / sensitivity
        return gain / (np.exp(exponent) + 1)

    return logistic


def get_causal_filter(kernel: np.ndarray) -> Callable[[np.ndarray], float]:
    """Get a function that applies a filter to the end of a timeseries, returning a scalar.

    Parameters
    ----------
    kernel
        Filter kernel. ``kernel[0]`` weights the most recent sample of the
        timeseries, ``kernel[1]`` the one before it, and so on. The kernel
        is copied (the caller's array is never modified); integer/boolean
        kernels are converted to float. If the kernel does not sum to 1 it
        is normalized and a warning is issued.

    Returns
    -------
    causal_filter
        Callable mapping a 1D timeseries to the filtered value at the
        timepoint just after its last sample.

    Raises
    ------
    ValueError
        If the kernel sums to zero (including an empty kernel), since it
        cannot be normalized.
    """
    flipped_kernel = np.flip(kernel).copy()
    if not np.issubdtype(flipped_kernel.dtype, np.inexact):
        # In-place division below would fail on an integer array.
        flipped_kernel = flipped_kernel.astype(float)

    kernel_integral = flipped_kernel.sum()
    if not np.isclose(kernel_integral, 1.):
        if kernel_integral == 0:
            # Normalizing would divide by zero and yield a NaN filter.
            raise ValueError('Kernel integral is zero; cannot normalize.')
        warnings.warn('Kernel integral is not 1. Normalizing.', stacklevel=2)
        flipped_kernel /= kernel_integral

    def causal_filter(timeseries: np.ndarray) -> float:
        """Filter the end of a timeseries, returning a scalar.

        Equivalent to the convolution of the filter kernel with
        timeseries, evaluated at the timepoint just after the end
        of the timeseries. If the timeseries is shorter than the
        kernel, only the overlapping part of the kernel is used
        (samples before the start are treated as zero).

        """
        assert np.ndim(timeseries) == 1
        # Align the tail of the (flipped) kernel with the tail of the
        # timeseries; whichever is longer is truncated at its start.
        return np.dot(
            flipped_kernel[max(0, len(flipped_kernel) - len(timeseries)):],
            timeseries[max(0, len(timeseries) - len(flipped_kernel)):]
        )

    return causal_filter


class TwoCompartmentNeuronState(TypedDict):
    """The internal state of a two compartment neuron at a single time point.

    All fields are required (the TypedDict is total). In
    TwoCompartmentNeuron, Vs_tot and Vd_tot at one timestep are computed
    from the previous timestep's state values.
    """
    Vd_linear: float  # Dendritic voltage generated by external inputs (filtered dendritic input)
    Vd_tot: float  # sigma(Vd_linear + Vs_tot); sigma is the dendritic nonlinearity
    Vs_linear: float  # Somatic voltage generated by external inputs (filtered somatic input)
    Vs_Na: float  # Voltage created by Na current in soma; sodium nonlinearity of the sodium-filtered somatic input
    Vs_tot: float  # Vs_linear + Vs_Na + Vd_tot


class TwoCompartmentNeuronStateRecord(dict):
    """Recording of the time-varying state of a two-compartment neuron.

    Maps each state-variable name to the values it took at successive
    recorded timesteps. While recording, values accumulate in lists; once
    the recorder is finalized they become read-only numpy arrays, and
    ``len(record)`` returns the number of recorded timesteps (not the
    number of keys, as it would for a plain dict). Calling ``len`` before
    finalization raises RuntimeError.
    """
    def __init__(self, capacity: int, dt: float, **kwargs):
        """Initialize the recording.

        This method should not be considered public.
        TwoCompartmentNeuronStateRecords should only be created by
        a TwoCompartmentNeuron.

        Parameters
        ----------
        capacity
            Maximum number of states that can be recorded. The recorder
            finalizes itself automatically once this many states have
            been recorded.
        dt
            Timestep width; only used to build the ``time`` vector.
        **kwargs
            Passed through to ``dict.__init__``.

        """
        self.dt = dt
        self.__num_recorded = 0  # Number of states recorded so far.
        self.__capacity = capacity
        self.__finalized = False
        super().__init__(**kwargs)

    @property
    def time(self) -> np.ndarray:
        """Get a time support vector for plotting.

        Raises RuntimeError if the recorder has not been finalized.
        """
        return np.arange(0, len(self)) * self.dt

    def __len__(self):
        if self.__finalized:
            return self.__capacity
        else:
            raise RuntimeError('Length of partially-filled recorder is undefined.')

    def _record(self, state: TwoCompartmentNeuronState):
        """Record a TwoCompartmentNeuronState.

        Raises RuntimeError if the recorder is already finalized or full.
        ``check_type`` raises if ``state`` is not a valid
        TwoCompartmentNeuronState. The recorder finalizes itself after
        recording its ``capacity``-th state.
        """
        if self.__finalized or self.__num_recorded >= self.__capacity:
            raise RuntimeError('Recorder capacity exceeded.')

        # Check that argument state is a valid TwoCompartmentNeuronState.
        # This guarantees that all state variables have been set correctly.
        check_type('state', state, TwoCompartmentNeuronState)

        # Append state to record.
        for state_var_name, state_var_val in state.items():
            if state_var_name in self:
                self[state_var_name].append(state_var_val)
            else:
                self[state_var_name] = [state_var_val]
        self.__num_recorded += 1

        if self.__num_recorded == self.__capacity:
            # Recorder is full. Finalize automatically.
            self._finalize()

    def _finalize(self):
        """Finish recording and make the recorder read-only.

        Shrinks the capacity to the number of states actually recorded
        (possibly zero) and converts every state record to a read-only
        numpy array. Does nothing if the recorder is already finalized.

        Raises RuntimeError if the state records do not all hold exactly
        one entry per recorded state.
        """
        if self.__finalized:
            return

        assert self.__num_recorded <= self.__capacity
        # Shrink capacity to fit. Using the count of recorded states (rather
        # than a timestep index) keeps the length correct for a recorder
        # finalized before it is full, including one with no states at all.
        self.__capacity = self.__num_recorded

        # Coerce state records to read-only arrays.
        for state_var_name in list(self.keys()):
            frozen_record = np.asarray(self[state_var_name])
            frozen_record.flags.writeable = False
            self[state_var_name] = frozen_record

        # Raise an error if not all records have the expected number of timesteps.
        if any(len(record) != self.__capacity for record in self.values()):
            raise RuntimeError(
                'Recorder corrupted: not all state records have expected length.'
            )

        self.__finalized = True

class TwoCompartmentNeuron:
    """Discrete-time model of a neuron with a somatic and a dendritic compartment.

    Somatic and dendritic inputs each pass through a causal linear filter.
    The somatic input also drives a sodium component through its own
    filter followed by a nonlinearity. The total somatic voltage feeds the
    dendritic nonlinearity and vice versa, with a one-timestep lag; see
    ``_compute_next_state`` for the exact update.
    """

    def __init__(
        self,
        sodium_kernel: np.ndarray,
        dendritic_kernel: np.ndarray,
        somatic_kernel: np.ndarray,
        sodium_nonlinearity: Union[Callable[[float], float], Callable[[np.ndarray], np.ndarray]],
        dendritic_nonlinearity: Union[Callable[[float], float], Callable[[np.ndarray], np.ndarray]],
    ):
        """Build the model from filter kernels and nonlinearities.

        Each kernel is turned into a causal filter with ``get_causal_filter``
        (``kernel[0]`` weights the most recent input sample). The
        nonlinearities are called with scalar arguments and must return
        values accepted as ``float`` by the state type check.
        """
        # Attach callable causal filters
        self._sodium_filter: Callable[[np.ndarray], float] = get_causal_filter(sodium_kernel)
        self._somatic_filter: Callable[[np.ndarray], float] = get_causal_filter(somatic_kernel)
        self._dendritic_filter: Callable[[np.ndarray], float] = get_causal_filter(dendritic_kernel)

        # Attach nonlinearities. Will only be called with scalars.
        self._sodium_nonlinearity = sodium_nonlinearity
        self._dendritic_nonlinearity = dendritic_nonlinearity

        # Allocate an attribute to hold state variables.
        # This will always be None, except during a simulation.
        self.__state: Optional[TwoCompartmentNeuronState] = None

    def __call__(
        self,
        somatic_input: np.ndarray,
        dendritic_input: np.ndarray,
        dt: float = TIMESTEP_WIDTH,
        initial_state: Optional[TwoCompartmentNeuronState] = None
    ) -> TwoCompartmentNeuronStateRecord:
        """Simulate the response of the model to a set of inputs.

        Parameters
        ----------
        somatic_input, dendritic_input
            1D input timeseries of equal length.
        dt
            Timestep width. It is only passed to the returned record to
            build its ``time`` vector; the filters operate sample by sample,
            so ``dt`` should match the sampling the kernels were built with.
        initial_state
            State at the first timestep. Defaults to all variables 0.

        Returns
        -------
        TwoCompartmentNeuronStateRecord
            One entry per input sample: the initial state followed by the
            states after each of the first ``len(somatic_input) - 1``
            updates. The state following the final update is computed but
            not recorded.
        """
        assert np.ndim(somatic_input) == 1
        assert np.ndim(dendritic_input) == 1
        assert len(somatic_input) == len(dendritic_input)

        try:
            # Set the initial state of the neuron.
            if initial_state is None:
                # Set all variables to zero by default.
                self.__state = TwoCompartmentNeuronState(
                    Vd_linear=0., Vd_tot=0., Vs_linear=0., Vs_Na=0., Vs_tot=0.
                )
            else:
                self.__state = initial_state

            # Allocate a TwoCompartmentNeuronStateRecord to monitor the values
            # of state variables during the simulation. Capacity is len(input):
            # the state is recorded *before* each update, so the initial state
            # and the states after the first len(input) - 1 updates are kept,
            # while the state after the last update is discarded.
            recorder = TwoCompartmentNeuronStateRecord(len(somatic_input), dt)

            # Run the simulation.
            for t_next in range(1, len(somatic_input) + 1):
                # Record current state.
                recorder._record(self.__state)

                # Update the internal state.
                self.__state = self._compute_next_state(
                    somatic_input, dendritic_input, t_next
                )
        finally:
            # Clear the state when the simulation is finished.
            self.__state = None

        # Finalize and return the record of states during the simulation.
        # (The recorder finalizes itself when full, in which case this is a
        # no-op.)
        recorder._finalize()
        return recorder

    def _compute_next_state(
        self, somatic_input: np.ndarray, dendritic_input: np.ndarray, t_next: int
    ) -> TwoCompartmentNeuronState:
        """Compute the state at timestep ``t_next`` from the current state.

        ``Vs_linear``, ``Vs_Na`` and ``Vd_linear`` filter the inputs up to
        but excluding index ``t_next``. ``Vs_tot`` and ``Vd_tot`` are built
        from the *current* state (``self.__state``), i.e. with a one-step
        lag -- presumably to break the mutual dependency between them.
        """
        next_state = TwoCompartmentNeuronState()

        # Compute values of all state variables at the next timestep.
        next_state['Vs_linear'] = self._somatic_filter(somatic_input[:t_next])
        next_state['Vs_Na'] = self._sodium_nonlinearity(
            self._sodium_filter(somatic_input[:t_next])
        )
        next_state['Vs_tot'] = (
            self.__state['Vs_linear'] + self.__state['Vs_Na'] + self.__state['Vd_tot']
        )
        next_state['Vd_linear'] = self._dendritic_filter(dendritic_input[:t_next])
        next_state['Vd_tot'] = self._dendritic_nonlinearity(
            self.__state['Vd_linear'] + self.__state['Vs_tot']
        )

        # Ensure that next state has been set correctly.
        check_type('next_state', next_state, TwoCompartmentNeuronState)

        return next_state