# Neuromorphic Software: Phase-Encoding

What it is
What it does
Explain fixed bio encoder + learnable MultiLayerPerceptron for Readout (classifier)


## Notebook setup

Imports, Random Seeding, Device Setup - Training will take about 20 minutes if running on CPU.

In [3]:
import random
from typing import Callable

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

In [4]:
# to ensure reproducibility, we should set the random seed consistently
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)  # for multi-GPU
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

In [5]:
# --- Device Configuration ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


## Experiment Setup

In [6]:
# --- 0.0 Define Experiment Parameters ---
NUM_CLASSES = 10  # 10 digits in MNIST Dataset
MAX_ROTATION_DEGREES = 35  # Rotate images by up to 35 degrees
MAX_TRANSLATION = 0.1  # Translate images by up to 10%

# --- 0.1 Define Classifier Model Parameters ---
L1_OUTPUT_DIMS = 784
L2_OUTPUT_DIMS = 512
DROPOUT_PROB = 0.5  # Randomly "mute" a proportion of input neurons

# --- 0.2 Define Training Parameters ---
NUM_EPOCHS = 20
BATCH_SIZE = 64
LEARNING_RATE = 0.001
PATIENCE = 3

## Loading and Preprocessing MNIST Data
- Defines image transformations to be applied to the training and testing data, respectively.
- The training data is augmented with random rotations and translations.
- Both datasets are converted to PyTorch tensors and normalized.
- Loads the MNIST dataset.
- The root argument specifies where to store the data, train=True indicates the training set, and download=True downloads the data if it's not already present.
- Splits the training data into training and validation sets.
- This is important to evaluate the model's performance during training and prevent overfitting.
- Creates data loaders for the training, validation, and testing sets.
- These loaders handle batching and shuffling of the data, making it easier to feed into the model during training and evaluation.

In [7]:
# --- 1. Load and Normalize MNIST ---
transform_train = transforms.Compose([
    transforms.RandomRotation(MAX_ROTATION_DEGREES),
    transforms.RandomAffine(degrees=0, translate=(MAX_TRANSLATION, MAX_TRANSLATION)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # Mean and std deviation for MNIST
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # Mean and std deviation for MNIST
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform_test)
# --- Split training data into training and validation sets ---
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])

# --- 2. Dataloaders ---
batch_size = BATCH_SIZE
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## Phase Encoding Setup

#### Phase Encoding Parameters:
- These variables (`N`, `omega_active`, `theta_thresh`, `omega_ref`, `n`, `kappa, x`) are parameters used in the phase encoding process.
- They control aspects of how the image data is transformed into a phase-encoded representation.

#### Input:

- $I$: Flattened image (a vector of pixel values)
- $N$: The X and Y dimensions of the image
- $\omega_{active}$: Active frequency parameter
- $x$: Spatial layout parameter (a vector) initialized to a range of values representing a phase gradient across a field of sensory neurons.

#### Parameters:

- $\theta_{thresh}$: Threshold phase (set to `0.0` in the code). When a Neuron's phase has rotated through $2\pi$ to $0$ we consider the Neuron has generated a Spike.
- $\omega_{ref}$: Reference frequency (set to `20 Hz` in the code).
- $n$ A constant (set to 4.0 in the code). TODO: What is this?
- $\kappa$: A constant (set to $2\pi$ in the code). TODO: What is this?

In [8]:
# --- 3. Phase Encoding Parameters ---
N = 28 * 28
omega_active = torch.ones(N, dtype=torch.float32) * 2 * np.pi * 20.0
theta_thresh = 0.0
omega_ref = 2 * np.pi * 8.0
n = 4.0
kappa = 2 * np.pi
x = torch.linspace(0, 1, N, dtype=torch.float32)  # TODO: Check if this is overwritten in torch.NN code?

EncoderFnType = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]

### Phase Encoder Function

#### Calculation:
- Initial Phase: $$\theta_{init} = I \cdot 2\pi$$
- Phase Difference: $$\Delta\theta = (\theta_{thresh} - \theta_{init} + 2\pi) \pmod{2\pi}$$
- Spike Time: $$t_{spike} = \frac{\Delta\theta}{\omega_{active}}$$
- Reference Phase: $$\theta_{ref} = \left(\frac{\omega_{ref} \cdot t_{spike} + \kappa \cdot x}{n}\right) \pmod{2\pi}$$
- Final Phase Difference: $$\phi = (\theta_{thresh} - \theta_{ref} + 2\pi) \pmod{2\pi}$$

