# State Space Model of Learning - Performance Evaluation

This notebook demonstrates usage of the framework for the state space model of learning problem. Additionally, it is used to create Fig. 1 in section IV.A.

In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import bernoulli as bern
import time
import seaborn as sns
sns.set(rc={'image.cmap': 'jet'},font_scale=1)
%matplotlib inline

import sys
sys.path
sys.path.append('../python/')
from ssml import ssml
from ssml_particle import particle_smooth
from ssml_kalman import kalman_smooth

# Temporary to avoid the "if self._edgecolors == str('face')" FutureWarning
import warnings
warnings.filterwarnings('ignore')

## *Generative Model*
First we'll create functions for generating a learning state and the subsequent observations.

In [2]:
def gen_learning_state(N=50,sparse_variations=False,plot=False,
                       gamma=0.1,phi=0.99,sigv=0.03,sigchi=0.15,p=0.05):

    X = np.zeros((N,))

    if sparse_variations:
        for n in range(1,len(X)):
            if np.random.uniform() < p:
                X[n] = X[n-1] + sigchi*(np.random.chisquare(2))
            else:
                X[n] = X[n-1]

    else:
        V = np.sqrt(sigv)*np.random.randn(N) # Gaussian vector for learning state random walk
        X[0] = gamma + V[0]
        for n in range(1,len(X)):
            X[n] = gamma + phi*X[n-1] + V[n]

    if plot:
        plt.figure(figsize=(15,4))
        plt.plot(range(1,N+1),X,'o')
        plt.xlim(0.75,N+0.5)
        plt.xticks(range(1,N+1))
        plt.xlabel('Trial Number')
        plt.ylabel('Learning State')
        plt.tight_layout();
    
    return X

In [3]:
def gen_reaction_time(X,delta=0.369,h=-0.38,sigg=0.75,plot=False):
    N = len(X)
    G = np.sqrt(sigg)*np.random.randn(N)
    Q = delta + h*X + G

    if plot:
        plt.figure(figsize=(15,4))
        plt.plot(range(1,N+1),Q,'o')
        plt.xlim(0.75,N+0.5)
        plt.xticks(range(1,N+1))
        plt.xlabel('Trial Number')
        plt.ylabel('Log Reaction Time')
        plt.tight_layout();
    
    return Q

In [4]:
def gen_trial_outcome(X,mu=-1.4170,eta=1.75,plot=False):
    N = len(X)
    p = np.exp(mu + eta*X)/(1 + np.exp(mu+eta*X))
    M = np.zeros((N,))
    col = []
    for n in range(len(M)):
        M[n] = bern.rvs(p[n],size=1)
        if M[n] == 1: col.append('b')
        else: col.append('r')

    if plot:
        plt.figure(figsize=(15,4))
        plt.plot([0]*(N+2),'k--',[1]*(N+2),'k--')
        plt.scatter(range(1,N+1),M,c=col,s=50)
        plt.xlim(0.75,N+0.5)
        plt.xticks(range(1,N+1))
        plt.ylim([-.1,1.1])
        plt.xlabel('Trial Number')
        plt.ylabel('Correct/Incorrect (1/0) Outcome')
        plt.tight_layout();
    return M

In [5]:
def gen_spiking(X,T=5,Del=0.001,psi=-3.5,g=2.2,c=[-20,-5,1,3],plot=False):
    J = int(T/Del)
    N = len(X)
    R = np.zeros((N,J))
    for n in range(N):
        for j in range(J):
            c_sum = 0
            for s in range(len(c)):
                if(j-(s+1)>=0):
                    c_sum += c[s]*R[n,j-(s+1)]
            lam_nj = np.exp(psi + g*X[n] + c_sum)
            R[n,j] = bern.rvs(lam_nj*Del*np.exp(-lam_nj*Del),size=1)

    if plot:
        R_vec = R.reshape(N*J,1)
        plt.figure(figsize=(15,4))
        plt.plot(np.arange(0,N*J/1000,1/1000),R_vec,linewidth=0.3)
        plt.xlim([0,N*J/1000])
        plt.xticks(range(0,130,5))
        plt.xlabel('Time (s)')
        plt.ylabel('Spiking Activity')
        plt.tight_layout();
    
    return R

## *Gaussian Learning State*

First we'll run multiple trials with the standard state space model of learning model with the Gaussian random walk. Start by defining the parameters:

In [18]:
# Learning State
N=25; gamma=0.1; phi=0.99; sigv=0.03
# Reaction time
delta=0.369; h=-0.38; sigg=0.75
# Trial outcomes
mu=-1.4170; eta=1.75
# Spiking 
T=5; Del=0.001; psi=-3.5; g=2.2; c=[-20,-5,1,3]; J = int(T/Del)

