# 2D SSN Model

1. Imports

In [1]:
import jax.numpy as np
import jax
import matplotlib.pyplot as plt
import time, os, json
import pandas as pd
from scipy import stats 
from tqdm import tqdm
import seaborn as sns
from jax import random

from SSN_classes_jax import SSN2DTopoV1_AMPAGABA_ONOFF
from util import GaborFilter, BW_Grating, find_A, create_gabor_filters

2. Create SSN network

In [2]:
#Network parameters
class ssn_pars:
    """Namespace of SSN dynamics parameters, read as class attributes."""
    # Rate nonlinearity parameters (presumably r = k * [input]_+ ** n -- confirm in SSN_classes_jax)
    n = 2
    k = 0.04
    # E and I time constants, in ms
    tauE = 30
    tauI = 10
    # Recurrent-weight scale: make_J2x2 multiplies the 2x2 weights by pi * psi
    psi = 0.774
    # Synaptic current decay time constants in ms: AMPA, GABA, NMDA
    tau_s = np.array([5, 7, 100])
    


#Grid parameters
class grid_pars:
    """Namespace of retinotopic grid parameters, read as class attributes."""
    gridsize_Nx = 9          # grid points along each edge; 3.2 deg * 2 mm/deg / 8 gaps -> dx = 0.8 mm
    gridsize_deg = 2 * 1.6   # edge length, degrees of visual angle
    magnif_factor = 2        # magnification, mm per degree
    hyper_col = 0.8          # mm (presumably hypercolumn width -- confirm)
    sigma_RF = 0.4           # degrees of visual angle (presumably receptive-field width)

# Caleb's parameters for the full model (including local connections).
Js0 = [1.82650658, 0.68194475, 2.06815311, 0.5106321]  # [Jee, Jei, Jie, Jii], unpacked into make_J2x2
gE, gI = 0.57328625, 0.26144141

# Connection widths, combined into s_2x2 below
sigEE, sigIE = 0.2, 0.40
sigEI, sigII = .09, .09

conn_pars = {
    'PERIODIC': False,
    'p_local': [.4, 0.7],  # [p_local_EE, p_local_IE]
    'sigma_oris': 1000,
}


def make_J2x2(Jee, Jei, Jie, Jii):
    """Return the signed 2x2 weight matrix [[Jee, -Jei], [Jie, -Jii]] scaled by pi and ssn_pars.psi.

    The second column (weights from I) is negated.
    """
    return np.array([[Jee, -Jei], [Jie, -Jii]]) * np.pi * ssn_pars.psi


J_2x2 = make_J2x2(*Js0)
s_2x2 = np.array([[sigEE, sigEI], [sigIE, sigII]])

# Build the network
ssn = SSN2DTopoV1_AMPAGABA_ONOFF(ssn_pars, grid_pars, conn_pars=conn_pars, J_2x2=J_2x2, s_2x2=s_2x2)




3. Create Gabor filters

In [3]:
# Gabor filter parameters
sigma_g = 0.5
k = np.pi / (6 * sigma_g)
# Parameters shared by the Gabor filters and the input stimuli
general_pars = {'k': k, 'edge_deg': 3.2, 'degree_per_pixel': 0.05}

# Build the filter bank (and the scale factor A) for the network
SSN_filters, A = create_gabor_filters(
    ssn,
    sigma_g=sigma_g,
    conv_factor=grid_pars.magnif_factor,
    **general_pars,
)

Average A is 0.0008299059715404468


4. Input: reference and target stimuli

In [5]:
# Stimulus parameters: grating-specific values, then the values shared with the Gabor filters
stimuli_pars = {'outer_radius': 3, 'inner_radius': 2.5, 'grating_contrast': 0.99, **general_pars}

# Reference stimulus
ori_ref = 0
ref_grating = BW_Grating(ori_deg=ori_ref, **stimuli_pars).BW_image()

# Target stimulus
ori_target = 10
target_grating = BW_Grating(ori_deg=ori_target, **stimuli_pars).BW_image()

