Examples of regression? #98

Closed
shilpakancharla opened this issue Mar 31, 2022 · 29 comments
Comments

@shilpakancharla

I was wondering if anyone had used snnTorch for regression, and perhaps how you set your networks up. Just looking for simple, general examples! MSELoss would likely be the loss to use, as I see it.

@ahenkes1
Collaborator

I would second this question, as it is not obvious to me how to decode spikes to the real line. If you decode rate-based, then predicting large numbers would take more spikes (= more expensive?). A simple example would be great!

@mahmad2005

Have you got any example of linear regression using snnTorch? I'm trying to test a simple f(x) = x linear regression problem with snnTorch, following Tutorial 5, but could not improve the training loss. I'm using MSELoss as the loss function and stochastic gradient descent (SGD) as the optimizer.

@shilpakancharla
Author

I actually used MSELoss and Adam in one of my projects, albeit I'm sure it could still use some work (see my repo on event-based velocity prediction). I've also been doing a literature search on appropriate loss and optimization functions for SNNs in general.

@mahmad2005

I also used Adam but could not get the model to improve. I think my code has a problem. I have posted about it in the discussion section, in case you or someone else could help me find where I made a mistake.
#122

@jeshraghian
Owner

The way I see it, there are a few ways to perform regression.

  1. Set the target of the membrane potential of a spiking neuron to reach the desired value
  2. Set the target of the total spike count at the end of the simulation to reach the desired value
  3. Set the target of the spike time to reach the desired value

The most effective way to implement '1' would be to ensure the output layer has the reset mechanism disabled (e.g., snn.Leaky(beta=beta, reset_mechanism="none")), or by setting the threshold to an arbitrarily large value. Then you would need to decide at what point in time you will be measuring the output. E.g., are you only concerned with the membrane potential at the final time step, or for all time steps? This will be quite task-dependent.
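A minimal sketch of approach '1' (the layer sizes, beta value, and dummy data below are illustrative assumptions): disable the reset on the readout neuron and apply MSELoss to its membrane potential.

import torch
import snntorch as snn

fc = torch.nn.Linear(in_features=10, out_features=1)   # hypothetical readout weights
li_out = snn.Leaky(beta=0.9, reset_mechanism="none")   # no reset on the output neuron
loss_fn = torch.nn.MSELoss()

x = torch.rand(25, 1, 10)   # [time, batch, features], dummy input
target = torch.rand(1, 1)   # desired real-valued output

mem = li_out.init_leaky()
mem_rec = []
for step in range(x.size(0)):
    _, mem = li_out(fc(x[step]), mem)   # keep the membrane potential, ignore the spikes
    mem_rec.append(mem)

loss = loss_fn(mem_rec[-1], target)   # here, only the final time step is supervised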

The approach for '2' is quite straightforward, as you would simply sum together all of the spikes and try to match the desired target value. The limiting factors here are i) the quantised nature of spike counts (i.e., they can only take on discrete values / natural numbers), and ii) the maximum permissible spike count is the total number of time steps. The second issue can be lessened by using multiple neurons to emit multiple spikes.
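A minimal sketch of approach '2' (the recorded spike train below is a random stand-in for the output layer's spikes): sum the spikes over time and push the count towards the target. With multiple output neurons, the maximum representable count grows to num_steps * neurons, as noted above.

import torch

num_steps = 100
# spk_rec stands in for the recorded output spikes, shape [time, batch, neurons];
# in practice it is produced by the network over num_steps time steps.
spk_rec = (torch.rand(num_steps, 1, 4) > 0.8).float()

target_count = torch.tensor([[37.0]])   # desired value, bounded by num_steps * neurons
spike_count = spk_rec.sum(dim=0).sum(dim=-1, keepdim=True)   # total spikes across time and neurons

loss = torch.nn.functional.mse_loss(spike_count, target_count)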

The approach for '3' is less explored in the context of a PyTorch backend, and in general, I find it is far less stable than rate-based loss functions. But I've put together a loss function with a few examples: mse_temporal_loss. I found that I often had to drop the threshold significantly to start seeing any action, especially as my networks became deeper.

https://snntorch.readthedocs.io/en/latest/snntorch.functional.html#snntorch.functional.loss.mse_temporal_loss

Let me know what types of problems you're trying to tackle, and I can account for it in future tutorials.

@ahenkes1
Collaborator

ahenkes1 commented Aug 8, 2022

Thank you very much for the extensive answer! To begin with, a simple affine transformation of a temporal (linear) sequence would be helpful. That is,

x = [0, 1, 2, ..., t] / t (that is, linspace normalized between 0 and 1)

y = ax + b for some real numbers a, b (possibly chosen randomly for every sample)
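For concreteness, the toy data described above could be generated along these lines (the values of t, a, and b are arbitrary placeholders):

import torch

t = 50                                          # number of time steps
a, b = 2.0, 0.5                                 # fixed affine parameters for overfitting
x = torch.linspace(0.0, 1.0, t).view(t, 1, 1)   # [time, batch=1, feature=1], normalized input
y = a * x + b                                   # target value at every time step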

I experimented with unrolled spiking LSTMs in combination with classical dense layers as a wrapper (also using different libraries), where I was able to overfit for small values of t, but not nearly to the precision of classical LSTMs.

So basically, I am stuck trying to overfit on a single sample with, say, 25 - 100 time steps for fixed values of a and b.

@jeshraghian
Owner

This should definitely be possible! Are you able to apply MSELoss() to the output membrane potential?

I have a notebook in another repo where I train the membrane potential to linearly increase over time given random inputs. I expect learning a linear mapping should be even easier. Check out the ipynb file in this repo for inspo:

https://github.com/jeshraghian/snn-tha

@ahenkes1
Collaborator

ahenkes1 commented Aug 8, 2022

Thank you for the feedback! I just used the normal MSE loss when I had a standard 'Linear' layer as output. I'll check your file and report back!

@ahenkes1
Collaborator

ahenkes1 commented Aug 8, 2022

It seems to work in principle (your approach 1 using membrane potentials), but I have some (explainable) artifacts in my results. I try to overfit on a single vector [TIME=10, BATCH=1, FEATURE=1], which is a simple linspace (True):

True            Prediction      
[('4.9626e-01', '0.0000e+00'),
 ('5.2647e-01', '0.0000e+00'),
 ('5.5669e-01', '0.0000e+00'),
 ('5.8691e-01', '0.0000e+00'),
 ('6.1713e-01', '6.1499e-01'),
 ('6.4735e-01', '6.4795e-01'),
 ('6.7757e-01', '6.7761e-01'),
 ('7.0779e-01', '7.0775e-01'),
 ('7.3800e-01', '7.3803e-01'),
 ('7.6822e-01', '7.6820e-01')]

I use the following network:

real input -> leaky integrator with linear weights -> leaky integrate and fire with linear weights -> leaky integrator with linear weights -> real output.

Now, how to get the "correct" results for the prediction also for the first iterations? I think in the beginning the membrane is not saturated enough ....

@jeshraghian
Owner

jeshraghian commented Aug 8, 2022 via email

@ahenkes1
Collaborator

ahenkes1 commented Aug 8, 2022

I used a different lib for that experiment; I will make a small script in snntorch with the whole code for everyone to check out and post it here. And yes, I mean "time steps". Maybe one could use dummy entries for the first few time steps?
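One hedged way to realize the "dummy entries" idea would be to prepend a few warm-up time steps to the input and exclude them from the loss. The helpers below are hypothetical and only meant as a sketch.

import torch

WARMUP = 3   # assumed number of dummy time steps, to let the membrane charge up

def pad_warmup(x, warmup=WARMUP):
    """Repeat the first time step `warmup` times in front of x ([time, batch, feature])."""
    return torch.cat([x[:1].repeat(warmup, 1, 1), x], dim=0)

def warmup_mse(mem_rec, target, warmup=WARMUP):
    """MSE on the recorded membrane trace, ignoring the first `warmup` (padded) time steps."""
    return torch.nn.functional.mse_loss(mem_rec[warmup:], target)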

@jeshraghian
Owner

jeshraghian commented Aug 8, 2022 via email

@ahenkes1
Collaborator

ahenkes1 commented Aug 8, 2022

Thanks for pointing out the paper! I already used 1024 as the width, which seems to be quite overkill for such a simple task, so I thought there would be a more fundamental error :D

@ahenkes1
Collaborator

ahenkes1 commented Aug 8, 2022

Well, I ported my architecture from the other lib to snntorch and .... it worked?!? Here is my code:

"""A simple regression task using snntorch."""
import snntorch
import torch
import torch.utils.data


class Regression_dataset(torch.utils.data.Dataset):
    """Simple regression dataset."""

    def __init__(self, timesteps):
        """Linear relation between input and output"""
        lin_vec = torch.linspace(start=0.0, end=0.1, steps=timesteps)
        self.feature = lin_vec.view(timesteps, 1, 1)
        self.label = self.feature * 10

    def __len__(self):
        """Only one sample."""
        return 1

    def __getitem__(self, idx):
        """General implementation, but we only have one sample."""
        return self.feature[:, idx, :], self.label[:, idx, :]


class SNN(torch.nn.Module):
    """Simple spiking neural network in snntorch."""

    def __init__(self, timesteps, hidden):
        super().__init__()
        self.timesteps = timesteps
        self.hidden = hidden

        self.fc1 = torch.nn.Linear(in_features=1, out_features=self.hidden)
        self.lif = snntorch.Leaky(beta=0.5)
        self.fc2 = torch.nn.Linear(in_features=self.hidden, out_features=1)
        self.li = snntorch.Leaky(beta=0.5, reset_mechanism="none")

    def forward(self, x):
        """Forward pass for 10 time steps."""
        mem1 = self.lif.init_leaky()
        mem2 = self.li.init_leaky()

        cur3_rec = []
        mem2_rec = []

        for step in range(self.timesteps):
            cur1 = self.fc1(x[step, :, :])
            spk1, mem1 = self.lif(cur1, mem1)
            cur2 = self.fc2(spk1)
            cur3, mem2 = self.li(cur2, mem2)
            cur3_rec.append(cur3)
            mem2_rec.append(mem2)

        return torch.stack(cur3_rec, dim=0), torch.stack(mem2_rec, dim=0)


def main():
    """Training loop and prediction."""
    DEVICE = "cuda"
    TIMESTEPS = 11
    ITER = 2000
    HIDDEN = 1024

    dataloader = torch.utils.data.DataLoader(
        dataset=Regression_dataset(timesteps=TIMESTEPS)
    )

    model = SNN(timesteps=TIMESTEPS, hidden=HIDDEN).to(DEVICE)
    model.train()

    optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-4)
    loss_function = torch.nn.MSELoss()

    feature = None
    label = None
    loss_val = None
    for i in range(ITER):
        train_batch = iter(dataloader)

        for feature, label in train_batch:
            feature = torch.swapaxes(input=feature, axis0=0, axis1=1)
            label = torch.swapaxes(input=label, axis0=0, axis1=1)
            feature = feature.to(DEVICE)
            label = label.to(DEVICE)

            cur, mem = model(feature)

            loss_val = loss_function(mem, label)
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

        print(f"Iter: {i}, Loss: {loss_val.detach().cpu().numpy()}")

    with torch.no_grad():
        model.eval()
        _, prediction = model(feature)

    label = torch.squeeze(label).cpu().numpy().tolist()
    prediction = torch.squeeze(prediction).cpu().numpy().tolist()
    result = list(zip(label, prediction))
    for i in result:
        print(i)

    return None


if __name__ == "__main__":
    main()

This is of course great news, but now I have to figure out what exactly happened and why the other code is not working. By the way, for large TIMESTEPS I am not able to overfit, I think due to vanishing gradients. I will try your LSTM implementation on that.

Best

@jeshraghian
Owner

jeshraghian commented Aug 8, 2022 via email

@ahenkes1
Collaborator

ahenkes1 commented Aug 9, 2022

I did some experiments with regard to the LSTM. I observed the following points:

  1. If the threshold for the LSTM is not set to trainable, the loss converges to a bad level.
  2. If the threshold is trainable, I can reach losses on the order of 1e-5, which is much worse than the architecture consisting only of LIF neurons.
  3. If I reduce the width of the LSTM to 512, I can overfit to machine precision. The LSTM consists of multiple sub-networks, so the complexity is higher, as is the difficulty of training. Reducing the width makes sense.
  4. reset_mechanism has no effect on the achievable loss.
  5. This is all for 10 time steps. My expectation was that, like standard LSTMs which work just fine here, the spiking LSTM would be able to overfit even for larger numbers of time steps, which is unfortunately not the case.

Do you have an idea why this happens or how to improve the LSTM?

P.S.: Concerning the other libs, I will dig into the implementation, write down the formulae and try to come up with a reasonable interpretation.

Here is the code to reproduce, you can switch between LSTM and LIF in the first layer by setting LSTM=True/False:

"""A simple regression task using snntorch."""
import numpy
import random
import snntorch
import snntorch.surrogate
import torch
import torch.utils.data

# Seed
torch.manual_seed(0)
random.seed(0)
numpy.random.seed(0)


class Regression_dataset(torch.utils.data.Dataset):
    """Simple regression dataset."""

    def __init__(self, timesteps):
        """Linear relation between input and output"""
        lin_vec = torch.linspace(start=0.0, end=1.0, steps=timesteps)
        self.feature = lin_vec.view(timesteps, 1, 1)
        self.label = self.feature * 1

    def __len__(self):
        """Only one sample."""
        return 1

    def __getitem__(self, idx):
        """General implementation, but we only have one sample."""
        return self.feature[:, idx, :], self.label[:, idx, :]


class SNN(torch.nn.Module):
    """Simple spiking neural network in snntorch."""

    def __init__(self, timesteps, hidden, lstm=False):
        super().__init__()
        self.timesteps = timesteps
        self.hidden = hidden
        self.lstm = lstm

        spike_grad = snntorch.surrogate.atan()

        if not self.lstm:
            beta_in = torch.rand(self.hidden)
            thr_in = torch.rand(self.hidden)

            self.fc1 = torch.nn.Linear(in_features=1, out_features=self.hidden)
            self.lif = snntorch.Leaky(
                beta=beta_in,
                threshold=thr_in,
                learn_beta=True,
                learn_threshold=True,
                spike_grad=spike_grad,
                reset_mechanism="subtract",
            )

        elif self.lstm:
            thr_lstm = torch.rand(self.hidden)

            self.slstm = snntorch.SLSTM(
                input_size=1,
                hidden_size=self.hidden,
                spike_grad=spike_grad,
                learn_threshold=True,
                threshold=thr_lstm,
                reset_mechanism="none",
            )

        else:
            raise SystemExit()

        beta_out = torch.rand(1)
        thr_out = torch.rand(1)

        self.fc2 = torch.nn.Linear(in_features=self.hidden, out_features=1)
        self.li = snntorch.Leaky(
            beta=beta_out,
            threshold=thr_out,
            learn_beta=True,
            learn_threshold=True,
            spike_grad=spike_grad,
            reset_mechanism="none",
        )

    def forward(self, x):
        """Forward pass for several time steps."""
        syn_in = None

        if not self.lstm:
            mem_in = self.lif.init_leaky()

        elif self.lstm:
            syn_in, mem_in = self.slstm.init_slstm()

        else:
            raise SystemExit()

        mem_out = self.li.init_leaky()

        cur_out_rec = []
        mem_out_rec = []

        for step in range(self.timesteps):
            x_timestep = x[step, :, :]

            if not self.lstm:
                cur_in = self.fc1(x_timestep)
                spk_in, mem_in = self.lif(cur_in, mem_in)

            elif self.lstm:
                spk_in, syn_in, mem_in = self.slstm(x_timestep, syn_in, mem_in)

            else:
                raise SystemExit()

            cur_out = self.fc2(spk_in)
            cur_out, mem_out = self.li(cur_out, mem_out)
            cur_out_rec.append(cur_out)
            mem_out_rec.append(mem_out)

        return torch.stack(cur_out_rec, dim=0), torch.stack(mem_out_rec, dim=0)


def main():
    """Training loop and prediction."""
    DEVICE = (
        torch.device("cuda")
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    TIMESTEPS = 10
    ITER = 2000
    HIDDEN = 1024
    LSTM = True

    dataloader = torch.utils.data.DataLoader(
        dataset=Regression_dataset(timesteps=TIMESTEPS)
    )

    model = SNN(timesteps=TIMESTEPS, hidden=HIDDEN, lstm=LSTM).to(DEVICE)
    model.train()

    optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
    loss_function = torch.nn.MSELoss()

    feature = None
    label = None
    loss_val = None
    for i in range(ITER):
        train_batch = iter(dataloader)

        for feature, label in train_batch:
            feature = torch.swapaxes(input=feature, axis0=0, axis1=1)
            label = torch.swapaxes(input=label, axis0=0, axis1=1)
            feature = feature.to(DEVICE)
            label = label.to(DEVICE)

            cur, mem = model(feature)

            loss_val = loss_function(mem, label)
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

        print(f"Iter: {i}, Loss: {loss_val.detach().cpu().numpy()}")

    with torch.no_grad():
        model.eval()
        _, prediction = model(feature)

    label = torch.squeeze(label).cpu().numpy().tolist()
    prediction = torch.squeeze(prediction).cpu().numpy().tolist()
    result = list(zip(label, prediction))
    for i in result:
        print(i)

    return None


if __name__ == "__main__":
    main()

@jeshraghian
Owner

jeshraghian commented Aug 9, 2022 via email

@ahenkes1
Collaborator

ahenkes1 commented Aug 9, 2022

I tried different widths, thresholds, learnable/not learnable thresholds, reset mechanism and surrogate gradients. Nothing seems to work for timesteps > 10 ....

@ahenkes1
Collaborator

ahenkes1 commented Aug 9, 2022

However, I am able to overfit if I pass the membrane potential of the slstm to the LI layer instead of the spikes, which makes me wonder whether the problem really lies in the LSTM or maybe in the LI layer. Does this count as cheating and going around the idea of spikes / spiking neural networks? :D
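For reference, that change amounts to a one-line edit in the forward loop of the script above (an experimental tweak, not an endorsed recipe):

# inside the time loop of SNN.forward(), LSTM branch:
spk_in, syn_in, mem_in = self.slstm(x_timestep, syn_in, mem_in)
cur_out = self.fc2(mem_in)   # feed the membrane potential instead of spk_in; this bypasses the spiking bottleneck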

@jeshraghian
Owner

jeshraghian commented Aug 10, 2022 via email

@ahenkes1
Collaborator

Thank you for the explanation, the engineering point of view suits me well ;)

What do you think, do you want my simple example to be added to your tutorial section? We could also try something more difficult, like f(x) = x^2, I don't know .... We could close this "issue" and open a thread in the discussion area, maybe to prepare a PR?

@jeshraghian
Owner

jeshraghian commented Aug 11, 2022 via email

@shilpakancharla
Author

Just read through the discussion. I'd love to contribute to the tutorial or join in on the discussion as it's written - I've been using snnTorch quite frequently for my work and this would be a great opportunity for me too I think!

@shilpakancharla
Author

I think another interesting tutorial could use event data (timestamp, x, y, polarity) with regression in order to predict something about what is being captured by an event camera. Perhaps this is a bit more advanced as a regression problem, but I did my Master's thesis recently on this and have open-sourced my dataset, and there are a ton of existing datasets in Tonic that could maybe be used, albeit I'm not sure they have regression applications. Interested in hearing your thoughts or other potential ideas for a more real-world example of regression with SNNs.

@ahenkes1
Collaborator

I am currently working on a baseline code. Unfortunately, I have a memory leak which needs to be fixed, and I have little experience with that in PyTorch. If everything is working, I'll report back!

@jeshraghian
Owner

jeshraghian commented Oct 11, 2022 via email

@shilpakancharla
Author

(quoting jeshraghian's reply via email) "That's awesome, let's do it. Do you have Discord? We're currently brainstorming in the tutorial-dev channel there. I haven't used input spike times directly (without expanding them out to tensors), but I'm definitely intrigued!"

I do have Discord ~ I'm currently away on work, but I will contact you as soon as I get back. I'd love to be part of that channel.

@shilpakancharla
Author

What's the name of the discord?

@jeshraghian
Owner

Oops, teaching quarter just finished and I'm catching up on life now.
The channel name is just snnTorch & the link is in the readme "Chat" badge.
Converting this to a discussion in the meantime!

Repository owner locked and limited conversation to collaborators Dec 11, 2022
@jeshraghian jeshraghian converted this issue into discussion #166 Dec 11, 2022
