# 2D SSN Model

Information on training vmap:
- https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html
- https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html

Imports

In [51]:
import matplotlib.pyplot as plt
import time, os, json
import pandas as pd
from scipy import stats 
from tqdm import tqdm
import seaborn as sns
import jax
from jax import random
from jax.config import config 
import jax.numpy as np
from jax import vmap
import pdb
import optax
import time
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader


#config.update('jax_debug_nans', True)
from SSN_classes_jax import SSN2DTopoV1_AMPAGABA_ONOFF
from util import GaborFilter, BW_Grating, find_A, create_gabor_filters, create_gratings

## GENERATE DATA

In [74]:
#Gabor parameters 
sigma_g= 0.5
k = np.pi/(6*sigma_g)

#Parameters shared with stimuli
general_pars = dict(k=k , edge_deg=3.2,  degree_per_pixel=0.05)

#Stimuli parameters
stimuli_pars = dict(outer_radius=3, inner_radius=2.5, grating_contrast=0.5, snr = 0.9)
stimuli_pars.update(general_pars)

#Create gratings at given or ientation and list of labels
data = create_gratings(ref_ori=55, number=250, offset= 0.5, jitter_val=5, **stimuli_pars)
train, test = train_test_split(data, test_size = 0.2)
len(train), len(test)

(200, 50)

In [75]:
#batch params 
batch_size= 20
train_dataloader =DataLoader(train, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test, batch_size=len(test), shuffle=False)
number_batches= int(len(train) / batch_size)
number_batches

10

# MODEL TRAINING

1. Define parameters

In [66]:
#Network parameters
class ssn_pars():
    n = 2
    k = 0.04
    tauE = 30 # in ms
    tauI = 10 # in ms
    psi = 0.774
    tau_s = np.array([5, 7, 100]) #in ms, AMPA, GABA, NMDA current decay time constants
    

#Grid parameters
class grid_pars():
    gridsize_Nx = 9 # grid-points across each edge # gives rise to dx = 0.8 mm
    gridsize_deg = 2 * 1.6 # edge length in degrees
    magnif_factor = 2  # mm/deg
    hyper_col = 0.8 # mm   
    sigma_RF = 0.4 # deg (visual angle)

# Caleb's params for the full (with local) model:
Js0 = [1.82650658, 0.68194475, 2.06815311, 0.5106321]
gE, gI = 0.57328625, 0.26144141

sigEE, sigIE = 0.2, 0.40
sigEI, sigII = .09, .09
conn_pars = dict(
    PERIODIC = False,
    p_local = [.4, 0.7], # [p_local_EE, p_local_IE],
    sigma_oris = 1000) # sigma_oris


make_J2x2 = lambda Jee, Jei, Jie, Jii: np.array([[Jee, -Jei], [Jie,  -Jii]]) * np.pi * ssn_pars.psi
J_2x2 = make_J2x2(*Js0)
s_2x2 = np.array([[sigEE, sigEI],[sigIE, sigII]])

#Positive reparameterization
signs=np.array([[1, -1], [1, -1]])
logJ_2x2 =np.log(J_2x2*signs)
logs_2x2 = np.log(s_2x2)


#Sigmoid parameters
N_neurons = 25
key = random.PRNGKey(60)
w_sig = random.normal(key, shape = (N_neurons,)) / np.sqrt(N_neurons)
b_sig = 0.0


#Optimization pars
opt_pars = dict(logJ_2x2 = logJ_2x2, logs_2x2 = logs_2x2, w_sig = w_sig, b_sig=b_sig)


#Parameters exclusive to Gabor filters
filter_pars = dict(sigma_g = sigma_g, conv_factor = grid_pars.magnif_factor)
filter_pars.update(general_pars) 

#Convergence parameters
conv_pars=dict(dt = 1, xtol = 1e-5, Tmax = 200, verbose=False, silent=True)

3. TRAINING!

In [13]:
def sigmoid(x):
    return 1/(1+np.exp(x))

def binary_loss(n, x):
    return - (n*np.log(x) + (1-n)*np.log(1-x))