# ADMM params
rho = 30
max_iters = 25
verbosity = 0

# Kalman param
eps=0.0000001

# Particle param
num_p=100

# The parameters that get passed into each solution method
params = [gamma,phi,sigv,delta,h,sigg,mu,eta,psi,g,c,Del,J]

# Number of trials to conduct
num_trials = 50

In [19]:
results = {'admm':{'rmse':[],'time':[]},
           'kalman':{'rmse':[],'time':[]},
           'particle':{'rmse':[],'time':[]}}

for trial in range(num_trials):
    
    # generate learning states
    X = gen_learning_state(N)
    
    # create observations
    Q = gen_reaction_time(X)
    M = gen_trial_outcome(X)
    R = gen_spiking(X)
    
    # make predictions and store results
    start = time.time()
    x_admm = ssml(observations=(M,Q,R), rho=rho, params=params, max_iters=max_iters, verbosity=verbosity)
    results['admm']['time'].append(time.time()-start)
    results['admm']['rmse'].append(np.linalg.norm(X-x_admm)/N)
    
    start = time.time()
    x_kalman = kalman_smooth(obs=(M,Q,R),eps=eps,params=params)
    results['kalman']['time'].append(time.time()-start)
    results['kalman']['rmse'].append(np.linalg.norm(X-x_kalman)/N)
    
    start = time.time()
    x_particle = particle_smooth(obs=(M,Q,R),params=params,num_p=num_p)
    results['particle']['time'].append(time.time()-start)
    results['particle']['rmse'].append(np.linalg.norm(X-x_particle)/N)

    print('\n---------- Trial %i ----------'%(trial+1))
    print('ADMM - RMSE: %0.3f, Time: %0.1f'%(results['admm']['rmse'][trial],results['admm']['time'][trial]))
    print('Kalman - RMSE: %0.3f, Time: %0.1f'%(results['kalman']['rmse'][trial],results['kalman']['time'][trial]))
    print('Particle - RMSE: %0.3f, Time: %0.1f'%(results['particle']['rmse'][trial],results['particle']['time'][trial]))


---------- Trial 1 ----------
ADMM - RMSE: 0.047, Time: 1.9
Kalman - RMSE: 0.050, Time: 2.9
Particle - RMSE: 0.039, Time: 56.8

---------- Trial 2 ----------
ADMM - RMSE: 0.034, Time: 1.8
Kalman - RMSE: 0.040, Time: 2.7
Particle - RMSE: 0.044, Time: 57.5

---------- Trial 3 ----------
ADMM - RMSE: 0.060, Time: 1.9
Kalman - RMSE: 0.055, Time: 2.8
Particle - RMSE: 0.082, Time: 57.3

---------- Trial 4 ----------
ADMM - RMSE: 0.026, Time: 1.8
Kalman - RMSE: 0.023, Time: 2.9
Particle - RMSE: 0.050, Time: 57.2

---------- Trial 5 ----------
ADMM - RMSE: 0.033, Time: 1.9
Kalman - RMSE: 0.032, Time: 2.8
Particle - RMSE: 0.064, Time: 56.3

---------- Trial 6 ----------
ADMM - RMSE: 0.026, Time: 1.9
Kalman - RMSE: 0.027, Time: 3.0
Particle - RMSE: 0.032, Time: 57.2

---------- Trial 7 ----------
ADMM - RMSE: 0.023, Time: 1.9
Kalman - RMSE: 0.023, Time: 2.8
Particle - RMSE: 0.037, Time: 56.8

---------- Trial 8 ----------
ADMM - RMSE: 0.032, Time: 1.8
Kalman - RMSE: 0.032, Time: 2.7
Particle - 

In [20]:
print('ADMM - RMSE: %0.3f, Time: %0.1f'%(np.mean(results['admm']['rmse']),np.mean(results['admm']['time'])))
print('Kalman - RMSE: %0.3f, Time: %0.1f'%(np.mean(results['kalman']['rmse']),np.mean(results['kalman']['time'])))
print('Particle - RMSE: %0.3f, Time: %0.1f'%(np.mean(results['particle']['rmse']),np.mean(results['particle']['time'])))

ADMM - RMSE: 0.033, Time: 1.9
Kalman - RMSE: 0.034, Time: 2.9
Particle - RMSE: 0.040, Time: 57.0


## *Sparse Variation Learning State*

In [21]:
# Learning State
N=50; gamma=0.1; phi=0.99; sigv=0.03; sigchi = 0.15; p = 0.05
# Reaction time
delta=0.369; h=-0.38; sigg=0.75
# Trial outcomes
mu=-1.4170; eta=1.75
# Spiking 
T=5; Del=0.001; psi=-3.5; g=2.2; c=[-20,-5,1,3]; J = int(T/Del)

