In [1]:
import jax
import tensorflow as tf
import matplotlib.pyplot as plt
import jax.numpy as jnp
import numpy as np

from jax import jit, vmap
from functools import partial
from params_init import create_conn_matrix
from rk4_ode_solver import solve_ode
from functions import cross_entropy_grad, cross_entropy_loss, accuracy



In [2]:
# Load MNIST (training set: 60,000 imgs | test set: 10,000 imgs) and one-hot encode its labels.
num_classes = 10  # MNIST has 10 classes (digits 0-9)
(input_train, label_train), (input_test, label_test) = tf.keras.datasets.mnist.load_data()

# Row k of the identity matrix is the one-hot code of class k, so indexing it by
# the integer labels yields the one-hot matrices directly.
label_train_one_hot, label_test_one_hot = (
    jnp.eye(num_classes)[labels] for labels in (label_train, label_test)
)

In [3]:
# Neural Network dynamics
@jit
def system(t, y, kappa, kappa1, g, J, x_in, p_in):
    """
    Right-hand side dy/dt of the network ODE.

    The state y = [x, p] stacks two length-N halves, with N = len(kappa):
        dx/dt = -(kappa + kappa1)/2 * x + (g/2) * (x**2 + p**2) * p + J @ p - sqrt(kappa) * x_in
        dp/dt = -(kappa + kappa1)/2 * p - (g/2) * (x**2 + p**2) * x - J @ x - sqrt(kappa) * p_in
    `t` is not used (the right-hand side has no explicit time dependence); it is
    presumably kept because solve_ode calls the system as f(t, y, *args).
    Returns the concatenation [dx/dt, dp/dt].
    """
    n_modes = kappa.shape[0]
    x, p = y[:n_modes], y[n_modes:]

    damping = -0.5 * (kappa + kappa1)
    nonlinear = 0.5 * g * (x**2 + p**2)
    coupling_in = jnp.sqrt(kappa)

    dxdt = damping * x + nonlinear * p + J @ p - coupling_in * x_in
    dpdt = damping * p - nonlinear * x - J @ x - coupling_in * p_in
    return jnp.concatenate([dxdt, dpdt])

In [4]:
# Params update according to Scattering Backpropagation
@partial(jit, static_argnums=(5,))
def update_weights(vec, x_free, p_free, graph, J, N):
    """
    Projects `vec` onto the derivative of the right-hand side w.r.t. each trained coupling.

    For every edge (j, l) in `graph` the length-2N vector dF/dJ_jl holds
        p_free[l] at j,      p_free[j] at l,
       -x_free[l] at j + N,  -x_free[j] at l + N,
    with those entries halved when j == l (the two index pairs then coincide).
    The update returned for that edge is dot(dF/dJ_jl, vec).

    These entries mirror the `J @ p` term of dx/dt and the `-J @ x` term of dp/dt
    in `system` when J_jl and J_lj are tied (symmetric J) — assumed, confirm.
    NOTE(review): for a diagonal coupling the exact derivative would be p_j / -x_j;
    the halving looks like a deliberate rescaling — confirm it is intended.

    Args:
    - vec : length-2N vector (x-half then p-half)
    - x_free, p_free : length-N x- and p-halves of the free-phase state
    - graph : (E, 2) integer array of (j, l) edge indexes
    - J : not used here; kept for call compatibility
    - N : number of modes (static argument for jit)
    Returns: length-E array with one update per edge.
    """
    def edge_update(edge):
        j, l = edge[0], edge[1]
        # Multiplying by 0.5 equals dividing by 2 exactly, and by 1.0 is a no-op.
        scale = jnp.where(j == l, 0.5, 1.0)
        rows = jnp.stack([j, l, j + N, l + N])
        vals = jnp.stack([p_free[l], p_free[j], -x_free[l], -x_free[j]]) * scale
        # When j == l the duplicate rows receive identical values, so the
        # unspecified ordering of duplicate scatter writes does not matter.
        dFdJ_jl = jnp.zeros(2 * N).at[rows].set(vals)
        return jnp.dot(dFdJ_jl, vec)

    return vmap(edge_update)(graph)


