# Linear pRNN dynamics

By [Marcus Ghosh](https://profiles.imperial.ac.uk/m.ghosh/)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ghoshm/pRNN_tutorials/blob/main/linear_dynamics_tutorial.ipynb)

## Aim

We're going to explore how a neural network's structure shapes its dynamics:   
* **Structure**: how neurons are connected — the weight matrix. 
* **Dynamics**: how each neuron's activity (activation or output) evolves over time when you run the network.  

There are 4 parts to this tutorial: 
* Understand the **models** - how we can build networks with different structures and define their dynamics. 
* **Simulate** model dynamics (numerically) and visualize their behaviour.  
* **Solve** model dynamics (analytically). 
* **Extensions** - explore oscillatory dynamics, and more complex models.    

Throughout, instructions and questions are marked like this: 

> 0. Read, code or answer a question.

When you need to fill in code, you will see a  ```pass``` statement. Replace this with your code! 

There are **10** instructions for you to tackle! 

# Setup

In [None]:
# Imports
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib import colors
import itertools
import os 

# For Google Colab
if not os.path.exists('src.py'):
  !git clone https://github.com/ghoshm/pRNN_tutorials.git
  %cd pRNN_tutorials

from src import * # Import all functions from src.py

plt.style.use("./style_sheet.mplstyle")

eps = 1e-12

# Models

To explore structure and dynamics, we need to create networks with different structures and then simulate or solve their dynamics.

## Structure
Instead of creating artificial neural networks with different architectures, we will build *toy* models.   

> 1. Read this [short article](https://doi.org/10.53053/WUSL4267). Then discuss with your partner - what are toy models, what are their advantages compared to larger models, and what might we lose by focussing on them? 

We'll start by building models with:
* Three neurons (nodes or units) - labelled input (*i*), hidden (*h*), output (*o*). While three neurons may seem small, you'll see that this size allows us to explore a complete set of structures, with surprisingly diverse dynamics.  
* A <span style='color: #59656d;'>feedforward</span> connection from the input to the hidden neuron (*ih*). 
* A <span style='color: #59656d;'>feedforward</span> connection from the hidden to the output neuron (*ho*). 

This is a *feedforward* structure - as signals can only travel from the input to the output neuron.

But, there are $7$ other connections we could add, each of which will change the model's structure and its signal flow: 
* <span style='color: #0189a0;'>Lateral (or recurrent)</span> connections from each neuron to itself: *ii*, *hh*, *oo*. 
* <span style='color: #13bbaf;'> Skip</span> connections which bypass the hidden neuron: *io*, *oi*. 
* <span style='color: #f97306;'> Backwards</span> connections: *hi*, *oh*. 

This figure shows our model with its: 
* $3$ nodes.
* $2$ feed-forward connections - in grey.
* All $7$ of the other possible connections: lateral - blue, skip - green, backwards - orange.

<img src="./images/FR_pathways_labs.png" width="400" >

By keeping the $2$ feed-forward connections (so that signals can always travel from the input to the output), and allowing each of the other connections to be present or absent in any combination ($7$ binary choices), we can generate:

$2^{(3+2+2)} = 2^7 = 128$ unique network structures 

In [Ghosh & Goodman, 2025](https://doi.org/10.1101/2025.07.28.667142) we term these *partially recurrent neural networks* (pRNNs). 

Running the code below will:
* Create a list with an adjacency matrix per structure. These are binary $(3,3)$ matrices with connections $(1)$ between (source, target). So `W[src, tgt] = 1` means a connection from neuron `src` → neuron `tgt`.
 
* Plot all $128$ pRNN structures - with their $3$ neurons (circles) and connections (arrows).
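As a minimal sketch of the `W[src, tgt]` convention, the block below builds the adjacency matrix for the purely feedforward structure by hand, and then adds an *hh* connection. The index order (input = 0, hidden = 1, output = 2) and the names `IN`, `HID`, `OUT`, `W_ff` and `W_hh` are assumptions for illustration; `create_pRNN_adj_matrix` in `src.py` may be implemented differently.

```python
import numpy as np

# Assumed neuron indices, following an (input, hidden, output) ordering
IN, HID, OUT = 0, 1, 2

# Start with no connections
W_ff = np.zeros((3, 3), dtype=int)

# Feedforward connections: W[src, tgt] = 1
W_ff[IN, HID] = 1   # ih: input -> hidden
W_ff[HID, OUT] = 1  # ho: hidden -> output

# Copy the feedforward structure and add one lateral connection
W_hh = W_ff.copy()
W_hh[HID, HID] = 1  # hh: hidden -> hidden

print(W_ff)
print(W_hh)
```

Reading along a row gives a neuron's outgoing connections, and reading down a column gives its incoming connections.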

In [None]:
# Define all pRNN structures 

wm_flags = np.array(list(itertools.product([0,1], repeat=7)), dtype=int) # array: (architectures, extra connection pathways)

adj_ms = [create_pRNN_adj_matrix(wm_flag=f) for f in wm_flags] # list: architectures (source, target).             

# Plot 
fig, ax = plt.subplots(nrows=8, ncols=16, figsize=(30,15), sharex=True, sharey=True)
for a, _ in enumerate(ax.ravel()):
    plt.sca(ax.ravel()[a])
    plot_pRNN_architecture(wm_flags[a], ax=ax.ravel()[a])

## Dynamics

Instead of allowing each neuron to compute a non-linear input-output transformation (such as ReLU in machine learning, or leaky integrate-and-fire in computational neuroscience), we will treat our models as *linear dynamical systems*. 

At every time-step ($t$) a network's state is defined by a vector ($x_t$) with each node's activation:

$$
\begin{aligned}
x_t = [i_t, h_t, o_t]
\end{aligned}
$$

To determine how this state changes over time, we define an update:

$$ 
x_{t+1} = x_tW
$$ 

Where $W$ is the network's ($3,3$) adjacency matrix.

Even though these models are linear, activity may grow, persist, decay or oscillate depending on the structure! 
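To make the update rule concrete, here is a small sketch which applies $x_{t+1} = x_tW$ to a structure with *ii*, *ih* and *ho* connections. This structure is chosen purely for illustration, the matrix is unnormalised, and the variable names are assumptions.

```python
import numpy as np

# Adjacency matrix (source, target) with ii, ih and ho connections
W = np.array([
    [1, 1, 0],  # input  -> input, hidden
    [0, 0, 1],  # hidden -> output
    [0, 0, 0],  # output -> nothing
], dtype=float)

x = np.array([1.0, 0.0, 0.0])  # x_0: activity in the input neuron only
states = [x]
for t in range(3):
    x = x @ W  # row-vector update: each neuron sums the activity it receives
    states.append(x)

for t, s in enumerate(states):
    print(f"x_{t} = {s}")
```

In this example the *ii* connection keeps the input neuron active, so activity spreads to the hidden neuron, then to the output neuron, and then persists.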

> 2. Calculate $x_1$, $x_2$ and $x_3$ by hand, after starting the purely feed-forward model (which only has *ih* and *ho* connections) from the state $x_0 = [1,0,0]$ - which is like inputting a signal to the network. How does the network's state evolve over time?

# Simulations

One approach to exploring each structure's dynamics is to use numerical simulations. 

To do so, we just need to: 
* Set an initial state ($x_0$).  
* Run the state update (above) for a fixed number of time steps ($T$).  

In [None]:
def simulate_pRNN_dynamics(adj_ms, x0, T): 
    """
    Simulate dynamics from an initial state. 
    Arguments: 
        adj_ms: list of weight matrices (source, target).
        x0: an initial network state, array: (nodes).   
        T: the number of steps to simulate for, int. 
    Returns: 
        x_hist: an array of dynamics, (architectures, time, nodes). 
    """
    x_hist = np.zeros((len(adj_ms), T + 1, len(x0)), dtype=float) # array: (architectures, time, nodes)

    for a, W in enumerate(adj_ms):
        x = x0.copy() # current state
        x_hist[a, 0] = x # store

        for t in range(T):
            x = x @ W # update current state
            x_hist[a, t + 1] = x # store

    return x_hist

x0 = np.array([1.0, 0.0, 0.0]) # initial state to simulate from (input, hidden, output)
T = 15 # number of time steps to simulate

x_hist = simulate_pRNN_dynamics(adj_ms, x0, T)

> 3. Check that the activity in the feed-forward model `x_hist[0]` matches your answer above. If not, what did you get wrong?

Now let's plot each model's dynamics. The function below will plot $128$ subplots ($1$ per model). Within each subplot:
* The x-axis is time.
* The y-axis is activity. 
* There are $3$ lines - representing the activity of the <span style='color: #d8dcd6;'> input</span> (grey), <span style='color: #7e1e9c;'> hidden</span> (purple) and <span style='color: #1ef876;'> output</span> (green) neuron over time. 

In [None]:
# Figure 
plot_pRNN_dynamics(x_hist=x_hist)

In the figure above it is hard to compare dynamics across architectures! 

One reason for this is that in different models, neurons can receive very different amounts of input (depending on their number of connections).

For example, in the least complex model (`adj_ms[0]`) the output neuron receives one input, while in the most complex model (`adj_ms[-1]`) the output neuron receives three inputs! 

To rectify this, we can normalise the adjacency matrices in `adj_ms`. 

> 4. In the cell below, normalise each adjacency matrix so that each neuron receives either no input ($0$) or a total input of $1$.   

In [None]:
# Normalise 
for a, W in enumerate(adj_ms):
    pass # Fill in your code here!

# Simulate
x_hist = simulate_pRNN_dynamics(adj_ms, x0, T)

# Figure 
plot_pRNN_dynamics(x_hist=x_hist)

If you have done this correctly, you will now see very different dynamics across the different models!

> 5. In the cell below, calculate the number of structures with unique dynamics (from ```x_hist```).

In [None]:
pass # Fill in your code here! 

# Solvable

Another way to understand each pRNN's behaviour is to look at its [eigendecomposition](https://www.datacamp.com/tutorial/eigendecomposition). 

This method factors (decomposes) a diagonalisable square matrix ($A$) into three matrices: 

$$ 
\begin{aligned}
A = PDP^{-1}
\end{aligned}
$$

Where:
* $P$ - contains eigenvectors (the characteristic "modes" of the system). 
* $D$ - contains eigenvalues ($\lambda$) on the diagonal. 
* $P^{-1}$ - the inverse of $P$.  
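As a quick sanity check of this factorisation, the sketch below decomposes a small symmetric matrix with `np.linalg.eig` and rebuilds it from $P$, $D$ and $P^{-1}$. The matrix `A` is an arbitrary example, not one of the pRNN structures.

```python
import numpy as np

# A small symmetric (and therefore diagonalisable) example matrix
A = np.array([[2.0, 1.0],
              [1.0, 2.0]])

eigenvalues, P = np.linalg.eig(A)  # columns of P are eigenvectors
D = np.diag(eigenvalues)           # eigenvalues on the diagonal
A_rebuilt = P @ D @ np.linalg.inv(P)

print(np.sort(eigenvalues))        # the eigenvalues of this matrix are 1 and 3
print(np.allclose(A, A_rebuilt))   # checks that A = P D P^{-1}
```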

For our networks, the eigenvectors correspond to activity patterns that evolve over time, and the eigenvalues tell us how each pattern changes:

* $|\lambda| > 1$: grows.
* $|\lambda| \approx 1$: persists.
* $|\lambda| < 1$: decays.
* complex $\lambda$: oscillates.
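These cases can be illustrated with two hand-picked matrices. The sketch below is not drawn from the pRNN family (the rotation matrix has a negative weight) and its names are assumptions; it only shows how the size and type of each eigenvalue shows up in the dynamics.

```python
import numpy as np

# Diagonal matrix: each axis is an eigenvector with its own eigenvalue
W_diag = np.diag([1.1, 1.0, 0.5])
x = np.ones(3)
for _ in range(20):
    x = x @ W_diag
print(x)  # first entry has grown, second is unchanged, third has decayed

# A 90-degree rotation has purely imaginary eigenvalues (+1j and -1j)
W_rot = np.array([[0.0, 1.0],
                  [-1.0, 0.0]])
print(np.linalg.eigvals(W_rot))

y = np.array([1.0, 0.0])
trajectory = [y]
for _ in range(4):
    y = y @ W_rot
    trajectory.append(y)
print(np.array(trajectory))  # the state returns to its start after 4 steps
```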


In [None]:
# Eigendecomposition 

eigvals_list = []
eigvecs_list = []
for W in adj_ms:
    vals, vecs = np.linalg.eig(W.T) # transpose as we use a row-vector update.
    eigvals_list.append(vals)
    eigvecs_list.append(vecs)

eigvals = np.array(eigvals_list) # array: (128, 3)
eigvecs = np.array(eigvecs_list) # array: (128, 3, 3)
# Note: eigvecs[a, :, b] is the eigenvector corresponding to eigvals[a, b]
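
The transpose matters because the update multiplies a row vector from the left. As a minimal sketch, the check below confirms that an eigenvector $v$ of $W^\top$ satisfies $vW = \lambda v$, so it is only rescaled by each update. The matrix `W` here is a hypothetical example, not one of the entries in `adj_ms`.

```python
import numpy as np

# A hypothetical (3, 3) weight matrix, used only to illustrate the check
W = np.array([[0.2, 0.8, 0.0],
              [0.0, 0.5, 0.5],
              [0.3, 0.0, 0.7]])

vals, vecs = np.linalg.eig(W.T)  # eigenvectors of W.T are left eigenvectors of W

checks = []
for b in range(len(vals)):
    v = vecs[:, b]  # eigenvector paired with vals[b]
    # As a row vector, v is only rescaled by the update x @ W
    checks.append(np.allclose(v @ W, vals[b] * v))

print(checks)
```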

The long-term behaviour of each network is dictated by its spectral radius: 

$$
\begin{aligned}
\rho (W) = \max_{i} |\lambda_i|
\end{aligned}
$$

> 6. In the cell below calculate each structure's spectral radius.    

In [None]:
# Analysis

gs = [] # a list to store each architecture's spectral radius

for a in range(len(eigvals)):
    pass # Fill in code here

gs = np.array(gs) # array: (128,)

The code below will plot the spectral radius (y-axis) of each architecture (x-axis) sorted from lowest to highest. 

It will also highlight the <span style='color: #5a86ad;'> feedforward model</span> (in blue), the <span style='color: #1fb57a;'> RNN model</span> (in green) and the <span style='color: #7e1e9c;'> fully recurrent model </span> (in purple). 

In [None]:
# Figure  

interest = [0, 32, 127]
i_cols = ['xkcd:dusty blue', 'xkcd:dark seafoam', 'xkcd:purple']
i_labels = ['Feedforward', 'RNN', 'Fully recurrent']

fig = plt.figure(figsize=(10,5))
plt.scatter(range(len(gs)), np.sort(gs), s=2, c='k')

for a, i in enumerate(interest): 
    ms, sl, bl = plt.stem(np.where(np.argsort(gs) == i), gs[i])
    ms.set_markersize(2)
    ms.set_markerfacecolor(i_cols[a])
    ms.set_markeredgecolor(i_cols[a])
    sl.set_linewidth(2)
    sl.set_edgecolor(i_cols[a])

plt.xticks([])
plt.xlabel("Architectures")
plt.ylabel(r"$\rho(W)$")

Assuming everything is correct, you should see that for most architectures $\rho (W) = 1$, but for some $\rho (W) < 1$. If not, you may need to check your normalisation or how you calculate $\rho (W)$.    

The figure below will plot every *pRNN* structure, highlighting those with $\rho (W) < 1$ in green. 

In [None]:
# Figure 
fig, ax = plt.subplots(
    nrows=8, ncols=16, figsize=(30, 15),
    sharex=True, sharey=True
)

for a, _ in enumerate(ax.ravel()):
    plt.sca(ax.ravel()[a])
    if gs[a] < 0.99:
        plot_pRNN_architecture(wm_flag=wm_flags[a], ax=ax.ravel()[a], color="xkcd:dark seafoam")
    else: 
        plot_pRNN_architecture(wm_flag=wm_flags[a], ax=ax.ravel()[a], color="xkcd:light grey")

    if wm_flags[a][1] == 1: # A first guess at a structural rule 
        plt.scatter(1,1, color='xkcd:light purple')

> 7. What feature or features do all of the structures with $\rho (W) < 1$ share? Can you write a *structural rule* - which identifies these? 

The code above tries a first rule, that these are the structures with an *hh* connection, and adds a purple dot to each structure which satisfies it. Comparing the green structures (those with $\rho (W) < 1$) to those with a purple dot (with *hh* connections) shows that this rule isn't very predictive. Try to discover the correct rule!   

# Extensions

Now you've got the hang of thinking about structure and dynamics, try working on these extensions (in any order):

8. **Oscillations** - some pRNNs will show oscillatory dynamics. You can find these by looking at your simulation results or by finding structures with complex eigenvalues. Can you find a *structural rule* (as above) which predicts which structures will oscillate and which will not? 

9. **Wider structures** - by adding an extra input neuron (*i0*, *i1*, *h*, *o*), keeping a similar feed-forward structure (*i0->h*, *i1->h*, *h->o*), and allowing all of the other possible connections to be present or absent in different combinations, you could generate a larger space of $8{,}192$ structures. Generate these models and explore their dynamics. 

10. **Deeper structures** - by adding an extra hidden neuron (*i*, *h0*, *h1*, *o*), keeping a similar feed-forward structure (*i->h0*, *h0->h1*, *h1->o*), and allowing all of the other possible connections to be present or absent in different combinations, you could generate a larger space of $8{,}192$ structures. Generate these models and explore their dynamics.  
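
The structure counts above follow from the same argument as the three-neuron case: every (source, target) pair is a possible connection, the feed-forward connections are always present, and each remaining connection is a binary choice. A small sketch of that arithmetic, using a hypothetical helper named `count_structures`:

```python
def count_structures(n_neurons, n_fixed):
    """Number of structures when n_fixed connections are always present."""
    n_possible = n_neurons ** 2        # every (source, target) pair, including self-connections
    n_optional = n_possible - n_fixed  # connections that can be present or absent
    return 2 ** n_optional

print(count_structures(3, 2))  # three-neuron pRNNs
print(count_structures(4, 3))  # wider or deeper four-neuron structures
```

One way to enumerate the larger spaces would be `itertools.product([0, 1], repeat=13)`, mirroring how `wm_flags` is built for the three-neuron case.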