# 2D SSN Model

1. Imports

In [1]:
import jax.numpy as np
import jax
import matplotlib.pyplot as plt
import time, os, json
import pandas as pd
from scipy import stats 
from tqdm import tqdm
import seaborn as sns
from jax import random
from jax.config import config 

config.update('jax_debug_nans', True)
from SSN_classes_jax import SSN2DTopoV1_AMPAGABA_ONOFF
from util import GaborFilter, BW_Grating, find_A, create_gabor_filters, create_gratings

# MODEL TRAINING

## --> INCLUDE NUMBER OF TRIALS

1. Define parameters

In [7]:
#Convergence parameters

#Network parameters
class ssn_pars():
    n = 2
    k = 0.04
    tauE = 30 # in ms
    tauI = 10 # in ms
    psi = 0.774
    tau_s = np.array([5, 7, 100]) #in ms, AMPA, GABA, NMDA current decay time constants
    

#Grid parameters
class grid_pars():
    gridsize_Nx = 9 # grid-points across each edge # gives rise to dx = 0.8 mm
    gridsize_deg = 2 * 1.6 # edge length in degrees
    magnif_factor = 2  # mm/deg
    hyper_col = 0.8 # mm   
    sigma_RF = 0.4 # deg (visual angle)

# Caleb's params for the full (with local) model:
Js0 = [1.82650658, 0.68194475, 2.06815311, 0.5106321]
gE, gI = 0.57328625, 0.26144141

sigEE, sigIE = 0.2, 0.40
sigEI, sigII = .09, .09
conn_pars = dict(
    PERIODIC = False,
    p_local = [.4, 0.7], # [p_local_EE, p_local_IE],
    sigma_oris = 1000) # sigma_oris


make_J2x2 = lambda Jee, Jei, Jie, Jii: np.array([[Jee, -Jei], [Jie,  -Jii]]) * np.pi * ssn_pars.psi
J_2x2 = make_J2x2(*Js0)
s_2x2 = np.array([[sigEE, sigEI],[sigIE, sigII]])

#Positive reparameterization
signs=np.array([[1, -1], [1, -1]])
logJ_2x2 =np.log(J_2x2*signs)
logs_2x2 = np.log(s_2x2)


#Sigmoid parameters
N_neurons = 25
key = random.PRNGKey(10)
w_sig = random.normal(key, shape = (N_neurons,)) / np.sqrt(N_neurons)
b_sig = 0.0

#Gabor parameters 
sigma_g= 0.5
k = np.pi/(6*sigma_g)

#Parameters shared with stimuli
general_pars = dict(k=k , edge_deg=3.2,  degree_per_pixel=0.05)

#Parameters exclusive to Gabor filters
filter_pars = dict(sigma_g = sigma_g, conv_factor = grid_pars.magnif_factor)

#Concatenate all parameters
filter_pars.update(general_pars) 

2. Create training data

In [3]:
#Stimuli parameters
    
#Add jitter to reference and target
key = random.PRNGKey(86)
jitter_val = 5
jitter =random.uniform(key, minval=- jitter_val , maxval= jitter_val)

stimuli_pars = dict(outer_radius=3, inner_radius=2.5, grating_contrast=0.99, snr = 0)
stimuli_pars.update(general_pars)


#Create training data
training_data=np.array(([55, 50], [55,60], [55, 50], [55,60], [55, 50]))
#training_data=np.array(([55,50]))
gratings, labels = create_gratings(training_data, jitter_val= 5, **stimuli_pars)

3. TRAINING!

In [4]:
def sigmoid(x):
    return 1/(1+np.exp(x))

def binary_loss(n, x):
    return - n*np.log(x) - (1-n)*np.log(1-x)

def model(logJ_2x2, logs_2x2, w_sig, b_sig, ssn_pars, grid_pars, conn_pars, train_data, labels, filter_pars,  **conv_pars):
    
    signs=np.array([[1, -1], [1, -1]])
    
    J_2x2 =np.exp(logJ_2x2)*signs
    s_2x2 = np.exp(logs_2x2)
    
    total_loss=0
    
    for i in range(len(train_data)):
        
        #Initialise network
        ssn=SSN2DTopoV1_AMPAGABA_ONOFF(ssn_pars=ssn_pars, grid_pars=grid_pars, conn_pars=conn_pars, filter_pars=filter_pars, J_2x2=J_2x2, s_2x2=s_2x2)
                                           
        
        #Apply Gabor filters to stimuli
        output_ref=np.matmul(ssn.gabor_filters, train_data[i, 0].ravel())*ssn.A
        output_target=np.matmul(ssn.gabor_filters, train_data[i,1].ravel())*ssn.A

        #Rectify output
        SSN_input_ref=np.maximum(0, output_ref)
        SSN_input_target=np.maximum(0, output_target)
        
        #Input to SSN
        r_init = np.zeros(SSN_input_ref.shape[0])
        
        fp_ref, _ = ssn.fixed_point_r(SSN_input_ref, r_init=r_init, **conv_pars)
        x_ref = ssn.apply_bounding_box(fp_ref, size=3.2)
        #print('x_ref:{}'.format(fp_ref))

        fp_target, _ = ssn.fixed_point_r(SSN_input_target, r_init=r_init, **conv_pars)
        x_target = ssn.apply_bounding_box(fp_target, size=3.2)
        #print('x_target:{}'.format(fp_target))

        #Combine reference and target 
        #x = fp_ref + fp_target

        #Apply sigmoid function
        x = sigmoid( np.dot(w_sig, (x_ref.ravel() - x_target.ravel())) + b_sig)
        print('x = ', x)

        #Calculate binary cross entropy loss
        total_loss+=np.sum(binary_loss(labels[i], x))
        
    print('Loss {}'.format(total_loss))
   
    #check what indices are nan
    #indices = np.argwhere(np.isnan(loss))
    
    return total_loss
    

