# 2D SSN Model

1. Imports

In [1]:
import jax.numpy as np
import jax
import matplotlib.pyplot as plt
import time, os, json
import pandas as pd
from scipy import stats 
from tqdm import tqdm
import seaborn as sns
from jax import random

from SSN_classes_jax import SSN2DTopoV1_AMPAGABA_ONOFF
from util import GaborFilter, BW_Grating, find_A, create_gabor_filters, create_gratings

2. Create SSN network

In [None]:
#Network parameters
class ssn_pars():
    n = 2
    k = 0.04
    tauE = 30 # in ms
    tauI = 10 # in ms
    psi = 0.774
    tau_s = np.array([5, 7, 100]) #in ms, AMPA, GABA, NMDA current decay time constants
    

#Grid parameters
class grid_pars():
    gridsize_Nx = 9 # grid-points across each edge # gives rise to dx = 0.8 mm
    gridsize_deg = 2 * 1.6 # edge length in degrees
    magnif_factor = 2  # mm/deg
    hyper_col = 0.8 # mm   
    sigma_RF = 0.4 # deg (visual angle)

# Caleb's params for the full (with local) model:
Js0 = [1.82650658, 0.68194475, 2.06815311, 0.5106321]
gE, gI = 0.57328625, 0.26144141

sigEE, sigIE = 0.2, 0.40
sigEI, sigII = .09, .09

conn_pars = dict(
    PERIODIC = False,
    p_local = [.4, 0.7], # [p_local_EE, p_local_IE],
    sigma_oris = 1000) # sigma_oris



make_J2x2 = lambda Jee, Jei, Jie, Jii: np.array([[Jee, -Jei], [Jie,  -Jii]]) * np.pi * ssn_pars.psi
J_2x2 = make_J2x2(*Js0)
s_2x2 = np.array([[sigEE, sigEI],[sigIE, sigII]])

#Create network
ssn = SSN2DTopoV1_AMPAGABA_ONOFF(ssn_pars, grid_pars, conn_pars=conn_pars, J_2x2=J_2x2, s_2x2=s_2x2)


3. Create Gabor filters

In [None]:
#Gabor parameters 
sigma_g= 0.5
k= np.pi/(6*sigma_g)
general_pars = dict(k=k, edge_deg=3.2,  degree_per_pixel=0.05) #parameters shared with input stimuli


#Create filters
SSN_filters, A =create_gabor_filters(ssn, sigma_g=sigma_g, conv_factor = grid_pars.magnif_factor, **general_pars)

3. Input target and reference

In [None]:
#Stimuli parameters
stimuli_pars = dict(outer_radius=3, inner_radius=2.5, grating_contrast=0.99)
stimuli_pars.update(general_pars)

#Create reference stimuli
ori_ref = 0
ref_grating=BW_Grating(ori_deg = ori_ref, **stimuli_pars).BW_image()

#Create target stimuli
ori_target= 10
target_grating=BW_Grating(ori_deg = ori_target, **stimuli_pars).BW_image()

# MODEL TRAINING

In [None]:
def sigmoid(x):
    return 1/(1+np.exp(-x))

def binary_loss(n, x):
    return - n*np.log(x) - (1-n)*np.log(1-x)

def model(J_2x2, s_2x2, ssn_pars, grid_pars, conn_pars, general_pars, label, ref, target, **conv_pars):
    
    #Initialise network
    ssn=SSN2DTopoV1_AMPAGABA_ONOFF(ssn_pars, grid_pars, conn_pars=conn_pars, J_2x2=J_2x2, s_2x2=s_2x2)
   
    #Create Gabor filters for network
    SSN_filters, A =create_gabor_filters(ssn, sigma_g=sigma_g, conv_factor = grid_pars.magnif_factor, **general_pars)                               
                                       
    #Apply Gabor filters to stimuli
    output_ref=np.matmul(SSN_filters, ref.ravel())*A
    output_target=np.matmul(SSN_filters, target.ravel())*A
    
    #Rectify output
    SSN_input_ref=np.maximum(0, output_ref)
    SSN_input_target=np.maximum(0, output_target)
    
    #Input to SSN
    r_init = np.zeros(SSN_input_ref.shape[0])
    
    fp_ref, _ = ssn.fixed_point_r(SSN_input_ref, r_init=r_init, **conv_pars)
    
    fp_target, _ = ssn.fixed_point_r(SSN_input_target, r_init=r_init, **conv_pars)
    
    #Combine reference and target 
    x = fp_ref + fp_target
    
    #Apply sigmoid function
    x = sigmoid(0.1*x)
    
    #Calculate binary cross entropy loss
    loss=np.sum(binary_loss(label, x))
   
    #check what indices are nan
    #indices = np.argwhere(np.isnan(loss))
    
    return loss
    

