Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev/learning 2f on chip #528

Merged
merged 19 commits into from
Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/lava/magma/compiler/var_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ class LoihiSynapseAddress(LoihiAddress):
syn_entry_id: int


@dataclass
class LoihiInAxonAddress(LoihiAddress):
    """Address of an in-axon entry on a Loihi core.

    Extends LoihiAddress with the profile the entry is associated with.
    """

    # Profile on the core that this entry belongs to.
    # NOTE(review): the original comment said "a synapse" — for an
    # *in-axon* address this presumably identifies the in-axon's
    # profile; confirm against the compiler mapping code.
    profile_id: int


@dataclass
class AbstractVarModel(ABC):
var: InitVar[Var] = None
Expand Down
53 changes: 34 additions & 19 deletions src/lava/magma/core/learning/learning_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ def __init__(
t_epoch: ty.Optional[int] = 1,
rng_seed: ty.Optional[int] = None,
) -> None:

self._dw_str = None if dw is None else "dw = " + str(dw)
self._dd_str = None if dd is None else "dd = " + str(dd)
self._dt_str = None if dt is None else "dt = " + str(dt)

# dict of string learning rules
str_learning_rules = {
str_symbols.DW: dw,
Expand Down Expand Up @@ -627,24 +632,34 @@ class Loihi3FLearningRule(LoihiLearningRule):
"""

def __init__(
self,
dw: ty.Optional[str] = None,
dd: ty.Optional[str] = None,
dt: ty.Optional[str] = None,
x1_impulse: ty.Optional[float] = 0.0,
x1_tau: ty.Optional[float] = 0.0,
x2_impulse: ty.Optional[float] = 0.0,
x2_tau: ty.Optional[float] = 0.0,
y1_impulse: ty.Optional[float] = 0.0,
y1_tau: ty.Optional[float] = 0.0,
t_epoch: ty.Optional[int] = 1,
rng_seed: ty.Optional[int] = None,
self,
dw: ty.Optional[str] = None,
dd: ty.Optional[str] = None,
dt: ty.Optional[str] = None,
x1_impulse: ty.Optional[float] = 0.0,
x1_tau: ty.Optional[float] = 0.0,
x2_impulse: ty.Optional[float] = 0.0,
x2_tau: ty.Optional[float] = 0.0,
y1_impulse: ty.Optional[float] = 0.0,
y1_tau: ty.Optional[float] = 0.0,
t_epoch: ty.Optional[int] = 1,
rng_seed: ty.Optional[int] = None,
) -> None:

super().__init__(dw=dw, dd=dd, dt=dt,
x1_impulse=x1_impulse, x1_tau=x1_tau,
x2_impulse=x2_impulse, x2_tau=x2_tau,
y1_impulse=y1_impulse, y1_tau=y1_tau,
y2_impulse=0, y2_tau=2 ** 32 - 1,
y3_impulse=0, y3_tau=2 ** 32 - 1,
t_epoch=t_epoch, rng_seed=rng_seed)
super().__init__(
dw=dw,
dd=dd,
dt=dt,
x1_impulse=x1_impulse,
x1_tau=x1_tau,
x2_impulse=x2_impulse,
x2_tau=x2_tau,
y1_impulse=y1_impulse,
weidel-p marked this conversation as resolved.
Show resolved Hide resolved
y1_tau=y1_tau,
y2_impulse=0,
y2_tau=2**32 - 1,
y3_impulse=0,
y3_tau=2**32 - 1,
t_epoch=t_epoch,
rng_seed=rng_seed,
)
53 changes: 29 additions & 24 deletions src/lava/magma/core/model/py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,34 @@
NUM_Y_TRACES = len(str_symbols.POST_TRACES)


class AbstractLearningConnection:
    """Base class for plastic connection ProcessModels.

    Declares the learning-related ports and variables shared by every
    backend-specific implementation; concrete subclasses (e.g. the
    Python/CPU models) bind these placeholders to real port and var
    objects.
    """

    # Learning Ports
    s_in_bap = None  # backprop action potentials from the neuron
    s_in_y1 = None
    s_in_y2 = None
    s_in_y3 = None

    # Learning Vars
    # Pre-synaptic dependency, spike time and traces.
    x0 = None
    tx = None
    x1 = None
    x2 = None

    # Post-synaptic dependency, spike time and traces.
    y0 = None
    ty = None
    y1 = None
    y2 = None
    y3 = None

    tag_2 = None
    tag_1 = None


class PyLearningConnection(AbstractLearningConnection):
"""Base class for plastic connection ProcessModels in Python / CPU.
weidel-p marked this conversation as resolved.
Show resolved Hide resolved

This class provides commonly used functions for simulating the Loihi
learning engine. It is subclasses for floating and fixed point
Expand Down Expand Up @@ -75,27 +101,6 @@ class LearningConnection:
Parameters from the ProcessModel
"""

# Learning Ports
s_in_bap = None
s_in_y1 = None
s_in_y2 = None
s_in_y3 = None

# Learning Vars
x0 = None
tx = None
x1 = None
x2 = None

y0 = None
ty = None
y1 = None
y2 = None
y3 = None

tag_2 = None
tag_1 = None

def __init__(self, proc_params: dict) -> None:
super().__init__(proc_params)

Expand Down Expand Up @@ -484,7 +489,7 @@ def _reset_dependencies_and_spike_times(self) -> None:
self.ty = np.zeros_like(self.ty)


class LearningConnectionModelBitApproximate(LearningConnection):
class LearningConnectionModelBitApproximate(PyLearningConnection):
"""Fixed-point, bit-approximate implementation of the Connection base
class.

Expand Down Expand Up @@ -1020,7 +1025,7 @@ def _saturate_synaptic_variable(
)


class LearningConnectionModelFloat(LearningConnection):
class LearningConnectionModelFloat(PyLearningConnection):
"""Floating-point implementation of the Connection Process.

This ProcessModel constitutes a behavioral implementation of Loihi synapses
Expand Down
9 changes: 8 additions & 1 deletion src/lava/magma/core/model/py/neuron.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.magma.core.model.py.ports import PyOutPort
from lava.magma.core.model.py.ports import PyOutPort, PyInPort
from lava.magma.core.model.py.type import LavaPyType
import numpy as np

Expand All @@ -19,6 +19,7 @@ class LearningNeuronModel(PyLoihiProcessModel):
"""

# Learning Ports
a_third_factor_in = None
s_out_bap = None
s_out_y1 = None
s_out_y2 = None
Expand Down Expand Up @@ -51,6 +52,10 @@ class LearningNeuronModelFixed(LearningNeuronModel):
"""

# Learning Ports
a_third_factor_in: PyInPort = LavaPyType(
PyInPort.VEC_DENSE, np.int32, precision=7
)

s_out_bap: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, bool, precision=1)
s_out_y1: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, np.int32, precision=7)
s_out_y2: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, np.int32, precision=7)
Expand Down Expand Up @@ -80,6 +85,8 @@ class LearningNeuronModelFloat(LearningNeuronModel):
"""

# Learning Ports
a_third_factor_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, float)

