In [267]:
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import numpy as np
from scipy.ndimage.filters import convolve1d
%matplotlib inline

In [269]:
class NeuralField:
    '''One-dimensional linear neural field integrated with forward Euler.

    The field u(x, t) evolves as

        tau * du/dt = -u + sum_y w(x - y) u(y) dx + I(x)

    with a Gabor-shaped interaction kernel w (see ``kernel``) and a Gaussian
    input I (see ``current``).

    Parameters
    ----------
    tau : float
        Time constant of the dynamics (same time units as ``dt``).
    a : float
        Amplitude of the interaction kernel.
    b : float
        Width parameter of the kernel's Gaussian envelope.
    d : float
        Width parameter of the Gaussian input.
    k : float
        Spatial frequency of the kernel's cosine modulation.
    '''

    def __init__(self, tau, a, b, d, k):
        self.tau = tau
        self.a = a
        self.b = b
        self.d = d
        self.k = k

    def kernel(self, x):
        '''Interaction kernel (Gabor filter) evaluated at spatial offset(s) ``x``.

        w(x) = a * exp(-x**2 / (4 b**2)) * cos(k x) / (b * sqrt(pi))
        '''
        return self.a * (np.exp(-(x**2)/(4*self.b**2))*np.cos(self.k*x))/(self.b*np.sqrt(np.pi))

    def current(self, x):
        '''External input: Gaussian bump centred at 0, evaluated at position(s) ``x``.'''
        return np.exp(-(x**2)/(4*self.d**2))/(2*self.d*np.sqrt(np.pi))

    def pad(self, f):
        '''Return ``f`` followed by ``len(f)`` zeros (length ``2 * len(f)``).

        Kept for backward compatibility; ``convolve`` does not use it.
        '''
        out = np.zeros(len(f)*2)
        out[:len(f)] = f
        return out

    def _weights(self, n, dx):
        '''Matrix W[i, j] = kernel((i - j) * dx) for an ``n``-point grid of spacing ``dx``.'''
        idx = np.arange(n)
        return self.kernel((idx[:, None] - idx[None, :]) * dx)

    def convolve(self, f, dx=1.0):
        '''Linear (zero outside the domain) convolution of ``f`` with the kernel.

        Computes r[k] = sum_p kernel((k - p) * dx) * f[p] over *all* grid
        points p. The kernel is evaluated at the spatial distance between
        points, not at the field values, and contributions from both sides of
        each point are included. The result is not multiplied by ``dx``;
        multiply by ``dx`` to approximate the integral.

        Parameters
        ----------
        f : array_like, shape (n,)
            Field values on a regular 1-D grid.
        dx : float, optional
            Grid spacing used to turn index offsets into distances.
            Defaults to 1.0 (distances measured in grid units).

        Returns
        -------
        numpy.ndarray, shape (n,)
        '''
        f = np.asarray(f, dtype=float)
        return self._weights(len(f), dx) @ f

    def simulate(self, A, B, dt, N, T=None):
        '''Simulate the neural field with forward Euler from a zero initial state.

        Parameters
        ----------
        A, B : float
            Bounds of the spatial domain. It is split into ``N`` cells of width
            ``dx = (B - A) / N``; the cell midpoints are the grid positions.
        dt : float
            Euler time step.
        N : int
            Number of spatial grid points.
        T : float, optional
            Simulated duration; the number of time samples is
            ``len(np.arange(0, T, dt))``. When omitted, the number of samples is
            ``len(np.arange(A, B, dt))``, which preserves the original behaviour
            of deriving the time axis from the spatial bounds.

        Returns
        -------
        numpy.ndarray, shape (n_steps, N)
            Row ``i`` is the field at time step ``i``; row 0 is all zeros.
        '''
        dx = (B - A) / N
        x = A + (np.arange(N) + 0.5) * dx
        n_steps = len(np.arange(A, B, dt) if T is None else np.arange(0, T, dt))
        u = np.zeros((n_steps, N))

        # The lateral interaction (as a Riemann sum, hence * dx) and the input
        # depend only on the grid, so build them once outside the time loop.
        weights = self._weights(N, dx) * dx
        drive = self.current(x)
        rate = dt / self.tau
        for i in range(1, n_steps):
            prev = u[i - 1]
            u[i] = prev + rate * (-prev + weights @ prev + drive)
        return u
            

In [270]:
# Model parameters: time constant, kernel amplitude and envelope width,
# input width, and kernel spatial frequency.
tau, a, b, d, k = 10, 1, 0.6, 2, 4
f = NeuralField(tau=tau, a=a, b=b, d=d, k=k)

In [271]:
# A, B bound the spatial domain (dx = (B - A) / N); simulate() also derives
# its number of time steps from them by default. dt is the Euler step.
A, B = -10, 10
dt, N = 1, 200
du = f.simulate(A=A, B=B, dt=dt, N=N)

In [272]:
# Physical coordinates for the surface axes: midpoints of the N spatial cells
# spanning [A, B], and elapsed time in steps of dt (row i of du is step i).
space_coords = A + (np.arange(N) + 0.5) * (B - A) / N
time_coords = np.arange(du.shape[0]) * dt

# tickmode='array' without tickvals suppressed the Space ticks, so the axis
# now uses default ticks over the real coordinates.
layout = go.Layout(
    scene=dict(
        xaxis_title='Space',
        yaxis_title='Time / ms',
        zaxis_title='Firing rate',
    ),
    width=700,
    margin=dict(r=20, b=10, l=10, t=10),
)

# Surface maps z[i, j] to (x[j], y[i]): columns of du are space, rows are time.
fig = go.Figure(data=go.Surface(x=space_coords, y=time_coords, z=du), layout=layout)
fig.show()