In [1]:
import sys
sys.path.insert(0, "../..")

In [2]:
import numpy as np
from matplotlib import pyplot as plt

from module.base.network import Network

from module.components.discrete_gaussian1D import DiscreteGaussian1D
from module.components.discrete_gaussian2D import DiscreteGaussian2D
from module.components.lawrence_dist import LawrenceDist

from module.simulation.meanfield import MeanField
from module.simulation.set_meanfield2 import SetMeanField2

import module.components.CONST as CONST

from module.simulation.meanfield2 import MeanField2

In [3]:
# Build the network under study. The arguments are passed positionally to Network;
# presumably a 3x3x1 island lattice plus four electrode positions — TODO confirm
# against Network's signature.
net = Network(3,3,1, [[0,0,0], [2,0,0], [0,2,0], [2,2,0]])
# NOTE(review): two voltage configurations are applied back to back. If
# set_voltage_config fully replaces the previous configuration, the first call is
# dead and should be removed — confirm before deleting.
net.set_voltage_config([-2.625876307246630126e-01, -9.355253310870321679e-03, 1.687528930716121478e-02, -1.963316637927263186e-01], 5.020559674550297523e-03)
net.set_voltage_config([0.1,0,-0.02,0.01],0.03)

# Mean-field solver on the same network; its solution seeds `means` in the run cell.
mf = MeanField(net)

# Discrete Gaussian helpers. Their `phase_space` arrays and `calc_prob` methods are
# what the rate helpers and the integration loop below evaluate and weight with.
g2 = DiscreteGaussian2D(phase_space_bounds_n=(-5,5), phase_space_bounds_m=(-5,5))
g1 = DiscreteGaussian1D(phase_space_min=-5, phase_space_max=5)

In [4]:
# Neighbour lookup for every particle index. Row i is searched for a partner
# index by the get_/set_ helpers below, and the run loop skips entries equal to -1.
neighbour_table = net.get_nearest_neighbours(np.arange(0, net.N_particles))

# Pair covariances, addressed as covs[particle, slot], where slot is the position
# of the partner inside that particle's neighbour_table row.
# NOTE(review): the width 6 is hard-coded — presumably the maximum neighbour
# count returned by get_nearest_neighbours; TODO confirm.
covs = np.zeros((net.N_particles, 6))
def get_cov(i, j):
    """Look up the stored covariance for the pair (i, j).

    The partner j is searched for in row i of ``neighbour_table``; the position
    of the first match selects the column of ``covs`` to read. When j does not
    occur in that row the plain integer 0 is returned.
    """
    slots = np.nonzero(neighbour_table[i] == j)[0]
    return covs[i, slots[0]] if slots.size else 0

def set_cov(i, j, value):
    """Store ``value`` as the covariance of the pair (i, j) in both directions.

    For each ordering of the pair, the partner is searched for in the owner's
    ``neighbour_table`` row; the entry in ``covs`` is written only when exactly
    one matching slot exists, otherwise that direction is silently skipped.
    """
    for owner, partner in ((i, j), (j, i)):
        slots = np.nonzero(neighbour_table[owner] == partner)[0]
        if slots.size == 1:
            covs[owner, slots[0]] = value

# Per-epoch increments for `covs`, using the same (particle, neighbour slot)
# layout; the integration loop rebinds this name to a fresh zero array each epoch.
dcovs = np.zeros((net.N_particles, 6))
def get_dcov(i, j):
    """Look up the stored covariance increment for the pair (i, j).

    The partner j is searched for in row i of ``neighbour_table``; the position
    of the first match selects the column of ``dcovs`` to read. When j does not
    occur in that row the plain integer 0 is returned.
    """
    slots = np.nonzero(neighbour_table[i] == j)[0]
    return dcovs[i, slots[0]] if slots.size else 0

def set_dcov(i, j, value):
    """Store ``value`` as the covariance increment of (i, j) in both directions.

    For each ordering of the pair, the partner is searched for in the owner's
    ``neighbour_table`` row; the entry in ``dcovs`` is written only when exactly
    one matching slot exists, otherwise that direction is silently skipped.
    """
    for owner, partner in ((i, j), (j, i)):
        slots = np.nonzero(neighbour_table[owner] == partner)[0]
        if slots.size == 1:
            dcovs[owner, slots[0]] = value

---

In [5]:
def calc_effective_states(i, j):
    """Tile the global ``means`` over the 2D phase-space grid and vary i and j.

    The result has one copy of ``means`` per grid point of ``g2.phase_space``
    (leading two axes), with entry i replaced by component 0 of the grid and
    entry j replaced by component 1. If i == j the second write wins.
    """
    grid = g2.phase_space
    base = np.asarray(means)

    # Broadcast to (grid rows, grid cols, *means.shape) and materialise a
    # writable copy so the two components below can be overwritten.
    tiled = np.broadcast_to(base, grid.shape[:2] + base.shape).copy()

    tiled[:, :, i] = grid[:, :, 0]
    tiled[:, :, j] = grid[:, :, 1]
    return tiled

