In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from scipy.special import softmax
from scipy.spatial.distance import euclidean

# MS: TODO comment out for now
#import biomart

import umap
import pickle
import scipy.spatial as sp
import seaborn as sns
import itertools

from sklearn.decomposition import PCA
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import load_iris

from scipy.spatial.distance import pdist,squareform
from scipy.cluster import hierarchy

import glob
import torch

# Goal = Attempt to recreate Fig. 2 of biorxiv manuscript

Utility functions

In [None]:
""" code from Karin2024
def run(self, x0, w, beta, dt=0.1, tmax=10):
    x=x0.copy()
    hist = []
    for t in np.arange(0,tmax,dt):
        hist.append(x.copy())
        x+=dt*(np.matmul(self.Q.T,softmax(w+beta*np.matmul(self.XI,x)))-x)
        
    hist = np.array(hist)
    return x,hist
"""

def run_traj(f_of_xtp, x0, params, t0=0.0, tmax=10.0, dt_max=0.1):
    """
    Integrate dx/dt = f_of_xtp(x, t, params) from t0 to tmax with forward Euler.

    Args:
        f_of_xtp: vector field f(x, t, params); must return an array
                  broadcastable to the state shape (N,)
        x0:       initial state, array-like of length N (converted to float)
        params:   passed through unchanged to f_of_xtp
        t0, tmax: start and end time, tmax >= t0; the time grid ends exactly at tmax
        dt_max:   largest allowed Euler step; the uniform step actually used is
                  (tmax - t0) / n_steps <= dt_max

    Returns:
        x_traj:     float array of shape (T, N); row i is the state at times_traj[i]
        times_traj: array of shape (T,), with T = n_steps + 1

    Raises:
        ValueError: if dt_max <= 0 or tmax < t0
    """
    # TODO have local function for x_t+1 = foo(x_t) which is autograd-able
    if dt_max <= 0:
        raise ValueError('dt_max must be positive, got %r' % (dt_max,))
    if tmax < t0:
        raise ValueError('tmax (%r) must be >= t0 (%r)' % (tmax, t0))

    x0 = np.asarray(x0, dtype=float)

    # Euler on a uniform grid. np.arange(t0, tmax + dt_max, dt_max) with a float step
    # can emit an extra point past tmax (e.g. t0=1.0, tmax=1.2, dt_max=0.1 ends at 1.3).
    # Instead, count the steps (rounding absorbs float noise such as
    # 0.19999999999999996 / 0.1) and let linspace pin the last time to tmax.
    n_steps = int(np.ceil(np.round((tmax - t0) / dt_max, 9)))
    dt = (tmax - t0) / n_steps if n_steps > 0 else 0.0
    times_traj = np.linspace(t0, tmax, n_steps + 1)      # size T
    x_traj = np.zeros((len(times_traj), x0.shape[0]))    # size T x N

    x_traj[0, :] = x0
    for idx in range(n_steps):
        current_vel = f_of_xtp(x_traj[idx, :], times_traj[idx], params)
        x_traj[idx + 1, :] = x_traj[idx, :] + dt * current_vel

    # TODO scipy integrator (e.g. Radau call), or write up Runge-Kutta
    return x_traj, times_traj

## Silly example testing the generic trajectory function: 
## - launch particle, only force is gravity $d v_y / dt = -g$

In [None]:
# Note: keep signature of function fixed for now as x, t, p - state [dim N], times [scalar], params [dim p]
# TODO jit all func, or torch for eventual Autodiff

def dxdt_particle(x, t, params):
    """
    Velocity field of a ball thrown in R^2 under constant gravity.

    The state x is ignored: the velocity depends only on time t.
    params = (vx0, vy0, g): initial horizontal velocity, initial vertical velocity, gravity.
    Returns np.array([dx/dt, dy/dt]) = [vx0, vy0 - g * t].
    """
    vel_x0, vel_y0, grav = params
    return np.array([vel_x0, vel_y0 - grav * t])

# sample call
dxdt_particle([10, 10], 0, [1, 1, 9.8])

In [None]:
# Ball starts at position (x, y) = (10, 100); params = (vx0, vy0, g) for dxdt_particle
x0 = np.array([10, 100])
params = [1, 10, 9.8]
x_traj, times_traj = run_traj(dxdt_particle, x0, params,
                              t0=0.0, tmax=2, dt_max=0.1)

In [None]:
# Plot each coordinate of the particle trajectory against time, with labelled axes.
fig, axarr = plt.subplots(2, 1, sharex=True)
pltstyle = dict(linestyle='--', marker='o')
for ax, col, label in zip(axarr, range(2), ('x', 'y')):
    ax.plot(times_traj, x_traj[:, col], **pltstyle)
    ax.set_title('%s vs t' % label)
    ax.set_ylabel(label)
axarr[-1].set_xlabel('t')
plt.suptitle('Simple example: launch particle with gravity - euler method')
plt.show(); plt.close('all')

# Classic (quadratic) hopfield network with continuous variables $x \in \mathbb{R}^N$

My suggested form
### $\frac{dx}{dt} = -x + \textrm{tanh}(\beta [Jx + b])$

where 

- $\beta=1/T$ is the inverse temperature -- in the zero-temperature limit ($\beta \to \infty$) the nonlinearity becomes digital (the sign function)
- $J=\xi \xi^T$ is defined by the $N \times K$ matrix of cell types, with cell type $\mu=1, ..., K$ indexing the cell types represented by columns of $\xi$ 
- $b \in \mathbb{R}^N$ is an $N$ dimensional applied field on specific genes; we will reserve $w \in \mathbb{R}^K$ to denote an analogous forcing applied in the direction of the $K$ encoded cell types

The form in Hopfield, 1984 (PNAS):
### $\frac{dx}{dt} = -x + J \:\textrm{tanh}(\beta x) + b$

#### Remark 1: on representations
The main dynamics stated at the top are not unique. The nonlinearity can be applied before or after the linear map. Here is a general recipe for a recurrent dynamics in $\mathbb{R}^N$ combining an affine transform and a static elementwise nonlinearity $\sigma(z)$:
- Representation 1:  $dx/dt = -x + \sigma(Ax + b)$
- Representation 2:  $dy/dt = -y + A\sigma(y) + b$

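I tend to prefer Representation 1, but the two are sometimes treated interchangeably in the literature. Consider specifically the affine change of variables $y=Ax+b$: along a solution of Rep (1), $dy/dt = A\,dx/dt = -Ax + A\sigma(y) = -y + A\sigma(y) + b$, which is Rep (2). When $A$ is invertible, this map can be undone and the two representations are equivalent. Otherwise, the situation is more subtle: 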
- Going from Rep (1) to Rep (2) works even when $A$ is not full rank. 
- However, if $A$ is not full rank, starting from Rep (2) one comes to $A\: dx/dt = A (-x + \sigma(Ax+b))$, and $A$ cannot simply be cancelled. Consider the thin SVD of $A$, corresponding to the first sum in $A= \sum_{i=1}^p \sigma_i u_i v_i^T + \sum_{i=p+1}^N 0 \cdot u_i v_i^T$, accounting for the $p<N$ non-zero singular values ($p = N$ would imply invertible). In matrix form $A=UDV^T$ with $U, V \in \mathbb{R}^{N \times p}$ (orthonormal columns) and $D = \textrm{diag}(\sigma_1, \dots, \sigma_p)$. Left multiplying by the pseudoinverse of $U$ (namely $U^T$) and the inverse of $D$ gives the reduced dynamics for $m=V^T x$ in $p<N$ dimensions:  $dm/dt = -m + V^T \:\sigma(U D m + b)$ -- isn't it interesting that this $p$-dim dynamics is equivalent to the $N>p$ dim dynamics? 

I also tend to use the nonlinearity $\sigma(z) = \textrm{tanh}(\beta z)$. This choice is inspired by the mean-spin flip update rule of the binary Ising model with (stochastic) Glauber dynamics. Combined with Representation 1, it has a nice interpretation that any fixed point lies in or on the $\pm 1$ hypercube ($\equiv \Omega_N$); more generally, note that $\Omega_N$ is positively invariant for this choice of $\sigma(z)$. To "think outside the box", one could consider alternative nonlinearities like $\sigma(z)=\textrm{ReLU}(z)$. 

#### Remark 2: Alternative RBM-like two-layer form see [Krotov and Hopfield, 2021, ICLR] (TODO q: when is it equivalent?) :
- $\tau_x\: \frac{dx}{dt} = -x + W_x \,f(m) + b$
- $\tau_m\: \frac{dm}{dt} = -m + W_m \,g(x) + w$

Notation: 

- Let $W_x \in \mathbb{R}^{N \times K}$ be a matrix of connections from "hidden" units $m \in \mathbb{R}^K$ to "visible" units $x \in \mathbb{R}^N$. Let $W_m \in \mathbb{R}^{K \times N}$ denote reverse connections from the visible to the hidden units. The authors assume symmetric connections between the layers, i.e. $W_x=\xi$ and $W_m=\xi^T$. 
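- $f:\mathbb{R}^K \rightarrow \mathbb{R}^K$, $g:\mathbb{R}^N \rightarrow \mathbb{R}^N$ are potentially nonlinear functions, acting either on each neuron separately (e.g. $f_{\mu}(m) \equiv \textrm{tanh}(m_{\mu})$) or on the entire layer (e.g. in the case of $\textrm{softmax}$). The shapes follow from the products above: $W_x$ is $N \times K$, so $f(m) \in \mathbb{R}^K$; $W_m$ is $K \times N$, so $g(x) \in \mathbb{R}^N$. 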

##### Comment on relative timescales for the two-layer dynamics
If one timescale ($\tau_x, \tau_m$) is much faster (i.e. small $\tau$), then the associated dynamics can be eliminated (treated as being at "quasi-steady-state"). In that case, the coupled dynamics can be viewed as a "singular expansion" of the uncoupled form. 

E.g., if $\tau_m \ll \tau_x$, then $m(t) \approx \xi^T g(x) + w$ at all times (since any deviation will rapidly be corrected, relatively speaking).

As a concrete example, the following choice of $f$, $g$ maps to a classic Hebbian Hopfield network (TODO check this vs. form above; if $b$ is outside the $\textrm{tanh}$ then state can leave the $\pm 1$ hypercube): 
- $\tau_x\: \frac{dx}{dt} = -x + \xi \:\textrm{tanh}(\beta m) + b$
- $\tau_m\: \frac{dm}{dt} = -m + \xi^T x + w$
 
Assuming  $\tau_m \ll \tau_x$ and making the QSS substitution gives the condensed dynamics $\tau_x\: dx/dt = -x + \xi \:\textrm{tanh} \left[ \beta (\xi^T x + w) \right] + b$.

#### Additional notes from recent [Krotov and Hopfield, 2021, ICLR]
TODO...
  

In [None]:
def build_xi_rolling_correlated(N, p=3):
    """
    Build p correlated +/-1 patterns by rolling a block of +1s around the N units.

    Pattern mu (column mu) is -1 everywhere except for a contiguous block of p
    entries equal to +1 starting at unit p * mu. Distinct patterns therefore have
    disjoint +1 blocks and pairwise overlap (N - 4p) / N (e.g. -1/3 for N=9, p=3).

    N: number of visible units
    p: number of patterns (column vectors); also the length of each +1 block

    Returns a float array of shape (N, p).

    Raises ValueError if p < 1, if N is not a multiple of p, or if p * p > N.
    In the last case the rolled blocks would wrap around and overlap; for example,
    N = p used to silently return p identical patterns.
    """
    if p < 1:
        raise ValueError('p must be >= 1, got %d' % p)
    if N % p != 0:
        raise ValueError('N (%d) must be a multiple of p (%d)' % (N, p))
    if p * p > N:
        raise ValueError('need N >= p**2 so the %d rolled blocks do not wrap (got N=%d)' % (p, N))

    base = -1 * np.ones(N)
    base[0:p] = 1

    xi = np.zeros((N, p))
    for idx in range(p):
        # shift by p*idx <= p*(p-1) <= N - p, so no block wraps past the end
        xi[:, idx] = np.roll(base, p * idx)
    return xi

def build_xi_random_binary(N, p, seed=0):
    """
    Draw p independent random +/-1 patterns of length N.

    N: number of visible units
    p: number of patterns (column vectors)
    seed: seed for a private np.random.RandomState. The same seed yields exactly the
          patterns that np.random.seed(seed) + np.random.randint(...) would, but the
          global numpy RNG state is no longer reset as a hidden side effect.

    Returns an int array of shape (N, p) with entries in {-1, +1}.
    """
    rng = np.random.RandomState(seed)
    xi = 2 * rng.randint(2, size=(N, p)) - 1
    return xi

In [None]:
# Overlap matrices N^{-1} xi^T xi for random (nearly orthogonal) vs. rolled (correlated) patterns
N_rand, K_rand = 5000, 4
xi_rand = build_xi_random_binary(N_rand, K_rand)
print('Corr. matrix for K=%d random patterns in N=%d dims' % (K_rand, N_rand))
print(xi_rand.T @ xi_rand / N_rand)

N_simple, K_simple = 9, 3
xi_9_3_corr = build_xi_rolling_correlated(N_simple, p=K_simple)
# Label with this example's own sizes (it used to print K_rand, N_rand in swapped order)
print('\nLow dim correlated example (N=%d, K=%d)' % (N_simple, K_simple))
print(xi_9_3_corr.T @ xi_9_3_corr / xi_9_3_corr.shape[0])

In [None]:
# Note: keep signature of function fixed for now as x, t, p - state [dim N], times [scalar], params [dim p]
# TODO jit all func, or torch for eventual Autodiff

#@torch.jit
def dxdt_HN_quadratic_hebb_A(x, t, params):
    """
    Hebbian Hopfield field dx/dt = -x + tanh(beta * xi xi^T x)  (no 1/N factor).

    - My preferred form: arises from mean spin flip rule for stochastic discrete updates
    - This form works roughly as expected

    x: state of shape (N,), or (N, T) to evaluate T states (as columns) at once
    t: unused (autonomous field); kept for the run_traj signature
    params: (beta, xi) with xi the N x K pattern matrix
    """
    beta, xi = params
    # Associate as xi @ (xi^T x): O(N K) work, instead of first forming the N x N
    # matrix xi xi^T (O(N^2 K) time and N^2 memory; 25M entries for N = 5000).
    overlaps = xi.T @ x
    arg_of_sigma = xi @ overlaps
    dxdt = -x + np.tanh(beta * arg_of_sigma)
    return dxdt

def dxdt_HN_quadratic_hebb_B(x, t, params):
    """
    RBM-type Hebbian Hopfield field (hidden layer at quasi-steady state):
        dx/dt = -x + xi tanh(beta * xi^T x / N)

    x: state of shape (N,), or (N, T) to evaluate T states (as columns) at once
    t: unused (autonomous field); kept for the run_traj signature
    params: (beta, xi) with xi the N x K pattern matrix

    The 1/N normalization scales the overlaps m = xi^T x / N inside the tanh.
    It used to multiply xi @ tanh(...) from outside. That caps the drive at K/N per
    unit, so every trajectory decayed towards x = 0 and the stored patterns could
    not be fixed points.
    """
    beta, xi = params
    normalization = 1 / xi.shape[0]
    # m_mu = xi_mu . x / N lies in [-1, 1] whenever x is in the +/-1 hypercube
    overlaps = normalization * (xi.T @ x)
    dxdt = -x + xi @ np.tanh(beta * overlaps)
    return dxdt

def dxdt_HN_quadratic_hebb_C(x, t, params):
    """
    Hopfield-1984-like field: dx/dt = -x + (1/N) xi xi^T tanh(beta * x).

    This form is closer to Hopfield 1984
    - See spurious minima when point in vicinity of pattern 0 is given small linear push towards other patterns

    x: state of shape (N,), or (N, T) to evaluate T states (as columns) at once
    t: unused (autonomous field); kept for the run_traj signature
    params: (beta, xi) with xi the N x K pattern matrix
    """
    beta, xi = params
    arg_of_sigma = x
    normalization = 1 / xi.shape[0]
    activity = np.tanh(beta * arg_of_sigma)
    # xi @ (xi^T s) costs O(N K) and never builds the N x N matrix xi xi^T
    dxdt = -x + normalization * (xi @ (xi.T @ activity))
    return dxdt

# sample call: the (beta, xi) parameters shared by the classic-HN runs below
local_beta = 10.0
local_xi = xi_rand
params = (local_beta, local_xi)
local_N, local_K = local_xi.shape

Get trajectories and plot the projection onto pattern subspace + norm of vector field

In [None]:
# Seed the global numpy RNG. Both this x0 and the unseeded np.random.normal draw in a later cell depend on it.
np.random.seed(0)
x0 = np.random.rand(local_N)   # entries uniform in [0, 1): a random point in the positive orthant

In [None]:
methods_str = ['(A)', '(B)', '(C)']
methods_foo = [dxdt_HN_quadratic_hebb_A, dxdt_HN_quadratic_hebb_B, dxdt_HN_quadratic_hebb_C]
methods_traj_x = [0] * len(methods_str)
methods_traj_t = [0] * len(methods_str)

# Integrate each vector field from the same random IC x0
for idx, foo in enumerate(methods_foo):
    x_traj, times_traj = run_traj(foo, x0, params, t0=0.0, tmax=4, dt_max=0.1)
    methods_traj_x[idx] = x_traj
    methods_traj_t[idx] = times_traj

# Shape (T, N) of the trajectories just computed. This used to be printed before the
# loop, where it showed the stale x_traj left over from the particle example.
print(x_traj.shape)

In [None]:
fig, axarr = plt.subplots(len(methods_str), 2, sharex=True, figsize=(8,9))
pltstyle = dict(linestyle='--', marker='o', markersize=4)
for idx in range(len(methods_str)):
    # Plot against integration time so the 't' x-labels below are accurate
    # (plotting against the array index would show step numbers instead)
    times = methods_traj_t[idx]
    axarr[idx, 0].plot(times, methods_traj_x[idx] @ params[1] / local_N,  **pltstyle)
    axarr[idx, 0].set_title('HN classic quadratic %s' % methods_str[idx])
    axarr[idx, 0].axhline(0, linestyle='--', linewidth=2, c='k')

    # Evaluate the field on every state of the trajectory at once (states as columns)
    dxdt_arr = methods_foo[idx](methods_traj_x[idx].T, 0, params)
    axarr[idx, 1].plot(times, np.linalg.norm(dxdt_arr, axis=0),  **pltstyle)
    axarr[idx, 1].set_title('Norm of $dx/dt$ for %s' % methods_str[idx])
    axarr[idx, 1].axhline(0, linestyle='--', linewidth=2, c='k')

axarr[-1, 0].set_xlabel('t')
axarr[-1, 1].set_xlabel('t')
plt.suptitle('Example: HN classic quadratic\n' + r'Trajectory $N^{-1} \xi^T x(t)$ from random IC (N=%d, K=%d)' % (local_N, local_K))
plt.show(); plt.close('all')

In [None]:
# Overlaps N^{-1} xi^T x(t_end) of the final state of the method (B) trajectory with each stored pattern
methods_traj_x[1][-1, :] @ params[1] / local_N

### Classic HN variants: Repeat but for non-random IC (in vicinity of pattern 0)

In [None]:
# Random weights for patterns 1..K-1. No seed is set here, so the values depend on the
# global numpy RNG state left by earlier cells.
coeffs_rand = np.random.normal(loc=0, scale=0.5, size=local_K - 1)
print(coeffs_rand)

# IC = pattern 0 plus a random linear combination of the other patterns
x0_near_pattern_0 = params[1][:, 0] + coeffs_rand @ params[1][:, 1:].T
print('overlaps:', params[1].T @ x0_near_pattern_0 / local_N)

In [None]:
methods_str = ['(A)', '(B)', '(C)']
methods_foo = [dxdt_HN_quadratic_hebb_A, dxdt_HN_quadratic_hebb_B, dxdt_HN_quadratic_hebb_C]
methods_traj_x = [0] * len(methods_str)
methods_traj_t = [0] * len(methods_str)

# Integrate each vector field from the IC near pattern 0
for idx, foo in enumerate(methods_foo):
    x_traj, times_traj = run_traj(foo, x0_near_pattern_0, params, t0=0.0, tmax=4, dt_max=0.1)
    methods_traj_x[idx] = x_traj
    methods_traj_t[idx] = times_traj

# Shape (T, N) of the trajectories just computed (printed after the loop so it is not
# stale output from a previous cell)
print(x_traj.shape)

In [None]:
fig, axarr = plt.subplots(len(methods_str), 2, sharex=True, figsize=(8,9))
pltstyle = dict(linestyle='--', marker='o', markersize=4)
for idx in range(len(methods_str)):
    # Plot against integration time so the 't' x-labels below are accurate
    # (plotting against the array index would show step numbers instead)
    times = methods_traj_t[idx]
    axarr[idx, 0].plot(times, methods_traj_x[idx] @ params[1] / local_N,  **pltstyle)
    axarr[idx, 0].set_title('HN classic quadratic %s' % methods_str[idx])
    axarr[idx, 0].axhline(0, linestyle='--', linewidth=2, c='k')

    # Evaluate the field on every state of the trajectory at once (states as columns)
    dxdt_arr = methods_foo[idx](methods_traj_x[idx].T, 0, params)
    axarr[idx, 1].plot(times, np.linalg.norm(dxdt_arr, axis=0),  **pltstyle)
    axarr[idx, 1].set_title('Norm of $dx/dt$ for %s' % methods_str[idx])
    axarr[idx, 1].axhline(0, linestyle='--', linewidth=2, c='k')

axarr[-1, 0].set_xlabel('t')
axarr[-1, 1].set_xlabel('t')
plt.suptitle('Example: HN classic quadratic\n' + r'Trajectory $N^{-1} \xi^T x(t)$ from IC near $\xi^{(0)}$ (N=%d, K=%d)' % (local_N, local_K))
plt.show(); plt.close('all')

In [None]:
# Inspect the method (A) field at two states: step 10 of the (A) trajectory, then
# exactly at stored pattern 0. For each, print dx/dt, its shape, and its Euclidean norm.
for dxdt_arr in (dxdt_HN_quadratic_hebb_A(probe, 0, params)
                 for probe in (methods_traj_x[0][10, :], params[1][:, 0])):
    print(dxdt_arr)
    print(dxdt_arr.shape)
    print(np.sqrt(dxdt_arr.T @ dxdt_arr))

### Now let's implement the modern Hopfield network vector field  
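$dx/dt = \ldots$ (TODO: not implemented yet -- the cells below still repeat the particle example. For reference, the Karin2024 `run` snippet in the utility cell integrates $dx/dt = -x + Q^T \textrm{softmax}(w + \beta\, \Xi x)$.)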

In [None]:
# Note: keep signature of function fixed for now as x, t, p - state [dim N], times [scalar], params [dim p]
# TODO jit all func, or torch for eventual Autodiff
# TODO: modern Hopfield vector field goes here. Until then, reuse dxdt_particle from the
# particle-example cell; it was previously re-defined here verbatim, which silently
# shadowed the earlier definition.

# sample call
dxdt_particle([10,10], 0, [1, 1, 9.8])

In [None]:
# TODO placeholder: repeats the particle example until the modern HN field exists
# Ball starts at position (x, y) = (10, 100); params = (vx0, vy0, g) for dxdt_particle
x0 = np.array([10, 100])
params = [1, 10, 9.8]
x_traj, times_traj = run_traj(dxdt_particle, x0, params,
                              t0=0.0, tmax=2, dt_max=0.1)

In [None]:
# Plot each coordinate of the particle trajectory against time, with labelled axes.
fig, axarr = plt.subplots(2, 1, sharex=True)
pltstyle = dict(linestyle='--', marker='o')
for ax, col, label in zip(axarr, range(2), ('x', 'y')):
    ax.plot(times_traj, x_traj[:, col], **pltstyle)
    ax.set_title('%s vs t' % label)
    ax.set_ylabel(label)
axarr[-1].set_xlabel('t')
plt.suptitle('Simple example: launch particle with gravity - euler method')
plt.show(); plt.close('all')