In [1]:
import os, sys
# NOTE: despite its name, `currentdir` is the PARENT of the working directory
# (os.path.dirname of the cwd); `parentdir` is one level above that.
currentdir = os.path.dirname(os.path.abspath(os.getcwd()))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, currentdir)
sys.path.insert(0, parentdir)
# Build the path with os.path.join: the former literal "\Code" relied on the
# invalid escape sequence "\C" (a SyntaxWarning on recent Pythons) and produced
# a backslash-separated path that only resolves on Windows.
sys.path.insert(0, os.path.join(currentdir, "Code"))

### Calculate basins of attraction under weight noise

In [2]:
import math
import pickle
import numpy as np
import sklearn.decomposition
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib import cm
from matplotlib import rc
import matplotlib as mpl
import matplotlib.lines as mlines
from matplotlib.ticker import LinearLocator
from tqdm import tqdm
import scipy
from scipy.integrate import odeint, DOP853, solve_ivp
from scipy.stats import special_ortho_group
from itertools import chain, combinations, permutations
import seaborn as sns

from ring_functions_noorman import *

"""cmap = 'gray'
rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})
rc('text', usetex=True)"""

"cmap = 'gray'\nrc('font', **{'family': 'serif', 'serif': ['Computer Modern']})\nrc('text', usetex=True)"

In [3]:
# Symmetric cosine weight matrix: W^sym_{jk} = J_I + J_E cos(theta_j - theta_k),
# where J_E and J_I respectively control the strength of the tuned and untuned components of recurrent connectivity between neurons with preferred headings theta_j and theta_k.

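# For a network of size N, there are N 3 such “optimal” values of local excitation J_E^*.
# TODO(review): "N 3" looks garbled (a lost operator or superscript) — check against Noorman et al., 2022.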

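# The parameters (J_I, J_E) can be set such that this system generates a population profile that qualitatively looks like a discretely sampled “bump” of activity.
# (J_I, J_E) lie within the subset \Omega = \Omega_{J_I} \times \Omega_{J_E} \subset (−1, 1) \times (2, 1)
# TODO(review): "(2, 1)" is not a valid interval (likely garbled when copied) — check the bounds against Noorman et al., 2022.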

def get_corners(N, m):
    """Return the N corners of the piecewise-linear ring template, one per row.

    The base corner is a length-N vector whose first half is +m and second half
    is -m, with the last quarter of each half set to 0. Row k of the result is
    that base corner cyclically rolled by k positions. Intended for even N.
    """
    half = N // 2
    quarter = N // 4
    base = np.array([m] * N)
    base[half:] *= -1
    base[half - quarter:half] = 0
    base[N - quarter:] = 0
    return np.array([np.roll(base, shift) for shift in range(N)])

def get_bumps_along_oneside_ring(N, m, corners, step_size=0.1):
    """Sample states along the ring edge from corners[0] to corners[1].

    Positions 0, step_size, 2*step_size, ... (np.arange up to m + step_size)
    are mapped linearly from corners[0] (position 0) to corners[1] (position m),
    neuron by neuron with np.interp, which clamps positions outside [0, m] to
    the nearest endpoint.

    Returns an array of shape (N, n_positions); column i is the state at the
    i-th sampled position.
    """
    positions = np.arange(0, m + step_size, step_size)
    start, end = corners[0], corners[1]
    profile = np.zeros((N, positions.shape[0]))
    for neuron in range(N):
        profile[neuron, :] = np.interp(positions, [0, m], [start[neuron], end[neuron]])
    return profile

def get_all_bumps(N, bumps):
    """Stack every cyclic shift of every bump profile into one 2-D array.

    `bumps` holds one profile per column. The result has one row per
    (shift, profile) pair in shift-major order: the first n_profiles rows are
    the profiles rolled by 0, the next n_profiles rolled by 1, ..., up to N-1.
    """
    n_profiles = bumps.shape[1]
    return np.array([
        np.roll(bumps[:, col], shift)
        for shift in range(N)
        for col in range(n_profiles)
    ])

def get_all_bumps_2darray(N, bumps):
    """Cyclic shifts of every bump profile, keeping shift and profile as separate axes.

    Returns a float array of shape (N, n_profiles, N) whose entry [s, p] is
    column p of `bumps` rolled by s positions. `bumps` must have N rows.
    """
    n_profiles = bumps.shape[1]
    shifted = np.zeros((N, n_profiles, N))
    for col in range(n_profiles):
        profile = bumps[:, col]
        for shift in range(N):
            shifted[shift, col] = np.roll(profile, shift)
    return shifted

def get_noorman_symmetric_weights(N, J_I = 1, J_E = 1):
    """Symmetric cosine connectivity W^sym_jk = J_I + J_E cos(theta_j - theta_k).

    Preferred headings are theta_j = 2*pi*j/N. The kernel depends only on the
    heading difference, so W is circulant and is built from its first column.
    """
    headings = 2 * np.pi * np.arange(N) / N
    first_column = J_I + J_E * np.cos(headings)
    return scipy.linalg.circulant(first_column)


# W^asym_jk = sin(theta_j - theta_k)
def get_noorman_asymmetric_weights(N):
    """Antisymmetric sine connectivity W^asym_jk = sin(theta_j - theta_k).

    With theta_j = 2*pi*j/N the kernel depends only on j - k, so the matrix is
    circulant: scipy.linalg.circulant places sin(2*pi*(j - k)/N) at entry [j, k].
    """
    headings = 2 * np.pi * np.arange(N) / N
    return scipy.linalg.circulant(np.sin(headings))

def noorman_ode(t,x,tau,transfer_function,W_sym,W_asym,c_ff,N):
    """Right-hand side of the head-direction ring network (Noorman et al., 2022).

    tau * dx/dt = -x + (W_sym + v_in(t) * W_asym) @ phi(x) / N + c_ff

    t: time (forwarded to the module-level input function v_in)
    x: current network state
    tau: integration time constant
    transfer_function: nonlinearity phi applied to each neuron's state
    W_sym, W_asym: symmetric and asymmetric weight matrices
    c_ff: constant feedforward input shared by all neurons
    N: number of neurons (normalises the recurrent input)

    The angular-velocity input is not a parameter: it is read from the
    module-level function `v_in(t)`.
    """
    W_eff = W_sym + v_in(t) * W_asym
    recurrent = np.dot(W_eff, transfer_function(x)) / N
    return (-x + recurrent + c_ff) / tau

def noorman_ode_with_noise(t,x,tau,transfer_function,W_sym,W_asym,c_ff,N,noises):
    """Head-direction ring network of Noorman et al., 2022, with optional additive input noise.

    tau * dx/dt = -x + (W_sym + v_in(t) * W_asym) @ phi(x) / N + c_ff [+ noises[k]]

    t: time
    x: current network state
    tau: integration time constant
    transfer_function: nonlinearity phi applied to each neuron's state
    W_sym, W_asym: symmetric and asymmetric weight matrices
    c_ff: constant feedforward input to all neurons in the network
    N: number of neurons in the network
    noises: None for no noise, or an indexable sequence of additive input
        vectors. At time t the entry k = round(t) is added, with k clamped
        to [0, len(noises) - 1].

    The angular-velocity input comes from the module-level function v_in(t),
    not from a parameter.
    """
    total = -x + np.dot(W_sym + v_in(t) * W_asym, transfer_function(x)) / N + c_ff
    if noises is not None:
        # Clamp the index: the solver may evaluate the right-hand side at the
        # end of t_span (e.g. t == maxT). That is one past the last entry when
        # noises holds one vector per integer time in [0, maxT), and
        # previously raised IndexError.
        k = min(max(int(round(t)), 0), len(noises) - 1)
        total = total + noises[k]
    return total / tau

#Bump perturbations
def noorman_ode_pert(t,x,tau,transfer_function,W_sym,W_asym,c_ff,N,center,rotation_mat,amplitude,b):
    """
    Ring-attractor vector field (noorman_ode) plus one localized bump perturbation.

    center, rotation_mat, amplitude and b are fixed by the caller and forwarded
    unchanged to bump_perturbation; the remaining arguments go to noorman_ode.
    """
    perturbation = bump_perturbation(x, center, rotation_mat, amplitude, b)
    unperturbed = noorman_ode(t, x, tau, transfer_function, W_sym, W_asym, c_ff, N)
    return unperturbed + perturbation

def noorman_ode_Npert(t,x,tau,transfer_function,W_sym,W_asym,c_ff,N,Nbumps,bump_centers=None):
    """
    Ring-attractor vector field (noorman_ode) plus Nbumps random local bump perturbations.

    For each bump:
    - the centre is a randomly chosen bump profile (a column of
      `bump_centers`), rolled by a random number of neurons;
    - rotation_mat, amplitude and b are also drawn at random.

    Everything is re-drawn from the global NumPy RNG on every call, so the
    returned vector field is not a deterministic function of (t, x).
    NOTE(review): ODE solvers assume a deterministic right-hand side — confirm
    that per-evaluation re-sampling is intended.

    Nbumps: number of perturbations to add.
    bump_centers: (N, n_profiles) array of bump profiles, one per column.
        Defaults to the notebook-global `bumps`, which the previous version
        read implicitly, so existing callers are unaffected.
    """
    if bump_centers is None:
        bump_centers = bumps
    noorode = noorman_ode(t,x,tau,transfer_function,W_sym,W_asym,c_ff,N)
    for _ in range(Nbumps):
        # Profiles are stored column-wise (bump_centers[:, i]), so draw the
        # column index from shape[1]. The old shape[0] (= N rows) could index
        # past the last column, or never reach the later columns.
        bump_i = np.random.randint(bump_centers.shape[1])
        roll_j = np.random.randint(N)
        center = np.roll(bump_centers[:, bump_i], roll_j).copy()
        rotation_mat = special_ortho_group.rvs(N)
        amplitude = np.random.rand()
        b = np.random.rand()
        noorode += bump_perturbation(x, center, rotation_mat, amplitude, b)

    return noorode

# Fixed points and their stabilities
def noorman_jacobian(x, W_sym):
    """Jacobian of the threshold-linear ring dynamics at state x (v_in = 0 case).

    For dx/dt = -x + W_sym @ ReLU(x) / N + c_ff the Jacobian is
    -I + W_sym @ diag(x > 0) / N: only the columns of W_sym belonging to
    active neurons (x_k > 0) contribute. The 1/tau factor and W_asym are not
    included.
    """
    n = W_sym.shape[0]
    active = x > 0
    gated = np.zeros((n, n))
    gated[:, active] = W_sym[:, active]
    return gated / n - np.eye(n)

def powerset(iterable):
    """Iterate over every subset of `iterable` as a tuple, smallest subsets first.

    powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    The input is materialised immediately, when powerset is called.
    """
    items = tuple(iterable)

    def _subsets():
        for size in range(len(items) + 1):
            yield from combinations(items, size)

    return _subsets()

def noorman_fixed_points(W_sym, c_ff):
    r"""
    Enumerate the fixed points of the threshold-linear ring network.

    Takes as argument all the parameters of the recurrent part of the model
    (W_sym, c_ff) and solves
        \dot x = -x + 1/N W_sym ReLU(x) + c_ff = 0
    once for every non-empty candidate support r (set of active neurons),
    i.e. the linear system (W_sym[:, r]/N - I) x = -c_ff. A solution is kept
    only if it is self-consistent: x_i > 0 for i in r and x_i < 0 otherwise.

    W_sym: (N, N) recurrent weight matrix.
    c_ff: constant feedforward input (scalar, applied to every neuron).

    Returns an array of shape (n_fixed_points, N). Rows are ordered by
    support size, then by support in itertools.combinations order. If no
    fixed point is found, the result is an empty array of shape (0,).

    Cost grows as 2^N linear solves, so this is only practical for small N.
    Supports whose linear system is singular have no unique solution and are
    skipped, rather than aborting the whole enumeration with LinAlgError.
    """
    N = W_sym.shape[0]
    rhs = np.full(N, c_ff, dtype=float)
    fixed_point_list = []

    for size in range(1, N + 1):
        for support in combinations(range(N), size):
            r = np.array(support)
            inactive = np.ones(N, dtype=bool)
            inactive[r] = False

            # Only the columns of active neurons feed back through ReLU.
            W_sub = np.zeros((N, N))
            W_sub[:, r] = W_sym[:, r]
            A = W_sub / N - np.eye(N)
            try:
                fixed_point = -np.linalg.solve(A, rhs)
            except np.linalg.LinAlgError:
                continue

            # Self-consistency of the assumed support.
            if np.all(fixed_point[r] > 0) and np.all(fixed_point[inactive] < 0):
                fixed_point_list.append(fixed_point)

    return np.array(fixed_point_list)




def bump_perturbation(x, center, rotation_mat, amplitude, b=1):
    """
    Localized perturbation made of a parallel (constant-direction) vector field.

    The direction is e_0 rotated by `rotation_mat` (i.e. its first row). The
    magnitude at state x is bump_function(x, center=center,
    amplitude=amplitude, b=b), so the perturbation is concentrated around
    `center`.

    x: current state. Its first-axis length is taken as the dimension N.
       NOTE(review): an older docstring said x.shape = (Numberofpoints, N), but
       N = x.shape[0] treats x as a single N-dimensional state — confirm.
    center: state around which the bump is centred.
    rotation_mat: (N, N) rotation that sets the perturbation direction.
    amplitude, b: forwarded to bump_function (not defined in this notebook;
       presumably provided by the star import of ring_functions_noorman).

    The caller's rotation_mat is now honoured. Previously it was overwritten
    with a fresh special_ortho_group.rvs(N) sample, which silently ignored
    the requested orientation.
    """
    N = x.shape[0]
    unit = np.zeros(N)
    unit[0] = 1.
    direction = np.dot(unit, rotation_mat)
    magnitude = bump_function(x, center=center, amplitude=amplitude, b=b)
    return np.multiply(direction, magnitude)

# we will take phi(·) to be threshold linear

def v_in(t):
    """Angular-velocity input to the network at time t.

    Always 0: these simulations apply no rotational drive, so the W_asym term
    of the dynamics vanishes.
    """
    no_drive = 0
    return no_drive

def ReLU(x):
    """Threshold-linear transfer function phi(x) = max(x, 0), elementwise.

    Negative entries become 0. All other entries, including NaN, pass through
    unchanged.
    """
    negative = x < 0
    return np.where(negative, 0, x)


# # circle_in_vectorspace = 

# # #create points on hypersphere as:
# # 1. take p=[1, 0, ... 0]
# # 2. rotate point(s) along next axis ????

# def fibonacci_sphere(samples=1000):

#     points = []
#     phi = math.pi * (math.sqrt(5.) - 1.)  # golden angle in radians

#     for i in range(samples):
#         y = 1 - (i / float(samples - 1)) * 2  # y goes from 1 to -1
#         radius = math.sqrt(1 - y * y)  # radius at y

#         theta = phi * i  # golden angle increment

#         x = math.cos(theta) * radius
#         z = math.sin(theta) * radius

#         points.append((x, y, z))

#     return points

def hypersphere_lattice(n_points, n_dim):
    """
    Draw n_points random points, uniformly distributed on the unit hypersphere in n_dim dimensions.

    Despite the name this is not a lattice: the points are random (global NumPy
    RNG). Each point is a standard-normal vector divided by its norm, which is
    uniform on the sphere because the Gaussian is rotationally symmetric.
    (Normalising samples from the cube [-1, 1]^n_dim, as done previously,
    over-represents directions towards the cube's corners.)

    Returns an array of shape (n_points, n_dim) with unit-norm rows.
    """
    points = np.random.normal(size=(n_points, n_dim))
    norms = np.linalg.norm(points, axis=1)
    return points / norms[:, np.newaxis]

def uniform_hypersphere_points(n_points, n_dim):
    """
    Draw n_points random unit vectors in n_dim dimensions, uniformly on the sphere.

    Each point is e_0 mapped by an independent random rotation from
    scipy's special_ortho_group (global NumPy RNG), i.e. that rotation's first
    column. Returns an array of shape (n_points, n_dim).
    """
    e0 = np.zeros(n_dim)
    e0[0] = 1
    rows = []
    for _ in range(n_points):
        rotation = special_ortho_group.rvs(n_dim)
        rows.append(rotation @ e0)
    return np.array(rows, dtype=float).reshape(n_points, n_dim)



In [4]:
# Simulate the unperturbed ring network from a random initial state. The peak
# activity m of that run sets the scale of the piecewise-linear ring template.
np.random.seed(1331)

# Network parameters
N = 20
tau = 1
transfer_function = ReLU
J_I = -2.4
J_E = 5
c_ff = 1.
W_sym = get_noorman_symmetric_weights(N, J_I, J_E)
W_asym = get_noorman_asymmetric_weights(N)

# Simulation time grid
maxT = 100
n_timesteps = 100
t = np.linspace(0, maxT, n_timesteps)

y0 = np.random.uniform(0, 1, N)
ode_args = (tau, transfer_function, W_sym, W_asym, c_ff, N)
sol = solve_ivp(noorman_ode, y0=y0, t_span=[0, maxT], t_eval=t, args=ode_args, dense_output=True)

# Peak rate of the simulated bump.
# TODO(review): should m be rounded, or taken from the value given in the paper?
m = np.max(sol.sol(t))
corners = get_corners(N, m)
bumps = get_bumps_along_oneside_ring(N, m, corners, step_size=0.05)
# NOTE(review): x / step_size are not used by the visible cells (bumps uses step 0.05).
step_size = .1
x = np.arange(0, m + step_size, step_size)
all_bumps = get_all_bumps(N, bumps)
all_bumps_2d = get_all_bumps_2darray(N, bumps)

sns.set_context("notebook", font_scale=1.25, rc={"lines.linewidth": 1})
fig, axs = plt.subplots(1, 1, figsize=(15, 8), sharex=True, sharey=True)
axs.imshow(sol.sol(t))
axs.set_xlabel("Time")
axs.set_ylabel("Neurons")

In [None]:
def get_trajectories(starting_points,W_sym_with_noise):
    """Integrate the ring ODE (without additive noise) from each starting point.

    Reads the notebook globals t, maxT, N, tau, transfer_function, W_asym and
    c_ff; W_sym_with_noise takes the place of W_sym. noorman_ode_with_noise is
    called with noises=None, so no additive input noise is applied.

    Returns an array of shape (len(starting_points), len(t), N): the dense-output
    solution sampled at t for each start.
    """
    trajectories = np.zeros((len(starting_points), t.shape[0], N))
    ode_args = (tau, transfer_function, W_sym_with_noise, W_asym, c_ff, N, None)
    for idx in range(len(starting_points)):
        result = solve_ivp(noorman_ode_with_noise, y0=starting_points[idx], t_span=[0, maxT],
                           t_eval=t, args=ode_args, dense_output=True)
        trajectories[idx] = result.sol(t).T
    return trajectories

In [None]:
def get_trajectories_per_sphere_radius(base_sphere, radius, seeds=None, norms=None):
    """
    Simulate trajectories that start on a sphere around every ring fixed point,
    for a sweep of random weight perturbations, and pickle the results.

    base_sphere: (n_sphere_points, N) unit vectors. They are scaled by `radius`
        and added to each fixed point to form the starting points.
    radius: sphere radius; also used in the output file name.
    seeds, norms: random seeds and Frobenius norms of the weight perturbation
        to sweep. They default to the values previously hard-coded here.

    For every (seed, norm):
    - a uniform(-1, 1) N x N matrix eps is drawn and rescaled to Frobenius
      norm `norm`;
    - eps is added to the notebook-global W_sym;
    - trajectories are simulated with get_trajectories.

    The dict {(seed, norm): array of shape (n_fixed_points, n_sphere_points,
    len(t), N)} is written to
    results/basin_of_attraction/trajectories_<radius>.pkl and also returned.
    """
    if seeds is None:
        seeds = [27, 0, 3, 13, 418, 1550]  # 33->0
    if norms is None:
        norms = [1e-10, 1e-5, 1e-1, .5, 1, 1.1, 1.5]
    sphere = base_sphere * radius

    all_trajectories = {}
    for seed in seeds:
        for norm in norms:
            np.random.seed(seed)
            eps = np.random.uniform(-1, 1, (N, N))
            eps /= np.linalg.norm(eps)
            eps *= norm
            W_noisy = W_sym + eps

            fixed_points = noorman_fixed_points(W_noisy, c_ff)
            # Exclude the last fixed point, which is assumed not to lie on the ring.
            # NOTE(review): noorman_fixed_points orders solutions by support size,
            # so the last one is the widest-support solution. Confirm it exists
            # and is the off-ring state for every (seed, norm).
            fixed_points = fixed_points[:-1]

            # -> shape (n_fixed_points, n_sphere_points, timesteps, N)
            trajectories_per_network = []
            # For each fixed point, start trajectories on a sphere of points around it.
            for fixed_point in fixed_points:
                trajectories_per_network.append(get_trajectories(fixed_point + sphere, W_noisy))

            all_trajectories[(seed, norm)] = np.array(trajectories_per_network)

    out_dir = os.path.join("results", "basin_of_attraction")
    os.makedirs(out_dir, exist_ok=True)  # the old version failed if the folder was missing
    with open(os.path.join(out_dir, "trajectories_{}.pkl".format(radius)), "wb") as f:
        pickle.dump(all_trajectories, f)
    return all_trajectories

In [None]:
# Starting directions: a few random unit vectors in the N-dimensional state
# space, scaled to each radius around every fixed point.
n_dim = N
n_sphere_points = 2
base_sphere = uniform_hypersphere_points(n_sphere_points, n_dim)

radii = [1000000]
for radius in radii:
    print(radius)
    get_trajectories_per_sphere_radius(base_sphere, radius)

## Analyze trajectories

In [None]:
# Sanity check on the stored trajectories: around every stable fixed point,
# each trajectory should move more slowly at the last time step than at the first.
seeds = [27, 0, 3, 13, 418, 1550] #33->0
norms =  [1e-10, 1e-5, 1e-1, .5, 1, 1.1, 1.5]


def _contains_row(rows, row):
    """Return True if `row` equals some row of `rows`.

    `row in rows` on NumPy arrays is True as soon as ANY single element
    matches, so it is not a row-membership test.
    """
    row = np.asarray(row)
    rows = np.asarray(rows).reshape(-1, row.shape[-1])
    return bool(np.any(np.all(rows == row, axis=1)))


# The fixed-point classifications do not depend on the radius, so load them once.
# NOTE(review): these files come from the "N6" folder while this notebook uses N = 20 — confirm.
with open("results//weight_noise//N6//{}//fixed_points.pkl".format(J_E), "rb") as f:
    fixed_points_all = pickle.load(f)
with open("results//weight_noise//N6//{}//stable_points.pkl".format(J_E), "rb") as f:
    stable_points_all = pickle.load(f)
with open("results//weight_noise//N6//{}//saddle_points.pkl".format(J_E), "rb") as f:
    saddle_points_all = pickle.load(f)

for radius in range(20):
    with open("results//basin_of_attraction//trajectories_{}.pkl".format(radius), "rb") as f:
        trajectories_all_seeds_and_norms = pickle.load(f)

    for seed in seeds:
        for norm in norms:
            # trajectories_per_network shape (n_fixed_points,n_starting_points,timesteps,N)
            trajectories_per_network = trajectories_all_seeds_and_norms[(seed,norm)]
            # fixed_points shape (n_fixed_points,N)
            fixed_points = fixed_points_all[(seed,norm)]
            stable_points = stable_points_all[(seed,norm)]
            saddle_points = saddle_points_all[(seed,norm)]
            # get_trajectories_per_sphere_radius drops the last fixed point, so
            # the two collections can differ in length. Visit only indices
            # present in both.
            # NOTE(review): index i is assumed to denote the same fixed point in both — confirm.
            n_fixed_points = min(fixed_points.shape[0], len(trajectories_per_network))

            for i in range(n_fixed_points):
                fixed_point = fixed_points[i]
                if not _contains_row(stable_points, fixed_point):
                    continue
                # trajectories_per_fixed_point shape (n_starting_points,timesteps,N)
                trajectories_per_fixed_point = trajectories_per_network[i]
                # Distance to this fixed point is not used: a large sphere can put
                # the start nearer another fixed point, or the trajectory may simply
                # project onto the manifold. Check that the speed decreases instead.
                # speed shape (n_starting_points, timesteps-1)
                speed = np.linalg.norm(np.diff(trajectories_per_fixed_point, axis=1), axis=2)
                # Compare last vs first step for EACH starting point. The old
                # speed[-1] - speed[0] mixed two different starting points and
                # raised ValueError when the resulting array was used in `if`.
                not_decreasing = speed[:, -1] >= speed[:, 0]
                if np.any(not_decreasing):
                    print("Trajectory speed not decreasing for seed {} norm {} fixed point index {}".format(seed,norm,i))
                # optional: save all speeds?

In [None]:
def plot_trajectory(radius,seed,norm,fixed_point_idx,starting_point_idx,intermediate_timestep,sols):
    """
    Plot one stored trajectory in the 2-D PCA plane of the fixed points.

    Loads the trajectories saved for `radius`, plus the fixed, stable and saddle
    points saved for the notebook-global J_E. Fits a 2-component PCA on the
    fixed points and draws, on the current matplotlib axes:
    - the original ring (notebook-global `corners`) as a thick translucent polygon;
    - stable fixed points (blue stars) and saddles (orange stars);
    - the initial (green) and final (red) state of trajectory
      `starting_point_idx`, launched around fixed point `fixed_point_idx`
      for the weight perturbation (seed, norm).

    intermediate_timestep: currently unused (its plot line was disabled);
        kept for backward compatibility.
    sols: unused; kept for backward compatibility with existing calls.
    """
    with open("results//basin_of_attraction//trajectories_{}.pkl".format(radius), "rb") as f:
        trajectories_all_seeds_and_norms = pickle.load(f)
    with open("results//weight_noise//N6//{}//fixed_points.pkl".format(J_E), "rb") as f:
        fixed_points_all = pickle.load(f)
    with open("results//weight_noise//N6//{}//stable_points.pkl".format(J_E), "rb") as f:
        stable_points_all = pickle.load(f)
    with open("results//weight_noise//N6//{}//saddle_points.pkl".format(J_E), "rb") as f:
        saddle_points_all = pickle.load(f)

    # A single trajectory, shape (timesteps, N).
    trajectory = trajectories_all_seeds_and_norms[(seed,norm)][fixed_point_idx][starting_point_idx]
    fixed_points = fixed_points_all[(seed,norm)]
    stable_points = stable_points_all[(seed,norm)]
    saddle_points = saddle_points_all[(seed,norm)]

    def contains_row(rows, row):
        # `row in rows` on NumPy arrays is True if ANY single element matches;
        # compare whole rows instead.
        rows = np.asarray(rows).reshape(-1, row.shape[-1])
        return bool(np.any(np.all(rows == row, axis=1)))

    pca = sklearn.decomposition.PCA(n_components=2)
    fixed_points_proj = pca.fit_transform(fixed_points.reshape(-1,N))
    corners_proj = pca.transform(corners)
    initial_proj = pca.transform(trajectory[0].reshape(-1,N))
    final_proj = pca.transform(trajectory[-1].reshape(-1,N))

    ax = plt.gca()
    # Original ring: closed polygon through the corners (i - 1 wraps to the last corner).
    for i in range(N):
        ax.plot([corners_proj[i-1,0], corners_proj[i,0]],
                [corners_proj[i-1,1], corners_proj[i,1]],
                'k', label="Original attractor" if i == 0 else None, zorder=0, alpha=0.1, linewidth=10,
                solid_capstyle='round')

    for i in range(fixed_points.shape[0]):
        if contains_row(stable_points, fixed_points[i]):
            ax.plot(fixed_points_proj[i,0], fixed_points_proj[i,1], '*', color="darkblue", label="Stable fixed point", zorder=10, alpha=1., markersize=5)
        elif contains_row(saddle_points, fixed_points[i]):
            ax.plot(fixed_points_proj[i,0], fixed_points_proj[i,1], '*', color="darkorange", label="Saddle", zorder=10, alpha=1., markersize=5)

    ax.plot(initial_proj[:,0], initial_proj[:,1], '.g', label="Initial state", zorder=10, alpha=1., markersize=10)
    ax.plot(final_proj[:,0], final_proj[:,1], '.r', label="Final state", zorder=10, alpha=1., markersize=10)

    # Several artists share a label; keep one legend entry per label.
    handles, labels = ax.get_legend_handles_labels()
    unique = dict(zip(labels, handles))
    ax.legend(list(unique.values()), list(unique.keys()))


In [None]:
# Inspect one stored trajectory.
# plot_trajectory ignores its last argument (`sols`). The previous call passed
# sols[3], but no global `sols` is defined in this notebook's cells (it only
# exists inside get_trajectories), so the cell raised NameError on a fresh kernel.
# NOTE(review): starting_point_idx=5 needs at least 6 sphere points, while the
# generation cell above uses n_sphere_points = 2 — confirm which run produced radius 9.
seed=seeds[0]
norm=norms[0]
fixed_point_idx=1
starting_point_idx=5
intermediate_timestep=8
radius=9
plot_trajectory(radius,seed,norm,fixed_point_idx,starting_point_idx,intermediate_timestep,None)