#### Output:
- Encoded image (a concatenated vector of cosine and sine of the final phase difference):  $$[\cos(\phi), \sin(\phi)]$$

### Formal Definition:
$$ \text{encode\_image}(I, \omega_{active}, x) = [\cos(\phi), \sin(\phi)] $$
where
$$ \phi = (\theta_{thresh} - \theta_{ref} + 2\pi) \pmod{2\pi} $$
and $\theta_{ref}$ is calculated as described in the steps above.

In [9]:
# --- 4. Phase Encoder Function ---
def encode_image(img_flat: torch.Tensor,
                 omega_active_param: torch.Tensor,
                 x_param: torch.Tensor) -> torch.Tensor:
    theta_init = (img_flat * 2 * np.pi)
    delta_theta = torch.fmod(theta_thresh - theta_init + 2 * np.pi, 2 * np.pi)
    t_spike = delta_theta / omega_active_param
    theta_ref = torch.fmod((omega_ref * t_spike + kappa * x_param) / n, 2 * np.pi)
    phase_diff = torch.fmod(theta_thresh - theta_ref + 2 * np.pi, 2 * np.pi)
    return torch.cat([torch.cos(phase_diff), torch.sin(phase_diff)])

In [10]:
# --- 5. Define Classifier ---
class PhaseClassifier(nn.Module):
    def __init__(self, num_classes: int):
        super(PhaseClassifier, self).__init__()
        self.layer1 = nn.Linear(N * 2, L1_OUTPUT_DIMS)
        self.layer2 = nn.Linear(L1_OUTPUT_DIMS, L2_OUTPUT_DIMS)
        self.layer3 = nn.Linear(L2_OUTPUT_DIMS, num_classes)
        self.dropout = nn.Dropout(DROPOUT_PROB)
        self.relu = nn.ReLU()

    def forward(self,
                batch: torch.Tensor,
                omega_active: torch.Tensor,
                spatial_layout: torch.Tensor,
                encoder_fn: EncoderFnType = encode_image
                ):
        encoded_images: list[torch.Tensor] = []
        for img in batch:
            img_flat = img.view(-1)  # Flatten the image
            encoded_img = encoder_fn(
                img_flat,
                omega_active,
                spatial_layout)
            encoded_images.append(encoded_img)
        encoded_images_t = torch.stack(encoded_images)
        encoded_images_t = self.dropout(encoded_images_t)
        _x = self.relu(self.layer1(encoded_images_t))
        _x = self.relu(self.layer2(_x))
        logits = self.layer3(_x)
        return logits

## Training Setup
- Blah

In [24]:
# --- Visualising Training Loss Over Epochs ---
import plotly.express as px
import pandas as pd
from IPython.display import display, clear_output

# Create an empty DataFrame to hold epoch and loss values
df = pd.DataFrame({'Epoch': [], 'Loss': [], 'Validation Loss': []})


# Function to update the Plotly Express figure
def update_training_plot(epoch, loss, validation_loss):
    global df
    df = df.append({'Epoch': epoch, 'Loss': loss, 'Validation Loss': validation_loss}, ignore_index=True)
    fig = px.line(df, x='Epoch', y=['Loss', 'Validation Loss'], title='Training Losses Over Epochs')
    clear_output(wait=True)  # Clear previous output for real-time update effect
    display(fig)

In [26]:
# --- 6. Training Loop ---
def run_training(train_loader: DataLoader,
                 val_loader: DataLoader,
                 encoder_fn: EncoderFnType = encode_image) -> nn.Module:
    num_classes = NUM_CLASSES
    model = PhaseClassifier(num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    # --- Early Stopping Parameters ---
    best_val_loss = float('inf')
    patience = PATIENCE
    counter = 0

    num_epochs = NUM_EPOCHS
    loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        loop = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}", leave=False)
        for images, labels in loop:
            images, labels = images.to(device), labels.to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(images, omega_active.to(device), x.to(device), encoder_fn)
            loss = criterion(outputs, labels)

            # Backward and optimize
            loss.backward()
            optimizer.step()

            # Update tqdm loop
            loop.set_postfix(loss=loss.item())

        # --- Validation ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images, omega_active.to(device), x.to(device))
                loss = criterion(outputs, labels)
                val_loss += loss.item()

        val_loss /= len(val_loader)
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.4f}, Validation Loss: {val_loss:.4f}")
        update_training_plot(epoch + 1, loss.item(), val_loss)

        # --- Early Stopping ---
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print("Early stopping triggered!")
                break

    return model