# ADMM params
rho = 30
max_iters = 20
verbosity = 0
beta = 25

# Kalman param
eps=0.0000001

# Particle param
num_p=100

# The parameters that get passed into each solution method
params = [gamma,phi,sigv,delta,h,sigg,mu,eta,psi,g,c,Del,J,beta]
kalman_params = [0,1,4*(sigchi**2),delta,h,sigg,mu,eta,psi,g,c,Del,J]
particle_params = [gamma,phi,sigv,delta,h,sigg,mu,eta,psi,g,c,Del,J,p,sigchi]

# Number of trials to conduct
num_trials = 20

In [22]:
results = {'admm':{'rmse':[],'time':[]},
           'kalman':{'rmse':[],'time':[]},
           'particle':{'rmse':[],'time':[]}}

for trial in range(num_trials):
    
    # generate learning states
    X = gen_learning_state(N,sparse_variations=True)
    
    # create observations
    Q = gen_reaction_time(X)
    M = gen_trial_outcome(X)
    R = gen_spiking(X)
    
    # make predictions and store results
    start = time.time()
    x_admm = ssml(observations=(M,Q,R), rho=rho, params=params, max_iters=max_iters, verbosity=verbosity,sparse=True)
    results['admm']['time'].append(time.time()-start)
    results['admm']['rmse'].append(np.linalg.norm(X-x_admm)/N)
    
    start = time.time()
    x_kalman = kalman_smooth(obs=(M,Q,R),eps=eps,params=kalman_params)
    results['kalman']['time'].append(time.time()-start)
    results['kalman']['rmse'].append(np.linalg.norm(X-x_kalman)/N)
    
    start = time.time()
    x_particle = particle_smooth(obs=(M,Q,R),params=particle_params,num_p=num_p,sparse=True,forward_only=True)
    results['particle']['time'].append(time.time()-start)
    results['particle']['rmse'].append(np.linalg.norm(X-x_particle)/N)

    print('\n---------- Trial %i ----------'%(trial+1))
    print('ADMM - RMSE: %0.3f, Time: %0.1f'%(results['admm']['rmse'][trial],results['admm']['time'][trial]))
    print('Kalman - RMSE: %0.3f, Time: %0.1f'%(results['kalman']['rmse'][trial],results['kalman']['time'][trial]))
    print('Particle - RMSE: %0.3f, Time: %0.1f'%(results['particle']['rmse'][trial],results['particle']['time'][trial]))


---------- Trial 1 ----------
ADMM - RMSE: 0.010, Time: 2.0
Kalman - RMSE: 0.041, Time: 5.8
Particle - RMSE: 0.016, Time: 113.8

---------- Trial 2 ----------
ADMM - RMSE: 0.031, Time: 2.2
Kalman - RMSE: 0.030, Time: 5.8
Particle - RMSE: 0.030, Time: 113.8

---------- Trial 3 ----------
ADMM - RMSE: 0.028, Time: 2.3
Kalman - RMSE: 0.030, Time: 6.1
Particle - RMSE: 0.045, Time: 113.2

---------- Trial 4 ----------
ADMM - RMSE: 0.020, Time: 2.1
Kalman - RMSE: 0.021, Time: 5.7
Particle - RMSE: 0.025, Time: 113.3

---------- Trial 5 ----------
ADMM - RMSE: 0.007, Time: 1.9
Kalman - RMSE: 0.027, Time: 5.7
Particle - RMSE: 0.014, Time: 113.0

---------- Trial 6 ----------
ADMM - RMSE: 0.026, Time: 2.0
Kalman - RMSE: 0.043, Time: 5.8
Particle - RMSE: 0.028, Time: 113.5

---------- Trial 7 ----------
ADMM - RMSE: 0.017, Time: 2.3
Kalman - RMSE: 0.020, Time: 6.1
Particle - RMSE: 0.040, Time: 115.0

---------- Trial 8 ----------
ADMM - RMSE: 0.032, Time: 2.1
Kalman - RMSE: 0.034, Time: 5.8
Part

In [23]:
print('ADMM - RMSE: %0.3f, Time: %0.1f'%(np.mean(results['admm']['rmse']),np.mean(results['admm']['time'])))
print('Kalman - RMSE: %0.3f, Time: %0.1f'%(np.mean(results['kalman']['rmse']),np.mean(results['kalman']['time'])))
print('Particle - RMSE: %0.3f, Time: %0.1f'%(np.mean(results['particle']['rmse']),np.mean(results['particle']['time'])))

ADMM - RMSE: 0.022, Time: 2.1
Kalman - RMSE: 0.032, Time: 5.9
Particle - RMSE: 0.030, Time: 113.9