def model(opt_pars, ssn_pars, grid_pars, conn_pars, data, filter_pars,  conv_pars):
    
    signs=np.array([[1, -1], [1, -1]])
    
    J_2x2 =np.exp(opt_pars['logJ_2x2'])*signs
    s_2x2 = np.exp(opt_pars['logs_2x2'])
    
    total_loss=0
    
    for i in range(len(data['ref'])):
        #Initialise network
        ssn=SSN2DTopoV1_AMPAGABA_ONOFF(ssn_pars=ssn_pars, grid_pars=grid_pars, conn_pars=conn_pars, filter_pars=filter_pars, J_2x2=J_2x2, s_2x2=s_2x2)
                                                   
        #Apply Gabor filters to stimuli
        output_ref=np.matmul(ssn.gabor_filters, data['ref'][i])*ssn.A
        output_target=np.matmul(ssn.gabor_filters, data['target'][i])*ssn.A

        #Rectify output
        SSN_input_ref=np.maximum(0, output_ref)
        SSN_input_target=np.maximum(0, output_target)
        
        #Input to SSN
        r_init = np.zeros(SSN_input_ref.shape[0])
        
        fp_ref, _ = ssn.fixed_point_r(SSN_input_ref, r_init=r_init, **conv_pars)
        x_ref = ssn.apply_bounding_box(fp_ref, size=3.2)

        fp_target, _ = ssn.fixed_point_r(SSN_input_target, r_init=r_init, **conv_pars)
        x_target = ssn.apply_bounding_box(fp_target, size=3.2)

        #Apply sigmoid function - combine ref and target
        x = sigmoid( np.dot(opt_pars['w_sig'], (x_ref.ravel() - x_target.ravel())) + opt_pars['b_sig'])

        #Calculate binary cross entropy loss
        total_loss+=np.sum(binary_loss(data['label'][i], x))
        
   
    return total_loss
    

def train_SSN(opt_pars, ssn_pars, grid_pars, conn_pars, train_dataloader, test_dataloader, filter_pars, conv_pars, batches, epochs=1, eta=10e-4):
    
    #Initialize loss
    val_loss_per_epoch = []
    
    #Initialise optimizer
    optimizer = optax.adam(eta)
    opt_state = optimizer.init(opt_pars)
    
    #Define test data - no need to iterate
    test_iterator = iter(test_dataloader)
    test_data = next(test_iterator)
    test_data['ref'] = test_data['ref'].numpy()
    test_data['target'] = test_data['target'].numpy()
    test_data['label'] = test_data['label'].numpy()
    
    for epoch in range(epochs):
        
        #Loss and accuracy before training
        if epoch == 0:
            val_loss= model(opt_pars, ssn_pars, grid_pars, conn_pars, test_data, filter_pars,  conv_pars )
            print('Before training  -- loss: {} '.format(val_loss))
            val_loss_per_epoch.append(val_loss)

        
        start_time = time.time()
        train_iterator = iter(train_dataloader)
        test_iterator = iter(test_dataloader)
        
        for batch in range(batches):
            train_data = next(train_iterator)

            train_data['ref'] = train_data['ref'].numpy()
            train_data['target'] = train_data['target'].numpy()
            train_data['label'] = train_data['label'].numpy()

            #compute loss and gradient 
            grad =jax.grad(model)(opt_pars, ssn_pars, grid_pars, conn_pars, train_data, filter_pars,  conv_pars)

            #Apply SGD through Adam optimizer
            updates, opt_state = optimizer.update(grad, opt_state)
            opt_pars = optax.apply_updates(opt_pars, updates)
        
        epoch_time =time.time() - start_time
    
        #Evaluate model
        val_loss = model(opt_pars, ssn_pars, grid_pars, conn_pars, test_data, filter_pars,  conv_pars )
        print('Validation -- loss: {}, at epoch {}, (time {})'.format(val_loss, epoch+1, epoch_time))
        val_loss_per_epoch.append(val_loss)
    
    
    #reparametize parameters
    signs=np.array([[1, -1], [1, -1]])    
    opt_pars['logJ_2x2'] = np.exp(opt_pars['logJ_2x2'])*signs
    opt_pars['logs_2x2'] = np.exp(opt_pars['logs_2x2'])
    opt_pars['J_2x2'] = opt_pars['logJ_2x2']
    opt_pars['s_2x2'] = opt_pars['logs_2x2']
    del opt_pars['logJ_2x2'], opt_pars['logs_2x2']
    
    return opt_pars, val_loss_per_epoch

In [70]:
#model_test = model(opt_pars, ssn_pars, grid_pars, conn_pars, gratings, labels, filter_pars,  **conv_pars)
new_pars, offset_5_loss= train_SSN(opt_pars, ssn_pars, grid_pars, conn_pars, train_dataloader, test_dataloader, filter_pars, conv_pars, batches = number_batches, epochs=20)

# Vmap implementation

Vmap implementation