In [12]:
# --- 7. Evaluation ---
def run_evaluation(model: nn.Module,
                   test_loader: DataLoader,
                   encoder_fn: EncoderFnType = encode_image):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs: torch.Tensor = model(images, omega_active.to(device), x.to(device), encoder_fn)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Accuracy on the test set: {accuracy:.2f}%")

## Training and Testing the Model
Let's hook it all up!

In [13]:
model = run_training(train_loader, val_loader, encode_image)
run_evaluation(model, test_loader, encode_image)

                                                                        

Epoch 1/20, Loss: 1.5342, Validation Loss: 1.5002


                                                                        

Epoch 2/20, Loss: 1.3446, Validation Loss: 1.2470


                                                                         

Epoch 3/20, Loss: 1.1018, Validation Loss: 1.0600


                                                                         

Epoch 4/20, Loss: 0.8272, Validation Loss: 0.8868


                                                                         

Epoch 5/20, Loss: 0.6571, Validation Loss: 0.7906


                                                                         

Epoch 6/20, Loss: 0.7485, Validation Loss: 0.7417


                                                                         

Epoch 7/20, Loss: 0.7817, Validation Loss: 0.6953


                                                                         

Epoch 8/20, Loss: 0.6116, Validation Loss: 0.6440


                                                                         

Epoch 9/20, Loss: 0.5416, Validation Loss: 0.6144


                                                                          

Epoch 10/20, Loss: 0.5430, Validation Loss: 0.6002


                                                                          

Epoch 11/20, Loss: 0.4946, Validation Loss: 0.6088


                                                                          

Epoch 12/20, Loss: 0.5523, Validation Loss: 0.5896


                                                                          

Epoch 13/20, Loss: 0.6199, Validation Loss: 0.5807


                                                                          

Epoch 14/20, Loss: 0.5869, Validation Loss: 0.5998


                                                                          

Epoch 15/20, Loss: 0.5164, Validation Loss: 0.5799


                                                                          

Epoch 16/20, Loss: 0.6186, Validation Loss: 0.5466


                                                                          

Epoch 17/20, Loss: 0.3927, Validation Loss: 0.5391


                                                                          

Epoch 18/20, Loss: 0.3869, Validation Loss: 0.4756


                                                                          

Epoch 19/20, Loss: 0.4707, Validation Loss: 0.5099


                                                                          

Epoch 20/20, Loss: 0.3311, Validation Loss: 0.5051
Accuracy on the test set: 90.75%


## Ablation Study:

We'll create a modified version of the code where we remove the `t_spike` calculation and directly encode the pixel values using cosine and sine.

Comparing the performance of this modified version to our original model should further evidence the importance of the first spike time calculation.

In [21]:
def direct_encode(img_flat: torch.Tensor,
                  omega_active: torch.Tensor,
                  spatial_layout: torch.Tensor) -> torch.Tensor:
    """
    Directly encodes pixel values using cosine and sine functions.
    
    This implementation returns the same dimensions as encode_image calculated
    as [cos(scaled_pixels), sin(scaled_pixels)] where scaled_pixels is in the range [0, 2*pi].

    Args:
      img_flat: A flattened PyTorch tensor representing the image pixels.
      omega_active: Not used in this direct method.
      spatial_layout: Not used in this direct method.

    Returns:
      A PyTorch tensor with shape (2*N,) corresponding to the encoded image.
    """
    # Scale pixel values to the range [0, 2*pi]
    scaled_pixels = img_flat * 2 * torch.pi

    # Compute cosine and sine directly for each pixel
    cos_encoding = torch.cos(scaled_pixels)
    sin_encoding = torch.sin(scaled_pixels)

    # Concatenate to produce an encoding with 2N dimensions
    encoded_representation = torch.cat([cos_encoding, sin_encoding])

    return encoded_representation

## Training and Testing the Default Encoder


In [None]:
from functools import partial

wrapped_default_encode = partial(direct_encode)
model = run_training(train_loader, val_loader, wrapped_default_encode)
run_evaluation(model, test_loader, wrapped_default_encode)

Epoch 1/20:  83%|████████▎ | 620/750 [00:39<00:08, 14.88it/s, loss=1.62]