In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from scipy.special import softmax
from scipy.spatial.distance import euclidean

# MS: TODO comment out for now
#import biomart

import umap
import pickle
import scipy.spatial as sp
import seaborn as sns
import itertools

from sklearn.decomposition import PCA
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import load_iris

from scipy.spatial.distance import pdist,squareform
from scipy.cluster import hierarchy

import glob
import torch

import os

In [None]:
# Folder that receives every figure written by this notebook; created on first run.
NB_OUTPUT = 'output'
os.makedirs(name=NB_OUTPUT, exist_ok=True)

# Goal = Attempt to recreate Fig. 2 of biorxiv manuscript

Utility functions

In [None]:
""" code from Karin2024
def run(self, x0, w, beta, dt=0.1, tmax=10):
    x=x0.copy()
    hist = []
    for t in np.arange(0,tmax,dt):
        hist.append(x.copy())
        x+=dt*(np.matmul(self.Q.T,softmax(w+beta*np.matmul(self.XI,x)))-x)
        
    hist = np.array(hist)
    return x,hist
"""
from scipy.integrate import solve_ivp


USE_SCIPY_INTEGRATOR = True  # If False, use Euler integration (manual)


def run_traj(f_of_txp, x0, params, t0=0.0, tmax=10.0, dt_max=0.1, use_scipy=None):
    """
    Integrate dx/dt = f_of_txp(t, x, params) and sample the state on a uniform time grid.

    Args:
        f_of_txp:  callable (t, x, params) -> dx/dt, returning an array shaped like x
        x0:        initial state, 1D array-like of length N (cast to float)
        params:    forwarded unchanged as the third argument of f_of_txp
        t0, tmax:  start / end of the integration window (tmax >= t0)
        dt_max:    spacing of the returned time grid; also the step size of the Euler branch
        use_scipy: True  -> scipy solve_ivp (RK45) evaluated through its dense output
                   False -> manual forward Euler
                   None  -> fall back to the module-level USE_SCIPY_INTEGRATOR flag

    Returns:
        x_traj:     array of shape (T, N)
        times_traj: array of shape (T,), t0, t0 + dt_max, ... and never beyond tmax

    Raises:
        ValueError:   if dt_max <= 0 or tmax < t0
        RuntimeError: if solve_ivp reports a failed integration
    """
    # TODO have local function for x_t+1 = foo(x_t) which is autograd-able
    if dt_max <= 0:
        raise ValueError('dt_max must be positive, got %r' % (dt_max,))
    if tmax < t0:
        raise ValueError('tmax (%r) must be >= t0 (%r)' % (tmax, t0))
    if use_scipy is None:
        use_scipy = USE_SCIPY_INTEGRATOR

    # Accept lists / integer arrays: both branches need a float state with a .shape
    x0 = np.asarray(x0, dtype=float)

    # Build the grid from an integer step count instead of np.arange(t0, tmax + dt_max, dt_max):
    # the latter returns a final sample beyond tmax whenever (tmax - t0) is not a multiple of
    # dt_max, which forces the dense output to extrapolate outside the solved interval.
    # The 1e-9 slack keeps ratios such as 0.3 / 0.1 = 2.9999999999999996 from losing a step.
    n_steps = int(np.floor((tmax - t0) / dt_max + 1e-9))
    times_traj = np.minimum(t0 + dt_max * np.arange(n_steps + 1), tmax)   # size T

    if n_steps == 0:
        # Degenerate window: the trajectory is just the initial condition
        return x0[np.newaxis, :].copy(), times_traj

    if use_scipy:
        # https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html
        # methods: RK45 [=default], Radau, ...
        sol = solve_ivp(f_of_txp, (t0, tmax), x0, args=(params,), method='RK45', dense_output=True)
        if not sol.success:
            raise RuntimeError('solve_ivp failed: %s' % sol.message)
        # Sample the continuous interpolant on the uniform grid (the solver's own steps are adaptive)
        x_traj = sol.sol(times_traj).T                       # size T x N
    else:
        # simple euler integration
        x_traj = np.zeros((times_traj.size, x0.shape[0]))    # size T x N
        x_traj[0, :] = x0
        for idx, tval in enumerate(times_traj[:-1]):
            current_vel = np.asarray(f_of_txp(tval, x_traj[idx, :], params))
            x_traj[idx + 1, :] = x_traj[idx, :] + dt_max * current_vel
    return x_traj, times_traj

