In [None]:

import functools
import warnings
from dataclasses import dataclass

import numpy as np
import scipy.stats as spstat
from joblib import Parallel, delayed
#from scipy.optimize import root
#from scipy.optimize import fsolve

In [None]:
# Class for the CANN model

class cann_model:
    """Ring-shaped continuous attractor neural network (CANN) with short-term depression (STD).

    ``N`` neurons have preferred stimuli evenly spaced on ``[z_min, z_min + z_range)``.
    The space is a ring, so displacements are wrapped by :meth:`dist`.  The
    constructor builds:

    * ``x``      -- preferred stimulus of each neuron, shape (N,);
    * ``x_diff`` -- wrapped differences ``dist(x[i] - x[j])``, shape (N, N);
    * ``Jxx``    -- Gaussian excitatory couplings of width ``a``, shape (N, N);
    * ``y``      -- state vector of length ``N + N*N`` (first N entries 0, the rest 1);
    * ``beta``   -- per-connection STD strength, shape (N, N), with mean ``argument.beta``.

    With ``beta_f`` false, ``beta`` is deterministic (no random sampling).  In
    column ``i`` it is the gamma quantile function, with shape/scale selected
    by ``case``, evaluated at levels that decrease linearly with
    ``|x_diff[:, i]|``.  The whole matrix is then rescaled to the requested
    mean.  With ``beta_f`` true, ``beta`` is constant and equal to
    ``argument.beta``.
    """

    # Preferred stimuli span [z_min, z_min + z_range) = [-pi, pi).
    z_min = -np.pi
    z_range = 2.0 * np.pi
    # Time scale shared by all instances.
    tau = 2.0

    # (shape, scale) of the gamma distribution used to build `beta`, per case.
    _GAMMA_SHAPE_SCALE = {
        0: (1.377771974410986, 29.196273404252505),
        1: (3.354520641938138, 9.743699331037247),
        2: (1, 40.0),
    }

    def dist(self, c):
        """Wrap displacement(s) ``c`` onto the ring, into ``(-z_range/2, z_range/2]``.

        A Python int/float or NumPy float64 scalar gives a scalar back.
        Array-like input gives a new ndarray; the input is not modified.
        """
        wrapped = np.remainder(c, self.z_range)
        half = 0.5 * self.z_range
        if isinstance(wrapped, (int, float)):
            return wrapped - self.z_range if wrapped > half else wrapped
        return np.where(wrapped > half, wrapped - self.z_range, wrapped)

    def __init__(self, argument):
        """Build the network from ``argument`` (needs k, beta, beta_f, case, taud, a, N).

        Raises
        ------
        ValueError
            If ``argument.case`` is not one of the known heterogeneity cases.
        """
        self.k = argument.k              # rescaled inhibition
        self.beta_f = argument.beta_f    # if true, use a uniform STD matrix
        self.case = argument.case        # selects the gamma parameters of `beta`
        self.taud = argument.taud        # rescaled STD time scale
        self.a = argument.a              # range of excitatory connections
        self.N = argument.N              # number of units / neurons
        self.dx = self.z_range / self.N  # separation between neighbouring neurons

        if self.case not in self._GAMMA_SHAPE_SCALE:
            raise ValueError(
                f"unknown case {self.case!r}; expected one of {sorted(self._GAMMA_SHAPE_SCALE)}"
            )

        # Preferred stimulus of each neuron: bin centres on the ring.
        self.x = (np.arange(self.N) + 0.5) * self.dx + self.z_min

        # Wrapped pairwise differences, x_diff[i, j] = dist(x[i] - x[j]).
        self.x_diff = self.dist(self.x[:, np.newaxis] - self.x[np.newaxis, :])

        # Gaussian excitatory coupling between every pair of neurons.
        self.Jxx = np.exp(-0.5 * np.square(self.x_diff / self.a)) / (np.sqrt(2 * np.pi) * self.a)

        # Dynamical variables: N entries starting at 0, then N*N entries starting at 1.
        self.y = np.zeros(self.N + self.N * self.N)
        self.y[self.N:] = 1.0

        if argument.beta_f:
            self.beta = np.ones_like(self.Jxx) * argument.beta
        else:
            self.beta = self._heterogeneous_beta(argument.beta)

    def _heterogeneous_beta(self, mean_beta):
        """Deterministic gamma-quantile STD matrix, rescaled to have mean ``mean_beta``."""
        local_x = np.abs(self.x_diff)
        half_dx = (self.z_range / self.N) / 2.0
        # Per column i, map the distances |x_diff[:, i]| linearly onto quantile
        # levels strictly inside (0, 1); the closest pair gets the highest level.
        x_max = np.max(local_x, axis=0) + half_dx
        x_min = np.min(local_x, axis=0) - half_dx
        levels = (x_max - local_x) / (x_max - x_min)
        shape, scale = self._GAMMA_SHAPE_SCALE[self.case]
        beta = spstat.gamma.ppf(levels, shape) * scale
        beta *= mean_beta / np.mean(beta)
        return beta
        
