In [1]:
%matplotlib qt
from scipy.stats import ttest_1samp as pvaluef
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns 
import scipy
from functools import partial

In [2]:
def odeIntegral(x,t,J,I=0):
    """Right-hand side of the rate network dx/dt = -x + J @ tanh(x) + I.

    The argument order (state, time, ...) follows scipy.integrate.odeint.

    Parameters
    ----------
    x : array_like
        Network state; squeezed and treated as a vector of length N.
    t : float
        Current time (not used: the dynamics are autonomous).
    J : ndarray of shape (N, N)
        Connectivity matrix.
    I : scalar or array_like of length N, optional
        Constant external input added to the drive. The default 0 gives the
        input-free dynamics.

    Returns
    -------
    ndarray
        dx/dt, squeezed (shape (N,) for N > 1).
    """
    # (-1, 1) instead of (len(x), 1) so a single-unit (0-d after squeeze) state also works
    x = np.reshape(np.squeeze(x), (-1, 1))
    # a vector input must be a column too, otherwise adding it to an (N, 1) array broadcasts to (N, N)
    drive = np.reshape(I, (-1, 1)) if np.ndim(I) > 0 else I
    dxdt = -x + J @ np.tanh(x) + drive
    return np.squeeze(dxdt)
def odesimulation(t,xinit,Jpt,I):
    """Integrate the rate network defined by ``odeIntegral`` from ``xinit``.

    ``Jpt`` is the connectivity and ``I`` the external input; both are forwarded
    to ``odeIntegral`` through odeint's ``args`` tuple. The result is odeint's
    trajectory array: one row per entry of ``t``, one column per unit.
    """
    extra_args = (Jpt, I)
    return scipy.integrate.odeint(odeIntegral, xinit, t, args=extra_args)

In [5]:
# NOTE: reads `mstats`, which is built in the population-statistics cell — run that cell first.
for label, rank in (('mean m:', 0), ('mean m2:', 1)):
    print(label, mstats[rank,:,0])

mean m: [-0.70710678 -0.70710678  0.70710678  0.70710678]
mean m2: [ 0.70710678 -0.70710678 -0.70710678  0.70710678]


In [4]:
### define two unit rank structures
### also there are four populations with mixed selectivity
Rm, Rn = 1, 2.3
nrank, npop = 2,4
# mstats[r, p, 0] / nstats[r, p, 0]: mean of m_r / n_r in population p
# mstats[r, p, 1] / nstats[r, p, 1]: standard deviation of m_r / n_r in population p
mstats, nstats = np.zeros((nrank,npop,2)), np.zeros((nrank,npop,2))
mnstats = np.zeros((npop,nrank,nrank))
### m and n entries are drawn independently within each population; the population
### means make the cross-rank overlaps (m_1.n_2, m_2.n_1) vanish on average, while
### the within-rank overlaps m_r.n_r are nonzero (m and n means are aligned)
Nparams = [250,250,250,250]
N = np.sum(Nparams)

for p in range(npop):
    ### means: population p sits at angle theta on circles of radius Rm (m) and Rn (n)
    # alternative without the pi/4 phase: theta = 2*np.pi*(p+1)/npop
    theta = 2*np.pi*(p+1)/npop + np.pi/4
    mstats[0,p,0], nstats[0,p,0] = Rm*np.cos(theta), Rn*np.cos(theta)
    mstats[1,p,0], nstats[1,p,0] = Rm*np.sin(theta), Rn*np.sin(theta)

    ### standard deviations (the same for every population and rank):
    ### variance 0.1 for m, 0.5 for n
    mstats[:,p,1] = np.sqrt(0.1)
    nstats[:,p,1] = np.sqrt(0.5)

### generate low-rank networks; population p occupies rows bounds[p]:bounds[p+1]
bounds = np.concatenate(([0], np.cumsum(Nparams)))
mvec, nvec = np.zeros((N,nrank)), np.zeros((N,nrank))
for r in range(nrank):
    for p in range(npop):
        rows = slice(bounds[p], bounds[p+1])
        # m is drawn before n for every (r, p), the same random-number order as before
        mvec[rows,r] = mstats[r,p,0] + mstats[r,p,1]*np.random.normal(0,1,(Nparams[p]))
        nvec[rows,r] = nstats[r,p,0] + nstats[r,p,1]*np.random.normal(0,1,(Nparams[p]))

### rescale every m column to unit standard deviation (n is left as drawn)
# NOTE(review): this moves the empirical m means away from mstats[:,:,0], which the
# mean-field cells later use as the m means — confirm the mismatch is intended.
# TODO: seed the RNG (e.g. np.random.seed) for a reproducible notebook.
mvec /= np.std(mvec, axis=0)

