Commit 950cd4e

Merge pull request #33 from cplab/analog
Engine Patch
matham authored Jul 13, 2024
2 parents 8e0942f + 12826b9 commit 950cd4e
Showing 14 changed files with 131 additions and 112 deletions.
32 changes: 16 additions & 16 deletions .github/workflows/pythonapp.yml
@@ -2,9 +2,6 @@ name: Python Application
 
 on:
   push:
-    branches:
-      - merger
-      - main
   pull_request:
     branches:
       - main
@@ -13,27 +10,30 @@ jobs:
   lint:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
-      - name: Set up Python 3.x
-        uses: actions/setup-python@v1
+      - uses: actions/checkout@v4
+      - name: Set up Python 3.11
+        uses: actions/setup-python@v5
         with:
-          python-version: 3.x
+          python-version: 3.11
       - name: Install dependencies
         run: |
           python3 -m pip install --upgrade pip virtualenv wheel setuptools
       - name: Lint with pycodestyle
         run: |
           python3 -m pip install flake8
-          python3 -m flake8 . --count --ignore=E125,E126,E127,E128,E203,E402,E741,E731,W503,F401,W504,F841 --show-source --statistics --max-line-length=120 --exclude=__pycache__,.tox,.git/,doc/
+          python3 -m flake8 . --count --ignore=E125,E126,E127,E128,E203,E226,E402,E741,E731,W503,F401,W504,F841 --show-source --statistics --max-line-length=120 --exclude=__pycache__,.tox,.git/,doc/
   linux:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
-      - name: Set up Python 3.x
-        uses: actions/setup-python@v1
+      - uses: actions/checkout@v4
+      - name: Set up Python 3.11
+        uses: actions/setup-python@v5
         with:
-          python-version: 3.x
+          python-version: 3.11
       - name: Install dependencies
         run: |
          python3 -m pip install --upgrade pip virtualenv wheel setuptools
       - name: Make sdist
         run: python3 setup.py sdist --formats=gztar
       - name: Install dependencies
@@ -65,11 +65,11 @@ jobs:
   docs:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
-      - name: Set up Python 3.x
-        uses: actions/setup-python@v1
+      - uses: actions/checkout@v4
+      - name: Set up Python 3.11
+        uses: actions/setup-python@v5
         with:
-          python-version: 3.x
+          python-version: 3.11
       - name: Install dependencies
         run: |
           python3 -m pip install --upgrade pip virtualenv wheel setuptools m2r2
2 changes: 1 addition & 1 deletion sapicore/__init__.py
@@ -2,4 +2,4 @@
 -----------
 """
 
-__version__ = "0.3.0"
+__version__ = "0.3.3"
24 changes: 12 additions & 12 deletions sapicore/data/__init__.py
@@ -61,10 +61,10 @@ def __init__(self, name: str = "", labels: list | NDArray | Tensor = None, axis:
         self.axis = axis
         self.labels = np.array(labels)
 
-    def __getitem__(self, index: slice) -> Tensor:
+    def __getitem__(self, index: Any) -> Tensor:
         return self.labels[index]
 
-    def __setitem__(self, index: slice, values: Any):
+    def __setitem__(self, index: Any, values: Any):
         self.labels[index] = np.array(values)
 
 
@@ -205,11 +205,11 @@ def __init__(
         # passes silently if not implemented by the user.
         self._standardize()
 
-    def __getitem__(self, index: slice):
+    def __getitem__(self, index: Any):
         """Calls :meth:`access` to slice into the data or access specific file(s), returning the value(s) at `index`."""
         return self.access(index)
 
-    def __setitem__(self, index: slice, values: Tensor):
+    def __setitem__(self, index: Any, values: Tensor):
         """Sets buffer values at the given indices to `values`."""
         self.modify(index, values)
 
@@ -289,7 +289,7 @@ def _standardize(self):
         """
         pass
 
-    def access(self, index: slice, axis: int = None) -> Tensor:
+    def access(self, index: Any, axis: int = None) -> Tensor:
         """Specifies how to access data by mapping indices to actual samples (e.g., from file(s) in `root`).
 
         The default implementation slices into `self.buffer` to accommodate the trivial cases where the user has
@@ -301,7 +301,7 @@ def access(self, index: slice, axis: int = None) -> Tensor:
         Parameters
         ----------
-        index: slice
+        index: Any
             Index(es) to slice into.
 
         axis: int, optional
@@ -315,15 +315,15 @@
         """
         return self.buffer.index_select(axis, torch.as_tensor(index)) if axis is not None else self.buffer[index]
 
-    def load(self, indices: slice = None):
+    def load(self, indices: Any = None):
         """Populates the `buffer` tensor buffer and/or `descriptors` attribute table by loading one or more files
         into memory, potentially selecting only `indices`.
 
         Since different datasets and pipelines call for different formats, implementation is left to the user.
 
         Parameters
         ----------
-        indices: slice
+        indices: Any
             Specific indices to include, one for each file.
 
         Returns
@@ -340,15 +340,15 @@ def load(self, indices: slice = None):
         """
         pass
 
-    def modify(self, index: slice, values: Tensor):
+    def modify(self, index: Any, values: Tensor):
         """Set or modify data values at the given indices to `values`.
 
         The default implementation edits the `buffer` field of this :class:`Data` object.
         Users may wish to override it in cases where the buffer is not used directly.
 
         Parameters
         ----------
-        index: slice
+        index: Any
             Indices to modify.
 
         values: Tensor
@@ -453,13 +453,13 @@ def sample(self, method: Callable, axis: int = 0, **kwargs):
         # trim buffer and labels, returning a new partial dataset without mutating the original.
         return self.trim(index=subset, axis=axis)
 
-    def trim(self, index: slice, axis: int = None):
+    def trim(self, index: Any, axis: int = None):
        """Trims this instance by selecting `indices`, potentially along `axis`, returning a subset of the original
         dataset in terms of both buffer entries and labels/descriptors. Does not mutate the underlying object.
 
         Parameters
         ----------
-        index: slice
+        index: Any
             Index(es) to retain.
 
         axis: int, optional
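
To illustrate what the widened `index: Any` annotations permit, here is a minimal sketch. It assumes `Data` is constructible without arguments, that `buffer` can be assigned directly, and that the default `access`/`modify` implementations are in effect:

    import torch
    from sapicore.data import Data

    data = Data()  # assumption: Data has usable defaults.
    data.buffer = torch.arange(12.0).reshape(4, 3)

    row = data[2]                 # integer index, routed through access().
    rows = data[1:3]              # slice, as before.
    picked = data[[0, 2]]         # list of indices, now also covered by `Any`.
    data[0] = torch.zeros(3)      # assignment, routed through modify().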
17 changes: 4 additions & 13 deletions sapicore/data/sampling/__init__.py
@@ -56,20 +56,11 @@ def __call__(self, frame: DataFrame, group_keys: str | list[str], n: int | float
         frame["index"] = frame.index
         grouped = frame.groupby(group_keys, group_keys=False)
 
-        if self.stratified:
-            # convert `n` to fraction if need be.
-            if isinstance(n, int):
-                frac = len(frame["index"].tolist()) * n
+        # convert `n` to integer if need be.
+        if isinstance(n, float):
+            n = int(n * len(frame["index"].tolist()))
 
-            # perform stratified sampling of `frac` out of every group.
-            subset = grouped.apply(lambda x: x.sample(frac=frac, replace=self.replace))
-
-        else:
-            # convert `n` to integer if need be.
-            if isinstance(n, float):
-                n = int(n * len(frame["index"].tolist()))
-
-            subset = grouped.apply(lambda x: x.sample(n, replace=self.replace))
+        subset = grouped.apply(lambda x: x.sample(n, replace=self.replace))
 
         return subset["index"].tolist()
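
The surviving float-to-count conversion can be seen in isolation with a toy frame. This sketch assumes only pandas and mirrors the logic above; the column names and values are illustrative:

    import pandas as pd

    frame = pd.DataFrame({"label": ["a", "a", "b", "b", "b", "c"], "value": range(6)})
    frame["index"] = frame.index
    grouped = frame.groupby("label", group_keys=False)

    n = 0.3
    if isinstance(n, float):
        # the fraction is taken of the whole frame, then applied as a per-group count.
        n = int(n * len(frame["index"].tolist()))  # 0.3 * 6 -> 1

    subset = grouped.apply(lambda x: x.sample(n, replace=False))
    print(subset["index"].tolist())  # one row index per label group.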
4 changes: 4 additions & 0 deletions sapicore/engine/component/__init__.py
@@ -58,6 +58,10 @@ def __init__(self, identifier: str = None, configuration: dict = None, device: s
         self.simulation_step = 0
         self.dt = DT
 
+        # we don't know what attributes derivative component classes might introduce, but we want them initialized.
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
     def configure(self, configuration: dict[str, Any] = None, log_destination: str = ""):
         """Applies a configuration to this object by adding the keys of `configuration` as instance attributes,
         initializing their values, and updating the `_config_props_` tuple to reflect the new keys.
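
The effect of the new kwargs pass-through, assuming `Component` can be instantiated directly; the `gain` attribute is a hypothetical extension, not part of the library:

    from sapicore.engine.component import Component

    # any extra keyword reaches the instance as an attribute before configure() runs.
    comp = Component(identifier="probe", gain=0.5)
    print(comp.identifier, comp.gain)  # probe 0.5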
25 changes: 9 additions & 16 deletions sapicore/engine/network/__init__.py
@@ -314,19 +314,6 @@ def add_data_hook(self, data_dir: str, steps: int, *args: Component) -> list:
 
         return hooks
 
-    # To micromanage the forward/backward sweeps, subclass Network and override summation(), forward(), backward().
-    @staticmethod
-    def summation(synaptic_input: list[torch.tensor]) -> torch.tensor:
-        """Adds up inputs from multiple synapse objects onto the same ensemble, given as rows.
-
-        Note
-        ----
-        If your model requires some preprocessing of inputs to the postsynaptic neuron, it can be implemented
-        by overriding this method.
-
-        """
-        return torch.sum(torch.vstack(synaptic_input), dim=0)
-
     def backward(self) -> None:
         """Processes a backward sweep for this network object.
@@ -377,9 +364,14 @@ def forward(self, data: torch.tensor) -> dict:
             ensemble_ref = self.graph.nodes[ensemble]["reference"]
 
             if ensemble_ref.identifier not in self.roots:
-                # apply a summation function to synaptic data flowing into this ensemble (torch.sum by default).
+                # apply an aggregation function to synaptic data flowing into this ensemble.
                 if incoming_synapses:
-                    integrated_data = self.summation([synapse.output for synapse in incoming_synapses]).to(self.device)
+                    inputs = [synapse.output for synapse in incoming_synapses]
+                    ids = [synapse.identifier for synapse in incoming_synapses]
+
+                    # aggregation is (micro)managed at the neuron level; torch.sum is used by default.
+                    integrated_data = ensemble_ref.aggregate(inputs, identifiers=ids).to(self.device)
+
                 else:
                     integrated_data = ensemble_ref.input
 
@@ -388,7 +380,8 @@ def forward(self, data: torch.tensor) -> dict:
                 external = [data[self.roots.index(ensemble_ref.identifier)]] if isinstance(data, list) else [data]
                 feedback = [synapse.output for synapse in incoming_synapses]
 
-                integrated_data = self.summation(external + feedback)
+                ids = [f"ext{z}" for z in range(len(external))] + [synapse.identifier for synapse in incoming_synapses]
+                integrated_data = ensemble_ref.aggregate(external + feedback, identifiers=ids)
 
             # forward current ensemble.
             ensemble_ref(integrated_data)
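
Since `Network.summation` is gone, its default behavior now lives in `Neuron.aggregate`. A quick sketch of the equivalence, using toy tensors and hypothetical identifiers:

    import torch
    from sapicore.engine.neuron import Neuron

    inputs = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]

    # the staticmethod defaults to row-wise summation, as Network.summation did.
    print(Neuron.aggregate(inputs, identifiers=["syn_a", "syn_b"]))  # tensor([4., 6.])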
31 changes: 28 additions & 3 deletions sapicore/engine/neuron/__init__.py
@@ -56,7 +56,7 @@ class Neuron(Component):
 
     Warning
     -------
     When defining `equation` for a custom neuron model, the present value of `voltage` should NOT be added to the
-    right hand side. Do NOT multiply by DT. These operations will be performed as part of the generic Euler forward.
+    right hand side. Do NOT multiply by DT. These operations will be performed within the Integrator.
 
     """
@@ -84,7 +84,7 @@ def __init__(self, equation: Callable = None, integrator: Integrator = None, **k
     def num_units(self):
         """Number of functional units represented by this object.
 
-        Neurons are singletons by coercion, as they are meant to express unit dynamics.
+        Neurons are singletons by coercion, as they are meant to express and encapsulate unit dynamics.
 
         Derivatives of :class:`~engine.ensemble.Ensemble` can modify this property and duplicate units as necessary.
@@ -108,7 +108,7 @@ def forward(self, data: Tensor) -> dict:
         Raises
         ------
         NotImplementedError
-            The forward method must be implemented by each derived class.
+            The forward method must be implemented by derivative classes.
 
         """
         raise NotImplementedError
@@ -138,3 +138,28 @@ def inject(self, current: Tensor):
         """
         self.voltage = self.voltage + current
 
+    @staticmethod
+    def aggregate(inputs: list[Tensor], identifiers: list[str] = None) -> Tensor:
+        """Determines how presynaptic inputs from multiple sources should be aggregated.
+
+        By default, neurons sum their inputs. However, many use cases may require more sophistication.
+        Shunting inhibition, for instance, can be expressed with torch.div (or torch.prod, if the source
+        synapse is expected to send the inverse).
+
+        Parameters
+        ----------
+        inputs: list of Tensor
+            Input arriving at this layer, synaptic or external.
+
+        identifiers: list of str, optional
+            Labels by which to micromanage input aggregation. Since some inputs may not be
+            synaptic, users are responsible for passing identifiers in an order matching that of the input tensors.
+
+        Note
+        ----
+        If your model requires identifier-dependent preprocessing of synaptic inputs to this neuron (e.g., a
+        combination of addition and multiplication), it can be implemented by overriding this method.
+
+        """
+        return torch.sum(torch.vstack(inputs), dim=0)
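
As the docstring suggests, identifier-aware overrides can express shunting inhibition. A sketch under the assumption that shunting synapses are tagged by name; the subclass and naming convention are hypothetical:

    import torch
    from torch import Tensor
    from sapicore.engine.neuron import Neuron

    class ShuntedNeuron(Neuron):
        @staticmethod
        def aggregate(inputs: list[Tensor], identifiers: list[str] = None) -> Tensor:
            # without identifiers, fall back to the default summation.
            if identifiers is None:
                return torch.sum(torch.vstack(inputs), dim=0)

            # sum excitatory inputs, then divide by inputs tagged as shunting.
            excitatory = [x for x, i in zip(inputs, identifiers) if "shunt" not in i]
            shunting = [x for x, i in zip(inputs, identifiers) if "shunt" in i]

            total = torch.sum(torch.vstack(excitatory), dim=0)
            if shunting:
                total = total / torch.sum(torch.vstack(shunting), dim=0)

            return total

    # staticmethod, so it can be exercised without instantiating the neuron:
    print(ShuntedNeuron.aggregate(
        [torch.tensor([4.0]), torch.tensor([2.0])], ["syn--exc", "syn--shunt"]))  # tensor([2.])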
11 changes: 7 additions & 4 deletions sapicore/engine/neuron/analog/__init__.py
@@ -6,7 +6,7 @@
 Analog neurons may perform normalization or provide otherwise transformed input to downstream layers.
 """
 
-from torch import tensor, Tensor
+from torch import Tensor
 
 from sapicore.engine.neuron import Neuron
 
 __all__ = ("AnalogNeuron",)
@@ -31,7 +31,10 @@ def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
     def forward(self, data: Tensor) -> dict:
-        """Adds input `data` to the numeric state stored in the instance attribute tensor `voltage`.
+        """Updates the numeric state stored in the instance attribute tensor `voltage` to `data`.
+
+        These default analog neurons integrate the total input impinging on them on every simulation step.
 
         Parameters
         ----------
@@ -50,9 +53,9 @@ def forward(self, data: Tensor) -> dict:
         """
         # update internal representation of input current for tensorboard logging purposes.
-        self.input = tensor([data.detach().clone()]) if not data.size() else data.detach().clone()
+        self.input = data
+        self.voltage = data
 
-        self.voltage = self.voltage.add(data)
         self.simulation_step += 1
 
         # return current state(s) of loggable attributes as a dictionary.
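
A sketch of the revised semantics, assuming `AnalogNeuron` is constructible with defaults and callable as a module: `voltage` now tracks the current step's input rather than accumulating across steps.

    import torch
    from sapicore.engine.neuron.analog import AnalogNeuron

    neuron = AnalogNeuron()
    neuron(torch.tensor([0.5]))
    neuron(torch.tensor([0.2]))

    print(neuron.voltage)  # tensor([0.2000]) -- replaced each step, no longer 0.7.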
12 changes: 12 additions & 0 deletions sapicore/engine/neuron/spiking/LIF.py
@@ -34,6 +34,12 @@ class LIFNeuron(SpikingNeuron):
     tau_ref: float or Tensor
         Refractory period (e.g., 1.0).
 
+    cycle_length: int, optional
+        Oscillatory cycle period, required to time optional resetting of the refractory period.
+
+    release_phase: int, optional
+        Oscillation phase at which to release all neurons from refractory mode, if required.
+
     References
     ----------
     `LIF Tutorial <https://compneuro.neuromatch.io/tutorials/W2D3_BiologicalNeuronModels/student/W2D3_Tutorial1.html>`_
@@ -110,5 +116,11 @@ def forward(self, data: Tensor) -> dict:
         self.refractory_steps = relu(self.refractory_steps - 1)
         self.simulation_step += 1
 
+        if hasattr(self, "release_phase") and hasattr(self, "cycle_length"):
+            # voltage will start to accumulate at a particular phase, canceling the refractory period across units.
+            if self.simulation_step % self.cycle_length == self.release_phase:
+                self.refractory_steps = torch.zeros_like(self.refractory_steps)
+                self.voltage = self.volt_rest
+
         # return current state(s) of loggable attributes as a dictionary.
         return self.loggable_state()
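
The refractory-release switch is opt-in: it only engages when both keywords are supplied, reaching the instance through the `Component` kwargs change above. A hypothetical configuration; the values are illustrative and other LIF parameters are assumed to have usable defaults:

    from sapicore.engine.neuron.spiking.LIF import LIFNeuron

    neuron = LIFNeuron(cycle_length=40, release_phase=0)

    # during forward(), whenever simulation_step % 40 == 0, refractory counters
    # are zeroed and voltage snaps back to volt_rest across all units.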