# Container for the model's input arguments, used in place of an argparse
# namespace.

@dataclass
class argument_c:
    """Input arguments for ``cann_model`` (stand-in for an argparse namespace).

    ``@dataclass`` generates the constructor.  Its positional order is the
    field order below: ``(k, beta, beta_f, case, taud, a, N)``; prefer
    keyword arguments.
    """

    k: float      # rescaled inhibition
    beta: float   # rescaled STD strength (mean of the model's STD matrix)
    beta_f: bool  # if true, the model uses a uniform STD matrix equal to `beta`
    case: int     # heterogeneity case selecting the gamma parameters (0, 1 or 2)
    taud: float   # rescaled STD time scale
    a: float      # range of excitatory connections
    N: int        # number of units / neurons
    

In [None]:
@functools.lru_cache(maxsize=32)
def _unit_beta_profile(N, k, beta_f, case, taud):
    """Return the STD matrix of a ``cann_model`` built with ``beta=1`` (cached, read-only).

    Only the *shape* of the STD heterogeneity is needed, so the model is built
    with unit ``beta`` (and ``a=0.5``).  The result is cached because ``F`` is
    called many times for the same model parameters.
    """
    arg_unit = argument_c(N=N, k=k, beta=1, beta_f=beta_f, case=case, taud=taud, a=0.5)
    profile = cann_model(arg_unit).beta
    # Every cache hit shares this array: forbid accidental in-place edits.
    profile.flags.writeable = False
    return profile


def F(x, the_model, max_iter=10000, tol=1e-8):
    """Relax the bump amplitude ``u0`` and STD depth ``p0`` at a fixed bump position ``s0``.

    Parameters
    ----------
    x : sequence of float
        Initial guess ``(u0, p0, s0)``.  ``s0`` is held fixed; ``u0`` and ``p0``
        are updated by projecting the residuals onto the bump / STD templates.
    the_model : cann_model
        Network whose ``Jxx``, ``beta``, ``x``, ``dx``, ``k``, ``a``, ``tau``
        and ``taud`` are used.
    max_iter : int, optional
        Maximum number of relaxation iterations (default 10000).
    tol : float, optional
        Relative change below which ``u0`` and ``p0`` count as converged
        (default 1e-8).  The check starts after 100 iterations.

    Returns
    -------
    tuple
        ``(u0, p0, v, Fv, res_u, res_p)``.
        ``v`` is the projection of the u-residual onto the bump's spatial
        derivative, divided by ``tau * u0``.
        ``Fv`` is the projection of the p-residual onto ``d psi/dx`` minus
        ``taud * p0 * v``.
        ``res_u`` / ``res_p`` are the residual norms divided by the number of
        elements.

    Raises
    ------
    ValueError
        If ``max_iter < 1`` (nothing would be computed).
    """
    if max_iter < 1:
        raise ValueError(f"max_iter must be >= 1, got {max_iter!r}")

    u0, p0, s0 = x[0], x[1], x[2]

    # Use the STD profile of a unit-beta copy of the model rather than
    # the_model.beta, which is all zeros when the model was built with beta == 0.
    beta_profile = _unit_beta_profile(the_model.N, the_model.k, the_model.beta_f,
                                      the_model.case, the_model.taud)

    # Bump templates centred on s0 (ring distance) and the STD template psi.
    offset = the_model.dist(the_model.x - s0)
    gauss_4a = np.exp(-0.25 * np.square(offset / the_model.a))
    gauss_2a = np.exp(-0.5 * np.square(offset / the_model.a))
    psi = beta_profile * np.exp(-0.5 * np.square(the_model.x / the_model.a))

    # Spatial derivatives and projection norms; all constant during the loop.
    dgauss_4a_dx = np.gradient(gauss_4a, the_model.x)
    dpsi_dx = np.gradient(psi, the_model.x, axis=0) + np.gradient(psi, the_model.x, axis=1)
    norm_gauss = np.sum(gauss_4a * gauss_4a)
    norm_dgauss = np.sum(dgauss_4a_dx * dgauss_4a_dx)
    norm_psi = np.sum(psi * psi)
    norm_dpsi = np.sum(dpsi_dx * dpsi_dx)

    for t in range(max_iter):
        rx = (u0 * u0 / (1 + 0.125 * the_model.k * u0 * u0)) * gauss_2a
        ux = u0 * gauss_4a
        pxx = 1 - p0 * psi

        Fu = -ux + np.dot(the_model.Jxx * pxx, rx) * the_model.dx
        Fp = 1 - pxx - the_model.beta * pxx * rx

        v = -np.sum(dgauss_4a_dx * Fu) / norm_dgauss / (1e-10 + the_model.tau * u0)
        Fv = np.sum(dpsi_dx * Fp) / norm_dpsi - the_model.taud * p0 * v

        # Projected updates; the p0 step is damped by a factor 0.05.
        u0_new = u0 + np.sum(gauss_4a * Fu) / norm_gauss
        p0_new = p0 - 0.05 * np.sum(psi * Fp) / norm_psi

        # After a 100-iteration warm-up, stop once a single update changes u0
        # and p0 by less than `tol` (relative to max(1, 1 + value)).  Comparing
        # with the current iterate rather than the one before it keeps a
        # period-2 oscillation from being mistaken for convergence.
        if t > 100:
            r_u0 = abs(u0_new - u0) / max(1.0, 1.0 + u0)
            r_p0 = abs(p0_new - p0) / max(1.0, 1.0 + p0)
            if r_u0 < tol and r_p0 < tol:
                break

        u0, p0 = u0_new, p0_new

    return u0, p0, v, Fv, np.linalg.norm(Fu) / len(Fu), np.linalg.norm(Fp) / len(Fp)
    



