# In[ ]:
import numpy as np

# In[ ]:
def S(X):
    """Return `X` raised to the third power (element-wise for arrays)."""
    cubed = pow(X, 3)
    return cubed

class Population:
    """Recurrent population of N units driven by an M-dimensional input.

    State and parameters held on the instance:

    * ``X`` -- (N, 1) unit states, integrated with Euler steps of size ``dt``.
    * ``J`` -- (N, N) recurrent weights, applied as ``J @ tanh(X)``.
    * ``B`` -- (N, M) input weights, applied as ``B @ U``.
    * ``e`` -- (N, N) trace accumulated in ``step`` and consumed by ``end_trial``.
    * ``bias`` -- indices of units whose state is forced to 1.0 at each step.
    * ``R_expected`` -- running average of the reward, keyed by trial inputs.
    """

    def __init__(self, M, N=200, g=1.5, tau=0.030, dt=1e-4, bias=4, sta_dt=0.005, eta=0.5, alpha_trace=0.33):
        """Draw random weights and initialise the state.

        Parameters
        ----------
        M : int
            Input dimension (number of columns of ``B``).
        N : int
            Number of units.
        g : float
            Gain used to scale the standard deviation of ``J``.
        tau : float
            Time constant of the state dynamics.
        dt : float
            Integration step used by ``step``.
        bias : int
            Number of units clamped to 1.0 at every step.
        sta_dt : float
            Length of the window (same unit as ``dt``) for the short-term average.
        eta : float
            Learning rate applied in ``end_trial``.
        alpha_trace : float
            Weight of the previous value in the running reward average.
        """
        self.M, self.N = M, N
        self.g = g
        self.tau = tau
        self.t, self.dt, self.sta_dt = 0, dt, sta_dt
        self.eta, self.alpha_trace = eta, alpha_trace

        # Draw the clamped units and the read-out unit in one call so that
        # they are guaranteed to be distinct.
        picked = np.random.choice(N, size=bias + 1, replace=False)
        self._ouput = picked[-1]  # index of the read-out unit (attribute spelling kept as-is)
        self.bias = picked[:-1]

        self.B = np.random.uniform(low=-1, high=1, size=(N, M))
        # NOTE(review): the standard deviation here is g / N**2; confirm this
        # scaling is intended (g / sqrt(N) is the more common choice).
        self.J = np.random.normal(loc=0.0, scale=self.g / self.N**2, size=(N, N))

        # Running reward average per trial type. Like J, it is learned across
        # trials and therefore is not cleared by reset().
        self.R_expected = {}

        self.reset()

    def reset(self):
        """Re-draw the unit states and clear the trace, history and clock.

        The weights ``J``/``B`` and the reward averages ``R_expected`` are kept.
        """
        self.X = np.random.uniform(low=-0.1, high=0.1, size=(self.N, 1))
        self.e = np.zeros((self.N, self.N))

        # Store a copy: step() updates self.X in place, which would otherwise
        # silently rewrite this history entry.
        self._sta_X = [self.X.copy()]
        self.t = 0

    def output(self):
        """Return tanh of the read-out unit's state as a scalar."""
        return np.tanh(self.X[self._ouput][0])

    def sta_X(self):
        """Short-term average of X over the last ``sta_dt / dt`` stored states.

        Also trims the stored history to that window.
        """
        # Guard against a zero-length window: a slice of [-0:] would keep
        # (and average over) the entire history instead of the latest state.
        window = max(1, int(self.sta_dt / self.dt))
        self._sta_X = self._sta_X[-window:]
        return np.mean(self._sta_X, axis=0)

    def step(self, U):
        """Advance time by self.dt under the (M, 1) input ``U``."""
        self.X[self.bias] = 1.0

        A = np.tanh(self.X)  # activation, taken before the state update
        self.X += self.dt / self.tau * (-self.X + np.dot(self.J, A) + np.dot(self.B, U))
        self.X = np.clip(self.X, -1.0, 1.0)  # necessary ?
        self.t += self.dt

        # Append a snapshot rather than self.X itself; the next call mutates
        # self.X in place before np.clip rebinds it.
        self._sta_X.append(self.X.copy())
        # NOTE(review): e[i, j] accumulates A[i] * S(X[j] - mean_X[j]); confirm
        # this index order matches J[i, j] as used in np.dot(self.J, A).
        self.e += np.dot(A, S(self.X - self.sta_X()).T)

    def end_trial(self, inputs, R):
        """Update the reward average for this trial type and apply the weight change.

        Parameters
        ----------
        inputs : sequence of 2-D arrays
            The stimuli of the trial; converted to nested tuples to serve as
            the key into ``R_expected``.
        R : float
            Reward obtained on this trial.
        """
        key = tuple(tuple(map(tuple, u)) for u in inputs)  # hashable inputs
        # The average is updated first, so the weight change below is computed
        # against the already-updated value.
        baseline = self.alpha_trace * self.R_expected.get(key, 0.0) + (1 - self.alpha_trace) * R
        self.R_expected[key] = baseline
        delta_J = self.eta * self.e * (R - baseline)
        # Each individual weight change is limited to +/- 1e-4 per trial.
        self.J += np.clip(delta_J, -1e-4, 1e-4)

# In[ ]:
def trial(pop, U_1, U_2):
    """Run one two-stimulus trial on `pop` and close it with `pop.end_trial`.

    Timeline, measured on `pop.t` (which is not reset here): `U_1` is
    presented until 0.2, zero input until 0.4, `U_2` until 0.6, then zero
    input until 1.0. During the last phase (entered once `pop.t` has passed
    0.7) the population output is recorded after every step.

    Parameters
    ----------
    pop : object
        Must provide `M`, `t`, `step(U)`, `output()` and `end_trial(inputs, R)`.
    U_1, U_2 : array
        First and second stimulus; compared element-wise to label the trial.

    Returns
    -------
    (R, outputs)
        `R` is -1.0 when the two stimuli are identical and 1.0 otherwise;
        `outputs` is the list of recorded outputs.
    """
    U_zero = np.zeros((pop.M, 1))
    outputs = []

    # (phase end time, input during the phase, record the output?)
    schedule = (
        (0.200, U_1, False),
        (0.400, U_zero, False),
        (0.600, U_2, False),
        (0.700, U_zero, False),
        (1.0, U_zero, True),
    )
    for t_end, U, record in schedule:
        while pop.t <= t_end:
            pop.step(U)
            if record:
                outputs.append(pop.output())

    R = -1.0 if np.all(U_1 == U_2) else 1.0

    # Close every trial, not only the matching-stimulus ones; otherwise the
    # non-matching trial types never trigger an update.
    # NOTE(review): R depends only on the stimuli, not on `outputs`; confirm
    # that passing it as the reward (rather than an output-based error) is
    # what end_trial expects.
    pop.end_trial([U_1, U_2], R)

    return R, outputs

# In[ ]:
# Fixed seed so the weight draw and the stimulus sequence are reproducible.
np.random.seed(0)

# The two one-hot stimuli, as (2, 1) column vectors.
U_A = np.array([[1], [0]])
U_B = np.array([[0], [1]])
stimuli = (U_A, U_B)

pop = Population(2, N=200)
for i in range(10000):  # takes a long time.
    # Pick each of the two stimuli independently and uniformly.
    U_1 = stimuli[np.random.choice(2)]
    U_2 = stimuli[np.random.choice(2)]
    R, outputs = trial(pop, U_1, U_2)

    # Distance between the trial label and the mean recorded output.
    error = np.abs(R - np.mean(outputs))
    pop.reset()
    print(error)