In [1]:
import numpy as np
from numpy.linalg import norm

In [2]:
def equation_pos_root(a, b, c):
   
    dis = b * b - 4 * a * c
    sqrt_val = (abs(dis))**.5
    return (-b + sqrt_val)/(2 * a)
          

In [3]:
class TCM_A():
    
    def __init__(self, dim):
        
        '''
        :param dim: dimensionality of context and item vectors 
        '''
        
        self.dim = dim
        self.t_prev = np.ones(dim)
        self.t_prev /= norm(self.t_prev)
        self.MFT_pre = np.identity(dim)
        self.MTF_pre = np.identity(dim)
        
        self.gamma_ft = .8
        self.gamma_tf = .8
        
        self.beta = 0.8
        
    def begin_list_presentation(self):
        
        self.MFT_exp = np.zeros((self.dim, self.dim))
        self.MTF_exp = np.zeros((self.dim, self.dim))
        self.update_M()
        
    def update_M(self):
        
        self.MFT = (1-self.gamma_ft)*self.MFT_pre + self.gamma_ft*self.MFT_exp
        self.MTF = (1-self.gamma_tf)*self.MTF_pre + self.gamma_tf*self.MTF_exp
        
    def compute_rho(self):
        
        cos_sim = np.dot(self.t_in, self.t_prev)
        c1 = 1 
        c2 = 2 * self.beta * cos_sim
        c3 = self.beta**2 - 1
        return equation_pos_root(c1, c2, c3)
        
    def context_evolution(self):

        # compute rho so that self.t_curr is unit norm 
        rho = self.compute_rho() 
        
        self.t_curr = rho*self.t_prev + self.beta*self.t_in
        assert round(norm(self.t_curr),2) == 1.0, print("Rho is incorrect")
        
        # set current context to previous context
        self.t_prev = self.t_curr 

    def normalize_vec(self, unnormalized):
        
        return unnormalized/np.linalg.norm(unnormalized)

    def compute_t_in(self, f_curr):
        
        '''
        :param np vector f_i: item presented/recalled at timestep i 
        '''
        
        self.t_in = self.normalize_vec(self.MFT@f_curr)
                
    def update_MTF_exp(self, lr, f_curr):
        
        '''
        :param float lr: learning rate 
        :param vector f_curr: current input 
        :param vector t_prev: previous context 
        '''
        self.MTF_exp += lr*np.outer(f_curr, self.t_prev.T)
        
    def update_MFT_exp(self, f_curr):
        
        self.MFT_exp += np.outer(self.t_prev, f_curr.T)
        
    def encode_item(self, lr, f_curr):
        
        # 1) bind existing context to item 
        self.update_MTF_exp(lr, f_curr)
        
        # 2) Compute t_in
        self.compute_t_in(f_curr)
        
        # 3) bind item to existing context 
        self.update_MFT_exp(f_curr)
        
        # 4) update context 
        self.context_evolution()
        
    def retrieve_item(self):
        
        f_retrieved = self.MTF@self.t_curr
        item = np.argmax(f_retrieved)
        f = np.zeros(self.dim)
        f[item] = 1
        self.compute_t_in(f)
        self.context_evolution()
        print(f)

In [91]:
tcm_a = TCM_A(3)
tcm_a.begin_list_presentation()
tcm_a.encode_item(1, np.asarray([1,0,0]))
tcm_a.encode_item(1, np.asarray([0,1,0]))
tcm_a.encode_item(1, np.asarray([0,0,1]))
tcm_a.retrieve_item()
print(tcm_a.t_curr)
tcm_a.retrieve_item()
print(tcm_a.t_curr)
tcm_a.retrieve_item()
print(tcm_a.t_curr)

[0. 0. 1.]
[0.05705558 0.10822976 0.99248727]
[0. 0. 1.]
[0.01148007 0.02177675 0.99969694]
[0. 0. 1.]
[0.00229657 0.00435641 0.99998787]
