Skip to content

Commit

Permalink
Monitor hook dtypes
Browse files Browse the repository at this point in the history
* Customize in-memory hook data types, allowing spike tracking for very large models (O(100k)).
* `Model` method signature changes.
  • Loading branch information
rm875 committed Sep 5, 2024
1 parent 71358c9 commit 83b213f
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 21 deletions.
19 changes: 13 additions & 6 deletions sapicore/engine/network/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
""" Networks are graph representations of neuron ensembles connected by synapses. """
from typing import Sequence
from typing import Sequence, Type

import os
from itertools import compress

import matplotlib.pyplot as plt
import torch

import networkx as nx
from networkx import DiGraph
Expand All @@ -19,6 +19,8 @@
from sapicore.utils.constants import SYNAPSE_SPLITTERS
from sapicore.utils.io import DataAccumulatorHook, flatten, load_yaml, MonitorHook

import matplotlib.pyplot as plt

__all__ = ("Network",)


Expand Down Expand Up @@ -290,7 +292,11 @@ def _build(self):
self.root_lock = False

def add_monitor_hook(
self, steps: int = None, attrs: Sequence[str] = None, comps: Sequence[Component] = None
self,
steps: int = None,
attrs: Sequence[str] = None,
comps: Sequence[Component] = None,
dtype: Type = torch.float,
) -> dict:
"""Attach a forward hook to some or all network components, buffering accumulated output in memory.
Expand All @@ -313,12 +319,13 @@ def add_monitor_hook(
hooks[ensemble.identifier] = MonitorHook(
ensemble, ensemble.loggable_props if not attrs else attrs, steps
)

for synapse in self.get_synapses():
hooks[synapse.identifier] = MonitorHook(synapse, synapse.loggable_props if not attrs else attrs, steps)
hooks[synapse.identifier] = MonitorHook(
synapse, synapse.loggable_props if not attrs else attrs, steps, dtype
)
else:
for comp in comps:
hooks[comp.identifier] = MonitorHook(comp, comp.loggable_props if not attrs else attrs, steps)
hooks[comp.identifier] = MonitorHook(comp, comp.loggable_props if not attrs else attrs, steps, dtype)

return hooks

Expand Down
17 changes: 4 additions & 13 deletions sapicore/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,21 +96,13 @@ def fit(
for synapse in self.network.get_synapses():
synapse.set_learning(False)

def predict(self, data: Tensor | Sequence[Tensor], **kwargs) -> Sequence:
def predict(self, data: Tensor, **kwargs) -> Sequence:
"""Predicts the labels of `data`.
Parameters
----------
data: Data or Tensor
Sapicore dataset or a standalone 2D tensor of data buffer, formatted sample X feature.
duration: int or Sequence of int
Duration of sample presentation. Simulates duration of exposure to a particular input.
If a list or a tensor is provided, the i-th sample in the batch is maintained for `duration[i]` steps.
rinse: int or Sequence of int
Null stimulation steps (0s in-between samples).
If a list or a tensor is provided, the i-th sample is followed by `rinse[i]` rinse steps.
data: Tensor
Standalone 2D tensor of data buffer, sample X feature.
Returns
-------
Expand All @@ -121,8 +113,7 @@ def predict(self, data: Tensor | Sequence[Tensor], **kwargs) -> Sequence:
raise NotImplementedError

def similarity(self, data: Tensor, metric: str | Callable, **kwargs) -> Tensor:
"""Performs rudimentary similarity analysis on the network's responses to `data`,
yielding a pairwise distance matrix.
"""Performs a similarity analysis on network responses to `data`, yielding a pairwise distance matrix.
Parameters
----------
Expand Down
8 changes: 6 additions & 2 deletions sapicore/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,12 @@ class MonitorHook(Module):
Expected number of simulation steps. Used to preallocate the buffer.
If unknown, `torch.vstack()` is used, which may result in slower performance.
dtype:
Data type.
"""

def __init__(self, component: torch.nn.Module, attributes: Sequence[str], entries: int = None):
def __init__(self, component: torch.nn.Module, attributes: Sequence[str], entries: int = None, dtype=torch.float):
super().__init__()

self.component = component
Expand All @@ -65,6 +68,7 @@ def __init__(self, component: torch.nn.Module, attributes: Sequence[str], entrie
self.entries = entries

self.cache = {}
self.dtype = dtype

# (de)activate hook without removing it.
self.active = True
Expand Down Expand Up @@ -98,7 +102,7 @@ def fn(_, __, output):
else:
# preallocate if number of steps is known.
dim = [self.entries, len(output[attr])] + ([output[attr].shape[1]] if odim == 2 else [])
self.cache[attr] = torch.empty(dim)
self.cache[attr] = torch.empty(dim, dtype=self.dtype)

else:
if self.entries is None:
Expand Down

0 comments on commit 83b213f

Please sign in to comment.