---
$$
\langle I_{ij} \rangle
$$

In [6]:
def calc_R_island(i, j):
    """Evaluate ``net.calc_rate_island`` for the ordered pair (i, j).

    The rate is computed on the state grid from ``calc_effective_states(i, j)``,
    i.e. one value per point of the 2D phase space.
    """
    return net.calc_rate_island(calc_effective_states(i, j), i, j)

def calc_R_island_inv(i, j):
    """Evaluate ``net.calc_rate_island`` for the reversed pair (j, i).

    The state grid is still built with ``calc_effective_states(i, j)``, so the
    grid axes keep the (i, j) ordering; only the island arguments are swapped.
    """
    return net.calc_rate_island(calc_effective_states(i, j), j, i)

---
$$
\langle I_{ei} \rangle
$$

In [7]:
def _electrode_states(electrode_index):
    """Build the state batch shared by both electrode-rate helpers.

    One copy of the global ``means`` is made per entry of ``g1.phase_space``;
    in each copy the component belonging to the island at
    ``net.electrode_pos[electrode_index]`` is replaced by that phase-space entry.
    """
    phase_space = g1.phase_space
    states = np.repeat(np.expand_dims(means, axis = 0), phase_space.shape[0], axis = 0)

    island_index = net.get_linear_indices(net.electrode_pos[electrode_index])
    states[:, island_index] = phase_space
    return states


def calc_R_from_electrode(electrode_index):
    """Evaluate ``net.calc_rate_from_electrode`` over the 1D phase space.

    Args:
        electrode_index: index into ``net.electrode_pos``.

    Returns:
        Whatever ``net.calc_rate_from_electrode`` returns for the state batch
        from ``_electrode_states`` (one state per phase-space entry).
    """
    return net.calc_rate_from_electrode(_electrode_states(electrode_index), electrode_index)


def calc_R_to_electrode(electrode_index):
    """Evaluate ``net.calc_rate_to_electrode`` over the 1D phase space.

    Args:
        electrode_index: index into ``net.electrode_pos``.

    Returns:
        Whatever ``net.calc_rate_to_electrode`` returns for the state batch
        from ``_electrode_states`` (one state per phase-space entry).
    """
    return net.calc_rate_to_electrode(_electrode_states(electrode_index), electrode_index)

---
$$
\langle n_i I_{ei} \rangle
$$

In [8]:
def calc_nR_to_electrode(electrode_index):
    """Weight the to-electrode rates elementwise by the 1D phase-space values.

    This is the ``n_i`` factor times the rate from ``calc_R_to_electrode``, in
    the notation of the notebook heading for this cell.
    """
    return calc_R_to_electrode(electrode_index) * g1.phase_space

def calc_nR_from_electrode(electrode_index):
    """Weight the from-electrode rates elementwise by the 1D phase-space values.

    This is the ``n_i`` factor times the rate from ``calc_R_from_electrode``,
    in the notation of the notebook heading for this cell.
    """
    return calc_R_from_electrode(electrode_index) * g1.phase_space

---
$$
\langle n_i I_{ej} \rangle
$$

In [9]:
def calc_nR_from_electrode_2(i, electrode_index):
    """From-electrode rates on the 2D grid, weighted by island i's grid value.

    The grid varies island i (component 0) together with the island located at
    ``net.electrode_pos[electrode_index]`` (component 1); the rates returned by
    ``net.calc_rate_from_electrode`` are multiplied by component 0.
    """
    attached_island = net.get_linear_indices(net.electrode_pos[electrode_index])
    joint_states = calc_effective_states(i, attached_island)
    electrode_rates = net.calc_rate_from_electrode(joint_states, electrode_index)
    return electrode_rates * g2.phase_space[:, :, 0]

def calc_nR_to_electrode_2(i, electrode_index):
    """To-electrode rates on the 2D grid, weighted by island i's grid value.

    The grid varies island i (component 0) together with the island located at
    ``net.electrode_pos[electrode_index]`` (component 1); the rates returned by
    ``net.calc_rate_to_electrode`` are multiplied by component 0.
    """
    attached_island = net.get_linear_indices(net.electrode_pos[electrode_index])
    joint_states = calc_effective_states(i, attached_island)
    electrode_rates = net.calc_rate_to_electrode(joint_states, electrode_index)
    return electrode_rates * g2.phase_space[:, :, 0]

---
$$
\langle n_j I_{ij} \rangle
$$

