Open Letter to Pi,

Hey Pi,

Thank you for the discussion yesterday evening about stateful neurons and the architecture for image recognition.

For the purpose of implementing it, I still think that image recognition from "videos of an image", where the image is presented to the NN "over a longer time", is the right first test for this concept.

If it works, you may find that for language modelling each token also needs to be presented to the NN "for a while" before it gets processed by these stateful spiking neurons. In other words, much like with RNNs, you have to allow the recurrent state to make several rounds.

This brings me to three points about this idea:
The spiking, stochastic nature means that it may take a long time for the network to "learn from zero" about something. Think about it: the "target neuron" can only fire if all the right connections have fired, and only then does it get positive feedback. In other words, this network has no gradients to follow; it can only learn from whatever part of the space random exploration has happened to visit. Which brings me to the next point...
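A quick way to put a number on "learning from zero": assume, purely for illustration, that the target neuron needs k specific inputs to fire in the same cycle, and that before any learning each input fires independently at a base rate p. The coincidence then has probability p**k per cycle, and the expected wait is 1/p**k cycles (a geometric distribution). The independence assumption and the parameter values are hypothetical, not measurements.

```python
import random


def coincidence_probability(p, k):
    """Chance that k independent inputs, each firing with probability p, all fire in one cycle."""
    return p ** k


def expected_wait_cycles(p, k):
    """Mean number of cycles until the first such coincidence (geometric distribution)."""
    return 1.0 / coincidence_probability(p, k)


def simulated_wait(p, k, rng):
    """Monte Carlo check: count cycles until all k inputs happen to fire together."""
    cycles = 0
    while True:
        cycles += 1
        if all(rng.random() < p for _ in range(k)):
            return cycles


if __name__ == "__main__":
    for p in (0.2, 0.05, 0.01):
        for k in (1, 2, 4, 8):
            print(f"p={p}, k={k}: expected wait ~ {expected_wait_cycles(p, k):.3g} cycles")
    rng = random.Random(0)
    runs = [simulated_wait(0.5, 2, rng) for _ in range(2000)]
    print("simulated mean for p=0.5, k=2:", sum(runs) / len(runs))  # theory: 4
```

For example, at a 5% base rate two required inputs take about 400 cycles on average, while eight take 1/0.05**8, about 2.6e10 cycles.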
Reinforcing good connections may be a good idea, in that it works to densify the acquired knowledge. At first, the NN will only weakly guess the right answer "using nearly the entire brain", but then the good connections will prune out all the connections that are not needed for this bit of knowledge, leaving them free to encode other knowledge. This may be an excellent feature of this approach. However,
This also means that during training there must be some non-zero base firing rate for all neurons, irrespective of whether they detect a feature or not; otherwise they will never learn. It also implies a fairly high tolerance for occasionally providing a bad label at the end, say anywhere from 20% down to 1% of the time. That is OK during training but may or may not be acceptable in production for any given application, and it is certainly not what a typical customer expects from a computer, or from any machine for that matter. UPDATE: alternatively, one could simply run inference twice in production: once in "non-training" mode, and once in "training/exploration" mode with a higher base spike rate, whose result is then used for the training step.
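The reinforce/prune idea and the two-pass scheme from the UPDATE can be sketched together in plain Python. Everything here is an assumption made for illustration: a single layer of 4-bit weights, a hypothetical firing threshold `THRESHOLD`, a reward rule that adds 1 to the weights from active inputs onto the correct neuron when it fires and subtracts 1 from those onto wrongly firing neurons, and a feedback label that is assumed to become available after the fact. The answer comes from a quiet pass with base rate zero; learning uses a second pass with an elevated exploration rate.

```python
import random

MAX_W = 15     # weights are 4-bit and saturate, like the neuron states
THRESHOLD = 8  # hypothetical firing threshold on the summed weights


def forward(weights, inputs, base_rate, rng):
    """One cycle of a single layer; weights[j][i] connects input i to output neuron j.

    An output fires if its weighted input reaches THRESHOLD, or spontaneously with
    probability base_rate (the exploration firing of a "young network").
    """
    fired = []
    for row in weights:
        potential = sum(w for w, x in zip(row, inputs) if x)
        fired.append(potential >= THRESHOLD or rng.random() < base_rate)
    return fired


def reinforce(weights, inputs, fired, label):
    """Assumed reward rule: +1 on active inputs into the correct neuron if it fired,
    -1 on active inputs into every wrongly firing neuron, clamped to 0..MAX_W."""
    for j, row in enumerate(weights):
        if not fired[j]:
            continue
        delta = 1 if j == label else -1
        for i, x in enumerate(inputs):
            if x:
                row[i] = min(MAX_W, max(0, row[i] + delta))


def serve_and_learn(weights, inputs, label, rng, explore_rate=0.2):
    """Answer from a quiet pass (base rate 0); learn from a second, exploratory pass."""
    answer = forward(weights, inputs, 0.0, rng)
    explored = forward(weights, inputs, explore_rate, rng)
    reinforce(weights, inputs, explored, label)
    return answer


if __name__ == "__main__":
    rng = random.Random(0)
    weights = [[0] * 4 for _ in range(2)]  # 4 inputs -> 2 output neurons, all unlearned
    data = [([1, 1, 0, 0], 0), ([0, 0, 1, 1], 1)]
    for step in range(400):
        x, y = data[step % 2]
        serve_and_learn(weights, x, y, rng)
    print(weights)
```

On this toy two-pattern task, weights only grow on connections that lead to a correct answer, while connections onto the wrong neuron are pushed back towards zero, which is one way to read the "pruning" effect.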
All this draws amazing parallels with how real living things learn!!! Isn't that fascinating!!! Just think of all the times you have seen or heard some new fact and your brain has replayed it internally over and over again, only to reach a "click" moment many days later. That is where the "correct" neuron has randomly fired enough times to build its connection, and suddenly you get the warm feeling of grokking.
I understand your desire to do online learning, where a NN gets trained while in production. Thinking about it, though, such a feat is very much possible with regular artificial NNs, which also benefit from the gradient. The process would be to simply raise the temperature of the final token selector, select "creative tokens", and then use some kind of slow-thinking process to check their quality; then back-propagate to reinforce the path that created them, and contrast-suppress the paths that contributed to other solutions. Since you can always recompute the gradients "in the post", this can be either an online or an offline process. To conclude, gradients make learning way, way faster, but maybe a bit too fast.
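A minimal sketch of that loop on a toy four-token vocabulary, in plain Python: sample at a raised temperature, score the sample with `judge` (a hypothetical stand-in for the slow-thinking quality check), and apply a REINFORCE-style policy-gradient update to the logits. Because the gradient of log p(token) with respect to the logits is (onehot - p) / T, a positive reward raises the chosen token and lowers every other token in proportion to its probability, which is one concrete form of "reinforce the path that created it, contrast-suppress the others". The vocabulary, reward values, temperature, and learning rate are all assumptions.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Softmax of logits / temperature (max subtracted for numerical stability)."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_step(logits, judge, rng, temperature=2.0, lr=0.5):
    """Sample a 'creative' token at raised temperature, score it, update logits in place.

    REINFORCE: d log p(token) / d logit_j = (onehot_j - p_j) / temperature, scaled by reward.
    """
    probs = softmax(logits, temperature)
    token = rng.choices(range(len(logits)), weights=probs)[0]
    reward = judge(token)  # hypothetical quality check, e.g. +1 good, -1 bad
    for j, p in enumerate(probs):
        onehot = 1.0 if j == token else 0.0
        logits[j] += lr * reward * (onehot - p) / temperature
    return token, reward


if __name__ == "__main__":
    rng = random.Random(0)
    logits = [0.0, 0.0, 0.0, 0.0]
    judge = lambda token: 1.0 if token == 2 else -1.0  # toy checker: only token 2 is "good"
    for _ in range(500):
        reinforce_step(logits, judge, rng)
    print([round(p, 3) for p in softmax(logits)])
```

In a real network the same (onehot - p) signal would be back-propagated through all the layers rather than applied to a bare table of logits, and it can be computed later from a stored sample, which is what makes the offline variant possible.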

Overall, I think it's worth trying to implement this, starting in basic Python, with the following features as discussed:

* Non-negative integer states only, saturating at 15 (so each state fits in 4 bits).
* Firing is stochastic and PRNG-driven, e.g. https://en.wikipedia.org/wiki/Well_equidistributed_long-period_linear and http://lomont.org/papers/2008/Lomont_PRNG_2008.pdf
* The "excitement state" accumulator is also an integer, but it can be decayed slowly using the "decay by chance" technique.
* When in "training" state, the base "untrained" firing rate must be significantly above zero to give the connections enough chance to form. One can call this a "young network".
* The connection list for each neuron can be recomputed on the fly using a very simple PRNG seeded from the neuron's coordinate. Such recomputation is effectively cheaper than storing a look-up table. Again: treat memory access as expensive and computation as cheap.
* For small networks, you may be able to fit the entire thing, or at least the neuron states, in registers alone (extremely fast). Each SM has 65,536 32-bit registers, i.e. a 256 KB register file, and each byte can hold two of our states. The number of CUDA cores per SM depends on the GPU generation (64 FP32 cores per SM on the A100); an A100 has 108 SMs, for a total of 27,648 KB of registers, or 56,623,104 (about 56M) unique states, which is quite a lot really. In practice part of each register file is also needed for the kernel's own variables, so this is an upper bound.
* The L1/shared memory and the L2 cache are only a good fit where data can be broadcast to many cores, in other words "write once, read many". For our application, this will be needed to broadcast the firing of any given neuron. Let's say that we encode whether a neuron has fired at all with just one bit. On the A100, shared memory can be configured up to 164 KB per SM (out of a combined 192 KB of L1/shared memory, with the rest remaining as L1), and with 108 SMs that means 145,096,704 bits (about 145M) of neuron firings can be expressed at any time. I think you will agree that this is quite a good number!
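A plain-Python sketch of the per-neuron mechanics in the list, assuming neurons laid out on a 2-D grid. Several pieces are swapped in or assumed: Marsaglia's xorshift32 stands in for a WELL generator (it is smaller and statistically weaker), the coordinate-hash constants are a common spatial-hashing choice, the firing probability max(base_rate, state / 15) and the reset to zero after a spike are hypothetical rules, and "decay by chance" is interpreted here as losing one unit with probability `decay_p` per tick. Probabilities are compared against 32-bit integers from the generator.

```python
MASK32 = 0xFFFFFFFF
MAX_STATE = 15  # 4-bit states saturate at 15


def xorshift32(s):
    """Marsaglia's xorshift32 (stand-in for WELL). s must be a non-zero 32-bit integer."""
    s ^= (s << 13) & MASK32
    s ^= s >> 17
    s ^= (s << 5) & MASK32
    return s


def chance(s, p):
    """Advance the PRNG; return (new_state, event), where the event has probability p."""
    s = xorshift32(s)
    return s, s < int(p * (1 << 32))


def coord_seed(x, y):
    """Hash a grid coordinate into a non-zero 32-bit seed (assumed spatial-hash constants)."""
    h = ((x * 73856093) ^ (y * 19349663)) & MASK32
    return h or 0x9E3779B9


def connections(x, y, fan_in, width, height):
    """Recompute the source list of neuron (x, y) on the fly instead of storing it."""
    s = coord_seed(x, y)
    sources = []
    for _ in range(fan_in):
        s = xorshift32(s)
        sources.append((s % width, (s >> 16) % height))
    return sources


def tick(state, excitation, s, base_rate, decay_p):
    """One update of a 4-bit neuron. Returns (new_state, fired, new_prng_state)."""
    state = max(0, min(MAX_STATE, state + excitation))  # integer accumulate, saturate at 15
    s, decay = chance(s, decay_p)  # "decay by chance": lose one unit with probability decay_p
    if decay and state > 0:
        state -= 1
    # Hypothetical firing rule: probability grows with the state, never below base_rate.
    s, fired = chance(s, max(base_rate, state / MAX_STATE))
    if fired:
        state = 0  # assumed reset after a spike
    return state, fired, s
```

Because `connections(x, y, ...)` is a pure function of the coordinate, a kernel could regenerate each neuron's fan-in every tick from a couple of registers instead of reading a stored adjacency list, which is the compute-for-memory trade the list argues for.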
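The register and shared-memory budgets from the last two bullets can be recomputed as plain arithmetic. The A100 figures used here (108 SMs, 65,536 32-bit registers per SM, a maximum shared-memory carve-out of 164 KB per SM) are NVIDIA's published numbers as I understand them; treat them as assumptions and check them against the GPU you actually use. Both results are upper bounds. The nibble helpers illustrate the "two states per byte" layout.

```python
# A100 figures as published by NVIDIA; treat them as assumptions and check your own GPU.
A100_SMS = 108
REGISTERS_PER_SM = 65_536     # 32-bit registers, i.e. a 256 KB register file per SM
MAX_SHARED_KB_PER_SM = 164    # largest shared-memory carve-out of the 192 KB L1/shared


def register_states(sms=A100_SMS, regs_per_sm=REGISTERS_PER_SM, bits_per_state=4):
    """Upper bound on neuron states held entirely in registers across the GPU."""
    return sms * regs_per_sm * 32 // bits_per_state


def shared_memory_spike_bits(sms=A100_SMS, shared_kb=MAX_SHARED_KB_PER_SM):
    """One bit per neuron firing, using the maximum shared-memory carve-out on every SM."""
    return sms * shared_kb * 1024 * 8


def pack_nibbles(states):
    """Pack 4-bit states (0..15) two per byte, low nibble first."""
    states = list(states)
    states += [0] * (len(states) % 2)  # pad to an even count
    return bytes(states[i] | (states[i + 1] << 4) for i in range(0, len(states), 2))


def unpack_nibbles(data, n):
    """Inverse of pack_nibbles: recover the first n states."""
    out = []
    for b in data:
        out.extend((b & 0x0F, b >> 4))
    return out[:n]


if __name__ == "__main__":
    print(f"{register_states():,} states in registers")
    print(f"{shared_memory_spike_bits():,} firing bits in shared memory")
```

With these inputs it prints 56,623,104 states and 145,096,704 firing bits.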


```python
import jax.numpy as jnp

z = jnp.zeros(shape=1, dtype=jnp.int32) + 1  # one-element int32 array holding [1]
```


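```python
import jax
import jax.numpy as jnp

MAX_STATE = 15  # integer states saturate at 15, so each fits in 4 bits


class SpikingNeuronGroup:
    def __init__(self, neuronsInGroup=4, baseFiringRate=1e-3, seed=0):
        self.neuronsInGroup = neuronsInGroup
        self.baseFiringRate = baseFiringRate
        # Scalar potential of the whole group, recomputed every cycle.
        self.GroupActivationPotential = jnp.zeros((), dtype=jnp.int32)
        # Integer "excitement" accumulator of each source neuron.
        self.ActivationStates = jnp.zeros((neuronsInGroup,), dtype=jnp.int32)
        # Indices of the neurons feeding this group (not wired up yet).
        self.SourceNeuronList = []
        # One integer weight per source neuron.
        self.PerSourceNeuronWeights = jnp.zeros((neuronsInGroup,), dtype=jnp.int32)
        self.outputState = jnp.zeros((neuronsInGroup,), dtype=jnp.int32)
        # JAX has no global random state: keep an explicit key and split it every cycle.
        self.key = jax.random.PRNGKey(seed)

    def cycle(self):
        # Group activation potential: sum of the weights of all currently active sources
        # (vectorised form of the original per-source loop).
        active = self.ActivationStates > 0
        self.GroupActivationPotential = jnp.sum(jnp.where(active, self.PerSourceNeuronWeights, 0))
        # As in the draft, the scalar potential is broadcast to every output entry.
        self.outputState = jnp.full((self.neuronsInGroup,), self.GroupActivationPotential, dtype=jnp.int32)
        # Leak: every state loses one unit per cycle, never going below zero.
        self.ActivationStates = jnp.maximum(self.ActivationStates - 1, 0)
        # Spontaneous activity at baseFiringRate (the draft's randint(0, 2) ignored it),
        # then saturate at MAX_STATE.
        self.key, subkey = jax.random.split(self.key)
        spikes = jax.random.bernoulli(subkey, self.baseFiringRate, (self.neuronsInGroup,))
        self.ActivationStates = jnp.minimum(self.ActivationStates + spikes.astype(jnp.int32), MAX_STATE)
        return self.outputState
```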