In [358]:
### simulate the full network from several structured initial conditions
tt  = np.linspace(0,50,1000)
dt  = tt[2]-tt[1]
ntt = len(tt)
Jpt   = (mvec@nvec.T)/N   # low-rank connectivity J = m n^T / N

# sign of the initial push given to each population (rows: initial condition, cols: population)
xinit0=np.array([[-1,-1,1,1],
                 [1,-1,-1,1],
                 [1,1,-1,-1],
                 [-1,1,1,-1]])
ninit = xinit0.shape[0]
xtemporal = np.zeros((ninit,ntt,N))            # (initial condition, time, unit)
xinitr = np.random.normal(0,1E-2,(ninit,N))   # small noise around the origin
bounds = np.concatenate(([0], np.cumsum(Nparams)))   # population p occupies units bounds[p]:bounds[p+1]

for iinit in range(ninit):
    xinit = xinitr[iinit,:].copy()
    for p in range(npop):
        # shift the whole population by a random half-normal amplitude, signed by xinit0
        xinit[bounds[p]:bounds[p+1]] += np.abs(0.5*np.random.randn())*xinit0[iinit,p]
    if iinit==0:
        print(xinit[:Nparams[0]:10])   # spot-check: every 10th unit of the first population
    xtemporal[iinit,:,:] = odesimulation(tt,xinit,Jpt,0)

[-0.04097999 -0.01866378 -0.05651627 -0.03573543 -0.04942766 -0.03420552
 -0.02444833 -0.01426918 -0.02841971 -0.02538111 -0.01564708 -0.02352852
 -0.0203076  -0.03090263 -0.02161493 -0.04068399 -0.02315862 -0.0209344
 -0.01848782 -0.03589201 -0.03836234 -0.03346646 -0.02763782 -0.02785541
 -0.02953609]


In [359]:
### population-averaged activity, one panel per initial condition
pop_colors = ['tab:red','tab:green','tab:blue','tab:purple']
bounds = np.concatenate(([0], np.cumsum(Nparams)))   # population p occupies units bounds[p]:bounds[p+1]
fig,ax = plt.subplots(2,2,figsize=(6,4),tight_layout=True, sharex=True, sharey=True)
for idx, axis in enumerate(ax.flat):   # row-major order: idx = 2*row + col
    for p in range(npop):
        axis.plot(tt, np.mean(xtemporal[idx,:,bounds[p]:bounds[p+1]],axis=1),
                  color=pop_colors[p], label=f'pop {p+1}')
    axis.set_title(f'init {idx+1}', fontsize='small')
for axis in ax[-1,:]:
    axis.set_xlabel('time')
for axis in ax[:,0]:
    axis.set_ylabel('mean $x$')
ax[0,0].legend(fontsize='x-small');

#### Latent dynamics $\kappa_1$ and $\kappa_2$

In [360]:
# Gauss-Hermite quadrature for Gaussian averages: for z ~ N(0, 1),
#   E[f(z)] ~= (1/sqrt(pi)) * sum_k w_k f(sqrt(2) x_k)
# with (x_k, w_k) the 300-point Hermite nodes and weights.
gaussian_norm = (1/np.sqrt(np.pi))
gauss_points, gauss_weights = np.polynomial.hermite.hermgauss(300)
gauss_points = gauss_points*np.sqrt(2)
def Phi(mu, delta0):
    """Gaussian average of tanh: E[tanh(mu + sqrt(delta0) * z)], z ~ N(0, 1).

    ``mu`` is the mean and ``delta0`` the variance of the argument; ``delta0``
    must be non-negative (its square root is taken).
    """
    spread = np.sqrt(delta0)
    values = np.tanh(mu + spread*gauss_points)
    return gaussian_norm*(values @ gauss_weights)

