# 2D SSN Model

1. Imports

In [1]:
import jax.numpy as np
import jax
import matplotlib.pyplot as plt
import time, os, json
import pandas as pd
from scipy import stats 
from tqdm import tqdm
import seaborn as sns
from jax import random
from jax.config import config 
import pdb
import optax

#config.update('jax_debug_nans', True)
from SSN_classes_jax import SSN2DTopoV1_AMPAGABA_ONOFF
from util import GaborFilter, BW_Grating, find_A, create_gabor_filters, create_gratings

# MODEL TRAINING

## --> INCLUDE NUMBER OF TRIALS

In [2]:
def sigmoid(x):
    return 1/(1+np.exp(x))

def binary_loss(n, x):
    return - (n*np.log(x) + (1-n)*np.log(1-x))

def model(opt_pars, ssn_pars, grid_pars, conn_pars, train_data, labels, filter_pars,  **conv_pars):
    
    signs=np.array([[1, -1], [1, -1]])
    
    J_2x2 =np.exp(opt_pars['logJ_2x2'])*signs
    s_2x2 = np.exp(opt_pars['logs_2x2'])
    
    total_loss=0
    
    for i in range(len(train_data)):
        #Initialise network
        ssn=SSN2DTopoV1_AMPAGABA_ONOFF(ssn_pars=ssn_pars, grid_pars=grid_pars, conn_pars=conn_pars, filter_pars=filter_pars, J_2x2=J_2x2, s_2x2=s_2x2)
                                                   
        #Apply Gabor filters to stimuli
        output_ref=np.matmul(ssn.gabor_filters, train_data[i, 0].ravel())*ssn.A
        output_target=np.matmul(ssn.gabor_filters, train_data[i,1].ravel())*ssn.A

        #Rectify output
        SSN_input_ref=np.maximum(0, output_ref)
        SSN_input_target=np.maximum(0, output_target)
        
        #Input to SSN
        r_init = np.zeros(SSN_input_ref.shape[0])
        
        fp_ref, _ = ssn.fixed_point_r(SSN_input_ref, r_init=r_init, **conv_pars)
        x_ref = ssn.apply_bounding_box(fp_ref, size=3.2)

        fp_target, _ = ssn.fixed_point_r(SSN_input_target, r_init=r_init, **conv_pars)
        x_target = ssn.apply_bounding_box(fp_target, size=3.2)

        #Apply sigmoid function - combine ref and target
        x = sigmoid( np.dot(opt_pars['w_sig'], (x_ref.ravel() - x_target.ravel())) + opt_pars['b_sig'])

        #Calculate binary cross entropy loss
        total_loss+=np.sum(binary_loss(labels[i], x))
        
    print('Loss {}'.format(total_loss))
   
    return total_loss
    

def train_SSN(opt_pars, ssn_pars, grid_pars, conn_pars, train_data, labels, filter_pars, conv_pars, epochs=1, eta=10e-4):
    
    #Initialize loss
    loss_per_epoch = []
    
    #Initialise optimizer
    optimizer = optax.adam(eta)
    opt_state = optimizer.init(opt_pars)
    
    for epoch in range(epochs):
        
        
        #compute loss and gradient 
        loss, grad =jax.value_and_grad(model)(opt_pars, ssn_pars, grid_pars, conn_pars, train_data, labels, filter_pars,  **conv_pars)
        
        print('Loss {} at epoch {}'.format(loss, epoch+1))
        
        #Apply SGD through Adam optimizer
        updates, opt_state = optimizer.update(grad, opt_state)
        opt_pars = optax.apply_updates(opt_pars, updates)
        
        loss_per_epoch.append(loss)
    
    #reparametize parameters
    signs=np.array([[1, -1], [1, -1]])    
    opt_pars['logJ_2x2'] = np.exp(opt_pars['logJ_2x2'])*signs
    opt_pars['logs_2x2'] = np.exp(opt_pars['logs_2x2'])
    
    return opt_pars, loss_per_epoch

1. Define parameters

In [7]:
#Convergence parameters

