In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Modified ReducedNANLayer
class ReducedNANLayer(nn.Module):
    """
    Reduced neuron-astrocyte network without symmetry assumptions.

    Three states -- neurons ``x``, synapses ``s`` and astrocytic processes
    ``p`` -- each with last dimension ``input_size``, are integrated with
    explicit Euler steps of size ``dt``:

        dx = -x + W_xx(g(s) * phi(x)) + inp
        ds = -s + W_ps * psi(p) + W_xs(phi(x))
        dp = -p + W_pp(psi(p)) + W_sp * g(s)

    ``W_xx``, ``W_xs`` and ``W_pp`` are bias-free dense linear maps;
    ``W_sp`` and ``W_ps`` are element-wise gain vectors. All activations are
    ReLU and all time constants are 1.

    Args:
        input_size: width of each state (and of the drive ``inp``).
        dt: Euler step size.
        steps: number of Euler steps per forward pass.
        device: device for the parameters and for default initial states.
        clamp: stored as ``self.clamp`` but not read anywhere in this class;
            clamping of neuron states is controlled by ``free_inds`` in
            ``forward``.
        store_intermediate: if True, ``forward`` also returns the neuron
            state after every step.
    """
    def __init__(self, input_size, dt, steps, device, clamp=False, store_intermediate=False):
        super(ReducedNANLayer, self).__init__()

        self.input_size = input_size
        self.dt = dt
        self.steps = steps
        self.device = device
        self.clamp = clamp  # NOTE(review): currently unused by forward()
        self.store_intermediate = store_intermediate

        # Activation functions
        self.phi = F.relu  # neuron activation
        self.g = F.relu    # synapse activation
        self.psi = F.relu  # process activation

        # Time constants and factors for Euler integration
        self.tau_x, self.tau_s, self.tau_p = 1.0, 1.0, 1.0
        self.alpha_x = self.dt / self.tau_x
        self.alpha_s = self.dt / self.tau_s
        self.alpha_p = self.dt / self.tau_p

        # Network weights. The construction order below determines RNG
        # consumption during initialisation, so keep it stable.
        self.W_xx = nn.Linear(input_size, input_size, bias=False, device=device)  # neuron-to-neuron
        self.W_xs = nn.Linear(input_size, input_size, bias=False, device=device)  # neuron-to-synapse
        self.W_sp = nn.Parameter(torch.empty(input_size, device=device)) # synapse-to-process
        self.W_ps = nn.Parameter(torch.empty(input_size, device=device)) # process-to-synapse
        # kaiming_normal_ needs a >=2-D tensor; unsqueeze(1) gives an
        # (input_size, 1) view sharing storage, so the parameter is filled in
        # place. With fan_in = 1 and the default gain this draws N(0, 2).
        # NOTE(review): confirm that this per-unit gain scale is intended.
        nn.init.kaiming_normal_(self.W_sp.unsqueeze(1))
        nn.init.kaiming_normal_(self.W_ps.unsqueeze(1))
        self.W_pp = nn.Linear(input_size, input_size, bias=False, device=device)  # process-to-process

    def forward(self, inp, x0=None, s0=None, p0=None, free_inds=None):
        """
        Integrate the dynamics for ``self.steps`` Euler steps.

        Args:
            inp: external drive added to the neuron update at every step.
                Default initial states are ``zeros_like(inp)``.
            x0, s0, p0: optional initial neuron / synapse / process states.
                They are cloned, never modified in place. Each defaults to zeros.
            free_inds: optional multiplicative mask on the neuron update only.
                Where it is 0, ``x`` stays at its initial value (clamped);
                where it is 1, ``x`` evolves freely. ``s`` and ``p`` are never
                masked.

        Returns:
            ``(x, xs)``: ``x`` is the final neuron state. ``xs`` is ``None``
            unless ``store_intermediate`` is set, in which case it is the
            stacked per-step neuron states with shape ``(steps, *x.shape)``
            (an empty ``(0, *x.shape)`` tensor when ``steps == 0``).
        """
        # Initialize states if not provided
        if x0 is None:
            x0 = torch.zeros_like(inp, device=self.device)
        if s0 is None:
            s0 = torch.zeros_like(inp, device=self.device)
        if p0 is None:
            p0 = torch.zeros_like(inp, device=self.device)

        x, s, p = x0.clone(), s0.clone(), p0.clone()
        # The mask is loop-invariant; the scalar 1 leaves the update untouched.
        x_mask = free_inds if free_inds is not None else 1
        history = []

        for _ in range(self.steps):
            # Activations (all computed from the pre-step states)
            phi_t = self.phi(x)
            g_t = self.g(s)
            psi_t = self.psi(p)

            # State derivatives
            dx = -x + self.W_xx(g_t * phi_t) + inp
            ds = -s + self.W_ps * psi_t + self.W_xs(phi_t)
            dp = -p + self.W_pp(psi_t) + self.W_sp * g_t

            # Euler integration; only the neuron update is masked
            x = x + self.alpha_x * dx * x_mask
            s = s + self.alpha_s * ds
            p = p + self.alpha_p * dp

            if self.store_intermediate:
                history.append(x)

        if not self.store_intermediate:
            return x, None
        if not history:
            # torch.stack rejects an empty list (steps == 0); return an empty
            # history with the same trailing shape instead of raising.
            return x, x.new_empty((0, *x.shape))
        return x, torch.stack(history)