## Silly example testing the generic trajectory function: 
## - launch particle, only force is gravity $d v_y / dt = -g$

In [None]:
# Note: keep signature of function fixed for now as t, x, p - time [scalar], state [dim N], params [dim p]
# TODO jit all func, or torch for eventual Autodiff

def dxdt_particle(t, x, params):
    """
    Velocity field for a ball thrown in R^2 under constant gravity.

    Args:
        t:      time (scalar)
        x:      current state (array-like of length >= 2); only its shape is used
        params: three values (vx0, vy0, g) - initial velocities and gravity

    Returns:
        float array shaped like x with entries [vx0, vy0 - g*t]
    """
    vx0, vy0, g = params

    # Force a float result: a plain zeros_like(x) inherits an integer dtype from
    # integer states such as [10, 10] and silently truncates vy0 - g*t.
    dxdt = np.zeros_like(x, dtype=float)
    dxdt[0] = vx0
    dxdt[1] = vy0 - g*t
    return dxdt

# sample call
dxdt_particle(0, [10, 10], [1, 1, 9.8])

In [None]:
# Direct solve_ivp call on the particle field with the implicit Radau method, t in [0, 4].
solve_ivp(
    fun=dxdt_particle,
    t_span=[0.0, 4.0],
    y0=[10, 10],
    method='Radau',
    dense_output=True,
    args=[(1, 1, 9.8)],
)

In [None]:
# Particle demo: initial position (10, 100), params = (vx0, vy0, g), sampled every 0.1 up to t = 2.
params = [1, 10, 9.8]
x0 = np.array([10, 100])
x_traj, times_traj = run_traj(
    dxdt_particle,
    x0,
    params,
    t0=0.0,
    tmax=2,
    dt_max=0.1,
)

In [None]:
# Plot both coordinates of the particle trajectory against time.
fig, axarr = plt.subplots(2, 1, sharex=True)
pltstyle = dict(linestyle='--', marker='o')
axarr[0].plot(times_traj, x_traj[:, 0], **pltstyle)
axarr[0].set_title('x vs t')
axarr[1].plot(times_traj, x_traj[:, 1], **pltstyle)
axarr[1].set_title('y vs t')
axarr[1].set_xlabel('t')
# The title used to hard-code "euler method" even though run_traj defaults to the scipy
# RK45 integrator; derive the label from the flag that actually selects the integrator.
integrator_label = 'scipy RK45' if USE_SCIPY_INTEGRATOR else 'euler method'
plt.suptitle('Simple example: launch particle with gravity - %s' % integrator_label)
plt.show()
plt.close('all')

# Classic (quadratic) hopfield network with continuous variables $x \in \mathbb{R}^N$

My suggested form
### $\frac{dx}{dt} = -x + \textrm{tanh}(\beta [Jx + b])$

where 

- $\beta=1/T$ is the inverse temperature -- in the zero-temperature limit ($\beta \to \infty$) the nonlinearity becomes digital (a sign function)
- $J=\xi \xi^T$ is defined by the $N \times K$ matrix of cell types, with cell type $\mu=1, ..., K$ indexing the cell types represented by columns of $\xi$ 
- $b \in \mathbb{R}^N$ is an $N$ dimensional applied field on specific genes; we will reserve $w \in \mathbb{R}^K$ to denote an analogous forcing applied in the direction of the $K$ encoded cell types