In [57]:
def sigmoid(x):
    return 1/(1+np.exp(x))

def binary_loss(n, x):
    return - (n*np.log(x) + (1-n)*np.log(1-x))

def test_model(opt_pars, ssn_pars, grid_pars, conn_pars, train_data, filter_pars,  conv_pars):
    
    signs=np.array([[1, -1], [1, -1]])
    
    J_2x2 =np.exp(opt_pars['logJ_2x2'])*signs
    s_2x2 = np.exp(opt_pars['logs_2x2'])
    
    #Initialise network
    ssn=SSN2DTopoV1_AMPAGABA_ONOFF(ssn_pars=ssn_pars, grid_pars=grid_pars, conn_pars=conn_pars, filter_pars=filter_pars, J_2x2=J_2x2, s_2x2=s_2x2)

    #Apply Gabor filters to stimuli
    output_ref=np.matmul(ssn.gabor_filters, train_data['ref'])*ssn.A
    output_target=np.matmul(ssn.gabor_filters, train_data['target'])*ssn.A

    #Rectify output
    SSN_input_ref=np.maximum(0, output_ref)
    SSN_input_target=np.maximum(0, output_target)

    #Input to SSN
    r_init = np.zeros(SSN_input_ref.shape[0])

    fp_ref, _ = ssn.fixed_point_r(SSN_input_ref, r_init=r_init, **conv_pars)
    x_ref = ssn.apply_bounding_box(fp_ref, size=3.2)

    fp_target, _ = ssn.fixed_point_r(SSN_input_target, r_init=r_init, **conv_pars)
    x_target = ssn.apply_bounding_box(fp_target, size=3.2)

    #Apply sigmoid function - combine ref and target
    x = sigmoid( np.dot(opt_pars['w_sig'], (x_ref.ravel() - x_target.ravel())) + opt_pars['b_sig'])

    #Calculate binary cross entropy loss
    loss=binary_loss(train_data['label'], x)
   
    return loss


def loss(opt_pars, ssn_pars, grid_pars, conn_pars, train_data, filter_pars,  conv_pars):
    '''
    Calculate parallelized loss for batch of data through vmap.
    Output:
        mean loss of all the input images
    '''
    
    vmap_model = vmap(test_model, in_axes = ({'b_sig': None, 'logJ_2x2': None, 'logs_2x2': None, 'w_sig': None}, None, None, {'PERIODIC': None, 'p_local': [None, None], 'sigma_oris': None},  {'ref':0, 'target':0, 'label':0}, {'conv_factor': None, 'degree_per_pixel': None, 'edge_deg': None, 'k': None, 'sigma_g': None}, {'Tmax': None, 'dt': None, 'silent': None, 'verbose': None, 'xtol': None}) )                   
    loss = np.sum(vmap_model(opt_pars, ssn_pars, grid_pars, conn_pars, train_data, filter_pars,  conv_pars))

    return loss


def train_SSN_vmap(opt_pars, ssn_pars, grid_pars, conn_pars, train_dataloader, test_dataloader, filter_pars, conv_pars, batches, epochs=1, eta=10e-4):
    
    #Initialize loss
    val_loss_per_epoch = []
    training_losses=[]
    accuracies=[]
    
    #Initialise optimizer
    optimizer = optax.adam(eta)
    opt_state = optimizer.init(opt_pars)
    
    #Define test data - no need to iterate
    test_iterator = iter(test_dataloader)
    test_data = next(test_iterator)
    test_data['ref'] = test_data['ref'].numpy()
    test_data['target'] = test_data['target'].numpy()
    test_data['label'] = test_data['label'].numpy()
    
    val_loss, accuracy= vmap_eval(opt_pars, ssn_pars, grid_pars, conn_pars, test_data, filter_pars,  conv_pars)

    print('Before training  -- loss: {}, accuracy: {} '.format(val_loss, accuracy))
    val_loss_per_epoch.append(val_loss)
    accuracies.append(accuracy)

    
    for epoch in range(epochs):
        start_time = time.time()
        train_iterator = iter(train_dataloader)
        epoch_loss = 0 
           
        #Iterate through data in batches
        for batch in range(batches): 

            #load next batch of data and convert
            train_data = next(train_iterator)
            #convert tensors to numpy
            train_data['ref'] = train_data['ref'].numpy()
            train_data['target'] = train_data['target'].numpy()
            train_data['label'] = train_data['label'].numpy()
            
            #Compute loss and gradient
            batch_loss, grad =jax.value_and_grad(loss)(opt_pars, ssn_pars, grid_pars, conn_pars, train_data, filter_pars,  conv_pars)
            
            #Apply SGD through Adam optimizer per batch
            updates, opt_state = optimizer.update(grad, opt_state)
            opt_pars = optax.apply_updates(opt_pars, updates)
            epoch_loss+=batch_loss
        
        epoch_time = time.time() - start_time
        
        
        #Evaluate model at the end of each epoch
        val_loss, accuracy= vmap_eval(opt_pars, ssn_pars, grid_pars, conn_pars, test_data, filter_pars,  conv_pars)
        print('Training loss: {} ¦ Validation -- loss: {}, accuracy: {} at epoch {}, (time {})'.format(epoch_loss, val_loss, accuracy, epoch+1, epoch_time))
        val_loss_per_epoch.append(val_loss)
        training_losses.append(epoch_loss)
        accuracies.append(accuracy)
    
    #reparametize parameters
    signs=np.array([[1, -1], [1, -1]])    
    opt_pars['logJ_2x2'] = np.exp(opt_pars['logJ_2x2'])*signs
    opt_pars['logs_2x2'] = np.exp(opt_pars['logs_2x2'])
    
    return opt_pars, val_loss_per_epoch, accuracies