In [372]:
def latent_dyns(kappainit, t, means, sigmas, tau):
    """Euler-integrate the mean-field latent dynamics of the mixture low-rank network.

    Within population p the input x = sum_r m_r * kappa_r is Gaussian, with the
    m_r drawn independently, so that
        mu_p    = sum_r means[p, r, 0] * kappa_r
        delta_p = sum_r sigmas[p, r, 0]**2 * kappa_r**2
    and the latent variables obey
        tau * dkappa_r/dt = -kappa_r + (1/npop) * sum_p means[p, r, 1] * Phi(mu_p, delta_p),
    i.e. all populations are assumed to have the same size.

    Parameters
    ----------
    kappainit : array_like of shape (nrank,)
        Initial latent state.
    t : 1-D array
        Time grid, assumed uniform; the Euler step is t[1] - t[0].
    means : ndarray of shape (npop, nrank, 2)
        [..., 0]: population means of m_r; [..., 1]: population means of n_r.
    sigmas : ndarray of shape (npop, nrank, 2)
        [..., 0]: population standard deviations of m_r ([..., 1] is not used).
    tau : float
        Time constant.

    Returns
    -------
    ndarray of shape (nrank, len(t))
        Latent trajectory; column 0 is ``kappainit``.
    """
    npop, nrank, _ = np.shape(means)
    ntt = len(t)
    dtt = t[1]-t[0]
    m_means = means[:, :, 0]        # (npop, nrank)
    n_means = means[:, :, 1]        # (npop, nrank)
    m_vars = sigmas[:, :, 0]**2     # (npop, nrank) per-rank variance of m

    kappa_temp = np.zeros((nrank, ntt))
    kappa_temp[:, 0] = kappainit
    for it in range(ntt-1):
        kappa = kappa_temp[:, it]
        # input statistics of every population at the current latent state
        mu = m_means @ kappa              # (npop,)
        # variances add for independent m_r (sum, not product, of the squared kappas)
        delta = m_vars @ kappa**2         # (npop,)
        phi = np.array([Phi(mu[p], delta[p]) for p in range(npop)])
        # recurrent drive averaged over equally sized populations, minus leak
        k_rec = (n_means.T @ phi)/npop - kappa
        kappa_temp[:, it+1] = kappa + k_rec*dtt/tau
    return kappa_temp

In [368]:
### pack population statistics for the mean-field model, shape (npop, nrank, 2):
###   means[p, r, 0] / means[p, r, 1]  = mean of m_r / n_r in population p
###   sigmas[p, r, 0] / sigmas[p, r, 1] = std  of m_r / n_r in population p
# the shape follows mstats/nstats instead of being hard-coded
means = np.stack([mstats[:,:,0].T, nstats[:,:,0].T], axis=-1)
sigmas = np.stack([mstats[:,:,1].T, nstats[:,:,1].T], axis=-1)

t  = np.linspace(0,50,1000)
tau = 1



[-1.04704564  0.03639826]
[-1.04584579  0.03584172]
[-1.04466492  0.03529499]
[-1.04350267  0.03475785]
[-1.0423587  0.0342301]
[-1.04123266  0.03371153]
[-1.04012423  0.03320196]
[-1.03903309  0.0327012 ]
[-1.03795892  0.03220906]
[-1.03690141  0.03172536]
[-1.03586025  0.03124993]
[-1.03483516  0.03078261]
[-1.03382584  0.03032322]
[-1.03283202  0.0298716 ]
[-1.03185341  0.0294276 ]
[-1.03088975  0.02899106]
[-1.02994078  0.02856183]
[-1.02900622  0.02813977]
[-1.02808584  0.02772473]
[-1.02717938  0.02731658]
[-1.02628661  0.02691518]
[-1.02540727  0.02652039]
[-1.02454115  0.02613209]
[-1.02368802  0.02575015]
[-1.02284764  0.02537445]
[-1.02201981  0.02500487]
[-1.02120431  0.02464129]
[-1.02040093  0.02428359]
[-1.01960946  0.02393167]
[-1.01882971  0.02358541]
[-1.01806148  0.02324471]
[-1.01730458  0.02290946]
[-1.01655881  0.02257956]
[-1.01582399  0.02225491]
[-1.01509995  0.02193542]
[-1.01438649  0.02162098]
[-1.01368346  0.02131151]
[-1.01299067  0.0210069 ]
[-1.01230796  

In [376]:
### latent trajectory from a jittered initial condition near (0, -1), stored in slot 3
kappainit = np.array([0+np.random.uniform()*0.3,-1+np.random.uniform()*0.3])
# Allocate the result array only when it does not exist yet, so the cell works on a
# fresh kernel while re-running it for another slot keeps earlier trajectories.
# TODO(review): slots 0-2 are not written anywhere visible (presumably filled by
# editing this cell by hand); list all initial conditions explicitly and loop.
if 'kappa_dyns' not in globals():
    kappa_dyns = np.zeros((4, kappainit.size, len(t)))
kappa_dyns[3,:,:] = latent_dyns(kappainit, t, means, sigmas, tau)

In [377]:
### latent trajectories in the (kappa_1, kappa_2) plane, one colour per initial condition
fig,ax=plt.subplots(figsize=(4,4))
for k, color in enumerate(['tab:red','tab:green','tab:blue','tab:purple']):
    ax.scatter(kappa_dyns[k,0,:],kappa_dyns[k,1,:],c=color)
ax.set_xlabel(r'$\kappa_1$')
ax.set_ylabel(r'$\kappa_2$');

<matplotlib.collections.PathCollection at 0x1c77b5460>