# Neuromorphic Software: Phase-Encoding

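This notebook demonstrates phase encoding on MNIST. Phase encoding is a neuromorphic, spiking-neuron-inspired way of representing images for classification. The model has two parts:

- A fixed encoder. This is a biologically inspired phase encoder with no trainable parameters. Each pixel sets the initial phase of one neuron. The time that phase takes to rotate to threshold is the neuron's first-spike time, which is then converted into a phase relative to a reference oscillation.
- A learnable readout. This is a multilayer perceptron (MLP) classifier trained on the encoded phases.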


# Phase Encoding - The Maths
The mathematical function implemented in `encode_image()` is written out below. Every operation is applied element-wise, with one neuron per pixel.

## Input:

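- $I$: Flattened image (a vector of $N = 28 \times 28$ pixel values). The mapping $\theta_{init} = 2\pi I$ sends values in $[0, 1]$ onto one full rotation. Note that the notebook standardizes pixels with `transforms.Normalize((0.1307,), (0.3081,))`, which places them in roughly $[-0.42, 2.82]$.
- $\omega_{active}$: Active (angular) frequency of each pixel's neuron (set to $2\pi \cdot 20$ for every neuron in the code)
- $x$: Spatial layout parameter (a vector with one position per pixel; `torch.linspace(0, 1, N)` in the code)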

### Parameters:

- $\theta_{thresh}$: Threshold phase (set to 0.0 in the code). When a neuron's phase has rotated through $2\pi$ back to $0$, we consider that the neuron has generated a spike.
- $\omega_{ref}$: Reference (angular) frequency (set to $2\pi \cdot 8$ in the code)
- $n$: A constant (set to 4.0 in the code)
- $\kappa$: A constant (set to $2\pi$ in the code)
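Plugging in the values set in the code gives a sense of scale. The notebook does not state time units, so the spike time below is in the same unspecified units as $1/\omega$.

For pixel values in $[0, 1]$, $\Delta\theta$ lies in $[0, 2\pi)$, so the spike time stays below one period of the active oscillation. Dividing by $n$ then bounds how far the temporal and spatial terms can move $\theta_{ref}$.

This is a small stdlib sketch:

```python
import math

# Parameter values as set in the notebook
omega_active = 2 * math.pi * 20.0   # active angular frequency
omega_ref = 2 * math.pi * 8.0       # reference angular frequency
n = 4.0
kappa = 2 * math.pi

# For I in [0, 1], delta_theta is in [0, 2*pi), so t_spike is below one active period
max_t_spike = 2 * math.pi / omega_active

# Largest shift of theta_ref (before the mod) contributed by each term
max_temporal_shift = omega_ref * max_t_spike / n  # from the spike time
max_spatial_shift = kappa * 1.0 / n               # from x = 1, the end of the spatial layout

print(f"max t_spike        = {max_t_spike:.4f}")
print(f"max temporal shift = {max_temporal_shift:.4f} rad ({max_temporal_shift / math.pi:.2f} pi)")
print(f"max spatial shift  = {max_spatial_shift:.4f} rad ({max_spatial_shift / math.pi:.2f} pi)")
```

With these values, the spike time stays below 0.05. The temporal term moves $\theta_{ref}$ by at most $0.2\pi$, and the spatial term by at most $\pi/2$.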

## Calculations:

#### Initial Phase: $$\theta_{init} = I \cdot 2\pi$$

#### Phase Difference: $$\Delta\theta = (\theta_{thresh} - \theta_{init} + 2\pi) \pmod{2\pi}$$
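The $\pmod{2\pi}$ here is the mathematical modulo, which always returns a value in $[0, 2\pi)$. The code computes it with `torch.fmod` instead, which keeps the sign of the dividend. The two agree only when the argument $\theta_{thresh} - 2\pi I + 2\pi$ is non-negative, that is, when $I \le 1$.

Because the notebook standardizes pixels with `transforms.Normalize((0.1307,), (0.3081,))`, bright pixels exceed 1. The stdlib sketch below shows the gap on the largest standardized pixel value. It uses `math.fmod`, which follows the same sign convention as `torch.fmod`, and Python's `%`, which follows the same convention as `torch.remainder`.

```python
import math

TWO_PI = 2 * math.pi
theta_thresh = 0.0

# Largest MNIST pixel (1.0) after Normalize((0.1307,), (0.3081,))
i_max = (1.0 - 0.1307) / 0.3081

arg = theta_thresh - i_max * TWO_PI + TWO_PI  # negative whenever I > 1

delta_fmod = math.fmod(arg, TWO_PI)  # sign follows the dividend, like torch.fmod
delta_mod = arg % TWO_PI             # always in [0, 2*pi), like torch.remainder

print(f"I = {i_max:.4f}")
print(f"fmod:   {delta_fmod:.4f}")
print(f"modulo: {delta_mod:.4f}")
```

A negative $\Delta\theta$ gives a negative $t_{spike}$. Swapping `torch.fmod` for `torch.remainder` in the $\Delta\theta$ step of `encode_image()` would match the equation. However, the training results reported further down were produced with `torch.fmod`.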

#### Spike Time: $$t_{spike} = \frac{\Delta\theta}{\omega_{active}}$$

#### Reference Phase: $$\theta_{ref} = \left(\frac{\omega_{ref} \cdot t_{spike} + \kappa \cdot x}{n}\right) \pmod{2\pi}$$

#### Final Phase Difference: $$\phi = (\theta_{thresh} - \theta_{ref} + 2\pi) \pmod{2\pi}$$
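For pixel values in $[0, 1]$, the first three steps amount to a time-to-first-spike code. For $0 < I \le 1$, $\Delta\theta = 2\pi(1 - I)$, so brighter pixels reach threshold sooner. The exception is $I = 0$: there $\Delta\theta = 2\pi \bmod 2\pi = 0$, so a black pixel spikes at the same time as a white one.

The sketch below uses the notebook's $\omega_{active} = 2\pi \cdot 20$ and Python's `%` for the mathematical modulo. The helper name `spike_time` is hypothetical.

```python
import math

TWO_PI = 2 * math.pi
theta_thresh = 0.0
omega_active = 2 * math.pi * 20.0  # value used in the notebook


def spike_time(i):
    """Time for a neuron starting at phase 2*pi*i to rotate to theta_thresh."""
    theta_init = i * TWO_PI
    delta_theta = (theta_thresh - theta_init + TWO_PI) % TWO_PI
    return delta_theta / omega_active


for i in [0.0, 0.25, 0.5, 0.75, 1.0]:
    print(f"I = {i:.2f} -> t_spike = {spike_time(i):.4f}")
```

The spike times fall from 0.0375 through 0.025 to 0.0125 as $I$ rises from 0.25 to 0.75. Both $I = 0$ and $I = 1$ give 0.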

## Output:

- Encoded image: $[\cos(\phi), \sin(\phi)]$, a vector of length $2N$. It is formed by concatenating the cosines of the final phase differences of all $N$ pixels with their sines.
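Using $[\cos\phi, \sin\phi]$ rather than $\phi$ itself places each pixel's phase on the unit circle. As a result, phases just above $0$ and just below $2\pi$, which are far apart as raw numbers, map to neighboring feature vectors.

This is a quick stdlib check; the helper `embed` is hypothetical:

```python
import math


def embed(phi):
    """Map a phase to a point on the unit circle."""
    return (math.cos(phi), math.sin(phi))


a, b = 0.01, 2 * math.pi - 0.01               # nearly the same angle
raw_gap = abs(a - b)                          # about 6.26
embedded_gap = math.dist(embed(a), embed(b))  # about 0.02

print(f"raw gap: {raw_gap:.4f}, embedded gap: {embedded_gap:.4f}")
```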

### Formal Definition:

$$ \text{encode\_image}(I, \omega_{active}, x) = [\cos(\phi), \sin(\phi)] $$

where

$$ \phi = (\theta_{thresh} - \theta_{ref} + 2\pi) \pmod{2\pi} $$

and $\theta_{ref}$ is calculated as described in the steps above.

Because every step is element-wise, each pixel is encoded independently. A pixel's position in the image enters only through its spatial layout value $x$.
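To make the steps concrete, take a single pixel with $I = 0.5$ at spatial position $x = 0$, using the notebook's parameter values. By hand:

- $\theta_{init} = \pi$ and $\Delta\theta = \pi$.
- $t_{spike} = \pi / (40\pi) = 0.025$.
- $\theta_{ref} = (16\pi \cdot 0.025)/4 = 0.1\pi$.
- $\phi = 2\pi - 0.1\pi = 1.9\pi$, giving $[\cos\phi, \sin\phi] \approx [0.9511, -0.3090]$.

The plain-Python sketch below mirrors `encode_image()` but uses Python's `%` for the modulo. The names `encode_pixel` and `encode_image_py` are hypothetical, and the constants are uppercased so they do not overwrite the notebook's variables.

```python
import math

TWO_PI = 2 * math.pi

# Parameter values used in the notebook
THETA_THRESH = 0.0
OMEGA_ACTIVE = 2 * math.pi * 20.0
OMEGA_REF = 2 * math.pi * 8.0
N_CONST = 4.0
KAPPA = 2 * math.pi


def encode_pixel(i, x):
    """Return (cos(phi), sin(phi)) for one pixel value i at spatial position x."""
    theta_init = i * TWO_PI
    delta_theta = (THETA_THRESH - theta_init + TWO_PI) % TWO_PI
    t_spike = delta_theta / OMEGA_ACTIVE
    theta_ref = ((OMEGA_REF * t_spike + KAPPA * x) / N_CONST) % TWO_PI
    phi = (THETA_THRESH - theta_ref + TWO_PI) % TWO_PI
    return math.cos(phi), math.sin(phi)


def encode_image_py(pixels, xs):
    """Encode a flat list of pixels: all cosines first, then all sines (length 2N)."""
    pairs = [encode_pixel(i, x) for i, x in zip(pixels, xs)]
    return [c for c, _ in pairs] + [s for _, s in pairs]


print(encode_pixel(0.5, 0.0))  # approximately (0.9511, -0.3090)
```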


```python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm

# --- Device Configuration ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- 1. Load and Normalize MNIST ---
transform_train = transforms.Compose([
    transforms.RandomRotation(35),  # Rotate images by a random angle in [-35, 35] degrees
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # Translate by up to 10%
    transforms.ToTensor(),
    # Standardize with the MNIST mean and std; pixels end up in roughly [-0.42, 2.82], not [0, 1]
    transforms.Normalize((0.1307,), (0.3081,))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # Mean and std deviation for MNIST
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform_test)

# --- Split training data into training and validation sets ---
# Both subsets share transform_train, so validation images are augmented as well.
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

# --- 2. Dataloaders ---
batch_size = 100
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# --- 3. Phase Encoding Parameters ---
N = 28 * 28  # one encoder neuron per pixel
omega_active = torch.ones(N, dtype=torch.float32) * 2 * np.pi * 20.0  # active angular frequency per neuron
theta_thresh = 0.0  # threshold phase: reaching it counts as a spike
omega_ref = 2 * np.pi * 8.0  # reference angular frequency
n = 4.0
kappa = 2 * np.pi
x = torch.linspace(0, 1, N, dtype=torch.float32)  # spatial layout: one position per pixel


# --- 4. Phase Encoder Function ---
def encode_image(img_flat, omega_active_param, x_param):
    """Phase-encode flattened images of shape (..., N) into features of shape (..., 2N).

    Every operation is element-wise, so a whole batch of shape (B, N) is encoded in one call.
    Note: torch.fmod keeps the sign of its first argument, so for pixel values above 1,
    delta_theta (and hence t_spike) can come out negative, unlike the mathematical mod.
    """
    theta_init = img_flat * 2 * np.pi
    delta_theta = torch.fmod(theta_thresh - theta_init + 2 * np.pi, 2 * np.pi)
    t_spike = delta_theta / omega_active_param  # time to first spike
    theta_ref = torch.fmod((omega_ref * t_spike + kappa * x_param) / n, 2 * np.pi)
    phase_diff = torch.fmod(theta_thresh - theta_ref + 2 * np.pi, 2 * np.pi)
    # Concatenate along the feature dimension: [cos of all N pixels, sin of all N pixels]
    return torch.cat([torch.cos(phase_diff), torch.sin(phase_diff)], dim=-1)


# --- 5. Define Classifier ---
class PhaseClassifier(nn.Module):
    """Fixed (non-learnable) phase encoder followed by a learnable MLP readout."""

    def __init__(self, num_classes):
        super().__init__()
        self.layer1 = nn.Linear(N * 2, 784)
        self.layer2 = nn.Linear(784, 512)
        self.layer3 = nn.Linear(512, num_classes)
        self.dropout = nn.Dropout(0.5)  # dropout on the encoded input features
        self.relu = nn.ReLU()

    def forward(self, x, omega_active, spatial_layout):
        img_flat = x.view(x.size(0), -1)  # (B, 1, 28, 28) -> (B, N)
        # The encoder is element-wise, so the whole batch is encoded at once: (B, N) -> (B, 2N)
        encoded_images = encode_image(img_flat, omega_active, spatial_layout)
        encoded_images = self.dropout(encoded_images)
        x = self.relu(self.layer1(encoded_images))
        x = self.relu(self.layer2(x))
        logits = self.layer3(x)
        return logits


# --- 6. Training Loop ---
num_classes = 10
model = PhaseClassifier(num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# The encoder parameters never change, so move them to the device once
omega_active = omega_active.to(device)
x = x.to(device)

# --- Early Stopping Parameters ---
best_val_loss = float('inf')
patience = 3  # epochs without improvement before stopping
counter = 0

num_epochs = 30  # upper bound; early stopping may end training sooner
for epoch in range(num_epochs):
    model.train()
    loop = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}", leave=False)
    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images, omega_active, x)
        loss = criterion(outputs, labels)

        # Backward and optimize
        loss.backward()
        optimizer.step()

        # Update tqdm loop
        loop.set_postfix(loss=loss.item())

    # --- Validation ---
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images, omega_active, x)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

    val_loss /= len(val_loader)
    # Note: `loss` now holds the loss of the last *validation* batch, because the
    # validation loop above reassigns it; it is not the training loss.
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.4f}, Validation Loss: {val_loss:.4f}")

    # --- Early Stopping ---
    # The best weights are not saved, so training ends with the weights from the last epoch.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered!")
            break

# --- 7. Evaluation ---
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images, omega_active, x)
        predicted = outputs.argmax(dim=1)  # index of the highest logit = predicted class
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Accuracy on the test set: {accuracy:.2f}%")
```

```text
Using device: cpu
Epoch 1/30, Loss: 1.4843, Validation Loss: 1.4621
Epoch 2/30, Loss: 1.2792, Validation Loss: 1.3284
Epoch 3/30, Loss: 1.0391, Validation Loss: 1.0601
Epoch 4/30, Loss: 0.9410, Validation Loss: 0.9420
Epoch 5/30, Loss: 0.9094, Validation Loss: 0.8504
Epoch 6/30, Loss: 0.7878, Validation Loss: 0.7932
Epoch 7/30, Loss: 0.8226, Validation Loss: 0.7014
Epoch 8/30, Loss: 0.7445, Validation Loss: 0.6780
Epoch 9/30, Loss: 0.6991, Validation Loss: 0.6404
Epoch 10/30, Loss: 0.7464, Validation Loss: 0.6339
Epoch 11/30, Loss: 0.5971, Validation Loss: 0.5584
Epoch 12/30, Loss: 0.5966, Validation Loss: 0.5421
Epoch 13/30, Loss: 0.7167, Validation Loss: 0.5464
Epoch 14/30, Loss: 0.6808, Validation Loss: 0.5079
Epoch 15/30, Loss: 0.6042, Validation Loss: 0.4925
Epoch 16/30, Loss: 0.6795, Validation Loss: 0.5230
Epoch 17/30, Loss: 0.5036, Validation Loss: 0.4747
Epoch 18/30, Loss: 0.6183, Validation Loss: 0.5004
Epoch 19/30, Loss: 0.4274, Validation Loss: 0.4272
Epoch 20/30, Loss: 0.5509, Validation Loss: 0.4650
Epoch 21/30, Loss: 0.6002, Validation Loss: 0.4874
Epoch 22/30, Loss: 0.5609, Validation Loss: 0.4358
Early stopping triggered!
Accuracy on the test set: 92.10%
```
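The early-stopping rule can be checked against the logged validation losses. The sketch below replays the notebook's rule (patience 3, strict improvement required) on the 22 values printed in the training log. The helper `replay_early_stopping` is hypothetical.

The replay stops at epoch 22, and the lowest validation loss occurs at epoch 19. Because the training loop does not save the best weights, the reported test accuracy comes from the epoch-22 model.

```python
# Validation losses copied from the training log (epochs 1-22)
val_losses = [
    1.4621, 1.3284, 1.0601, 0.9420, 0.8504, 0.7932, 0.7014, 0.6780,
    0.6404, 0.6339, 0.5584, 0.5421, 0.5464, 0.5079, 0.4925, 0.5230,
    0.4747, 0.5004, 0.4272, 0.4650, 0.4874, 0.4358,
]


def replay_early_stopping(losses, patience=3):
    """Apply the notebook's early-stopping rule; return (stop_epoch, best_epoch, best_loss)."""
    best_loss, best_epoch, counter = float("inf"), 0, 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch, counter = loss, epoch, 0
        else:
            counter += 1
            if counter >= patience:
                return epoch, best_epoch, best_loss
    return len(losses), best_epoch, best_loss


stop_epoch, best_epoch, best_loss = replay_early_stopping(val_losses)
print(f"stopped at epoch {stop_epoch}; best validation loss {best_loss} at epoch {best_epoch}")
```

One option for evaluating the best-epoch model instead is to save `model.state_dict()` whenever the validation loss improves and reload it before testing.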


## Ablation Study:

We'll create a modified version of the encoder that skips the spike-time (`t_spike`) and reference-phase steps, and instead encodes the pixel values directly with cosine and sine.

Training the same readout on this direct encoding and comparing its accuracy with the phase-encoded model would indicate how much the first-spike-time calculation contributes to performance.

```python
import math

import torch


def direct_encode(img_flat, encoding_dim):
    """
    Directly encodes pixel values using cosine and sine functions,
    bypassing the spike-time and reference-phase steps of encode_image().

    Args:
      img_flat: A flattened PyTorch tensor of N pixel values.
      encoding_dim: Number of frequencies k = 1, ..., encoding_dim per pixel.

    Returns:
      A tensor of shape (N, 2 * encoding_dim): for each pixel, cos(2*pi*k*I)
      for every k, followed by sin(2*pi*k*I) for every k.
    """

    # Frequencies start at 1: a k = 0 term would only give the constants cos(0) = 1 and sin(0) = 0
    encoding_indices = torch.arange(1, encoding_dim + 1, dtype=torch.float32, device=img_flat.device)

    # Scale pixel values so that the range [0, 1] maps to [0, 2*pi]
    scaled_pixels = img_flat * 2 * math.pi

    # Outer product of pixels and frequencies: (N, 1) * (1, encoding_dim) -> (N, encoding_dim)
    phases = scaled_pixels[:, None] * encoding_indices[None, :]
    cos_encoding = torch.cos(phases)
    sin_encoding = torch.sin(phases)

    # Concatenate cosine and sine encodings along the feature dimension
    encoded_representation = torch.cat([cos_encoding, sin_encoding], dim=-1)

    return encoded_representation
```
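As a sanity check on the ablation, here is a plain-Python sketch of a direct encoding with frequencies $k = 1, \dots, d$. It starts at $k = 1$ because a $k = 0$ term contributes only the constants $\cos 0 = 1$ and $\sin 0 = 0$. The helper `direct_encode_pixel` and the constant `NUM_PIXELS` are hypothetical.

With $d = 1$, the encoding reduces to $[\cos(2\pi I), \sin(2\pi I)]$: each pixel's initial phase is used directly, with no spike time, reference oscillator or spatial term. This yields $2N$ features per image, the same input width as `PhaseClassifier.layer1`. A larger $d$ would instead need an input layer of width $2Nd$ after flattening the `(N, 2 * encoding_dim)` output.

```python
import math


def direct_encode_pixel(i, encoding_dim):
    """Hypothetical direct encoding of one pixel: cosines, then sines, for k = 1..encoding_dim."""
    phases = [2 * math.pi * k * i for k in range(1, encoding_dim + 1)]
    return [math.cos(p) for p in phases] + [math.sin(p) for p in phases]


NUM_PIXELS = 28 * 28
for d in (1, 3):
    features_per_image = NUM_PIXELS * len(direct_encode_pixel(0.5, d))
    print(f"encoding_dim={d}: {features_per_image} features per image")

print(direct_encode_pixel(0.25, 1))  # approximately [0.0, 1.0]
```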