def eval_model(opt_pars, ssn_pars, grid_pars, conn_pars, test_data, filter_pars,  conv_pars):
    signs=np.array([[1, -1], [1, -1]])
    
    J_2x2 =np.exp(opt_pars['logJ_2x2'])*signs
    s_2x2 = np.exp(opt_pars['logs_2x2'])
    
    #Initialise network
    ssn=SSN2DTopoV1_AMPAGABA_ONOFF(ssn_pars=ssn_pars, grid_pars=grid_pars, conn_pars=conn_pars, filter_pars=filter_pars, J_2x2=J_2x2, s_2x2=s_2x2)

    #Apply Gabor filters to stimuli
    output_ref=np.matmul(ssn.gabor_filters, test_data['ref'])*ssn.A
    output_target=np.matmul(ssn.gabor_filters, test_data['target'])*ssn.A

    #Rectify output
    SSN_input_ref=np.maximum(0, output_ref)
    SSN_input_target=np.maximum(0, output_target)

    #Input to SSN
    r_init = np.zeros(SSN_input_ref.shape[0])

    fp_ref, _ = ssn.fixed_point_r(SSN_input_ref, r_init=r_init, **conv_pars)
    x_ref = ssn.apply_bounding_box(fp_ref, size=3.2)

    fp_target, _ = ssn.fixed_point_r(SSN_input_target, r_init=r_init, **conv_pars)
    x_target = ssn.apply_bounding_box(fp_target, size=3.2)

    #Apply sigmoid function - combine ref and target
    x = sigmoid( np.dot(opt_pars['w_sig'], (x_ref.ravel() - x_target.ravel())) + opt_pars['b_sig'])
    
    #compare prediction to label
    pred_label = np.round(x)

    #Calculate binary cross entropy loss
    loss=binary_loss(test_data['label'], x)
    
    return loss, pred_label

def vmap_eval(opt_pars, ssn_pars, grid_pars, conn_pars, test_data, filter_pars,  conv_pars):
    
    eval_vmap = vmap(eval_model, in_axes = ({'b_sig': None, 'logJ_2x2': None, 'logs_2x2': None, 'w_sig': None}, None, None, {'PERIODIC': None, 'p_local': [None, None], 'sigma_oris': None},  {'ref':0, 'target':0, 'label':0}, {'conv_factor': None, 'degree_per_pixel': None, 'edge_deg': None, 'k': None, 'sigma_g': None}, {'Tmax': None, 'dt': None, 'silent': None, 'verbose': None, 'xtol': None}) )
    losses, pred_labels = eval_vmap(opt_pars, ssn_pars, grid_pars, conn_pars, test_data, filter_pars,  conv_pars)
    
    accuracy = np.sum(test_data['label'] == pred_labels)/len(test_data['label']) 
    
    vmap_loss= np.mean(losses)
    
    return vmap_loss, accuracy

In [77]:
#Network parameters
class ssn_pars():
    n = 2
    k = 0.04
    tauE = 30 # in ms
    tauI = 10 # in ms
    psi = 0.774
    tau_s = np.array([5, 7, 100]) #in ms, AMPA, GABA, NMDA current decay time constants
    