def train_SSN(logJ_2x2, logs_2x2, w_sig, b_sig, ssn_pars, grid_pars, conn_pars, train_data, labels, filter_pars, conv_pars, epochs=6, eta=10e-4):
    
    #Initialize loss
    #total_loss = 0
    
    #define grad function with respect to params to differentiate
    
    grad_loss=jax.grad(model, argnums=(0,1,2,3))
    
    for i in range(epochs):
        
        
        dJ, ds, dw_sig, db_sig =grad_loss(logJ_2x2, logs_2x2,  w_sig, b_sig, ssn_pars, grid_pars, conn_pars, train_data, labels, filter_pars,  **conv_pars)

        #compute loss
        #model_out=model(J_2x2, s_2x2,  w_sig, b_sig, ssn_pars, grid_pars, conn_pars, train_data, labels, filter_pars,  **conv_pars)

        #update parameters using gradient descent
        logJ_2x2 = logJ_2x2 - dJ*eta
        
        logs_2x2 = logs_2x2 - ds*eta
        
        w_sig = w_sig - dw_sig*eta
        b_sig = b_sig - db_sig*eta
        
    signs=np.array([[1, -1], [1, -1]])    
    J_2x2 = np.exp(logJ_2x2)*signs
    s_2x2 = np.exp(logs_2x2)
    
    return J_2x2, s_2x2, w_sig, b_sig

In [8]:
conv_pars=dict(dt = 1, xtol = 1e-5, Tmax = 5, verbose=False, silent=True)

model_out=model(logJ_2x2, logs_2x2,  w_sig, b_sig, ssn_pars, grid_pars, conn_pars, gratings, labels, filter_pars,  **conv_pars)

x =  0.5062364
x =  0.4794171
x =  0.5099783
x =  0.48206148
x =  0.50236434
Loss 3.3532731533050537


In [5]:
type(b_sig)

float

In [9]:
J, s, w, b= train_SSN(logJ_2x2, logs_2x2, w_sig, b_sig, ssn_pars, grid_pars, conn_pars, gratings, labels, filter_pars, conv_pars, epochs=1)

x =  Traced<ConcreteArray(0.5062363743782043, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.5062364, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f9bf82200e0>, in_tracers=(Traced<ConcreteArray(1.0, dtype=float32):JaxprTrace(level=1/0)>, Traced<ConcreteArray(0.2562752962112427, dtype=float32):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7f9bf831f5e0; to 'JaxprTracer' at 0x7f9c00279680>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': 'true_divide', 'donated_invars': (False, False, False), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[] b:f32[] c:f32[]. let
    d:f32[] = neg c
    e:f32[] = mul d a
    f:f32[] = mul e b
  in (f,) }}, effects=set(), source_info=SourceInfo(tracebac

In [7]:
ssn=SSN2DTopoV1_AMPAGABA_ONOFF(ssn_pars=ssn_pars, grid_pars=grid_pars, conn_pars=conn_pars, filter_pars=filter_pars, J_2x2=J_2x2, s_2x2=s_2x2)

#Apply Gabor filters to stimuli
output_ref=np.matmul(ssn.gabor_filters, gratings[1, 0].ravel())*ssn.A
output_target=np.matmul(ssn.gabor_filters, gratings[1,1].ravel())*ssn.A

#Rectify output
SSN_input_ref=np.maximum(0, output_ref)
SSN_input_target=np.maximum(0, output_target)

#Input to SSN
r_init = np.zeros(SSN_input_ref.shape[0])

fp_ref, _ = ssn.fixed_point_r(SSN_input_ref, r_init=r_init, **conv_pars)

#fp_target, _ = ssn.fixed_point_r(SSN_input_target, r_init=r_init, **conv_pars)

Did not reach fixed point.


In [8]:
np.sum(ssn.W)

DeviceArray(2128.79824252, dtype=float64)