From 3e502dea4a8b4dd94505fc8d2ac2d78c75368d3a Mon Sep 17 00:00:00 2001 From: Manuel Breitenstein Date: Mon, 16 May 2022 15:18:08 +0200 Subject: [PATCH] Add spikes in detach_hidden and reset_hidden for rleaky. --- snntorch/_neurons/rleaky.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/snntorch/_neurons/rleaky.py b/snntorch/_neurons/rleaky.py index 853c6d51..96e3d580 100644 --- a/snntorch/_neurons/rleaky.py +++ b/snntorch/_neurons/rleaky.py @@ -242,6 +242,8 @@ def detach_hidden(cls): for layer in range(len(cls.instances)): if isinstance(cls.instances[layer], RLeaky): cls.instances[layer].mem.detach_() + cls.instances[layer].spk.detach_() + @classmethod def reset_hidden(cls): @@ -250,4 +252,5 @@ def reset_hidden(cls): Assumes hidden states have a batch dimension already.""" for layer in range(len(cls.instances)): if isinstance(cls.instances[layer], RLeaky): - cls.instances[layer].mem = _SpikeTensor(init_flag=False) + cls.instances[layer].spk, cls.instances[layer].mem = cls.instances[layer].init_rleaky() +