# MODEL TRAINING

In [None]:
def sigmoid(x):
    """Element-wise logistic function 1 / (1 + exp(-x)).

    Delegates to ``jax.nn.sigmoid`` instead of spelling out the formula.
    Under ``jax.grad`` the literal formula yields NaN for large negative
    ``x``: ``exp(-x)`` overflows to inf, and the chain rule then multiplies
    -0 by -inf.
    """
    return jax.nn.sigmoid(x)

def binary_loss(n, x, eps=1e-7):
    """Element-wise binary cross-entropy: -[n*log(x) + (1-n)*log(1-x)].

    Parameters
    ----------
    n : label(s), 0 or 1; scalar or array broadcastable against ``x``.
    x : predicted probability(ies) in [0, 1].
    eps : before taking logs, ``x`` is clipped to ``[eps, 1 - eps]``.

    Returns
    -------
    Non-negative per-element loss (lower is better), with the broadcast
    shape of ``n`` and ``x``.

    Notes
    -----
    * This is a loss to minimise, hence the leading minus sign. Without it
      the value is the log-likelihood, and descending it moves away from
      the label.
    * A saturated float32 sigmoid can give ``x`` exactly 0 or 1. The
      unguarded formula then evaluates ``0 * log(0) = 0 * -inf = nan``.
      Clipping keeps both logs finite, for the value and under ``jax.grad``.
    """
    x = np.clip(x, eps, 1 - eps)
    return -(n * np.log(x) + (1 - n) * np.log(1 - x))

def model(J_2x2, s_2x2, ssn_pars, grid_pars, conn_pars, general_pars, label, ref, target, **conv_pars):
    """Loss of the SSN read-out for one (reference, target) stimulus pair.

    The network is rebuilt from ``J_2x2`` and ``s_2x2`` inside the function
    so that ``jax.grad`` can differentiate with respect to them. Steps:

    1. Filter each stimulus with the network's Gabor filters.
    2. Scale by the filter factor and half-wave rectify.
    3. Run each drive to a fixed point.
    4. Add the two fixed-point rate vectors, pass the sum through
       ``sigmoid(0.1 * ...)``, score it against ``label`` with
       ``binary_loss``, and sum over units.

    Note: the Gabor width comes from the notebook-global ``sigma_g``.

    Parameters
    ----------
    J_2x2, s_2x2 : arrays forwarded to SSN2DTopoV1_AMPAGABA_ONOFF.
    ssn_pars, grid_pars, conn_pars : network configuration.
    general_pars : dict of extra keyword arguments for create_gabor_filters.
    label : label passed to binary_loss.
    ref, target : stimulus images (anything with ``.ravel()``).
    **conv_pars : keyword arguments for ``fixed_point_r`` (e.g. dt, xtol, Tmax).

    Returns
    -------
    ``np.sum`` of the element-wise ``binary_loss``.
    """
    network = SSN2DTopoV1_AMPAGABA_ONOFF(ssn_pars, grid_pars, conn_pars=conn_pars, J_2x2=J_2x2, s_2x2=s_2x2)
    filters, gain = create_gabor_filters(network, sigma_g=sigma_g, conv_factor=grid_pars.magnif_factor, **general_pars)

    # Filter each image, scale by the filter factor, then half-wave rectify.
    drives = [np.maximum(0, np.matmul(filters, image.ravel()) * gain) for image in (ref, target)]

    # Both fixed-point searches start from the same all-zero rate vector;
    # the reference is solved first, then the target.
    start = np.zeros(drives[0].shape[0])
    rates = []
    for drive in drives:
        fixed_point, _ = network.fixed_point_r(drive, r_init=start, **conv_pars)
        rates.append(fixed_point)

    probabilities = sigmoid(0.1 * (rates[0] + rates[1]))
    return np.sum(binary_loss(label, probabilities))
    

