# Oscillator Field Dynamics


## Notebook setup

Imports, Random Seeding, Device Setup - Training time varies based on hardware: approximately 20 minutes on CPU, and significantly faster on GPU.


In [None]:
import random
from typing import Callable

import numpy as np
import pandas as pd
import plotly.express as px
import torch
import torch.nn as nn
import torch.optim as optim
from IPython import display
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

# Seed every random number generator the notebook uses so runs are repeatable.
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)  # for multi-GPU
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

# mathematical constants
EPSILON: float = 1e-12  # for safe division
# Previously written as the bare annotation `TWO_PI: 2 * np.pi`, which never
# bound the name; it must be an annotated assignment to define the constant.
TWO_PI: float = 2 * np.pi

### Device Setup
Prints `Using device: cuda` when a GPU is available.
If it prints `Using device: cpu` instead, training will take significantly longer.

In [None]:
# --- Device Configuration ---
# Use the GPU when CUDA reports one is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

## Experiment Setup
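Defines the constants used throughout the notebook: the MNIST augmentation limits, the oscillator field simulation settings (timestep and number of steps), and the training hyperparameters.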


In [None]:
# --- 0.0 Experiment parameters ---
NUM_CLASSES = 10           # MNIST contains the ten digits 0-9
MAX_ROTATION_DEGREES = 35  # upper bound on the random rotation, in degrees
MAX_TRANSLATION = 0.1      # upper bound on the random shift (10% of the image)

# --- 0.1 Oscillator field parameters ---
DELTA_T = 1e-2   # integration timestep, in seconds
TIMESTEPS = 512  # number of steps; simulated time = TIMESTEPS * DELTA_T

# --- 0.2 Training parameters ---
NUM_EPOCHS = 20
BATCH_SIZE = 64
LEARNING_RATE = 1e-3
PATIENCE = 3


## Loading and Preprocessing MNIST Data
- Defines image transformations to be applied to the training and testing data, respectively.
- The training data is augmented with random rotations and translations.
- Both datasets are converted to PyTorch tensors and normalized.
- Normalization is crucial as it helps neural networks converge faster and perform better by ensuring all input features are on a similar scale and centered around zero.
- Loads the MNIST dataset.
- The root argument specifies where to store the data, train=True indicates the training set, and download=True downloads the data if it's not already present.
- Splits the training data into training and validation sets.
- This is important to evaluate the model's performance during training and prevent overfitting.
- Creates data loaders for the training, validation, and testing sets.
- These loaders handle batching and shuffling of the data, making it easier to feed into the model during training and evaluation.


In [None]:
# Mean and std deviation for MNIST
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081
transform_normalize_to_mnist_mean_and_std = transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))

# --- 1. Load and Normalize MNIST ---
# Training pipeline: random rotation and translation, then tensor conversion
# and normalisation.
transform_train = transforms.Compose([
    transforms.RandomRotation(MAX_ROTATION_DEGREES),
    transforms.RandomAffine(degrees=0, translate=(MAX_TRANSLATION, MAX_TRANSLATION)),
    transforms.ToTensor(),
    transform_normalize_to_mnist_mean_and_std
])

# Evaluation pipeline: tensor conversion and normalisation only (no augmentation).
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transform_normalize_to_mnist_mean_and_std
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform_test)

# --- Split training data into training and validation sets ---
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, _val_split = torch.utils.data.random_split(train_dataset, [train_size, val_size])

# random_split returns subsets that share the parent dataset (and therefore its
# augmenting transform). Re-wrap the held-out indices around an un-augmented
# view of the same training images so validation metrics are measured on clean
# data, like the test set.
val_dataset = torch.utils.data.Subset(
    datasets.MNIST(root='./data', train=True, download=True, transform=transform_test),
    _val_split.indices,
)

# --- 2. Dataloaders ---
batch_size = BATCH_SIZE
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## MNIST DEMO
For now, this grabs a single digit from the MNIST dataset to demonstrate the field dynamics.

In [None]:
# Demo configuration
DIGIT = 3  # digit class to visualise (integer label)
INVERT_POLARITY = False
H, W, D = 28, 28, 4  # 28x28 MNIST image replicated across 4 oscillator dimensions

# Un-augmented MNIST training set: tensor conversion followed by normalisation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transform_normalize_to_mnist_mean_and_std
])
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Take the first sample whose label equals DIGIT and drop its channel axis.
digit_idx = next(i for i, sample in enumerate(dataset) if sample[1] == DIGIT)
the_digit = dataset[digit_idx][0].squeeze().numpy()

# Convert pixel values to phase angles by multiplying by 2π.
# NOTE(review): the image has already been through Normalize, so its values
# are not confined to [0, 1] and the resulting phases span more than one full
# turn - confirm this is intended.
the_digit = the_digit * 2 * np.pi
if INVERT_POLARITY:
    the_digit = 2 * np.pi - the_digit

# Build the (H, W, D) float32 perturbation by broadcasting the same phase map
# into every one of the D dimensions.
perturbation = np.empty((H, W, D), dtype=np.float32)
perturbation[...] = the_digit[:, :, np.newaxis]

## Kuramoto Oscillator Model
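Defines `KuramotoVectorOscillatorField`, an `H × W` grid in which every site holds `D` unit-magnitude complex oscillators $z = e^{i\theta}$. On each step the field is advanced with an explicit Euler update of

$$\Delta z = k_{\text{coupling}}\Big(\sum_{n \in \text{nbrs}} z_n - z\Big) + k_{\omega}\,\Omega z + k_{\text{bias}}\,(c - z)$$

where the sum runs over the six nearest neighbours along H, W and D (missing neighbours at the boundary count as zero), $\Omega$ is a random skew-symmetric `D × D` matrix that mixes the components at each site, and $c$ is the external input field. After the update every oscillator is renormalised to unit magnitude, and a spike is recorded wherever the phase changed by more than `spike_thresh` during the step.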

In [None]:
import torch
import torch.nn.functional as F
import math

class KuramotoVectorOscillatorField:
    def __init__(
        self,
        height: int,
        width: int,
        dims: int = 4,
        delta_t: float = 0.01,
        k_coupling: float = 0.5,
        k_omega: float = 1.0,
        k_bias: float = 0.3,
        spike_thresh: float = 0.01,
        initial_phases: torch.Tensor | None = None,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        self.H, self.W, self.D = height, width, dims
        self.delta_t = delta_t
        self.k_coupling = k_coupling
        self.k_omega = k_omega
        self.k_bias = k_bias
        self.spike_thresh = spike_thresh
        self.device = device

        # Initialize phases if not provided and move to the specified device.
        if initial_phases is None:
            initial_phases = torch.rand(self.H, self.W, self.D, device=self.device) * 2 * math.pi

        # Oscillator state: unit complex vectors (H, W, D)
        # PyTorch supports complex tensors as of version 1.6.
        self.z = torch.exp(1j * initial_phases)
        self.z_prev = self.z.clone()

        # External bias (input) - initialised to initial phase field
        self.c = self.z.clone()

        # Skew-symmetric omega for cross-dimensional coupling (D, D)
        # Create a random matrix and subtract its transpose.
        omega_init = np.random.rand(self.D, self.D)
        omega_init = omega_init - omega_init.T
        #self.omega = self.normalize(omega_init)
        self.omega = torch.tensor(omega_init, dtype=torch.complex64, device=self.device)


    def neighbor_sum(self, z: torch.Tensor) -> torch.Tensor:
        # Add a padding of 1 along each dimension.
        # For 3D tensor, pad order for F.pad is (last dim start, last dim end, second last dim start, second last dim end, first dim start, first dim end)
        z_padded = F.pad(z, (1, 1, 1, 1, 1, 1), mode="constant", value=0)

        # Six-neighbor summation
        # Note that z has shape (H, W, D) and we sum neighbors along each axis.
        neighbors = (
            z_padded[1:-1, 1:-1, :-2] +    # Negative direction along D
            z_padded[1:-1, 1:-1, 2:] +     # Positive direction along D
            z_padded[1:-1, :-2, 1:-1] +    # Negative direction along W
            z_padded[1:-1, 2:, 1:-1] +     # Positive direction along W
            z_padded[:-2, 1:-1, 1:-1] +    # Negative direction along H
            z_padded[2:, 1:-1, 1:-1]       # Positive direction along H
        )
        return neighbors

    def normalize(self, z: torch.Tensor) -> torch.Tensor:
        # Normalize each complex number to have unit magnitude
        return z / (torch.abs(z) + EPSILON)

    def step(self, evolve_c: bool = False) -> torch.Tensor:
        self.z_prev = self.z.clone()

        # 1. Local spatial coupling
        z_neighbors = self.neighbor_sum(self.z)

        # 2. Cross-dimensional omega interaction
        # Using torch.einsum, note that self.z is of shape (H, W, D)
        z_omega = torch.einsum("ij,hwj->hwi", self.omega, self.z)

        # 3. Input bias direction
        bias = self.c - self.z

        # 4. Total update
        delta_z = (
            self.k_coupling * (z_neighbors - self.z)
            + self.k_omega * z_omega
            + self.k_bias * bias
        )

        # 5. Euler update + normalize
        self.z = self.normalize(self.z + self.delta_t * delta_z)

        # 6. Compute spikes (phase velocity)
        # Compute the phase difference between current and previous steps.
        # torch.angle returns the phase of a complex tensor.
        delta_theta = torch.angle(self.z * torch.conj(self.z_prev))
        spike_activity = (torch.abs(delta_theta) > self.spike_thresh).float()

        # 7. Evolve external input field if requested
        if evolve_c:
            delta_c = 0.1 * (self.z - self.c)
            self.c = self.normalize(self.c + self.delta_t * delta_c)

        return spike_activity

    def run(
        self, steps: int = 100, evolve_c: bool = False, return_history: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor] | tuple[list[torch.Tensor], list[torch.Tensor]]:
        z_history_list = []
        spike_history_list = []
        spikes_generated = torch.zeros((self.H, self.W, self.D), device=self.device)
        for _ in range(steps):
            spikes_generated = self.step(evolve_c=evolve_c)
            if return_history:
                z_history_list.append(self.z.clone())
                spike_history_list.append(spikes_generated.clone())
        if return_history:
            return z_history_list, spike_history_list
        else:
            return self.z, spikes_generated

## Run the Physics Simulation

In [None]:
# Initialize field
field = KuramotoVectorOscillatorField(
    height=H,
    width=W,
    dims=D,
    delta_t=DELTA_T,
    k_coupling=0.12,
    k_omega=0.12,
    k_bias=1.25,
    spike_thresh=0.00035,
    device=DEVICE
)

# Introduce perturbation as a symmetry-breaking condition
field.c = torch.exp(1j * torch.from_numpy(perturbation).to(DEVICE))

# Run for (default:2048) steps
z_history, spike_history = field.run(steps=TIMESTEPS, evolve_c=True, return_history=True)

## Visualise Phase Space Trajectory

This visualization shows how the phase angles of oscillators evolve over time in a reduced 2D space using PCA. The plot reveals:

- The collective motion and synchronization patterns of oscillator populations
- Temporal evolution of phase relationships between oscillators 
- Global dynamics and stability of the system as it evolves from initial conditions
- Potential attractor states or limit cycles in the phase space
- Clustering behavior of oscillators with similar phases


In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.decomposition import PCA

def _padded_axis_range(values, pad_fraction=0.05):
    """Return ``[low, high]`` spanning *values* with a small margin on each side."""
    low, high = float(values.min()), float(values.max())
    # Fall back to a unit margin when every value is identical (zero-width span).
    margin = (high - low) * pad_fraction or 1.0
    return [low - margin, high + margin]


def generate_phase_plot(Z, labels, title="Phase Space Trajectory", use_tsne=False):
    """
    Show an animated 2-D scatter of the samples, one frame per timestep.

    Z: np.ndarray of shape [T, N, D] where
       T = timesteps or layers,
       N = number of samples,
       D = phase-feature dimension (e.g. 2*neurons)

    labels: np.ndarray of shape [N] — class labels per sample

    title: figure title.

    use_tsne: reduce each frame with t-SNE instead of PCA.

    Each frame is reduced to two dimensions independently (a fresh reducer is
    fitted on Z[t]), so the axes of different frames are not guaranteed to
    correspond to the same directions.
    """
    T, _, _ = Z.shape

    if use_tsne:
        # Imported lazily, and only once, so the manifold module is loaded
        # only when t-SNE is actually requested.
        from sklearn.manifold import TSNE

    class_names = labels.astype(str)
    all_points = []

    for t in range(T):
        if use_tsne:
            reducer = TSNE(n_components=2, perplexity=30, init='pca', random_state=RANDOM_SEED)
        else:
            reducer = PCA(n_components=2)

        points_2d = reducer.fit_transform(Z[t])

        all_points.append(pd.DataFrame({
            "x": points_2d[:, 0],
            "y": points_2d[:, 1],
            "class": class_names,
            "timestep": t
        }))

    # ignore_index avoids repeating the 0..N-1 index once per frame.
    df_all = pd.concat(all_points, ignore_index=True)

    # Animated figures keep the axis range chosen for the first frame, so fix
    # the ranges over ALL frames to keep every point visible throughout.
    fig = px.scatter(df_all, x="x", y="y", color="class",
                     animation_frame="timestep",
                     title=title,
                     labels={"class": "Digit Class"},
                     opacity=0.7,
                     range_x=_padded_axis_range(df_all["x"]),
                     range_y=_padded_axis_range(df_all["y"]),
                     width=900, height=700)

    fig.update_traces(marker=dict(size=6), selector=dict(mode='markers'))
    fig.update_layout(template="plotly_dark", transition_duration=100)
    fig.show()


Processing will take about 30 seconds if using PCA, a few minutes if using TSNE.

In [None]:
import numpy as np

# Recover the run length and the field shape from the recorded history.
T = len(z_history)
H, W, D = z_history[0].shape
N = H * W  # one sample per (h, w) site

# Stack the complex states into (T, H, W, D), take their phase angles, then
# flatten the two spatial axes so Z has shape (T, N, D).
Z = np.angle(np.stack([z.cpu().numpy() for z in z_history], axis=0)).reshape(T, N, D)

# Every sample comes from the same digit, so they all share one dummy label.
labels = np.full(N, DIGIT)

# Animate the reduced phase angles (PCA; pass use_tsne=True for t-SNE).
generate_phase_plot(
    Z,
    labels,
    title="Phase Space Trajectory",
    use_tsne=False,
)

## Visualise Oscillator Field Dynamics


### Generate Field Phase State Animation

In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.decomposition import PCA


def generate_field_phase_plot(phase_history: list[torch.Tensor], title="Phase State"):
    """Show an animated heat-map of the mean phase at every site over time.

    phase_history: list of complex tensors of shape (H, W, D), one per timestep.

    title: figure title.

    The mean over the D axis is a circular mean: the complex values are
    averaged first and the angle is taken afterwards. Averaging the wrapped
    angles directly would report phases clustered around ±π as roughly 0.
    """
    theta_images: list[np.ndarray] = [
        np.angle(z.cpu().numpy().mean(axis=2)) for z in phase_history
    ]

    px.imshow(np.stack(theta_images),
              animation_frame=0,
              title=title,
              labels={"animation_frame": "Timestep"},
              template="plotly_dark").show()


### Visualise Phase State from Simulation History

In [None]:
# Animate the phase images of the recorded simulation states.
generate_field_phase_plot(phase_history=z_history)

### Generate Spike History Plot

In [None]:
import numpy as np
import plotly.express as px

def generate_spike_activity_plot(spike_history: list[torch.Tensor], title="Spike Activity"):
    """Show an animated heat-map of the spike activity at every site over time.

    spike_history: list of tensors of shape (H, W, D), one per timestep.

    title: figure title.

    Each frame is the mean of the spike tensor over its last (D) axis.
    """
    spike_images: list[np.ndarray] = [
        spikes.cpu().numpy().mean(axis=2) for spikes in spike_history
    ]

    px.imshow(np.stack(spike_images),
              animation_frame=0,
              title=title,
              labels={"animation_frame": "Timestep"},
              template="plotly_dark").show()


### Visualise Spike Activity from Simulation History

In [None]:
# Animate the spike activity recorded during the simulation.
generate_spike_activity_plot(spike_history=spike_history)