The form in Hopfield, 1984 (PNAS):
### $\frac{dx}{dt} = -x + J \:\textrm{tanh}(\beta x) + b$

#### Remark 1: on representations
The main dynamics stated at the top are not unique. The nonlinearity can be applied before or after the linearity. Here is a general recipe for a recurrent dynamics in $\mathbb{R}^N$ combining an affine transform and a static elementwise nonlinearity $\sigma(z)$:
- Representation 1:  $dx/dt = -x + \sigma(Ax + b)$
- Representation 2:  $dy/dt = -y + A\sigma(y) + b$

I tend to prefer Option 1, but they are sometimes treated interchangeably in the literature. Consider specifically the linear change of variables $y=Ax+b$. When $A$ is invertible, the two representations are equivalent. Otherwise, the situation is more subtle: 
- Going from Rep (1) to Rep (2) works even when $A$ is not full rank. 
- However, if $A$ is not full rank, starting from Rep (2) one comes to $A\: dx/dt = A (-x + \sigma(Ax+b))$. Consider the thin SVD of $A$, corresponding to the first sum in $A= \sum_{i=1}^p \sigma_i u_i v_i^T + \sum_{i=p+1}^N 0 \cdot u_i v_i^T$, accounting for the $p<N$ non-zero singular values ($p = N$ would imply invertible). In matrix form $A=UDV^T$ with $U, V \in \mathbb{R}^{N \times p}$. Left multiplying by the pseudoinverse for $U$ and inverse for $D$ gives the reduced dynamics for $m=V^T x$ in $p<N$ dimensions:  $dm/dt = -m + V^T \:\sigma(U D m + b)$ -- is it interesting that this $p$-dim dynamics is equivalent to the $N>p$ dim dynamics? 

I also tend to use the nonlinearity $\sigma(z) = \textrm{tanh}(\beta z)$. This choice is inspired by the mean-spin flip update rule of the binary Ising model with (stochastic) Glauber dynamics. Combined with Option 1, it has a nice interpretation that any fixed point lies in or on the $\pm 1$ hypercube ($\equiv \Omega_N$); more generally, note that $\Omega_N$ is positively invariant for this choice of $\sigma(z)$. To "think outside the box", one could consider alternative nonlinearities like $\sigma(z)=\textrm{ReLU}(z)$. 

#### Remark 2: Alternative RBM-like two-layer form see [Krotov and Hopfield, 2021, ICLR] (TODO q: when is it equivalent?) :
- $\tau_v\: \frac{dv}{dt} = -v + W_v \,f(h) + b$
- $\tau_h\: \frac{dh}{dt} = -h + W_h \,g(v) + w$

Notation: 

- Let $W_v \in \mathbb{R}^{N \times K}$ be a matrix of connections from "hidden" units $h \in \mathbb{R}^K$ to "visible" units $v \in \mathbb{R}^N$. Let $W_h \in \mathbb{R}^{K \times N}$ denote reverse connections from the visible to the hidden units. The authors assume symmetric connections between the layers, i.e. $W_v=\xi$ and $W_h=\xi^T$. 
- $f:\mathbb{R}^K \rightarrow \mathbb{R}^K$ and $g:\mathbb{R}^N \rightarrow \mathbb{R}^N$ are potentially nonlinear activation functions acting either on a single neuron (e.g. $g_{i} \equiv \textrm{tanh}(v_{i})$) or on the entire layer (e.g. in the case of $\textrm{softmax}$). Their dimensions follow from the products $W_v \,f(h)$ and $W_h \,g(v)$ above. 

In the 2021 ICLR paper for classical quadratic HN they use 
- $f(h)=h \:\:$  i.e. linear
- $g(v)=\textrm{sgn}(v)\approx \textrm{tanh}(\beta v)$
- 
##### Comment on relative timescales for the two-layer dynamics
If one timescale ($\tau_v, \tau_h$) is much faster (i.e. small $\tau$), then the associated dynamics can be eliminated (treated as being at "quasi-steady-state"). In that case, the coupled dynamics can be viewed as a "singular expansion" of the uncoupled form. 