# Modified ReducedNAN
class ReducedNAN(nn.Module):
    """
    Classifier built as read-in -> ReducedNANLayer dynamics -> linear readout.

    ``read_in`` projects the raw input to ``hidden_size``. That projection is
    the constant drive of a ``ReducedNANLayer``. The layer's final neuron state
    is mapped to ``output_size`` by ``readout``.
    """
    def __init__(self, input_size, hidden_size, output_size, dt, steps, device, clamp=False, store_intermediate=False):
        super().__init__()
        # Submodules are created in the order read_in -> reduced_nan -> readout.
        # That order fixes both parameter registration and RNG use at init.
        self.read_in = nn.Linear(input_size, hidden_size, device=device)
        self.reduced_nan = ReducedNANLayer(
            hidden_size, dt, steps, device, clamp, store_intermediate
        )
        self.readout = nn.Linear(hidden_size, output_size, device=device)

    def forward(self, inp, x0=None, s0=None, p0=None, free_inds=None):
        """
        Map ``inp`` to logits via the recurrent dynamics.

        The optional initial states and ``free_inds`` mask are forwarded
        unchanged to ``ReducedNANLayer``. Returns ``(logits, history)``, where
        ``history`` is whatever the inner layer returns as its second element.
        """
        drive = self.read_in(inp)
        last_x, trajectory = self.reduced_nan(drive, x0, s0, p0, free_inds)
        logits = self.readout(last_x)
        return logits, trajectory

# Training code
def _train_one_epoch(model, loader, criterion, optimizer, device, input_size, epoch, num_epochs):
    """Run one optimisation pass over ``loader``, printing the loss every 100 batches."""
    model.train()
    total_step = len(loader)
    for batch_idx, (images, labels) in enumerate(loader):
        # Flatten each image to a vector of length input_size.
        images = images.reshape(-1, input_size).to(device)
        labels = labels.to(device)

        # Forward pass (the second output, the state history, is unused here)
        outputs, _ = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (batch_idx + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{total_step}], Loss: {loss.item():.4f}')


def _evaluate(model, loader, device, input_size):
    """Return ``(correct, total)`` top-1 classification counts of ``model`` on ``loader``."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.reshape(-1, input_size).to(device)
            labels = labels.to(device)
            outputs, _ = model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct, total


def train_mnist(*, num_epochs=5, batch_size=64, learning_rate=0.001, data_root='./data'):
    """
    Train a ReducedNAN classifier on MNIST and report its test accuracy.

    Downloads MNIST into ``data_root`` if needed. Trains with Adam and
    cross-entropy loss, printing the loss every 100 batches, then evaluates on
    the test split.

    Keyword-only args (defaults reproduce the original script):
        num_epochs: number of passes over the training set.
        batch_size: mini-batch size for both data loaders.
        learning_rate: Adam learning rate.
        data_root: directory used by torchvision's MNIST dataset cache.

    Returns:
        Test accuracy in percent (float).
    """
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Using device:", device)

    # Model hyperparameters
    input_size = 28 * 28  # MNIST images are 28x28
    hidden_size = 128
    output_size = 10  # 10 classes for MNIST digits
    dt = 0.1
    steps = 10
    clamp = False
    store_intermediate = False

    # MNIST dataset (ToTensor only: pixels scaled to [0, 1], no normalisation)
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_dataset = datasets.MNIST(root=data_root, train=True, transform=transform, download=True)
    # download=True is a no-op when the files already exist, and avoids a
    # failure if only the test split is missing.
    test_dataset = datasets.MNIST(root=data_root, train=False, transform=transform, download=True)

    # Data loaders
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

    # Model, loss function, optimizer
    model = ReducedNAN(input_size, hidden_size, output_size, dt, steps, device, clamp, store_intermediate).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Training loop
    for epoch in range(num_epochs):
        _train_one_epoch(model, train_loader, criterion, optimizer, device, input_size, epoch, num_epochs)

    # Testing the model
    correct, total = _evaluate(model, test_loader, device, input_size)
    accuracy = 100 * correct / total
    print(f'Test Accuracy of the model on the {total} test images: {accuracy} %')
    return accuracy

if __name__ == "__main__":
    # Run the full MNIST train/evaluate cycle only when executed as a script,
    # not when this module is imported.
    train_mnist()


Using device: cuda
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:01<00:00, 5.08MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 133kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:06<00:00, 241kB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 4.52MB/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Epoch [1/5], Step [100/938], Loss: 0.5124
Epoch [1/5], Step [200/938], Loss: 0.2797
Epoch [1/5], Step [300/938], Loss: 0.2140
Epoch [1/5], Step [400/938], Loss: 0.1360
Epoch [1/5], Step [500/938], Loss: 0.1659
Epoch [1/5], Step [600/938], Loss: 0.1865
Epoch [1/5], Step [700/938], Loss: 0.0621
Epoch [1/5], Step [800/938], Loss: 0.3131
Epoch [1/5], Step [900/938], Loss: 0.2631
Epoch [2/5], Step [100/938], Loss: 0.1227
Epoch [2/5], Step [200/938], Loss: 0.0797
Epoch [2/5], Step [300/938], Loss: 0.0770
Epoch [2/5], Step [400/938], Loss: 0.0355
Epoch [2/5], Step [500/938], Loss: 0.2166
Epoch [2/5], Step [600/938], Loss: 0.0372
Epoch [2/5], Step [700/938], Loss: 0.1248
Epoch [2/5], Step [800/938], Loss: 0.1235
Epoch [2/5], Step [900/938], Loss: 0.2347
Epoch [3/5], Step [100/938], Loss: 0.1287
Epoch [3/5], Step [200/938], Loss: 0.0401
Epoch [3/5], Step [300/938], Loss: 0.2082
Epoch [3/5], Step [400/938], Loss: 0.1907
E