#Grid parameters
class grid_pars():
    gridsize_Nx = 9 # grid-points across each edge # gives rise to dx = 0.8 mm
    gridsize_deg = 2 * 1.6 # edge length in degrees
    magnif_factor = 2  # mm/deg
    hyper_col = 0.8 # mm   
    sigma_RF = 0.4 # deg (visual angle)

# Caleb's params for the full (with local) model:
Js0 = [1.82650658, 0.68194475, 2.06815311, 0.5106321]
gE, gI = 0.57328625, 0.26144141

sigEE, sigIE = 0.2, 0.40
sigEI, sigII = .09, .09
conn_pars = dict(
    PERIODIC = False,
    p_local = [.4, 0.7], # [p_local_EE, p_local_IE],
    sigma_oris = 1000) # sigma_oris


make_J2x2 = lambda Jee, Jei, Jie, Jii: np.array([[Jee, -Jei], [Jie,  -Jii]]) * np.pi * ssn_pars.psi
J_2x2 = make_J2x2(*Js0)
s_2x2 = np.array([[sigEE, sigEI],[sigIE, sigII]])

#Positive reparameterization
signs=np.array([[1, -1], [1, -1]])
logJ_2x2 =np.log(J_2x2*signs)
logs_2x2 = np.log(s_2x2)

accuracies=[]

for i in range(100):
    #Sigmoid parameters
    N_neurons = 25
    #key = random.PRNGKey(7)
    key, _ = random.split(key)
    w_sig = random.normal(key, shape = (N_neurons,)) / np.sqrt(N_neurons)
    b_sig = 0.0

    print(w_sig[0])

    #Optimization pars
    opt_pars = dict(logJ_2x2 = logJ_2x2, logs_2x2 = logs_2x2, w_sig = w_sig, b_sig=b_sig)

    #Parameters exclusive to Gabor filters
    filter_pars = dict(sigma_g = sigma_g, conv_factor = grid_pars.magnif_factor)
    filter_pars.update(general_pars) 

    #Convergence parameters
    conv_pars=dict(dt = 1, xtol = 1e-5, Tmax = 200, verbose=False, silent=True)


    test_iterator = iter(test_dataloader)
    test_data = next(test_iterator)
    test_data['ref'] = test_data['ref'].numpy()
    test_data['target'] = test_data['target'].numpy()
    test_data['label'] = test_data['label'].numpy()

    val_loss, accuracy= vmap_eval(opt_pars, ssn_pars, grid_pars, conn_pars, test_data, filter_pars,  conv_pars)
    accuracies.append(accuracy)

    #print('Before training  -- loss: {}, accuracy: {} '.format(val_loss, accuracy))

print('mean accuracy {}'.format(np.mean(np.array(accuracies))))
print('std accuracy {}'.format(np.std(np.array(accuracies))))

0.029492578
0.17246619
0.155092
0.06509497
-0.035398982
-0.10405494
-0.3808349
0.022488013
-0.35331386
-0.028675968
-0.18402055
-0.4617661
0.052420754
0.07784064
-0.34543487
-0.07090875
0.078980304
-0.16584781
-0.44027418
-0.16659573
-0.2857178
-0.27898508
0.31699798
-0.16648431
0.17404361
0.06896632
0.06442323
0.09719102
0.091520265
-0.15207894
-0.14929603
0.49952355
0.28899094
-0.017035412
0.07258329
0.03103529
-0.0013342563
0.045277845
0.055535436
-0.16484976
-0.030420557
0.015541026
0.29235873
-0.21536985
-0.16431247
0.10392799
-0.035985433
-0.18304285
0.24859548
0.09624269
0.121746615
-0.06740898
0.5067969
0.1442144
-0.07963194
-0.2553699
0.014697698
0.03550355
-0.19317183
-0.26313394
0.022527648
0.31488392
-0.21827042
-0.1459861
0.15279333
0.0018651752
-0.27461946
-0.28915852
-0.08333664
-0.020681921
0.13305351
-0.25832534
-0.18998267
0.4046653
-0.06035649
0.19500443
-0.15255456
-0.12465968
-0.057451893
0.011458481
-0.22376874
-0.07796078
-0.18848473
0.3088711
-0.022337638
-0.208

contrast = 0.8
mean accuracy 0.5005999803543091
std accuracy0.1579735428094864

contrast = 0.99 mean accuracy 0.4772000014781952
std accuracy 0.14111045002937317

In [50]:
#OFFSET - 5, eta = 10e-4
vmap_pars, vmap_val_loss, acc = train_SSN_vmap(opt_pars, ssn_pars, grid_pars, conn_pars, train_dataloader, test_dataloader, filter_pars,  conv_pars, batches = number_batches, epochs = 20)

