
Issue with inputting custom weights for rate-based SNN #66

Closed · kannanum opened this issue Oct 1, 2021 · 4 comments

@kannanum commented Oct 1, 2021

  • snntorch version: 0.4.4
  • Python version: 3.9.6
  • Operating System: Ubuntu 20.04.3 LTS

Description

Hi Jason,
First of all, I appreciate your wonderful effort in developing this package and its detailed documentation. I have recently started using snntorch for rate-based SNN coding. Although I am getting good performance for a purely software-based run, I am facing issues with inputting custom weights extracted from a synaptic device. My accuracy gets stuck at around 10%, which is the same as the untrained accuracy.

What I Did

I used a custom function to input the weights from a text file as shown in the screenshot. Please let me know how to solve this issue.
NB: I am pretty new to programming, so please excuse me if my code is too cumbersome :)
[Screenshot: custom function for loading the weights from the text file]
Here are the full file and the text file for data input:
rate_SNN_dev_weights.zip

Thanks,
Kannan

@jeshraghian (Owner)

Hi Kannan,

Really appreciate the kind words :)
Having run through your code, I don't think there's anything wrong with it; it looks good to me.
My guess is that the weight initialization override makes it difficult for any learning to take place.
E.g., if the weights are too small and the neuron spiking threshold is too high, then no learning can take place (dL/dW = 0).
There are a few ways you can overcome this:

  • Use a surrogate gradient so a non-zero gradient can flow through the spiking threshold, e.g.:

spike_grad = surrogate.fast_sigmoid(slope=25)
self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
  • Lower the threshold. The default is threshold=1. Take a look at your output neurons to see whether any of them are firing; if none are, lowering the threshold may help. Perhaps plot their membrane potentials too (see the sketch after this list); if they sit far below the threshold, then your weights likely aren't large enough or aren't distributed appropriately.
self.lif1 = snn.Leaky(beta=beta, threshold=0.3)
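
To record the output membrane potentials for plotting, here is a minimal sketch (fc2, lif2, spk1, and num_steps are assumed names, not your actual variables):

import torch

mem2 = lif2.init_leaky()              # initial hidden state of the output Leaky layer
mem2_rec = []

for step in range(num_steps):
    cur2 = fc2(spk1[step])            # spk1: recorded spikes from the hidden layer
    spk2, mem2 = lif2(cur2, mem2)
    mem2_rec.append(mem2)

mem2_rec = torch.stack(mem2_rec)      # [num_steps, batch, num_outputs]
# e.g. plt.plot(mem2_rec[:, 0, :].detach()) to view one sample's output neurons against the threshold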

If none of the above work, then you may need to scale or normalize your weights (or inputs).
In my own experience, training deep networks using device-derived values can become really challenging.
In general, it seems like the variance of the activations in the forward pass (and of the gradients in the backward pass) should be as close to 1 as possible to avoid vanishing & exploding gradients. I don't think this is your issue, as your network is quite shallow (1 hidden layer), but it's something to consider if you plan to scale your experiments up.
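
As a rough illustration only (this is not your original loading function; the file name and the 42×784 shape are assumptions), rescaling text-file device weights to zero mean and unit variance before copying them into a layer could look like:

import torch
import torch.nn as nn

fc1 = nn.Linear(784, 42, bias=False)                 # assumed sizes: 784 inputs, 42 hidden neurons

with open("device_weights.txt") as f:                # hypothetical file name
    rows = [[float(x) for x in line.split()] for line in f]

with torch.no_grad():
    w = torch.tensor(rows, dtype=torch.float32)      # must match fc1.weight shape: [42, 784]
    w = (w - w.mean()) / (w.std() + 1e-8)            # zero mean, unit variance
    fc1.weight.copy_(w)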

I also guess that you have no inhibitory connections (i.e., all of your weights are positive). This might be okay for simple tasks, but you might consider having a hidden layer of size 42*2 = 84, where 42 of the weights are initialized positive and the other 42 negative.
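
A rough sketch of that idea, reusing the (assumed) [42, 784] device matrix w from the snippet above:

fc1 = nn.Linear(784, 84, bias=False)
with torch.no_grad():
    # stack a positive (excitatory) copy and a mirrored negative (inhibitory) copy
    fc1.weight.copy_(torch.cat([w, -w], dim=0))      # [84, 784]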

It sounds like you're working on a really cool project too - most of my work is with RRAM devices which seems kind of relevant here, so feel free to update me on how you go!

@kannanum (Author) commented Oct 4, 2021

Hi Jason,
Thanks for your detailed reply. I was worried about whether my spiking input implementation was correct or not. I have already tried the surrogate gradient method, but I had the same issue there too. I will try changing the threshold firing potential and come up with a more uniform weight distribution, and I will get back to you with the results. Happy to know that you are also working on RRAM devices; I hope to catch up with you some time for RRAM discussions.
I tried to do away with the inhibitory connections to reduce the hardware array size and also because, as you pointed out, the task was pretty simple. Later I also tried scaling the weights to [-1, 1], but could not see much difference.

@jeshraghian (Owner)

I've been playing with your code a bit more. Your train/test losses are both decreasing, but as you've observed, the accuracy has struggled to improve.

This is a plot of 9 of the 10 output neurons; their membrane potentials explode. The values below exceed 100 (with threshold=1 meaning a spike).

[Plot: membrane potentials of the output neurons, reaching values above 100]

So I would guess you're facing an exploding gradient problem. This might make sense as the device weights are positively distributed (activations are becoming too positive).

I managed to (inexplicably) fix the issue by using the raw MNIST input (instead of spikes), and that trained successfully. After pre-training it on the raw input, I was then able to train it on the spiking version as well (50%+ accuracy before I exited the process).
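
For reference, here is a sketch of the two input regimes (spikegen.rate is snntorch's rate encoder; data, fc1, and num_steps are placeholder names):

from snntorch import spikegen

# rate-coded input: each pixel intensity becomes a per-step spike probability
spike_in = spikegen.rate(data, num_steps=num_steps)  # [num_steps, batch, 1, 28, 28]

# raw input: feed the same analog frame at every time step instead of spikes
for step in range(num_steps):
    cur1 = fc1(data.flatten(1))
    # ... rest of the forward pass as usual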

Alternative approaches that I would test out would be anything that can prevent the membrane potential from exploding (a rough sketch of these options follows the list):

  • decrease alpha and/or beta
  • equivalently, use snn.Leaky instead. The main impact of snn.Synaptic is that, by having two decay rates, it increases the overall time constant. This could be causing the membrane potential to accumulate far too much
  • try reset_mechanism="zero" to force the membrane potential back to zero after a spike; this could help reduce the membrane potential / prevent it from exploding
  • try using an MSE loss applied to the membrane potential. Cross-entropy loss can sometimes let the membrane increase indefinitely and may be unstable; MSE sets a target for each membrane potential to hone in on a given value, which promotes stability. I actually tested this out: it helped reduce the loss, but on its own it was not enough to improve accuracy.
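
A rough sketch of those options together (the beta value and the membrane targets are illustrative assumptions, not recommended settings):

import torch
import torch.nn.functional as F
import snntorch as snn
from snntorch import surrogate

# a single, lower decay rate with a hard reset to zero, in place of snn.Synaptic
lif_out = snn.Leaky(beta=0.5, threshold=1.0, reset_mechanism="zero",
                    spike_grad=surrogate.fast_sigmoid(slope=25))

# MSE loss on the recorded membrane potential: push the correct class towards
# the threshold (1.0) and the other classes towards 0
def membrane_mse_loss(mem_rec, labels, num_classes=10):
    # mem_rec: [num_steps, batch, num_classes]
    target = torch.zeros(labels.size(0), num_classes, device=mem_rec.device)
    target[torch.arange(labels.size(0)), labels] = 1.0
    return F.mse_loss(mem_rec, target.expand_as(mem_rec))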

Happy to chat further on RRAM, feel free to reach out! Contact details are on my website.

@jeshraghian (Owner)

Closing due to inactivity; I believe the issue was solved. The custom device weights worked, though they were likely leading to exploding gradients.