In [None]:

def one_round(k_in, beta_in, cases=(0, 1, 2), N=128, taud=100, a=0.5,
              ds0=0.1, ds_tol=1e-6, max_steps=1000):
    """Search the bump position for each heterogeneity case and report ``[u0, v]``.

    For every case a ``cann_model`` is built and the bump position ``s`` is
    advanced from 0 in steps of ``ds``, calling ``F`` at each trial position.
    Whenever ``F`` reports ``Fv >= 0``, the last step is undone and ``ds`` is
    divided by pi.  The search ends once ``ds <= ds_tol``.

    Parameters
    ----------
    k_in, beta_in : float
        Rescaled inhibition and STD strength of the model.
    cases : iterable of int, optional
        Heterogeneity cases to evaluate; one output row each (default 0, 1, 2).
    N, taud, a : optional
        Number of neurons, rescaled STD time scale and excitatory range.
    ds0, ds_tol : float, optional
        Initial and final step size of the position search.
    max_steps : int, optional
        Cap on ``F`` evaluations per case.  Without a cap the search never ends
        when ``Fv`` stays negative (or NaN).  Hitting the cap emits a
        RuntimeWarning and reports ``[nan, nan]`` for that case.

    Returns
    -------
    numpy.ndarray, shape (len(cases), 2)
        Rows ``[u0, v]`` from the last ``F`` evaluation.  ``v`` is set to 0
        when ``|v| < 1e-10`` or ``|u0| < 1e-6``.
    """
    output = []

    for case in cases:
        arg0 = argument_c(N=N, k=k_in, beta=beta_in, beta_f=False, case=case, taud=taud, a=a)
        the_model0 = cann_model(arg0)
        # Initial bump amplitude handed to F at every trial position.
        u0_guess = np.sqrt(8.) * 1.1 / the_model0.k

        s = 0
        ds = ds0
        u0 = v = np.nan  # reported if F is never evaluated
        n_steps = 0

        while ds > ds_tol:
            if n_steps >= max_steps:
                warnings.warn(
                    f"one_round: position search did not finish within {max_steps} steps "
                    f"(case={case}, k={k_in}, beta={beta_in}); reporting NaN",
                    RuntimeWarning,
                )
                u0 = v = np.nan
                break
            n_steps += 1

            s = s + ds
            u0, _, v, Fv, _, _ = F((u0_guess, 0, s), the_model0)

            # Fv >= 0 marks the crossing: step back and refine the step size.
            if Fv >= 0:
                s = s - ds
                ds = ds / np.pi

        if np.abs(v) < 1e-10 or np.abs(u0) < 1e-6:
            v = 0

        output.append([u0, v])

    return np.array(output)

# Parameter sweep: evaluate `one_round` on a regular (k, beta) grid with
# N_steps intervals (N_steps + 1 points) per axis, both endpoints included.
N_steps = 100

k_min = 0.01
k_max = 1.00
d_k = (k_max - k_min) / N_steps

beta_min = 0
beta_max = 5e-3
d_beta = (beta_max - beta_min) / N_steps

# linspace spans exactly [min, max]; a step of `max / N_steps` added to a
# non-zero minimum would run the grid past the maximum.
all_k = np.linspace(k_min, k_max, N_steps + 1)
all_beta = np.linspace(beta_min, beta_max, N_steps + 1)

n_jobs = 20
# One task per (k, beta) pair, k-major; each result holds one [u0, v] row per case.
all_data = Parallel(n_jobs=n_jobs)(
    delayed(one_round)(k, beta) for k in all_k for beta in all_beta
)







In [None]:
# Save the sweep results together with the grid axes they were computed on.
np.savez_compressed(
    "06_solve_v_by_disp_4.npz",
    data=all_data,
    k=all_k,
    beta=all_beta,
)

In [None]:
# Spot check of a single (k, beta) point; the cell displays the per-case [u0, v] rows.
one_round(k_in=0.5, beta_in=0.0008)