#Network parameters
class ssn_pars():
    n = 2
    k = 0.04
    tauE = 30 # in ms
    tauI = 10 # in ms
    psi = 0.774
    tau_s = np.array([5, 7, 100]) #in ms, AMPA, GABA, NMDA current decay time constants
    

#Grid parameters
class grid_pars():
    gridsize_Nx = 9 # grid-points across each edge # gives rise to dx = 0.8 mm
    gridsize_deg = 2 * 1.6 # edge length in degrees
    magnif_factor = 2  # mm/deg
    hyper_col = 0.8 # mm   
    sigma_RF = 0.4 # deg (visual angle)

# Caleb's params for the full (with local) model:
Js0 = [1.82650658, 0.68194475, 2.06815311, 0.5106321]
gE, gI = 0.57328625, 0.26144141

sigEE, sigIE = 0.2, 0.40
sigEI, sigII = .09, .09
conn_pars = dict(
    PERIODIC = False,
    p_local = [.4, 0.7], # [p_local_EE, p_local_IE],
    sigma_oris = 1000) # sigma_oris


make_J2x2 = lambda Jee, Jei, Jie, Jii: np.array([[Jee, -Jei], [Jie,  -Jii]]) * np.pi * ssn_pars.psi
J_2x2 = make_J2x2(*Js0)
s_2x2 = np.array([[sigEE, sigEI],[sigIE, sigII]])

#Positive reparameterization
signs=np.array([[1, -1], [1, -1]])
logJ_2x2 =np.log(J_2x2*signs)
logs_2x2 = np.log(s_2x2)


#Sigmoid parameters
N_neurons = 25
key = random.PRNGKey(10)
w_sig = random.normal(key, shape = (N_neurons,)) / np.sqrt(N_neurons)
b_sig = 0.0

#Gabor parameters 
sigma_g= 0.5
k = np.pi/(6*sigma_g)

#Optimization pars
opt_pars = dict(logJ_2x2 = logJ_2x2, logs_2x2 = logs_2x2, w_sig = w_sig, b_sig=b_sig)

#Parameters shared with stimuli
general_pars = dict(k=k , edge_deg=3.2,  degree_per_pixel=0.05)

#Parameters exclusive to Gabor filters
filter_pars = dict(sigma_g = sigma_g, conv_factor = grid_pars.magnif_factor)
filter_pars.update(general_pars) 

2. Create training data

In [11]:
#Stimuli parameters
stimuli_pars = dict(outer_radius=3, inner_radius=2.5, grating_contrast=0.99, snr = 0.9)
stimuli_pars.update(general_pars)

#Create gratings at given orientation and list of labels
gratings, labels = create_gratings(ref_ori=55, number=100, offset= 40, jitter_val=0, **stimuli_pars)

3. TRAINING!

In [38]:
conv_pars=dict(dt = 1, xtol = 1e-5, Tmax = 200, verbose=False, silent=True)

#model_test = model(opt_pars, ssn_pars, grid_pars, conn_pars, gratings, labels, filter_pars,  **conv_pars)
new_pars, loss= train_SSN(opt_pars, ssn_pars, grid_pars, conn_pars, gratings, labels, filter_pars, conv_pars, epochs=1)