def train_SSN(J_2x2, s_2x2, ssn_pars, grid_pars, conn_pars, general_pars, train_data, stimuli_pars, conv_pars):
    """Gradient of ``model``'s loss w.r.t. ``(J_2x2, s_2x2)`` for one training pair.

    ``train_data[0]`` is the reference orientation (deg) and
    ``train_data[1]`` the target orientation (deg). Each is rendered with
    ``BW_Grating(ori_deg=..., **stimuli_pars)``. The label is 1 when the
    reference orientation is larger than the target's, otherwise 0.

    Returns
    -------
    tuple
        ``(dloss/dJ_2x2, dloss/ds_2x2)``, from ``jax.grad(model, argnums=(0, 1))``.

    TODO: the weight-update step is not implemented; only the gradient is returned.
    """
    ori_ref, ori_target = train_data[0], train_data[1]
    label = 1 if ori_ref > ori_target else 0

    ref_grating = BW_Grating(ori_deg=ori_ref, **stimuli_pars).BW_image()
    target_grating = BW_Grating(ori_deg=ori_target, **stimuli_pars).BW_image()

    # Differentiate with respect to the first two arguments: J_2x2 and s_2x2.
    grad_loss = jax.grad(model, argnums=(0, 1))
    return grad_loss(J_2x2, s_2x2, ssn_pars, grid_pars, conn_pars, general_pars,
                     label, ref_grating, target_grating, **conv_pars)

In [None]:

# One training example: [reference orientation, target orientation] in degrees.
train_data = np.array([10, 20])
# Fixed-point solver settings. Defined here because the cell that used to
# provide them comes later, so Restart & Run All would hit a NameError.
conv_pars = dict(dt = 1, xtol = 1e-5, Tmax = 600)

train_SSN(J_2x2, s_2x2, ssn_pars, grid_pars, conn_pars, general_pars, train_data, stimuli_pars, conv_pars)

In [6]:
# Fixed-point solver settings for ssn.fixed_point_r.
conv_pars = dict(dt = 1, xtol = 1e-5, Tmax = 600)

# NOTE(review): these variables are named "ref" but are computed from
# target_grating -- confirm which stimulus is intended.
SSN_input_ref = np.maximum(0, np.matmul(SSN_filters, target_grating.ravel()) * A)
r_init = np.zeros(SSN_input_ref.shape[0])
fp_ref, _ = ssn.fixed_point_r(SSN_input_ref, r_init=r_init, **conv_pars)


def summed_fixed_point_rates(ssn_input):
    """Sum of the fixed-point rates for ``ssn_input``.

    ``jax.grad`` needs a scalar-valued function. ``ssn.fixed_point_r``
    returns a pair ``(rates, ...)``, so it cannot be differentiated directly.
    """
    rates, _ = ssn.fixed_point_r(ssn_input, r_init=r_init, **conv_pars)
    return np.sum(rates)


# NOTE(review): the saved output of this cell shows two problems:
#   * The solver did not reach a fixed point within Tmax=600.
#   * A TracerArrayConversionError was raised while tracing, apparently
#     inside fixed_point_r, which calls __array__() on a JAX tracer.
# fixed_point_r itself must be made traceable before this gradient can run.
grad_ssn = jax.grad(summed_fixed_point_rates)
grad_ssn(SSN_input_ref)


       max(abs(dx./max(abs(xvec), 1.0))) = 0.1498628854751587,   xtol=1e-05.

Did not reach fixed point.


TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray([2.39045306e-07 5.05174361e-02 3.54168625e+01 1.47604517e+03
 3.53433342e+01 4.93178144e-02 4.55805633e-07 1.40898908e-03
 3.44161719e-01 9.97807801e-05 3.79991978e-01 1.46471069e+03
 2.83114685e+02 1.46274438e+03 3.71743321e-01 3.45970882e-04
 6.38958882e-05 1.96450111e-03 1.36579215e-09 4.94456030e-02
 3.51516914e+01 1.46444324e+03 3.50581055e+01 4.76084426e-02
 3.62756909e-05 2.59627323e-05 1.28611634e-02 1.53637741e-10
 1.10259043e-06 4.41823229e-02 3.57002020e-01 2.58970540e-02
 4.03595768e-05 3.03160959e-06 1.17222185e-03 8.77917588e-01
 1.21548219e-05 1.10309920e-05 9.27108247e-03 4.94957045e-02
 2.59197364e-03 1.15707962e-06 1.18230097e-02 3.12171178e-04
 3.32510471e-03 3.49332261e+00 7.32741091e-06 2.74387912e-06
 1.01096899e-04 5.29806968e-03 3.85003159e-06 5.52617666e-06
 3.25750816e-03 5.01227856e+00 2.41945744e-01 3.02146364e-04
 9.41658072e-06 3.36052040e-07 1.35825780e-06 6.64325270e-08
 1.76019839e-05 1.08838431e-03 1.62803262e-04 2.26547308e-02
 1.64264720e-03 1.85137324e-04 1.56695936e-02 7.15451062e-01
 2.22016979e-06 1.01327966e-03 2.23005915e+00 7.91126966e-01
 5.04123306e+00 4.45488766e-02 2.39215661e-02 4.78328019e-02
 1.72335599e-02 6.92157680e-03 2.44008159e-04 1.83372939e+00
 1.95616283e+01 3.79712988e+03 1.09777878e+02 2.67002295e+03
 1.35775498e+04 2.65880493e+03 1.10153786e+02 3.73467236e+03
 1.19229942e+02 2.84601521e+00 6.83263672e+03 2.79380798e+02
 1.34788926e+04 4.32935900e+06 1.34695459e+04 2.80177765e+02
 6.64262061e+03 2.02965302e+02 3.41118026e+00 3.56081836e+03
 1.10421318e+02 2.64797974e+03 1.34977988e+04 2.64846924e+03
 1.10255013e+02 3.50496948e+03 1.18950562e+02 3.01873517e+00
 6.62234680e+02 4.82003857e+03 1.09844521e+02 2.80528381e+02
 1.09617119e+02 4.74906738e+03 6.55362976e+02 2.70264378e+01
 8.97035313e+00 5.15718269e+01 6.45551758e+02 3.28150317e+03
 5.16179639e+03 3.30050366e+03 6.39274292e+02 5.09981346e+01
 3.41962957e+00 1.42685509e+00 1.01157246e+01 2.33161469e+01
 1.10907883e+02 1.90045609e+02 1.08419716e+02 2.30633316e+01
 2.79284310e+00 1.17048889e-01 1.94730873e+01 1.74634588e+00
 4.49147314e-01 1.55767703e+00 2.50792360e+00 1.58440232e+00
 4.48806077e-01 1.45511791e-01 1.35346794e+00 2.90576243e+00
 2.38838717e-01 4.18374613e-02 3.52283823e-03 2.95951460e-02
 8.89666438e-01 4.76098806e-02 1.34586644e+00 4.76615868e+01
 1.10827850e+02 1.43388319e+01 1.19029693e-01 3.56732607e-02
 5.71300238e-02 3.18590216e-02 8.93243849e-02 2.88715529e+00
 1.14536575e+02 9.37667389e+01 1.30518221e-10 1.29126704e-07
 9.43186460e-05 3.93137056e-03 9.41418984e-05 1.27508997e-07
 6.05227999e-13 3.23746319e-09 5.00035000e+00 1.84663360e-13
 1.00962234e-06 3.90111632e-03 7.08393578e-04 3.89609602e-03
 9.93981416e-07 6.96998459e-09 7.74617176e-11 3.42199957e-09
 1.13464117e-11 1.30591786e-07 9.36589786e-05 3.90049233e-03
 9.34074851e-05 1.27706102e-07 5.01154728e-13 1.34228934e-10
 3.50442271e-08 2.35068584e-12 9.94621916e-15 1.22035175e-07
 9.51181676e-07 9.61299804e-08 1.11253324e-12 9.48410729e-13
 2.91648905e-09 2.86607862e+00 4.37594676e-13 6.09899013e-13
 1.51472002e-09 2.10111040e-07 2.58265354e-10 8.70264555e-11
 2.08061257e-09 2.24356020e-11 5.20397186e+00 4.11031609e-09
 6.17094960e-12 1.25937732e-11 1.82825005e-10 3.73396247e-08
 2.50152806e-12 6.60852015e-13 4.24275895e-05 4.88129648e-10
 6.55873775e+00 2.35146031e-01 3.65504176e-13 1.53000546e-11
 8.54772086e-11 6.22795912e-12 1.76353077e-03 2.91543256e-09
 4.86712635e-01 5.91308951e-01 1.48966476e-01 2.95371190e-02
 2.63620436e-01 9.36421826e-02 3.79317403e-02 1.12129085e-01
 5.94634287e-07 1.55835085e-06 9.47777967e-10 1.59048790e-03
 1.17888238e-04 6.71944988e-04 1.00038422e-04 1.62846624e-14
 3.18903293e-10 4.51858523e-06 4.73535110e-05 1.08632715e-02
 2.92142475e-04 7.11128861e-03 3.61628532e-02 7.08157336e-03
 2.93209887e-04 1.06833652e-02 3.40949657e-04 2.31939468e+01
 1.95433274e-02 7.44041230e-04 3.58998738e-02 1.23857145e+01
 3.58756483e-02 7.46334612e-04 1.90099515e-02 5.80490974e-04
 9.69203757e-06 1.01864180e-02 2.94046564e-04 7.05297943e-03
 3.59505154e-02 7.05426000e-03 2.93697551e-04 1.00254146e-02
 3.40347964e-04 8.61463559e-06 1.89446309e-03 1.37866689e-02
 2.92766665e-04 7.47172337e-04 2.93093617e-04 1.35827754e-02
 1.87465397e-03 7.72828862e-05 1.25515757e+01 1.47473242e-04
 1.84648274e-03 9.37683042e-03 1.47797344e-02 9.43558663e-03
 1.82904524e-03 1.44434132e-04 9.62995182e-06 1.47838497e+01
 8.86684507e-02 6.66825217e-05 3.17285390e-04 5.43577538e-04
 3.11621989e-04 6.59603829e-05 7.97742086e-06 3.25425416e-02
 2.38146409e-02 2.03193035e+01 5.44075012e-01 4.44855868e-06
 7.18155752e-06 1.29099808e-05 3.50614660e-04 1.25202611e-02
 3.36117910e-06 1.68256903e+00 1.22273445e+00 2.89484352e-01
 4.50304411e-02 3.15849334e-01 1.64918676e-01 5.64996004e-02
 3.47735614e-01 6.68101566e-05 2.71118683e-04 8.02660465e-01
 2.99814008e-02 2.04067561e-03 2.16134288e-03 8.83975939e-04
 1.23258433e-04 7.13570307e-06 2.83779518e-04 1.65235233e+01], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray([2.39045306e-07, 5.05174361e-02, 3.54168625e+01,
             1.47604517e+03, 3.53433342e+01, 4.93178144e-02,
             4.55805633e-07, 1.40898908e-03, 3.44161719e-01,
             9.97807801e-05, 3.79991978e-01, 1.46471069e+03,
             2.83114685e+02, 1.46274438e+03, 3.71743321e-01,
             3.45970882e-04, 6.38958882e-05, 1.96450111e-03,
             1.36579215e-09, 4.94456030e-02, 3.51516914e+01,
             1.46444324e+03, 3.50581055e+01, 4.76084426e-02,
             3.62756909e-05, 2.59627323e-05, 1.28611634e-02,
             1.53637741e-10, 1.10259043e-06, 4.41823229e-02,
             3.57002020e-01, 2.58970540e-02, 4.03595768e-05,
             3.03160959e-06, 1.17222185e-03, 8.77917588e-01,
             1.21548219e-05, 1.10309920e-05, 9.27108247e-03,
             4.94957045e-02, 2.59197364e-03, 1.15707962e-06,
             1.18230097e-02, 3.12171178e-04, 3.32510471e-03,
             3.49332261e+00, 7.32741091e-06, 2.74387912e-06,
             1.01096899e-04, 5.29806968e-03, 3.85003159e-06,
             5.52617666e-06, 3.25750816e-03, 5.01227856e+00,
             2.41945744e-01, 3.02146364e-04, 9.41658072e-06,
             3.36052040e-07, 1.35825780e-06, 6.64325270e-08,
             1.76019839e-05, 1.08838431e-03, 1.62803262e-04,
             2.26547308e-02, 1.64264720e-03, 1.85137324e-04,
             1.56695936e-02, 7.15451062e-01, 2.22016979e-06,
             1.01327966e-03, 2.23005915e+00, 7.91126966e-01,
             5.04123306e+00, 4.45488766e-02, 2.39215661e-02,
             4.78328019e-02, 1.72335599e-02, 6.92157680e-03,
             2.44008159e-04, 1.83372939e+00, 1.95616283e+01,
             3.79712988e+03, 1.09777878e+02, 2.67002295e+03,
             1.35775498e+04, 2.65880493e+03, 1.10153786e+02,
             3.73467236e+03, 1.19229942e+02, 2.84601521e+00,
             6.83263672e+03, 2.79380798e+02, 1.34788926e+04,
             4.32935900e+06, 1.34695459e+04, 2.80177765e+02,
             6.64262061e+03, 2.02965302e+02, 3.41118026e+00,
             3.56081836e+03, 1.10421318e+02, 2.64797974e+03,
             1.34977988e+04, 2.64846924e+03, 1.10255013e+02,
             3.50496948e+03, 1.18950562e+02, 3.01873517e+00,
             6.62234680e+02, 4.82003857e+03, 1.09844521e+02,
             2.80528381e+02, 1.09617119e+02, 4.74906738e+03,
             6.55362976e+02, 2.70264378e+01, 8.97035313e+00,
             5.15718269e+01, 6.45551758e+02, 3.28150317e+03,
             5.16179639e+03, 3.30050366e+03, 6.39274292e+02,
             5.09981346e+01, 3.41962957e+00, 1.42685509e+00,
             1.01157246e+01, 2.33161469e+01, 1.10907883e+02,
             1.90045609e+02, 1.08419716e+02, 2.30633316e+01,
             2.79284310e+00, 1.17048889e-01, 1.94730873e+01,
             1.74634588e+00, 4.49147314e-01, 1.55767703e+00,
             2.50792360e+00, 1.58440232e+00, 4.48806077e-01,
             1.45511791e-01, 1.35346794e+00, 2.90576243e+00,
             2.38838717e-01, 4.18374613e-02, 3.52283823e-03,
             2.95951460e-02, 8.89666438e-01, 4.76098806e-02,
             1.34586644e+00, 4.76615868e+01, 1.10827850e+02,
             1.43388319e+01, 1.19029693e-01, 3.56732607e-02,
             5.71300238e-02, 3.18590216e-02, 8.93243849e-02,
             2.88715529e+00, 1.14536575e+02, 9.37667389e+01,
             1.30518221e-10, 1.29126704e-07, 9.43186460e-05,
             3.93137056e-03, 9.41418984e-05, 1.27508997e-07,
             6.05227999e-13, 3.23746319e-09, 5.00035000e+00,
             1.84663360e-13, 1.00962234e-06, 3.90111632e-03,
             7.08393578e-04, 3.89609602e-03, 9.93981416e-07,
             6.96998459e-09, 7.74617176e-11, 3.42199957e-09,
             1.13464117e-11, 1.30591786e-07, 9.36589786e-05,
             3.90049233e-03, 9.34074851e-05, 1.27706102e-07,
             5.01154728e-13, 1.34228934e-10, 3.50442271e-08,
             2.35068584e-12, 9.94621916e-15, 1.22035175e-07,
             9.51181676e-07, 9.61299804e-08, 1.11253324e-12,
             9.48410729e-13, 2.91648905e-09, 2.86607862e+00,
             4.37594676e-13, 6.09899013e-13, 1.51472002e-09,
             2.10111040e-07, 2.58265354e-10, 8.70264555e-11,
             2.08061257e-09, 2.24356020e-11, 5.20397186e+00,
             4.11031609e-09, 6.17094960e-12, 1.25937732e-11,
             1.82825005e-10, 3.73396247e-08, 2.50152806e-12,
             6.60852015e-13, 4.24275895e-05, 4.88129648e-10,
             6.55873775e+00, 2.35146031e-01, 3.65504176e-13,
             1.53000546e-11, 8.54772086e-11, 6.22795912e-12,
             1.76353077e-03, 2.91543256e-09, 4.86712635e-01,
             5.91308951e-01, 1.48966476e-01, 2.95371190e-02,
             2.63620436e-01, 9.36421826e-02, 3.79317403e-02,
             1.12129085e-01, 5.94634287e-07, 1.55835085e-06,
             9.47777967e-10, 1.59048790e-03, 1.17888238e-04,
             6.71944988e-04, 1.00038422e-04, 1.62846624e-14,
             3.18903293e-10, 4.51858523e-06, 4.73535110e-05,
             1.08632715e-02, 2.92142475e-04, 7.11128861e-03,
             3.61628532e-02, 7.08157336e-03, 2.93209887e-04,
             1.06833652e-02, 3.40949657e-04, 2.31939468e+01,
             1.95433274e-02, 7.44041230e-04, 3.58998738e-02,
             1.23857145e+01, 3.58756483e-02, 7.46334612e-04,
             1.90099515e-02, 5.80490974e-04, 9.69203757e-06,
             1.01864180e-02, 2.94046564e-04, 7.05297943e-03,
             3.59505154e-02, 7.05426000e-03, 2.93697551e-04,
             1.00254146e-02, 3.40347964e-04, 8.61463559e-06,
             1.89446309e-03, 1.37866689e-02, 2.92766665e-04,
             7.47172337e-04, 2.93093617e-04, 1.35827754e-02,
             1.87465397e-03, 7.72828862e-05, 1.25515757e+01,
             1.47473242e-04, 1.84648274e-03, 9.37683042e-03,
             1.47797344e-02, 9.43558663e-03, 1.82904524e-03,
             1.44434132e-04, 9.62995182e-06, 1.47838497e+01,
             8.86684507e-02, 6.66825217e-05, 3.17285390e-04,
             5.43577538e-04, 3.11621989e-04, 6.59603829e-05,
             7.97742086e-06, 3.25425416e-02, 2.38146409e-02,
             2.03193035e+01, 5.44075012e-01, 4.44855868e-06,
             7.18155752e-06, 1.29099808e-05, 3.50614660e-04,
             1.25202611e-02, 3.36117910e-06, 1.68256903e+00,
             1.22273445e+00, 2.89484352e-01, 4.50304411e-02,
             3.15849334e-01, 1.64918676e-01, 5.64996004e-02,
             3.47735614e-01, 6.68101566e-05, 2.71118683e-04,
             8.02660465e-01, 2.99814008e-02, 2.04067561e-03,
             2.16134288e-03, 8.83975939e-04, 1.23258433e-04,
             7.13570307e-06, 2.83779518e-04, 1.65235233e+01],            dtype=float32)
  tangent = Traced<ShapedArray(float32[324])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[324]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7ffb942f5200>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7ffb942fff40; to 'JaxprTracer' at 0x7ffb942ff9a0>], out_avals=[ShapedArray(float32[324])], primitive=xla_call, params={'device': None, 'backend': None, 'name': 'fn', 'donated_invars': (False, False), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324] b:f32[324]. let c:f32[324] = add a b in (c,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7ffb943034b0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

In [None]:
# Step-count check using the same dt / xtol / Tmax values as conv_pars above
# (presumably mirrors how fixed_point_r derives its step count -- confirm).
dt, xtol, Tmax = 1, 1e-5, 600
Nmax = np.abs(np.round(Tmax / dt))
np.array(Nmax, int)

In [None]:
# Display the step count computed in the previous cell.
Nmax