def train_SSN(J_2x2, s_2x2, ssn_pars, grid_pars, conn_pars, general_pars, train_data, stimuli_pars, conv_pars):
    
    #Initialize loss
    #total_loss = 0
           
    #find label
    if train_data[0] > train_data[1]:
        label=1
    else:
        label=0
        
    ref_grating=BW_Grating(ori_deg = train_data[0], **stimuli_pars).BW_image()
    target_grating = BW_Grating(ori_deg = train_data[1], **stimuli_pars).BW_image()
        
        #total_loss+=model(ssn, A, SSN_filters, label, ref_grating, target_grating, **conv_pars )
    
    grad_loss=jax.grad(model, argnums=(0,1))
    dJ, ds=grad_loss(J_2x2, s_2x2, ssn_pars, grid_pars, conn_pars, general_pars, label, ref_grating, target_grating, **conv_pars)
    
    #compute loss
    #model_out=model(J_2x2, s_2x2, ssn_pars, grid_pars, conn_pars, general_pars, label, ref_grating, target_grating, **conv_pars)
    
    #update parameters using gradient descent
    J_2x2 = J_2x2 - dJ
    s_2x2 = s2_2x2 - ds
    
    
    return gradient
   

In [None]:
#fp_ref, _ = ssn.fixed_point_r(SSN_input_ref, r_init=r_init, **conv_pars)

train_data=train_data=[10,20]

J, s= train_SSN(J_2x2, s_2x2, ssn_pars, grid_pars, conn_pars, general_pars, train_data, stimuli_pars, conv_pars)

## --> INCLUDE NUMBER OF TRIALS

In [2]:
def sigmoid(x):
    return 1/(1+np.exp(-x))

def binary_loss(n, x):
    return - n*np.log(x) - (1-n)*np.log(1-x)

def model(J_2x2, s_2x2, ssn_pars, grid_pars, conn_pars, general_pars, train_data, labels, **conv_pars):
    
    total_loss=0
    for i in range(len(training_data)):
        
        #Initialise network
        ssn=SSN2DTopoV1_AMPAGABA_ONOFF(ssn_pars, grid_pars, conn_pars=conn_pars, J_2x2=J_2x2, s_2x2=s_2x2)

        #Create Gabor filters for network
        SSN_filters, A =create_gabor_filters(ssn, sigma_g=sigma_g, conv_factor = grid_pars.magnif_factor, **general_pars)                               
        
        
        #Apply Gabor filters to stimuli
        output_ref=np.matmul(SSN_filters, train_data[i, 0].ravel())*A
        output_target=np.matmul(SSN_filters, train_data[i,1].ravel())*A

        #Rectify output
        SSN_input_ref=np.maximum(0, output_ref)
        SSN_input_target=np.maximum(0, output_target)

        #Input to SSN
        r_init = np.zeros(SSN_input_ref.shape[0])

        fp_ref, _ = ssn.fixed_point_r(SSN_input_ref, r_init=r_init, **conv_pars)

        fp_target, _ = ssn.fixed_point_r(SSN_input_target, r_init=r_init, **conv_pars)

        #Combine reference and target 
        x = fp_ref + fp_target

        #Apply sigmoid function
        x = sigmoid(0.1*x)

        #Calculate binary cross entropy loss
        total_loss+=np.sum(binary_loss(labels[i], x))
   
    #check what indices are nan
    #indices = np.argwhere(np.isnan(loss))
    
    return total_loss
    

