In [3]:
import torch
from matplotlib import pyplot as plt

In [None]:
"""
Idea: 
Build a synapse that exhibits cooperative bistability through spine neck resistance

We need:
an AMPAR conductance in the synapse for converting input to current
an NMDAR conductance in the synapse for calcium influx (and slower depolarizations?)
a voltage/calcium-dependent spine neck resistance with a slow time constant (that doesn't change without calcium)
a voltage-dependent calcium conductance (for calcium evoked by back-propagating action potentials)

Mark Goldman's idea: 
A secondary signal (e.g. an exceptional voltage? or something more like dopamine) that acts as a bifurcation parameter:
it takes what would otherwise be a "normal" depolarization --> LTP event and turns it into depolarization --> super-LTP
---> i.e. this signal pushes the spine neck resistance into its bistable regime

==========================================

tau dR/dt = f(R, Ca, Da?)
dR/dt = 0 if Ca = 0
dR/dt < 0 if 0 < Ca < theta_ca
dR/dt > 0 if theta_ca < Ca
"""




In [29]:
import torch
import snntorch as snn
from snntorch import surrogate
import torch.nn as nn
import torch.optim as optim

# Define the network
class SpikingNet(nn.Module):
    """Two-layer fully connected spiking network built from snnTorch ``Leaky`` neurons.

    Architecture: ``Linear(num_inputs -> num_hidden) -> snn.Leaky ->
    Linear(num_hidden -> num_outputs) -> snn.Leaky``. The same (static) input
    ``x`` is presented to the network at every simulation step.

    Args:
        num_inputs: Size of the input feature dimension.
        num_hidden: Number of hidden spiking neurons.
        num_outputs: Number of output spiking neurons.
        beta: Membrane decay factor passed to both ``snn.Leaky`` layers.
            Defaults to 0.9, the value that was previously hard-coded.
    """

    def __init__(self, num_inputs, num_hidden, num_outputs, beta=0.9):
        super().__init__()
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x, num_steps):
        """Simulate the network for ``num_steps`` steps on a constant input.

        Args:
            x: Input tensor fed to ``fc1``; presumably (batch, num_inputs) as in
                this notebook's callers — TODO confirm for other uses.
            num_steps: Number of simulation time steps (must be >= 1, otherwise
                ``torch.stack`` receives an empty list and raises).

        Returns:
            ``(spk_rec, mem_rec)``: output-layer spikes and membrane potentials,
            one entry per step, stacked along a new leading time dimension (dim 0).
        """
        # Fresh membrane state for this forward pass (snnTorch's init_leaky).
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        spk2_rec = []
        mem2_rec = []

        # x is identical at every step, so its projection is loop-invariant:
        # compute it once instead of re-running fc1 num_steps times.
        cur1 = self.fc1(x)
        for _ in range(num_steps):
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

# Set up the network and training parameters
num_inputs = 5
num_hidden = 8
num_outputs = 2
num_steps = 25
batch_size = 32
seed = 0  # change to explore run-to-run variability

# Seed torch's global RNG *before* building the net so that weight init and
# every later unseeded torch random call (e.g. the torch.randint data draws)
# are reproducible under Restart Kernel -> Run All.
torch.manual_seed(seed)

net = SpikingNet(num_inputs, num_hidden, num_outputs)
optimizer = optim.Adam(net.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

# Generate some dummy data
def generate_data(num_samples, num_features=None, generator=None):
    """Create a random binary majority-vote classification batch.

    Each sample is a vector of independent 0/1 features. Its label is 1 when
    strictly more than ``num_features // 2`` features are 1 (i.e. a strict
    majority of ones), otherwise 0.

    Args:
        num_samples: Number of rows to generate (0 yields empty tensors).
        num_features: Feature width; defaults to the notebook-level ``num_inputs``.
        generator: Optional ``torch.Generator`` for reproducible draws; when
            ``None`` torch's global RNG is used (the original behavior).

    Returns:
        ``(x, y)``: ``x`` is a float32 tensor of shape (num_samples, num_features)
        with 0/1 entries; ``y`` is a long tensor of shape (num_samples,) with 0/1 labels.
    """
    if num_features is None:
        # Fall back to the notebook global to stay compatible with existing calls.
        num_features = num_inputs
    x = torch.randint(0, 2, (num_samples, num_features), generator=generator).float()
    y = (x.sum(dim=1) > num_features // 2).long()
    return x, y

# Training loop: every epoch draws a fresh random batch, runs the SNN for
# num_steps, and treats the final-step output membrane potential as logits.
num_epochs = 500
for epoch in range(num_epochs):
    # Clear gradients from the previous update. The forward pass below never
    # touches .grad, so clearing before it is equivalent to clearing after it.
    optimizer.zero_grad()

    x, y = generate_data(batch_size)
    spk, mem = net(x, num_steps)
    final_mem = mem[-1]
    loss = loss_fn(final_mem, y)

    loss.backward()
    optimizer.step()

    # Progress report on every 10th epoch.
    if not epoch % 10:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# Test the network on a fresh random batch.
# NOTE(review): with num_inputs binary features there are only 2**num_inputs
# distinct input patterns, so these "test" inputs almost surely also appeared
# during training — this measures fit, not generalization.
x_test, y_test = generate_data(100)
# Inference only: disable autograd so no graph is built across all num_steps.
with torch.no_grad():
    spk, mem = net(x_test, num_steps)
# Predicted class = output neuron with the largest final-step membrane potential.
pred = mem[-1].argmax(dim=1)
accuracy = (pred == y_test).float().mean()
print(f"Test accuracy: {accuracy.item():.4f}")