s_out_bap: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, bool)
s_out_y1: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float)
s_out_y2: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float)
Expand Down
9 changes: 5 additions & 4 deletions src/lava/magma/core/process/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,14 @@ class LearningConnectionProcess(AbstractProcess):
"""
def __init__(
self,
shape: tuple = (1, 1),
shape: tuple,
learning_rule: ty.Optional[LoihiLearningRule] = None,
**kwargs,
):
kwargs["learning_rule"] = learning_rule

kwargs["shape"] = shape
tag_1 = kwargs.get('tag_1', 0)
tag_2 = kwargs.get('tag_2', 0)

self.learning_rule = learning_rule

Expand All @@ -83,7 +84,7 @@ def __init__(
self.y2 = Var(shape=(shape[0],), init=0)
self.y3 = Var(shape=(shape[0],), init=0)

self.tag_2 = Var(shape=shape, init=0)
self.tag_1 = Var(shape=shape, init=0)
self.tag_1 = Var(shape=shape, init=tag_1)
self.tag_2 = Var(shape=shape, init=tag_2)

super().__init__(**kwargs)
4 changes: 3 additions & 1 deletion src/lava/magma/core/process/neuron.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import typing as ty
from lava.magma.core.process.ports.ports import OutPort
from lava.magma.core.process.ports.ports import OutPort, InPort
from lava.magma.core.learning.learning_rule import LoihiLearningRule
from lava.magma.core.process.variable import Var

Expand Down Expand Up @@ -31,6 +31,8 @@ def __init__(self,
kwargs['learning_rule'] = learning_rule

# Learning Ports
self.a_third_factor_in = InPort(shape=(shape[0],))

# Port for backprop action potentials
self.s_out_bap = OutPort(shape=(shape[0],))

Expand Down
4 changes: 2 additions & 2 deletions src/lava/proc/learning_rules/stdp_learning_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def __init__(

# Other learning-related parameters
# Trace impulse values
x1_impulse = kwargs.get("x1_impulse", 16)
y1_impulse = kwargs.get("y1_impulse", 16)
x1_impulse = kwargs.pop("x1_impulse", 16)
y1_impulse = kwargs.pop("y1_impulse", 16)

# Trace decay constants
x1_tau = tau_plus
Expand Down