Before training  -- loss: 0.5459374785423279, accuracy: 0.6600000262260437 
Training loss: 104.8504638671875 ¦ Validation -- loss: 0.48505544662475586, accuracy: 0.6800000071525574 at epoch 1, (time 55.28361701965332)
Training loss: 93.47003173828125 ¦ Validation -- loss: 0.41788673400878906, accuracy: 0.7599999904632568 at epoch 2, (time 55.12404465675354)
Training loss: 81.01006317138672 ¦ Validation -- loss: 0.34719958901405334, accuracy: 0.800000011920929 at epoch 3, (time 55.708760499954224)
Training loss: 68.13202667236328 ¦ Validation -- loss: 0.2796390950679779, accuracy: 0.8199999928474426 at epoch 4, (time 55.38471078872681)
Training loss: 50.94453430175781 ¦ Validation -- loss: 0.1730218529701233, accuracy: 0.8999999761581421 at epoch 5, (time 55.07560420036316)
Training loss: 19.09897232055664 ¦ Validation -- loss: 0.06391000747680664, accuracy: 1.0 at epoch 6, (time 54.87607789039612)
Training loss: 4.007383346557617 ¦ Validation -- loss: 0.010353709571063519, accuracy: 1.

KeyboardInterrupt: 

In [59]:
#OFFSET - 05, eta = 10e-4
vmap_pars_05, vmap_val_loss_05, acc_05 = train_SSN_vmap(opt_pars, ssn_pars, grid_pars, conn_pars, train_dataloader, test_dataloader, filter_pars,  conv_pars, batches = number_batches, epochs = 20)

Before training  -- loss: 0.7708415985107422, accuracy: 0.3400000035762787 
Training loss: 155.9490966796875 ¦ Validation -- loss: 0.7486924529075623, accuracy: 0.3400000035762787 at epoch 1, (time 53.977357149124146)
Training loss: 152.0327911376953 ¦ Validation -- loss: 0.7348071336746216, accuracy: 0.36000001430511475 at epoch 2, (time 53.79437446594238)
Training loss: 149.38291931152344 ¦ Validation -- loss: 0.7255972623825073, accuracy: 0.3799999952316284 at epoch 3, (time 54.45169234275818)
Training loss: 147.525146484375 ¦ Validation -- loss: 0.7191407680511475, accuracy: 0.4399999976158142 at epoch 4, (time 53.74654006958008)
Training loss: 146.1613006591797 ¦ Validation -- loss: 0.7143919467926025, accuracy: 0.4399999976158142 at epoch 5, (time 54.12783145904541)
Training loss: 145.11245727539062 ¦ Validation -- loss: 0.7107369899749756, accuracy: 0.4399999976158142 at epoch 6, (time 53.79789471626282)
Training loss: 144.2670440673828 ¦ Validation -- loss: 0.7077839374542236, 

In [65]:
#OFFSET - 05, eta = 10e-4
vmap_pars_0_25, vmap_val_loss_05_2, acc_05_2 = train_SSN_vmap(opt_pars, ssn_pars, grid_pars, conn_pars, train_dataloader, test_dataloader, filter_pars,  conv_pars, batches = number_batches, epochs = 20)

Before training  -- loss: 0.6479934453964233, accuracy: 0.699999988079071 
Training loss: 137.92486572265625 ¦ Validation -- loss: 0.6421316862106323, accuracy: 0.7400000095367432 at epoch 1, (time 54.21994471549988)
Training loss: 136.59756469726562 ¦ Validation -- loss: 0.6350752115249634, accuracy: 0.7799999713897705 at epoch 2, (time 55.3475604057312)
Training loss: 135.12440490722656 ¦ Validation -- loss: 0.6258025765419006, accuracy: 0.8199999928474426 at epoch 3, (time 55.09902834892273)
Training loss: 133.29287719726562 ¦ Validation -- loss: 0.6125142574310303, accuracy: 0.8199999928474426 at epoch 4, (time 54.1776180267334)
Training loss: 130.800537109375 ¦ Validation -- loss: 0.5912361145019531, accuracy: 0.8399999737739563 at epoch 5, (time 53.83504009246826)
Training loss: 127.19683837890625 ¦ Validation -- loss: 0.5565797090530396, accuracy: 0.8600000143051147 at epoch 6, (time 55.08782172203064)
Training loss: 121.86160278320312 ¦ Validation -- loss: 0.5024295449256897, a