In [10]:
def calc_nR_island(i, j):
    """Island rates for (i, j) weighted by grid component 1 (the j axis).

    Multiplies ``calc_R_island(i, j)`` elementwise with the second component of
    ``g2.phase_space``.
    """
    return calc_R_island(i, j) * g2.phase_space[:, :, 1]

def calc_nR_island_inv(i, j):
    """Reversed island rates for (i, j) weighted by grid component 1 (the j axis).

    Multiplies ``calc_R_island_inv(i, j)`` elementwise with the second component
    of ``g2.phase_space``.
    """
    return calc_R_island_inv(i, j) * g2.phase_space[:, :, 1]

---
$$
\langle n_i I_{ij} \rangle
$$

In [11]:
def calc_nR_island_alt(i, j):
    """Island rates for (i, j) weighted by grid component 0 (the i axis).

    Multiplies ``calc_R_island(i, j)`` elementwise with the first component of
    ``g2.phase_space``.
    """
    return calc_R_island(i, j) * g2.phase_space[:, :, 0]

def calc_nR_island_inv_alt(i, j):
    """Reversed island rates for (i, j) weighted by grid component 0 (the i axis).

    Multiplies ``calc_R_island_inv(i, j)`` elementwise with the first component
    of ``g2.phase_space``.
    """
    return calc_R_island_inv(i, j) * g2.phase_space[:, :, 0]

---
## Run

In [28]:
# Solve the plain mean-field model first; its result is kept untouched in
# `mf_means` for comparison at the end of the notebook.
mf_means = mf.numeric_integration_solve(N = 30)

# Initial moments for the integration loop. `means` is a global that the
# calc_* helpers defined earlier read at call time, so it must exist before they run.
means = np.copy(mf_means)
# NOTE(review): `vars` shadows the Python builtin of the same name; renaming it
# requires a coordinated change in the integration loop cell.
vars = np.ones(net.N_particles)
# Rebinds the global `covs` that get_cov/set_cov read, resetting all covariances to zero.
covs = np.zeros((net.N_particles, 6))

