
Commit

Persistent delay buffer to correctly run slayer blocks one timestep at a time (#169)

* CI fix for poetry error on Windows

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* removed a caching option

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* Reverted caching changes

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* Reverted caching changes

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* pipx installation of poetry

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* removed cache in python installation

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* Testing testing

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* Updated CI

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* Change to python v4

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* Remove poetry cache

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* comments cleanup

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* Updated Readme to point to lava-dl decolle

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* language update

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* Updated illustration image

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* Updated readme

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* Updated benchmarking notebooks to use new callback_fxs api

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* Fixed netx.hdf5 to handle scale dimensions on input.

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* Linting fixes

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* Restore data from lfs

* Temp removing file

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* added mnist network file

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* added lfs git attributes

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* Delay buffer added

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* Fixed quantization range

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

* Dummy change

Signed-off-by: bamsumit <bam_sumit@hotmail.com>

---------

Signed-off-by: bamsumit <bam_sumit@hotmail.com>
bamsumit committed Mar 3, 2023
1 parent 4a1ac7c commit c12d553
Showing 2 changed files with 37 additions and 10 deletions.
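
What the fix enables: with `persistent_state` enabled on a block's neuron, the new `step_delay()` carries the last timestep of each call over to the next call, so feeding a sequence one timestep at a time produces the same spikes as feeding the whole sequence at once. Previously, `delay(x, 1)` zero-padded every call independently, so single-timestep inputs lost their spikes entirely. Below is a minimal usage sketch, not part of the commit; it assumes the `slayer.block.cuba.Dense` API and that `persistent_state` is accepted in the neuron parameter dict.

# Hypothetical usage sketch (assumptions: block class path, neuron
# parameter keys, and 'persistent_state' flag as named here).
import torch
import lava.lib.dl.slayer as slayer

neuron_params = {
    'threshold': 1.0,
    'current_decay': 0.5,
    'voltage_decay': 0.5,
    'persistent_state': True,  # keep neuron state (and now the delay buffer)
}
block = slayer.block.cuba.Dense(neuron_params, 64, 10, delay_shift=True)

x = torch.rand(1, 64, 16)  # (batch, neurons, timesteps)

# Run one timestep at a time; step_delay() stitches consecutive calls
# together so the result should match a single full-sequence run.
y = torch.cat([block(x[..., t:t + 1]) for t in range(x.shape[-1])], dim=-1)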
43 changes: 35 additions & 8 deletions src/lava/lib/dl/slayer/block/base.py
@@ -12,6 +12,33 @@
 from lava.lib.dl.slayer.utils import recurrent
 
 
+def step_delay(module, x):
+    """Step delay computation. This simulates the 1 timestep delay needed
+    for communication between layers.
+
+    Parameters
+    ----------
+    module: module
+        python module instance
+    x : torch.tensor
+        Tensor data to be delayed.
+    """
+    if hasattr(module, 'delay_buffer') is False:
+        module.delay_buffer = None
+    persistent_state = hasattr(module, 'neuron') \
+        and module.neuron.persistent_state is True
+    if module.delay_buffer is not None:
+        if module.delay_buffer.shape[0] != x.shape[0]:  # batch mismatch
+            module.delay_buffer = None
+    if persistent_state:
+        delay_buffer = 0 if module.delay_buffer is None else module.delay_buffer
+        module.delay_buffer = x[..., -1]
+    x = delay(x, 1)
+    if persistent_state:
+        x[..., 0] = delay_buffer
+    return x
+
+
 class AbstractInput(torch.nn.Module):
     """Abstract input block class. This should never be instantiated on its own.
@@ -88,7 +115,7 @@ def forward(self, x):
         x = self.neuron(z)
 
         if self.delay_shift is True:
-            x = delay(x, 1)
+            x = step_delay(self, x)
 
         if self.input_shape is None:
             if self.neuron is not None:
@@ -515,7 +542,7 @@ def forward(self, x):
         z = self.synapse(x)
         x = self.neuron(z)
         if self.delay_shift is True:
-            x = delay(x, 1)
+            x = step_delay(self, x)
         if self.delay is not None:
             x = self.delay(x)
 
@@ -670,7 +697,7 @@ def forward(self, x):
         z = self.synapse(x)
         x = self.neuron(z)
         if self.delay_shift is True:
-            x = delay(x, 1)
+            x = step_delay(self, x)
         if self.delay is not None:
             x = self.delay(x)
 
@@ -822,7 +849,7 @@ def forward(self, x):
         z = self.synapse(x)
         x = self.neuron(z)
         if self.delay_shift is True:
-            x = delay(x, 1)
+            x = step_delay(self, x)
         if self.delay is not None:
             x = self.delay(x)
 
@@ -962,7 +989,7 @@ def forward(self, x):
         z = self.synapse(x)
         x = self.neuron(z)
         if self.delay_shift is True:
-            x = delay(x, 1)
+            x = step_delay(self, x)
         if self.delay is not None:
             x = self.delay(x)
 
@@ -1114,7 +1141,7 @@ def forward(self, x):
         z = self.synapse(x)
         x = self.neuron(z)
         if self.delay_shift is True:
-            x = delay(x, 1)
+            x = step_delay(self, x)
         if self.delay is not None:
             x = self.delay(x)
 
@@ -1320,7 +1347,7 @@ def forward(self, x):
         # self.spike_state = spike.clone().detach().reshape(z.shape[:-1])
 
         if self.delay_shift is True:
-            x = delay(x, 1)
+            x = step_delay(self, x)
 
         if self.count_log is True:
             return x, torch.mean(x > 0)
@@ -1442,7 +1469,7 @@ def forward(self, x):
         self.spike_state = spike.clone().detach().reshape(z.shape[:-1])
 
         if self.delay_shift is True:
-            x = delay(x, 1)
+            x = step_delay(self, x)
         if self.delay is not None:
             x = self.delay(x)
 
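
How `step_delay` works: it always applies the one-timestep shift, but when the neuron keeps persistent state it also stashes the last timestep of the current call in `module.delay_buffer` and splices the previously stashed timestep into the slot the shift vacates, resetting the buffer whenever the batch size changes. A self-contained sketch of the same logic, using a hand-rolled `delay_one` in place of slayer's `delay` (this is a reimplementation for illustration, not the library code):

import torch


def delay_one(x):
    """Shift right by one step along the last (time) axis, zero-padding."""
    return torch.nn.functional.pad(x, (1, 0))[..., :-1]


class BufferedDelay:
    """One-timestep delay that persists across calls, like step_delay()."""
    def __init__(self):
        self.delay_buffer = None

    def __call__(self, x):
        if self.delay_buffer is not None \
                and self.delay_buffer.shape[0] != x.shape[0]:
            self.delay_buffer = None  # batch size changed: drop stale state
        prev = 0 if self.delay_buffer is None else self.delay_buffer
        self.delay_buffer = x[..., -1]  # stash this call's last timestep
        x = delay_one(x)
        x[..., 0] = prev  # splice in the previous call's last timestep
        return x


x = torch.rand(2, 4, 8)
full = delay_one(x)  # delay applied to the whole sequence at once

buffered = BufferedDelay()
chunked = torch.cat(
    [buffered(x[..., t:t + 1]) for t in range(x.shape[-1])], dim=-1)

assert torch.allclose(full, chunked)  # stepwise run == full-sequence run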
4 changes: 2 additions & 2 deletions src/lava/lib/dl/slayer/neuron/base.py
@@ -191,11 +191,11 @@ def quantize_8bit(self, weight, descale=False):
         if descale is False:
             return quantize(
                 weight, step=2 / self.w_scale
-            ).clamp(-256 / self.w_scale, 255 / self.w_scale)
+            ).clamp(-256 / self.w_scale, 254 / self.w_scale)
         else:
             return quantize(
                 weight, step=2 / self.w_scale
-            ).clamp(-256 / self.w_scale, 255 / self.w_scale) * self.w_scale
+            ).clamp(-256 / self.w_scale, 254 / self.w_scale) * self.w_scale
 
     @property
     def weight_exponent(self):
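
Why 255 became 254: `quantize(weight, step=2 / self.w_scale)` rounds weights onto a grid of even integers (once multiplied by `w_scale`), giving the 256 levels of a signed 8-bit format. The old upper clamp of `255 / w_scale` sat off that grid, so a saturating weight was pushed to an odd value the format cannot represent; `254 / w_scale` is the largest on-grid value. A sketch with a stand-in `quantize` (the library's version is a custom autograd function; round-to-nearest-step is assumed here):

import torch


def quantize(weight, step):
    # stand-in for slayer's quantize: round to the nearest multiple of step
    return torch.round(weight / step) * step


w_scale = 8.0
w = torch.tensor([100.0])  # weight large enough to saturate the clamp

old = quantize(w, step=2 / w_scale).clamp(-256 / w_scale, 255 / w_scale)
new = quantize(w, step=2 / w_scale).clamp(-256 / w_scale, 254 / w_scale)

print(old * w_scale)  # tensor([255.]) -- odd, off the even-integer grid
print(new * w_scale)  # tensor([254.]) -- largest representable even value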