def train_SSN(J_2x2, s_2x2, ssn_pars, grid_pars, conn_pars, general_pars, train_data, labels, stimuli_pars, conv_pars, epochs=2):
    
    #Initialize loss
    #total_loss = 0
    
    for i in range(epochs):
        

        grad_loss=jax.grad(model, argnums=(0,1))
        dJ, ds=grad_loss(J_2x2, s_2x2, ssn_pars, grid_pars, conn_pars, general_pars, train_data, labels, **conv_pars)

        #compute loss
        #model_out=model(J_2x2, s_2x2, ssn_pars, grid_pars, conn_pars, general_pars, label, ref_grating, target_grating, **conv_pars)

        #update parameters using gradient descent
        J_2x2 = J_2x2 - dJ
        s_2x2 = s_2x2 - ds
    
    
    return J_2x2, s_2x2
   

1. Define parameters

In [3]:
#Convergence parameters
conv_pars=dict(dt = 1, xtol = 1e-5, Tmax = 600)

#Network parameters
class ssn_pars():
    n = 2
    k = 0.04
    tauE = 30 # in ms
    tauI = 10 # in ms
    psi = 0.774
    tau_s = np.array([5, 7, 100]) #in ms, AMPA, GABA, NMDA current decay time constants
    

#Grid parameters
class grid_pars():
    gridsize_Nx = 9 # grid-points across each edge # gives rise to dx = 0.8 mm
    gridsize_deg = 2 * 1.6 # edge length in degrees
    magnif_factor = 2  # mm/deg
    hyper_col = 0.8 # mm   
    sigma_RF = 0.4 # deg (visual angle)


    # Caleb's params for the full (with local) model:
Js0 = [1.82650658, 0.68194475, 2.06815311, 0.5106321]
gE, gI = 0.57328625, 0.26144141

sigEE, sigIE = 0.2, 0.40
sigEI, sigII = .09, .09

conn_pars = dict(
    PERIODIC = False,
    p_local = [.4, 0.7], # [p_local_EE, p_local_IE],
    sigma_oris = 1000) # sigma_oris



make_J2x2 = lambda Jee, Jei, Jie, Jii: np.array([[Jee, -Jei], [Jie,  -Jii]]) * np.pi * ssn_pars.psi
J_2x2 = make_J2x2(*Js0)
s_2x2 = np.array([[sigEE, sigEI],[sigIE, sigII]])

#Create network
ssn = SSN2DTopoV1_AMPAGABA_ONOFF(ssn_pars, grid_pars, conn_pars=conn_pars, J_2x2=J_2x2, s_2x2=s_2x2)


#Gabor parameters 
sigma_g= 0.5
k= np.pi/(6*sigma_g)
general_pars = dict(k=k, edge_deg=3.2,  degree_per_pixel=0.05) #parameters shared with input stimuli

stimuli_pars = dict(outer_radius=3, inner_radius=2.5, grating_contrast=0.99)
stimuli_pars.update(general_pars)



2. Create training data

In [4]:
#Create training data
training_data=np.array(([10, 20], [10,22], [10, 21]))
gratings, labels = create_gratings(training_data, **stimuli_pars)

3. TRAINING!

In [5]:
J, s= train_SSN(J_2x2, s_2x2, ssn_pars, grid_pars, conn_pars, general_pars, gratings,labels, stimuli_pars, conv_pars)

Average A is 0.0008299059715404468

       max(abs(dx./max(abs(xvec), 1.0))) = Traced<ConcreteArray(0.1498628854751587, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.14986289, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7fdf1846c740>, in_tracers=(Traced<ConcreteArray([0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0

Average A is 0.0008299059715404468

       max(abs(dx./max(abs(xvec), 1.0))) = Traced<ConcreteArray(0.1498628854751587, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.14986289, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7fdea0b581f0>, in_tracers=(Traced<ConcreteArray([0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0

Average A is 0.0008299059715404468

       max(abs(dx./max(abs(xvec), 1.0))) = Traced<ConcreteArray(nan, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(nan, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7fdea0465a20>, in_tracers=(Traced<ConcreteArray([0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0