E.g., if $\tau_h \ll \tau_v$, then $h(t) \approx \xi^T g(v) + w$ at all times (since any deviation will rapidly be corrected, relatively speaking).

As a concrete example, consider the choice of $f$, $g$ above mapping to a classic Hebbian Hopfield network. Assuming $\tau_h \ll \tau_v$ and making the QSS substitution $h \approx \xi^T \textrm{tanh}(\beta v) + w$ gives the condensed dynamics $\tau_v\: dv/dt = -v + \xi \left[ \xi^T \:\textrm{tanh}(\beta v) + w \right] + b$, i.e. the Hopfield (1984) form above with $J=\xi \xi^T$ and effective field $\xi w + b$.

#### Additional notes from recent [Krotov and Hopfield, 2021, ICLR]
TODO...

#### Refs:
- [Krotov and Hopfield, 2021, ICLR](https://arxiv.org/pdf/2008.06996)
- [Krotov and Hopfield, 2019, PNAS](https://www.pnas.org/doi/abs/10.1073/pnas.1820458116)

  

In [None]:
def build_xi_rolling_correlated(N, p=3):
    """
    Build p correlated +/-1 patterns by cyclically shifting a single template vector.

    The template is +1 on its first p entries and -1 everywhere else; column k of the
    result is that template rolled forward by p*k positions.

    N: number of visible units (must be divisible by p)
    p: number of patterns (column vectors)

    Returns a float array of shape (N, p).
    """
    assert N % p == 0
    template = np.full(N, -1.0)
    template[:p] = 1

    shifted_columns = [np.roll(template, p * k) for k in range(p)]
    return np.column_stack(shifted_columns)

def build_xi_random_binary(N, p, seed=0):
    """
    Draw an N x p matrix of independent, equiprobable +/-1 entries.

    N: number of visible units
    p: number of patterns (column vectors)
    seed: seed of the private random stream used for the draw

    Returns an integer array of shape (N, p) with entries in {-1, +1}.
    """
    # Use a private legacy RandomState instead of np.random.seed(seed): it produces the
    # same draws for a given seed, but no longer resets the global RNG as a hidden side
    # effect for whatever code runs after this call.
    rng = np.random.RandomState(seed)
    xi = 2 * rng.randint(2, size=(N, p)) - 1
    return xi

In [None]:
# Pattern overlap (correlation) matrices xi^T xi / N for the two pattern generators.
N_rand, K_rand = 5000, 4
xi_rand = build_xi_random_binary(N_rand, K_rand)
print('Corr. matrix for K=%d random patterns in N=%d dims' % (K_rand, N_rand))
print(xi_rand.T @ xi_rand / N_rand)

# Define the low-dim sizes BEFORE printing the header: it previously reported
# (K_rand, N_rand) under swapped N/K labels instead of this example's own dimensions.
N_simple, K_simple = 9, 3
xi_9_3_corr = build_xi_rolling_correlated(N_simple, p=K_simple)
print('\nLow dim correlated example (N=%d, K=%d)' % (N_simple, K_simple))
print(xi_9_3_corr.T @ xi_9_3_corr / xi_9_3_corr.shape[0])

In [None]:
# Note: keep signature of function fixed for now as t, x, p - time [scalar], state [dim N], params [dim p]
# TODO jit all func, or torch for eventual Autodiff

#@torch.jit
def dxdt_HN_quadratic_hebb_A(t, x, params):
    """
    Vector field dx/dt = -x + tanh(beta * J x) with Hebbian couplings J = xi xi^T / N.

    - My preferred form: arises from mean spin flip rule for stochastic discrete updates
    - This form works roughly as expected

    Args:
        t:      time (unused; the dynamics are autonomous)
        x:      state of shape (N,), or (N, T) to evaluate T column states at once
        params: pair (beta, xi) with xi the N x K pattern matrix

    Returns:
        dx/dt with the same shape as x
    """
    beta, xi = params
    xi = xi / np.sqrt(xi.shape[0])  # local normalize
    # Project onto the K patterns first: xi @ (xi.T @ x) costs O(N K), whereas the
    # left-to-right product (xi @ xi.T) @ x builds the full N x N matrix on every call.
    arg_of_sigma = xi @ (xi.T @ x)
    dxdt = -x + np.tanh(beta * arg_of_sigma)
    return dxdt

def dxdt_HN_quadratic_hebb_B(t, x, params):
    """
    Vector field dx/dt = -x + J tanh(beta * x) with Hebbian couplings J = xi xi^T / N.

    This form is closer to Hopfield 1984
    - Also Eq. (24) of K+H, 2021, ICLR 
    - See spurious minima when point in vicinity of pattern 0 is given small linear push towards other patterns

    Args:
        t:      time (unused; the dynamics are autonomous)
        x:      state of shape (N,), or (N, T) to evaluate T column states at once
        params: pair (beta, xi) with xi the N x K pattern matrix

    Returns:
        dx/dt with the same shape as x
    """
    beta, xi = params
    xi = xi / np.sqrt(xi.shape[0])  # local normalize
    arg_of_sigma = x
    # Apply xi.T first: xi @ (xi.T @ s) costs O(N K), whereas the left-to-right
    # product (xi @ xi.T) @ s builds the full N x N matrix on every call.
    dxdt = -x + xi @ (xi.T @ np.tanh(beta * arg_of_sigma))
    return dxdt

def dxdt_HN_quadratic_hebb_C(t, x, params):
    """
    Coupled two-layer (RBM-style) dynamics for visible units v and hidden units h:
        dv/dt = -v + xi f(h),      f(h) = h
        dh/dt = -h + xi^T g(v),    g(v) = tanh(beta * v)
    with xi normalized locally by sqrt(N).

    Issues with this form, which arises from RBM-type interpretation  
    - see Eq. (1) of K+H, 2021, ICLR 
    - try to match Sec 3.1 and App. B of K+H, 2021, ICLR 
    - TODO check implementation and issues...

    Here x is a dim N + dim K - coupled two-layer RBM style dynamics

    In the QSS limit of fast h dynamics, this reduces to the form in dxdt_HN_quadratic_hebb_B(...)

    Args:
        t:      time (unused; the dynamics are autonomous)
        x:      stacked state [v; h] of shape (N + K,), or (N + K, T) for T column states
        params: pair (beta, xi) with xi the N x K pattern matrix

    Returns:
        float array shaped like x holding [dv/dt; dh/dt]
    """
    beta, xi = params
    dim_N, dim_K = xi.shape
    xi = xi / np.sqrt(dim_N)  # local normalize

    x = np.asarray(x)
    state_v = x[:dim_N]
    state_h = x[dim_N:]

    f_mu_of_h = state_h                 # see Eq. (27)
    g_i_of_v = np.tanh(beta * state_v)  # like sign function

    # v dynamics - N dim
    dvdt = -state_v + xi @ f_mu_of_h  # TODO add applied fields?
    # h dynamics - K dim
    dhdt = -state_h + xi.T @ g_i_of_v

    # Concatenate the two float halves rather than writing into np.zeros_like(x),
    # which truncates the derivatives whenever x has an integer dtype.
    return np.concatenate((dvdt, dhdt), axis=0)


def dxdt_HN_quadratic_hebb_D(t, x, params):
    """
    Hidden-layer-only dynamics, with the visible units slaved to the hidden ones:
        dh/dt = -h + xi^T tanh(beta * xi h) / N

    - Arises from the RBM-type interpretation; try to match Sec 3.1 and App. B of K+H, 2021, ICLR
    - Here x is dim K (the h_mu variables); the visible layer is replaced by v = xi f(h)
      with f(h) = h (see Eq. (27)), and the applied field b is omitted
    - xi is NOT rescaled by sqrt(N) here; the 1/N factor is applied to the xi^T projection

    Args:
        t:      time (unused; the dynamics are autonomous)
        x:      hidden state h, first axis of length K
        params: pair (beta, xi) with xi the N x K pattern matrix

    Returns:
        dh/dt, shaped like x
    """
    beta, xi = params
    n_visible, _ = xi.shape

    # QSS visible layer, then squash (acts like a sign function for large beta)
    visible_qss = xi @ x
    squashed = np.tanh(beta * visible_qss)

    # h dynamics - K dim
    overlap_drive = (xi.T @ squashed) / n_visible
    return -x + overlap_drive

# Shared settings for the trajectory comparisons that follow (nothing is called here).
local_xi = xi_rand
local_beta = 10.0
local_N, local_K = local_xi.shape
params = (local_beta, local_xi)

Get trajectories and plot the projection onto pattern subspace + norm of vector field

Note: x is used as shorthand for state; it can mean: 
- just v (dim N); or 
- just h (dim K); or 
- v + h  (dim N+K)

In [None]:
# Random initial condition: visible state v0, matching hidden state h0, and the stacked [v0; h0].
np.random.seed(0)
v0 = np.random.rand(local_N)

# Option 1: target vector based on visible layer
# (Option 2 would be the zero vector: h0 = np.zeros(local_K))
h0 = local_xi.T @ v0 / local_N

vh0 = np.concatenate([v0, h0])
print('h0 (overlaps):', h0)
print('norm of h0:', np.linalg.norm(h0))
print(vh0.shape)

In [None]:
# One entry per model variant: vector field, display label, and initial state of matching dimension.
methods_foo = [
    dxdt_HN_quadratic_hebb_A,
    dxdt_HN_quadratic_hebb_B,
    dxdt_HN_quadratic_hebb_C,
    dxdt_HN_quadratic_hebb_D,
]
methods_str = ['(A - dim N)', '(B - dim N)', '(C - dim N+K)', '(D - dim K)']
methods_x0 = [v0, v0, vh0, h0]
n_methods = len(methods_foo)

# Integrate every variant over the same window and keep the trajectories in order.
methods_traj_x = []
methods_traj_t = []
for idx, (vector_field, init_state) in enumerate(zip(methods_foo, methods_x0)):
    print('Working on traj %d (%d total)...' % (idx, n_methods))
    x_traj, times_traj = run_traj(vector_field, init_state, params, t0=0.0, tmax=4, dt_max=0.1)
    methods_traj_x.append(x_traj)
    methods_traj_t.append(times_traj)

In [None]:
# Left column: pattern overlaps N^{-1} xi^T v(t) (or the hidden units themselves for the
# dim-K model). Right column: norm of the vector field along the trajectory.
# squeeze=False keeps axarr two-dimensional even if only one method is listed.
fig, axarr = plt.subplots(n_methods, 2, sharex=True, figsize=(8, 3*n_methods), squeeze=False)
pltstyle = dict(linestyle='-', marker='o', markersize=3)
pltstyle_h_mu = dict(linestyle='--', marker='^', markersize=6, markerfacecolor='None', alpha=0.75, zorder=10)

for idx in range(n_methods):
    traj_x = methods_traj_x[idx]
    # Plot against the stored time grid: the axes are labelled 't', but the curves were
    # previously drawn against the sample index.
    traj_t = methods_traj_t[idx]
    ax_proj, ax_norm = axarr[idx, 0], axarr[idx, 1]

    if traj_x.shape[1] == local_K:
        ax_proj.plot(traj_t, traj_x, **pltstyle_h_mu)  # in this case, we are plotting the hidden/memory variables directly
    else:
        ax_proj.plot(traj_t, traj_x[:, :local_N] @ params[1] / local_N, **pltstyle)
        if traj_x.shape[1] == (local_N + local_K):
            # also plot hidden dim directly, reusing the colors of the overlap curves.
            # The trailing K columns hold h; this branch used to re-plot the visible projection.
            # NOTE(review): h is plotted unscaled; the dim N+K model rescales xi internally,
            # so h need not be on the same scale as the overlaps - rescale here if desired.
            cc = [line_obj.get_c() for line_obj in ax_proj.get_lines()]
            ax_proj.set_prop_cycle('color', cc)
            ax_proj.plot(traj_t, traj_x[:, local_N:], **pltstyle_h_mu)
    ax_proj.set_title('HN classic quadratic %s' % methods_str[idx])
    ax_proj.axhline(0, linestyle='--', linewidth=2, c='k')

    # Evaluate the vector field on all T states at once (states as columns)
    dxdt_arr = methods_foo[idx](0, traj_x.T, params)
    ax_norm.plot(traj_t, np.linalg.norm(dxdt_arr, axis=0), **pltstyle)
    ax_norm.set_title('Norm of $dx/dt$ for %s' % methods_str[idx])
    ax_norm.axhline(0, linestyle='--', linewidth=2, c='k')

axarr[-1, 0].set_xlabel('t')
axarr[-1, 1].set_xlabel('t')
plt.suptitle('Example: HN classic quadratic\n' + r'Trajectory $N^{-1} \xi^T x(t)$ from random IC (N=%d, K=%d)' % (local_N, local_K))
plt.savefig(os.path.join(NB_OUTPUT, 'HN-traj.pdf'))
plt.show()
plt.close('all')

### Classic HN variants: Repeat but for non-random IC (in vicinity of pattern 0)

In [None]:
# Initial condition: pattern 0 plus a small random combination of all K patterns.
np.random.seed(0)
coeffs_rand = np.random.normal(loc=0, scale=0.2, size=local_K)
print(coeffs_rand)

patterns = params[1]
v0_near_pattern_0 = patterns[:, 0] + coeffs_rand @ patterns.T
print('overlaps:', patterns.T @ v0_near_pattern_0 / local_N)

# Option 1: target vector based on visible layer
# (Option 2 would be the zero vector: h0 = np.zeros(local_K))
h0 = local_xi.T @ v0_near_pattern_0 / local_N

v0_near_pattern_0_attach_h0 = np.concatenate([v0_near_pattern_0, h0])
print('h0 (overlaps):', h0)
print('norm of h0:', np.linalg.norm(h0))
print(v0_near_pattern_0_attach_h0.shape)

In [None]:
# Re-run every variant, now starting near pattern 0 (state dimension matched per model).
methods_x0 = [v0_near_pattern_0, v0_near_pattern_0, v0_near_pattern_0_attach_h0, h0]

methods_traj_x = []
methods_traj_t = []
for idx, (vector_field, init_state) in enumerate(zip(methods_foo, methods_x0)):
    x_traj, times_traj = run_traj(vector_field, init_state, params, t0=0.0, tmax=4, dt_max=0.1)
    methods_traj_x.append(x_traj)
    methods_traj_t.append(times_traj)

In [None]:
# Left column: pattern overlaps N^{-1} xi^T v(t) (or the hidden units themselves for the
# dim-K model). Right column: norm of the vector field along the trajectory.
# squeeze=False keeps axarr two-dimensional even if only one method is listed.
fig, axarr = plt.subplots(n_methods, 2, sharex=True, figsize=(8, 3*n_methods), squeeze=False)
pltstyle = dict(linestyle='-', marker='o', markersize=3)
pltstyle_h_mu = dict(linestyle='--', marker='^', markersize=6, markerfacecolor='None', alpha=0.75, zorder=10)

for idx in range(n_methods):
    traj_x = methods_traj_x[idx]
    # Plot against the stored time grid: the axes are labelled 't', but the curves were
    # previously drawn against the sample index.
    traj_t = methods_traj_t[idx]
    ax_proj, ax_norm = axarr[idx, 0], axarr[idx, 1]

    if traj_x.shape[1] == local_K:
        ax_proj.plot(traj_t, traj_x, **pltstyle_h_mu)  # in this case, we are plotting the hidden/memory variables directly
    else:
        ax_proj.plot(traj_t, traj_x[:, :local_N] @ params[1] / local_N, **pltstyle)
        if traj_x.shape[1] == (local_N + local_K):
            # also plot hidden dim directly, reusing the colors of the overlap curves.
            # The trailing K columns hold h; this branch used to re-plot the visible projection.
            # NOTE(review): h is plotted unscaled; the dim N+K model rescales xi internally,
            # so h need not be on the same scale as the overlaps - rescale here if desired.
            cc = [line_obj.get_c() for line_obj in ax_proj.get_lines()]
            ax_proj.set_prop_cycle('color', cc)
            ax_proj.plot(traj_t, traj_x[:, local_N:], **pltstyle_h_mu)
    ax_proj.set_title('HN classic quadratic %s' % methods_str[idx])
    ax_proj.axhline(0, linestyle='--', linewidth=2, c='k')

    # Evaluate the vector field on all T states at once (states as columns)
    dxdt_arr = methods_foo[idx](0, traj_x.T, params)
    ax_norm.plot(traj_t, np.linalg.norm(dxdt_arr, axis=0), **pltstyle)
    ax_norm.set_title('Norm of $dx/dt$ for %s' % methods_str[idx])
    ax_norm.axhline(0, linestyle='--', linewidth=2, c='k')

axarr[-1, 0].set_xlabel('t')
axarr[-1, 1].set_xlabel('t')
plt.suptitle('Example: HN classic quadratic\n' + r'Trajectory $N^{-1} \xi^T x(t)$ from IC perturbed from $\xi^{(0)}$ (N=%d, K=%d)' % (local_N, local_K))
plt.savefig(os.path.join(NB_OUTPUT, 'HN-traj_near_xi0.pdf'))
plt.show()
plt.close('all')

### Now let's implement the modern Hopfield network vector field  
$dx/dt = ...$

In [None]:
# Note: keep signature of function fixed for now as t, x, p - time [scalar], state [dim N], params [dim p]
# TODO jit all func, or torch for eventual Autodiff

def dxdt_mhn_A(t, x, params):
    """
    PLACEHOLDER - the modern Hopfield network vector field is not implemented here yet.

    The body is still the thrown-ball example in R^2: it ignores x and returns the
    velocity [vx0, vy0 - g*t] for the three params (vx0, vy0, gravity).
    """
    vx0, vy0, g = params
    velocity = [vx0, vy0 - g*t]
    return np.array(velocity)

# sample call
dxdt_mhn_A(0, [10,10], [1, 1, 9.8])

In [None]:
# Disabled cell: the triple-quoted string below is never executed. It holds a copy of the
# particle-trajectory call (it still references dxdt_particle), presumably kept as a
# template for a future modern-Hopfield trajectory run - TODO confirm or delete.
'''
x0 = np.array([10,100])
params = [1, 10, 9.8]
x_traj, times_traj = run_traj(dxdt_particle, x0, params, t0=0.0, tmax=2, dt_max=0.1)'''

In [None]:
# Disabled cell: the triple-quoted string below is never executed. It holds a copy of the
# particle plotting code (x vs t, y vs t), presumably kept as a template for plotting a
# future modern-Hopfield trajectory - TODO confirm or delete.
'''fig, axarr = plt.subplots(2, 1, sharex=True)
pltstyle = dict(linestyle='--', marker='o')
axarr[0].plot(times_traj, x_traj[:,0], **pltstyle); axarr[0].set_title('x vs t') 
axarr[1].plot(times_traj, x_traj[:,1], **pltstyle); axarr[1].set_title('y vs t')
plt.suptitle('Simple example: launch particle with gravity - euler method')
plt.show(); plt.close('all')'''