In [29]:
# Step size of the explicit update (x += dt * dx) applied to means, vars and covs
# at the end of each epoch. The step size and the epoch count are hard-coded.
dt = 0.07
for epoch in range(20):
    # Per-particle accumulators over neighbouring islands. The l_* terms use the
    # (j, i) rate, the r_* terms use the reversed rate — presumably flow into and
    # out of island i respectively; confirm against Network.calc_rate_island.
    l_R = np.zeros(net.N_particles)
    r_R = np.zeros(net.N_particles)
    l_nR = np.zeros(net.N_particles)
    r_nR = np.zeros(net.N_particles)

    for i in range(net.N_particles):
        for j in neighbour_table[i]:
            if not j == -1: # all neighbour relations
                # Joint probabilities are built in (j, i) order to match the grid
                # axes used by calc_R_island(j, i): component 0 is j, component 1 is i.
                probs = g2.calc_prob(means[j], means[i], vars[j], vars[i],get_cov(i, j))
                l_R[i] += np.sum(probs * calc_R_island(j, i))
                l_nR[i] += np.sum(probs * calc_nR_island(j, i))
                r_R[i] += np.sum(probs * calc_R_island_inv(j, i))
                r_nR[i] += np.sum(probs * calc_nR_island_inv(j, i))
                
    # Same accumulators for the electrode couplings; only islands that an
    # electrode position maps to receive a contribution.
    l_R_electrodes = np.zeros(net.N_particles)
    r_R_electrodes = np.zeros(net.N_particles)
    l_nR_electrodes = np.zeros(net.N_particles)
    r_nR_electrodes = np.zeros(net.N_particles)

    for electrode_index, pos in enumerate(net.electrode_pos):
        i = net.get_linear_indices(pos)

        # 1D distribution of the island attached to this electrode.
        probs = g1.calc_prob(means[i], vars[i])
        l_R_electrodes[i] += np.sum(probs * calc_R_from_electrode(electrode_index)) 
        l_nR_electrodes[i] += np.sum(probs * calc_nR_from_electrode(electrode_index))

        r_R_electrodes[i] += np.sum(probs * calc_R_to_electrode(electrode_index)) 
        r_nR_electrodes[i] += np.sum(probs * calc_nR_to_electrode(electrode_index))

    # islands
    I_islands = l_R - r_R
    I_dag_islands = l_R + r_R

    nI_islands = l_nR - r_nR

    # electrodes
    I_electrodes = l_R_electrodes - r_R_electrodes
    I_dag_electrodes = l_R_electrodes + r_R_electrodes

    nI_electrodes = l_nR_electrodes - r_nR_electrodes

    # total
    I = I_islands + I_electrodes
    I_dag = I_dag_islands + I_dag_electrodes
    nI = nI_islands + nI_electrodes
    
    # Increments of the first and second moments for this epoch.
    d_mean = I
    d_var = (2 * nI + I_dag) - 2 * means * I

    dcovs = np.zeros((net.N_particles, 6)) # reset dcovs
  
    # NOTE(review): every neighbour pair is visited in both orders and set_dcov
    # writes both directions, so each pair's increment is computed twice.
    for i in range(net.N_particles):
        for j in neighbour_table[i]:
            if not j == -1:
                # probs follows the (i, j) grid ordering, probs2 the (j, i) ordering.
                probs = g2.calc_prob(means[i], means[j], vars[i], vars[j], get_cov(i, j))
                probs2 = g2.calc_prob(means[j], means[i], vars[j], vars[i], get_cov(i, j))
                # NOTE(review): this does not depend on i or j; it could be hoisted
                # out of both loops if get_linear_indices has no side effects.
                island_indices = net.get_linear_indices(net.electrode_pos)


                # < ni Ij >
                # Start from the factorised product, remove the factorised i-j
                # exchange term, then add that term back evaluated on the joint grid.
                dcov = means[i] * I_islands[j]
                dcov -= means[i] * np.sum(probs * (calc_R_island(i, j) - calc_R_island_inv(i, j)))
                dcov += np.sum(probs * (calc_nR_island_alt(i, j) - calc_nR_island_inv_alt(i, j)))

                # Electrode contribution, only when island j is attached to exactly one electrode.
                electrode_index = np.where(island_indices == j)[0]
                if electrode_index.shape[0] == 1:
                    dcov += np.sum(probs * (calc_nR_from_electrode_2(i, electrode_index[0]) - calc_nR_to_electrode_2(i, electrode_index[0])))


                # < nj Ii >
                # Same construction with the roles of i and j exchanged.
                dcov += means[j] * I_islands[i]
                dcov -= means[j] * np.sum(probs2 * (calc_R_island(j, i) - calc_R_island_inv(j, i)))
                dcov += np.sum(probs2 * (calc_nR_island_alt(j, i) - calc_nR_island_inv_alt(j, i)))

                electrode_index = np.where(island_indices == i)[0]
                if electrode_index.shape[0] == 1:
                    dcov += np.sum(probs2 * (calc_nR_from_electrode_2(j, electrode_index[0]) - calc_nR_to_electrode_2(j, electrode_index[0])))

                # < I^dag_ij >
                dcov -= np.sum(probs * (calc_R_island(i, j) + calc_R_island_inv(i, j)))

                # Subtract the product-of-means terms so the stored value is the
                # increment of the covariance rather than of the raw second moment.
                set_dcov(i, j, dcov - d_mean[i] * means[j] - means[i] * d_mean[j])

    # In-place updates: `means` and `covs` keep their identity, so the helpers
    # that read these globals see the new values in the next epoch.
    means += dt * d_mean
    vars += dt * d_var
    covs += dt * dcovs

    # Clamp negative variances to zero after the step.
    vars = np.where(vars < 0, 0, vars)

    # NOTE(review): `epoch % 1 == 0` is always true, so this prints every epoch.
    # Printed values: largest absolute increment of the means, variances and covariances.
    if  epoch % 1 == 0:
        print(np.abs(d_mean).max() , np.abs(d_var).max(), np.abs(dcovs).max())

2.073654138885941 8.716554547732084 2.229202826985741
0.4919202318017213 1.7192192448359167 1.1757109480790093
0.1931019462632897 1.1811300141055348 0.5159628825786058
0.14488593627443447 0.4494517825908577 0.3313258159351238
0.12001349081539114 0.3362181440786046 0.20861123855422675
0.09587430638411676 0.2557613144717991 0.15945760688622868
0.07607001254814272 0.19580403431891885 0.14363945541179712
0.05898757172105873 0.1497263324211599 0.1317461096655582
0.04498091642509516 0.11401102688172443 0.12251627278476102
0.0355998923329468 0.08636509919736526 0.11513663430601856
0.034593802639050764 0.06514996754971797 0.1090979778258605
0.03344676936244578 0.049018188409353375 0.10399394792014503
0.032248435659666175 0.036875295235505846 0.09960057029287389
0.031069204431198094 0.027794154077604032 0.09591656519824496
0.029937768314868107 0.022249460202530356 0.09485280644357584
0.029053681061431635 0.02063892537137335 0.08945337937242884
0.0283310602076845 0.01870780643833585 0.0856292455

In [31]:
means

array([ 1.04378846,  0.28500306, -0.41355683,  0.14132274, -0.07578889,
       -0.29040089, -0.75628474, -0.39812275, -0.52754405])

In [32]:
mf_means

array([ 1.14150097,  0.30730644, -0.45897255,  0.17305524, -0.04159311,
       -0.18970848, -0.81723591, -0.25285378, -0.68534517])