In [5]:
def Evaluation_loop(key, J, input_pixels, lower, upper, tmax, num_steps, g, inputs=None, labels_one_hot=None):
    """
    Evaluates the model on the test-set samples with indices in [lower, upper).

    Args:
    - key : random key used to draw the initial state of the first sample
    - J : parameters (connectivity matrix), shape (N, N)
    - input_pixels : total number of input pixels (784 for MNIST)
    - lower : lowest test-set index to evaluate (min is 0 for MNIST)
    - upper : one past the highest test-set index to evaluate (max is 10,000 for MNIST)
    - tmax : final time for the dynamical simulation
    - num_steps : number of RK4 steps (dt = tmax / num_steps)
    - g : nonlinearity strength
    - inputs : images to evaluate; defaults to the notebook-level `input_test`
    - labels_one_hot : matching one-hot labels; defaults to `label_test_one_hot`
    Returns: (mean cross-entropy loss, accuracy rate) over the evaluated samples
    Raises: ValueError unless 0 <= lower < upper <= number of samples
            (an empty range would divide by zero, and out-of-range indices would
            be silently clamped by dynamic indexing).
    """
    if inputs is None:
        inputs = input_test
    if labels_one_hot is None:
        labels_one_hot = label_test_one_hot
    if not 0 <= lower < upper <= inputs.shape[0]:
        raise ValueError(
            f"Evaluation_loop needs 0 <= lower < upper <= {inputs.shape[0]}, "
            f"got lower={lower}, upper={upper}"
        )

    N = J.shape[0]
    y0 = jax.random.normal(key, shape=(2 * N,))
    kappa = 1. * jnp.ones((N,))
    kappa1 = 1. * jnp.ones((N,))
    dt = tmax / num_steps

    def evolution(idx, carry):
        y0, J, loss, acc_rate = carry
        # Dynamic indexing, because idx is a traced loop counter
        input_vec = jax.lax.dynamic_index_in_dim(inputs, idx, keepdims=False)
        target = jax.lax.dynamic_index_in_dim(labels_one_hot, idx, keepdims=False)

        # Encode inputs: the flattened image drives the x-inputs of the first
        # input_pixels modes, rescaled into the (0, 2.55) interval
        x_in = jnp.zeros(N)
        p_in = jnp.zeros(N)
        x_in = x_in.at[:input_pixels].add(jnp.reshape(input_vec, -1) / 100)

        # Inference Phase
        _, out = solve_ode(system, y0, (0., tmax), num_steps, dt, kappa, kappa1, g, J, x_in, p_in)
        solution_free = out[-1, :]
        x_free = solution_free[:N]
        # NOTE(review): the final state of this sample seeds the next one, whereas
        # Training_loop restarts every sample from the same random state; results
        # may therefore depend on sample order — confirm this is intended.
        y0 = solution_free

        # The last 10 x-components are the logits
        logits = x_free[-10:]
        loss += cross_entropy_loss(logits, target)
        acc_rate += (jnp.argmax(logits) == jnp.argmax(target))
        return y0, J, loss, acc_rate

    y0, J, loss, acc_rate = jax.lax.fori_loop(lower, upper, evolution, (y0, J, 0, 0))
    num_samples = upper - lower
    return loss / num_samples, acc_rate / num_samples