Traced<ConcreteArray(301.4615173339844, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(301.46152, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89a840e9f0>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f89c80a4720; to 'JaxprTracer' at 0x7f89c80a4220>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f897456e470>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(300.1659851074219, dtype=float32)>with<JVPTrace(level=2/0)> with
  p

Traced<ConcreteArray(-1.0381773710250854, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(-1.0381774, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89a8367820>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f890dcd19a0; to 'JaxprTracer' at 0x7f890dcd1d60>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f89a040b4f0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(-0.9167177081108093, dtype=float32)>with<JVPTrace(level=2/0)> wit

Traced<ConcreteArray(0.16371923685073853, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.16371924, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89a866a130>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f8984549040; to 'JaxprTracer' at 0x7f8984549130>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f89a85a7a70>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.15996681153774261, dtype=float32)>with<JVPTrace(level=2/0)> wit

Traced<ConcreteArray(0.06253541260957718, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.06253541, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89c824fc50>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f89a8225360; to 'JaxprTracer' at 0x7f89a8225040>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f89a046e670>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.060808055102825165, dtype=float32)>with<JVPTrace(level=2/0)> wi

Traced<ConcreteArray(0.02530558407306671, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.02530558, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89ec6eb130>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f890c10c0e0; to 'JaxprTracer' at 0x7f890c10c220>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f89a8536870>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.02463381551206112, dtype=float32)>with<JVPTrace(level=2/0)> wit

Traced<ConcreteArray(151.00173950195312, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(151.00174, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89a81a4ea0>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f890c00adb0; to 'JaxprTracer' at 0x7f890c00a860>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f89845d6870>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(81.77886962890625, dtype=float32)>with<JVPTrace(level=2/0)> with
  

Traced<ConcreteArray(-0.6192036271095276, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(-0.6192036, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89a053d930>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f88ff2126d0; to 'JaxprTracer' at 0x7f88ff2125e0>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f89a8622430>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(-0.5304888486862183, dtype=float32)>with<JVPTrace(level=2/0)> wit

Traced<ConcreteArray(0.12668010592460632, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.1266801, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f8984635420>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f89a05c4810; to 'JaxprTracer' at 0x7f89a05c4270>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f890df8c4b0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.12345685064792633, dtype=float32)>with<JVPTrace(level=2/0)> with

Traced<ConcreteArray(0.04836535081267357, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.04836535, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f890cd56440>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f8984492770; to 'JaxprTracer' at 0x7f8984492ae0>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f89845f7870>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.04700329527258873, dtype=float32)>with<JVPTrace(level=2/0)> wit

Traced<ConcreteArray(0.019113190472126007, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.01911319, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89a0450bd0>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f89a81da720; to 'JaxprTracer' at 0x7f89a81da180>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f89a82f6530>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.018641121685504913, dtype=float32)>with<JVPTrace(level=2/0)> w

Traced<ConcreteArray(281.21026611328125, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(281.21027, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89245b78e0>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f898478fcc0; to 'JaxprTracer' at 0x7f898478fae0>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f89a83b4e70>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(280.8089599609375, dtype=float32)>with<JVPTrace(level=2/0)> with
  

Traced<ConcreteArray(-0.4307332932949066, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(-0.4307333, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f890dd4a170>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f892413cb80; to 'JaxprTracer' at 0x7f892413c900>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f890dd52970>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(-0.35517072677612305, dtype=float32)>with<JVPTrace(level=2/0)> wi

Traced<ConcreteArray(0.15480941534042358, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.15480942, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89a862c170>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f890c73b860; to 'JaxprTracer' at 0x7f89a04692c0>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f89a86014f0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.15091755986213684, dtype=float32)>with<JVPTrace(level=2/0)> wit

Traced<ConcreteArray(0.06128125637769699, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.06128126, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89a8537780>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f89a0699ea0; to 'JaxprTracer' at 0x7f89a06990e0>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f890cc43a30>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.05971139669418335, dtype=float32)>with<JVPTrace(level=2/0)> wit

Traced<ConcreteArray(0.024161962792277336, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.02416196, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89241d6f70>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f890cd64270; to 'JaxprTracer' at 0x7f890d766310>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f89845d7a30>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.02361137419939041, dtype=float32)>with<JVPTrace(level=2/0)> wi

Traced<ConcreteArray(-8.669682502746582, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(-8.6696825, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f890da85ca0>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f890d6ff680; to 'JaxprTracer' at 0x7f890d6ff6d0>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f8974559570>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(-8.249068260192871, dtype=float32)>with<JVPTrace(level=2/0)> with


Traced<ConcreteArray(0.25791817903518677, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.25791818, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89844adc80>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f890da9fe00; to 'JaxprTracer' at 0x7f890da9f770>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f8984203830>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.2612846791744232, dtype=float32)>with<JVPTrace(level=2/0)> with

Traced<ConcreteArray(0.11434487998485565, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.11434488, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f890d0f06b0>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f890c4d8180; to 'JaxprTracer' at 0x7f890c4d84f0>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f89845b7b70>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.11105689406394958, dtype=float32)>with<JVPTrace(level=2/0)> wit

Traced<ConcreteArray(0.04549453407526016, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.04549453, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f898473bb30>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f89743c5810; to 'JaxprTracer' at 0x7f89743c5720>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f89a033b1b0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.04424654692411423, dtype=float32)>with<JVPTrace(level=2/0)> wit

Traced<ConcreteArray(0.015554627403616905, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.01555463, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f890d6ae5a0>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f89846d29f0; to 'JaxprTracer' at 0x7f89846d2180>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f890d6b73b0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.015172120183706284, dtype=float32)>with<JVPTrace(level=2/0)> w

Traced<ConcreteArray(346.04083251953125, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(346.04083, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f890cf9c120>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f88ff1ff900; to 'JaxprTracer' at 0x7f88ff1ff130>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f890cfaa4b0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(344.6546325683594, dtype=float32)>with<JVPTrace(level=2/0)> with
  

Traced<ConcreteArray(-2.643928050994873, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(-2.643928, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f890d756b90>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f890d2609f0; to 'JaxprTracer' at 0x7f892473c630>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f890df78970>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(-2.3866705894470215, dtype=float32)>with<JVPTrace(level=2/0)> with


Traced<ConcreteArray(0.16589730978012085, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.16589731, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f890df97310>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f88ff175590; to 'JaxprTracer' at 0x7f88ff1754f0>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f88ff1abe30>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.1613922119140625, dtype=float32)>with<JVPTrace(level=2/0)> with

Traced<ConcreteArray(0.062004320323467255, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.06200432, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89a052c790>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f88ff20c270; to 'JaxprTracer' at 0x7f88ff219310>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f88ff1fd5b0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.06015714630484581, dtype=float32)>with<JVPTrace(level=2/0)> wi

Traced<ConcreteArray(0.022786390036344528, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.02278639, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89a04c2d00>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f890cd84860; to 'JaxprTracer' at 0x7f890cd84450>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f890c9702b0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.02216007187962532, dtype=float32)>with<JVPTrace(level=2/0)> wi

Traced<ConcreteArray(41.25323486328125, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(41.253235, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f892415d390>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f8974283180; to 'JaxprTracer' at 0x7f8974283590>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f8974255eb0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(25.927730560302734, dtype=float32)>with<JVPTrace(level=2/0)> with
  

Traced<ConcreteArray(-0.3960622549057007, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(-0.39606225, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89840ba170>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f89840af900; to 'JaxprTracer' at 0x7f89840af590>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f8984096af0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(-0.3255399465560913, dtype=float32)>with<JVPTrace(level=2/0)> wi

Traced<ConcreteArray(0.15103644132614136, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.15103644, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89a0498050>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f89a048a130; to 'JaxprTracer' at 0x7f89a048a4a0>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f89a049cc30>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.1473565548658371, dtype=float32)>with<JVPTrace(level=2/0)> with

Traced<ConcreteArray(0.06170862913131714, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.06170863, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89a81875a0>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f892414c0e0; to 'JaxprTracer' at 0x7f89241448b0>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f8924130d30>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.060125917196273804, dtype=float32)>with<JVPTrace(level=2/0)> wi

Traced<ConcreteArray(0.025392454117536545, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.02539245, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f8984518cc0>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f8984377450; to 'JaxprTracer' at 0x7f892414c860>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f898437ad30>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.024812666699290276, dtype=float32)>with<JVPTrace(level=2/0)> w

Traced<ConcreteArray(285.283935546875, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(285.28394, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f890dead270>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f890c7c5ae0; to 'JaxprTracer' at 0x7f890c7c5630>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f890c7ff4f0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(284.85302734375, dtype=float32)>with<JVPTrace(level=2/0)> with
  prim

Traced<ConcreteArray(-0.6734197735786438, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(-0.6734198, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f890c9e3240>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f890c9e4770; to 'JaxprTracer' at 0x7f890c9e44a0>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f890c9e61f0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(-0.5699976682662964, dtype=float32)>with<JVPTrace(level=2/0)> wit

Traced<ConcreteArray(0.17125239968299866, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.1712524, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89742f2370>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f89742ce400; to 'JaxprTracer' at 0x7f89742ce4a0>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f89742fd870>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.1668526828289032, dtype=float32)>with<JVPTrace(level=2/0)> with


Traced<ConcreteArray(0.06786585599184036, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.06786586, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f890c9ab940>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f88ff100680; to 'JaxprTracer' at 0x7f88ff100630>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f88ff10bb70>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.06607255339622498, dtype=float32)>with<JVPTrace(level=2/0)> wit

Traced<ConcreteArray(0.028296321630477905, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.02829632, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f890d186df0>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f89c80b5680; to 'JaxprTracer' at 0x7f89c80b5270>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f89c80d1b30>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.02762288972735405, dtype=float32)>with<JVPTrace(level=2/0)> wi

Traced<ConcreteArray(21.383275985717773, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(21.383276, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89846a05c0>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f890d4c67c0; to 'JaxprTracer' at 0x7f890d4c6400>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f890d4d7fb0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(11.66224193572998, dtype=float32)>with<JVPTrace(level=2/0)> with
  

Traced<ConcreteArray(-0.22811883687973022, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(-0.22811884, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89a062f420>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f890d5be720; to 'JaxprTracer' at 0x7f890d5be4f0>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f890d5b1130>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(-0.1631789207458496, dtype=float32)>with<JVPTrace(level=2/0)> w

Traced<ConcreteArray(0.20413771271705627, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.20413771, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f88ff03e3d0>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f89a0628680; to 'JaxprTracer' at 0x7f88ff0660e0>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f88ff0419b0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.19923925399780273, dtype=float32)>with<JVPTrace(level=2/0)> wit

Traced<ConcreteArray(0.08478015661239624, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.08478016, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89a035c8d0>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f890cb75220; to 'JaxprTracer' at 0x7f890cb75630>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f890cb44970>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.08259086310863495, dtype=float32)>with<JVPTrace(level=2/0)> wit

Traced<ConcreteArray(0.03472917526960373, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.03472918, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f89a066cf50>, in_tracers=(Traced<ShapedArray(float32[324]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7f890d390220; to 'JaxprTracer' at 0x7f890d3909f0>], out_avals=[ShapedArray(float32[])], primitive=xla_call, params={'device': None, 'backend': None, 'name': '_reduce_sum', 'donated_invars': (False,), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[324]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f890d3a9730>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ConcreteArray(0.03387674316763878, dtype=float32)>with<JVPTrace(level=2/0)> wit

Loss 1.358145833015442 at epoch 1


In [15]:
ssn=SSN2DTopoV1_AMPAGABA_ONOFF(ssn_pars=ssn_pars, grid_pars=grid_pars, conn_pars=conn_pars, filter_pars=filter_pars, J_2x2=J_2x2, s_2x2=s_2x2)

#Apply Gabor filters to stimuli
output_ref=np.matmul(ssn.gabor_filters, gratings[1, 0].ravel())*ssn.A
output_target=np.matmul(ssn.gabor_filters, gratings[1,1].ravel())*ssn.A

#Rectify output
SSN_input_ref=np.maximum(0, output_ref)
SSN_input_target=np.maximum(0, output_target)

#Input to SSN
r_init = np.zeros(SSN_input_ref.shape[0])

fp_ref, _ = ssn.fixed_point_r(SSN_input_ref, r_init=r_init, **conv_pars)

#fp_target, _ = ssn.fixed_point_r(SSN_input_target, r_init=r_init, **conv_pars)


KeyboardInterrupt



In [8]:
np.sum(ssn.W)

DeviceArray(2128.79824252, dtype=float64)