diff --git a/snntorch/_neurons/rleaky.py b/snntorch/_neurons/rleaky.py index 853c6d51..96e3d580 100644 --- a/snntorch/_neurons/rleaky.py +++ b/snntorch/_neurons/rleaky.py @@ -242,6 +242,8 @@ def detach_hidden(cls): for layer in range(len(cls.instances)): if isinstance(cls.instances[layer], RLeaky): cls.instances[layer].mem.detach_() + cls.instances[layer].spk.detach_() + @classmethod def reset_hidden(cls): @@ -250,4 +252,5 @@ def reset_hidden(cls): Assumes hidden states have a batch dimension already.""" for layer in range(len(cls.instances)): if isinstance(cls.instances[layer], RLeaky): - cls.instances[layer].mem = _SpikeTensor(init_flag=False) + cls.instances[layer].spk, cls.instances[layer].mem = cls.instances[layer].init_rleaky() +