In [None]:
def Training_loop(key, params_history, input_pixels, num_epochs,  batch_size,  beta, learning_rate, lower, upper, tmax, num_steps, g):
    """
    Trains the connectivity matrix J with Scattering Backpropagation (SB) on the
    training batches [lower, upper) and evaluates it on the whole test set after
    every epoch.

    For each mini-batch, every sample (in parallel via vmap) runs
      1. a free (inference) phase over [0, tmax] from a random initial state that
         is shared by all samples of the epoch, and
      2. a feedback phase over [0, tmax/2] from the free-phase end state, with
         beta * cross_entropy_grad injected into the p-inputs of the last 10
         (logit) modes.
    The per-sample SB updates of the nonzero entries of J are then averaged over
    the batch and added to J.

    Args:
    - key : random key (shuffling, initial states, evaluation)
    - params_history : list whose element 0 is the initial J; element 0 is
      overwritten in place with the J that reached the best test accuracy so far
    - input_pixels : total number of input pixels (784 for MNIST)
    - num_epochs : number of training epochs
    - batch_size : number of samples whose updates are averaged per SGD step
    - beta : perturbation strength (for the Feedback Phase)
    - learning_rate : learning rate
    - lower : index of the first training batch used (min is 0)
    - upper : one past the last training batch used (max is 60,000 // batch_size for MNIST)
    - tmax : final time for the dynamical simulation
    - num_steps : number of RK4 steps (dt = tmax / num_steps)
    - g : nonlinearity strength
    Returns: (train loss history, train accuracy history, test loss history,
              test accuracy history, params_history, final J)
    Raises: ValueError unless 0 <= lower < upper.
    """
    if not 0 <= lower < upper:
        raise ValueError(f"Training_loop needs 0 <= lower < upper, got lower={lower}, upper={upper}")

    loss_history = []
    test_loss_history = []
    acc_history = []
    test_acc_history = []
    J = params_history[0]
    N = J.shape[0]
    num_samples = (upper - lower) * batch_size

    kappa = 1. * jnp.ones((N,))
    kappa1 = 1. * jnp.ones((N,))

    # Swaps the x- and p-halves of the state vector
    sigma_x = jnp.block([
        [jnp.zeros((N, N)), jnp.eye(N)],
        [jnp.eye(N), jnp.zeros((N, N))]
    ])

    # (row, col) indexes of ALL nonzero entries of the initial J (both triangles,
    # not only the upper one). Only these entries are ever updated, so the
    # sparsity pattern of J stays fixed during training.
    graph = jnp.stack(jnp.where(J != 0)).T
    dt = tmax / num_steps
    # Batch idx covers shuffled samples [idx * batch_size, (idx + 1) * batch_size)
    batch_offsets = jnp.arange(batch_size)

    def evolution(idx, carry):
        # Everything that changes between epochs (initial state, shuffled data) is
        # threaded through the carry rather than closed over: JAX may reuse the
        # cached trace of this body across epochs, which would freeze closed-over
        # arrays at their first-epoch values.
        epoch_loss, epoch_acc, J, y0_init, inputs, targets = carry

        def single_input_ev(j):
            # Dynamic indexing, because j is traced
            input_vec = jax.lax.dynamic_index_in_dim(inputs, j, keepdims=False)
            target = jax.lax.dynamic_index_in_dim(targets, j, keepdims=False)

            # Encode inputs: rescale input pixels into the (0, 2.55) interval
            x_in = jnp.zeros(N)
            p_in = jnp.zeros(N)
            x_in = x_in.at[:input_pixels].add(jnp.reshape(input_vec, -1) / 100)

            # Inference Phase
            _, out = solve_ode(system, y0_init, (0., tmax), num_steps, dt, kappa, kappa1, g, J, x_in, p_in)
            solution_free = out[-1, :]
            x_free = solution_free[:N]
            p_free = solution_free[N:]
            logits = x_free[-10:]

            # Feedback Phase: inject the error signal, restart from the free state
            p_in = p_in.at[-10:].add(beta * cross_entropy_grad(logits, target))
            _, out = solve_ode(system, solution_free, (0., tmax/2), num_steps//2, dt, kappa, kappa1, g, J, x_in, p_in)
            solution_perturbed = out[-1, :]

            # Right-hand part of SB's gradient approximation formula
            # (Note that U' = sigma_x corresponds to the quasi-symmetry U=sigma_y in the a(t)-basis)
            vec = (learning_rate) * sigma_x @ (solution_perturbed - solution_free) / beta

            # Loss and accuracy are both computed on the same 10 logits
            return (update_weights(vec, x_free, p_free, graph, J, N),
                    cross_entropy_loss(logits, target),
                    accuracy(logits, target))

        batch_ids = idx * batch_size + batch_offsets
        updates, losses, accs = vmap(single_input_ev)(batch_ids)

        J = J.at[graph[:, 0], graph[:, 1]].add(jnp.mean(updates, axis=0))

        epoch_loss += jnp.sum(losses)
        epoch_acc += jnp.sum(accs)
        return epoch_loss, epoch_acc, J, y0_init, inputs, targets

    best_acc = 0

    # Training Loop
    for epoch in range(num_epochs):
        # Shuffle the first upper * batch_size training samples
        key, subkey = jax.random.split(key)
        indices = jax.random.permutation(subkey, upper * batch_size)
        input_train_shuffled = input_train[indices]
        label_train_one_hot_shuffled = label_train_one_hot[indices]
        # Initial state shared by every training sample of this epoch
        y0_init = jax.random.normal(key, shape=(2 * N,))

        # Training
        carry = (0, 0, J, y0_init, input_train_shuffled, label_train_one_hot_shuffled)
        epoch_loss, epoch_acc, J, *_ = jax.lax.fori_loop(lower, upper, evolution, carry)

        # Per-sample averages for the epoch
        epoch_loss = epoch_loss / num_samples
        epoch_acc = epoch_acc / num_samples
        loss_history.append(epoch_loss)
        acc_history.append(epoch_acc)
        # TODO: learning-rate schedule (not implemented)

        # Compute test loss / accuracy
        key, subkey = jax.random.split(key)
        epoch_test_loss, epoch_test_acc = Evaluation_loop(key = subkey, J = J, input_pixels = input_pixels, lower = 0, upper = 10000, tmax = tmax, num_steps = num_steps, g = g)

        # Save the best weights
        if epoch_test_acc > best_acc:
            best_acc = epoch_test_acc
            params_history[0] = J

        test_loss_history.append(epoch_test_loss)
        test_acc_history.append(epoch_test_acc)
        # jnp.linalg.norm of a matrix is its Frobenius norm
        print(f"End of Epoch {epoch+1}: Train loss = {epoch_loss}, Train Accuracy = {epoch_acc}, Test loss = {epoch_test_loss},  Test Accuracy = {epoch_test_acc}, ||J||_F = {jnp.linalg.norm(J)}", flush=True)

    return loss_history, acc_history, test_loss_history, test_acc_history, params_history, J

In [None]:
# Reproducible randomness: one subkey initialises J, the other drives training
key = jax.random.key(41)
key, subkey1, subkey2 = jax.random.split(key, num=3)

# Connectivity: kernel sizes and strides handed to create_conn_matrix
kernel_sizes = [6, 4]
strides = [2, 2]
image_side = 28  # MNIST images are 28 x 28 pixels
num_logits = 10  # one output (logit) per digit class

# NOTE(review): create_conn_matrix's `input_pixels` receives the image side length
# (28), unlike Training_loop's `input_pixels` (784) — confirm against params_init.
J = create_conn_matrix(subkey1, kernel_sizes=kernel_sizes, strides=strides, input_pixels=image_side, num_logits=num_logits)

params_history = [J]

input_pixels = image_side * image_side  # 784 flattened input pixels
num_epochs = 30
batch_size = 10
beta = 0.01
learning_rate = 0.1

# Train on batches [lower, upper), i.e. on training images
# [lower * batch_size, upper * batch_size)
lower = 0
upper = 60000 // batch_size

tmax = 60.
num_steps = 600
g = 0.2

train_loss_history, train_acc_history, test_loss_history, test_acc_history, params_history, J = Training_loop(
    key=subkey2,
    params_history=params_history,
    input_pixels=input_pixels,
    num_epochs=num_epochs,
    batch_size=batch_size,
    beta=beta,
    learning_rate=learning_rate,
    lower=lower,
    upper=upper,
    tmax=tmax,
    num_steps=num_steps,
    g=g,
)



In [None]:
# Plot per-epoch accuracy and cross-entropy loss for the training and test sets
fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 5))

ax_acc.plot(train_acc_history, label='Train Accuracy')
ax_acc.plot(test_acc_history, label='Test Accuracy')
ax_acc.set(xlabel='Epoch', ylabel='Accuracy', title='Accuracy Over Epochs')
ax_acc.legend()
ax_acc.grid()

ax_loss.plot(train_loss_history, label='Train Loss')
ax_loss.plot(test_loss_history, label='Test Loss')
ax_loss.set(xlabel='Epoch', ylabel='Cross-Entropy Loss', title='Loss Over Epochs')
ax_loss.legend()
ax_loss.grid()

fig.tight_layout()
plt.show()