In [77]:
# testing model output with phase matched stimuli

import numpy
import jax.numpy as np
import matplotlib.pyplot as plt
import os
import sys
numpy.random.seed(0)

from training.util_gabor import init_untrained_pars, create_gabor_filters_ori_map
from analysis.analysis_functions import tuning_curve, SGD_step_indices
from util import load_parameters
from parameters import (
    grid_pars,
    filter_pars,
    stimuli_pars,
    readout_pars,
    ssn_pars,
    conv_pars,
    training_pars,
    loss_pars,
    pretrain_pars # Setting pretraining to be true (pretrain_pars.is_on=True) should happen in parameters.py because w_sig depends on it
)

# Pretraining has to be switched on in parameters.py itself (w_sig depends on it there),
# so stop early when the imported configuration has it off.
if not pretrain_pars.is_on:
    raise ValueError('Set pretrain_pars.is_on to True in parameters.py to run training with pretraining!')


# Analysis settings
tc_ori_list = numpy.arange(0, 180, 2)  # orientations 0..178 in steps of 2 (presumably for tuning curves)
num_training = 10
final_folder_path = os.path.join('results', 'Apr10_v1')
########## Calculate and save gabor outputs ############
num_phases = 4
i = 0  # index of the saved run whose orimap/results files are loaded below

# Files saved for run i
orimap_filename = os.path.join(final_folder_path, f"orimap_{i}.npy")
results_filename = os.path.join(final_folder_path, f"results_{i}.csv")

# Orientation map of run i, then the Gabor filters / untrained parameters built on it
orimap_loaded = numpy.load(orimap_filename)
untrained_pars = init_untrained_pars(
    grid_pars, stimuli_pars, filter_pars, ssn_pars, conv_pars,
    loss_pars, training_pars, pretrain_pars, readout_pars,
    orimap_loaded=orimap_loaded,
)

# Parameters read from run i's results CSV (row chosen via iloc_ind=0)
trained_pars_stage1, trained_pars, offset_last = load_parameters(results_filename, iloc_ind=0)

In [78]:
# Calculating model response
from util import sep_exponentiate
from training.util_gabor import BW_image_jax_supp, BW_image_vmap
from training.SSN_classes import SSN_mid, SSN_sup
from training.model import vmap_evaluate_model_response_mid, vmap_evaluate_model_response

# Orientations (0..178 in steps of 2) of the test gratings
ori_vec = np.arange(0, 180, 2)
ref_ori_saved = float(untrained_pars.stimuli_pars.ref_ori)

# Recurrent weights: trained_pars stores either the raw J matrices or a log-parameterised
# form that sep_exponentiate converts back. A raw 'J_2x2_m' entry takes precedence when
# both are present (same effective precedence as before).
if 'J_2x2_m' in trained_pars:
    J_2x2_m = trained_pars['J_2x2_m']
    J_2x2_s = trained_pars['J_2x2_s']
elif 'log_J_2x2_m' in trained_pars:
    J_2x2_m = sep_exponentiate(trained_pars['log_J_2x2_m'])
    J_2x2_s = sep_exponentiate(trained_pars['log_J_2x2_s'])
else:
    # Fail here with a clear message instead of a NameError at SSN_mid(...) further down,
    # or (on a re-run in a live kernel) silently reusing J values from a previous run.
    raise KeyError("trained_pars contains neither 'J_2x2_m' nor 'log_J_2x2_m'")

# Constant inputs: trained values if present, otherwise the untrained defaults
if 'c_E' in trained_pars:
    c_E = trained_pars['c_E']
    c_I = trained_pars['c_I']
else:
    c_E = untrained_pars.ssn_pars.c_E
    c_I = untrained_pars.ssn_pars.c_I

# f_E / f_I: log form takes precedence, then raw trained values, then untrained defaults
if 'log_f_E' in trained_pars:
    f_E = np.exp(trained_pars['log_f_E'])
    f_I = np.exp(trained_pars['log_f_I'])
elif 'f_E' in trained_pars:
    f_E = trained_pars['f_E']
    f_I = trained_pars['f_I']
else:
    f_E = untrained_pars.ssn_pars.f_E
    f_I = untrained_pars.ssn_pars.f_I

ssn_mid = SSN_mid(ssn_pars=untrained_pars.ssn_pars, grid_pars=untrained_pars.grid_pars, J_2x2=J_2x2_m)

# Grid geometry. Note: `ssn_pars` is rebound here from the module imported from
# parameters.py to untrained_pars.ssn_pars; the buffer sizes below use the latter.
num_ori = len(ori_vec)
new_rows = []
x_map, y_map = untrained_pars.grid_pars.x_map, untrained_pars.grid_pars.y_map
ssn_pars = untrained_pars.ssn_pars
grid_size = x_map.shape[0] * x_map.shape[1]

# Buffers with one row per orientation, sized for E and I cells of every grid point and
# phase (middle layer) and of every grid point (superficial layer)
responses_mid_phase_match = numpy.zeros((len(ori_vec), grid_size * ssn_pars.phases * 2))
responses_sup_phase_match = numpy.zeros((len(ori_vec), grid_size * 2))

# Only the first grid point and phase are probed; the full sweep would be
#for i in range(x_map.shape[0]):
#    for j in range(x_map.shape[1]):
#        for phase_ind in range(ssn_pars.phases):
i = j = phase_ind = 0

# Grating stimulus (full_grating=True) at the (x0, y0) of grid point (i, j), phase phase_ind * pi/2
# NOTE(review): this uses the imported stimuli_pars, while ref_ori is read from
# untrained_pars.stimuli_pars above; confirm both refer to the same settings.
x0, y0 = x_map[i, j], y_map[i, j]
BW_image_jax_inp = BW_image_jax_supp(stimuli_pars, x0=x0, y0=y0, phase=phase_ind * np.pi/2, full_grating=True)
x, y, alpha_channel, mask = BW_image_jax_inp[4:8]
train_data = BW_image_vmap(BW_image_jax_inp[0:4], x, y, alpha_channel, mask, ori_vec, np.zeros(num_ori))

# Middle-layer response for every orientation in ori_vec (nothing is copied into
# responses_mid_phase_match in this cell)
output_mid, _, _, _, _, responses_mid = vmap_evaluate_model_response_mid(ssn_mid, train_data, untrained_pars.conv_pars, c_E, c_I, untrained_pars.gabor_filters)



In [84]:
# Sum, across phases, the E-cell response at grid point 0 and compare it with output_mid[:, 0].
# Assumed layout of responses_mid columns: one block of grid_size cells per (phase, E/I) map,
# E and I alternating, so the same grid point's E cells in consecutive phases are 2*grid_size
# columns apart (this reproduces the previously hard-coded 0, 162, 324, 486) — TODO confirm.
e_cols_grid0 = range(0, 2 * grid_size * ssn_pars.phases, 2 * grid_size)
# Sequential sum keeps the same float addition order as writing the terms out with '+'
print(sum(responses_mid[:, col] for col in e_cols_grid0))
print(output_mid[:, 0])

[160.78006   180.52899   185.29173   171.91898   142.86722   107.79882
  84.06265    68.30645    53.13357    41.5691     37.56923    35.07416
  29.229929   26.886387   52.828045   75.11182    72.204575   54.45756
  30.309048   13.105339   10.432635    9.441031    9.41687    10.708129
  13.190528   15.62453    16.313242   19.369041   24.95577    31.696783
  37.444675   23.939312   11.406364   13.1194     12.507507    8.098734
   5.5139823   5.2599053   5.5070515   5.810134    6.365381   12.3858185
  31.247892   41.960205   41.136223   29.220032   14.75413    12.457674
  12.124348   12.728067   13.127708   14.078316   15.195295   16.519722
  21.634674   40.403984   60.680843   77.477036   89.310585   96.27819
  99.12331    99.07792    96.984665   94.29112    91.72234    89.96203
  89.45399    90.338425   92.70887    96.46286   101.30008   107.03344
 113.30092   119.21159   124.20784   127.05716   127.133675  123.538666
 115.84911   104.289856   90.335045   78.47967    76.499176   79.4373

In [None]:
# Investigating why output_mid (sum of E cells) has multiple bumps: the code below replicates vmap_evaluate_model_response_mid step by step
from training.model import constant_to_vec
from training.SSN_classes import obtain_fixed_point
constant_vector = constant_to_vec(c_E=c_E, c_I=c_I, ssn=ssn_mid)

# vmap_evaluate_model_response_mid is vectorised over stimuli, but the lines below run
# unvectorised, so they must act on a single stimulus. train_data holds one stimulus per
# entry of ori_vec along axis 0; multiplying the filter bank by the whole batch would not
# match the per-stimulus computation and fails on a shape mismatch.
stim_ind = 0  # index into ori_vec of the stimulus to inspect
stimulus = train_data[stim_ind]

# Apply Gabor filters to the stimulus to create input of middle layer
input_mid = np.matmul(untrained_pars.gabor_filters, stimulus)

# Rectify middle layer input before fixed point calculation
SSN_mid_input = np.maximum(0, input_mid) + constant_vector

# NOTE(review): Rmax_E / Rmax_I are not passed to obtain_fixed_point below — unused here.
Rmax_E = 40
Rmax_I = 80
fp, avg_dx = obtain_fixed_point(ssn=ssn_mid, ssn_input=SSN_mid_input, conv_pars=untrained_pars.conv_pars)

# Map numbers selecting E (odd) and I (even) maps
map_numbers_E = np.arange(1, 2 * ssn_mid.phases, 2)  # 1,3,5,7
map_numbers_I = np.arange(2, 2 * ssn_mid.phases + 1, 2)  # 2,4,6,8

fp_E = ssn_mid.select_type(fp, map_numbers_E)
fp_I = ssn_mid.select_type(fp, map_numbers=map_numbers_I)

# Define output as sum of E neurons
layer_output = np.sum(fp_E, axis=0)

# Disabled phase-matched bookkeeping, kept as comments: as a bare string literal at the end
# of the cell, Jupyter would echo it as this cell's output.
# NOTE(review): mid_cell_ind assumes E/I cells interleaved per (grid point, phase), whereas the
# earlier check sums responses_mid columns one E map apart (0, 162, 324, 486), i.e. a
# map-major layout — confirm which is right before re-enabling.
# NOTE(review): both indices multiply i by x_map.shape[0]; row-major flattening of (i, j)
# would use x_map.shape[1] (the two agree only for square grids).
# mid_cell_ind = i*x_map.shape[0]*ssn_pars.phases*2+j*ssn_pars.phases*2+phase_ind*2
# responses_mid_phase_match[:,mid_cell_ind]=responses_mid[:,mid_cell_ind] # E cell
# responses_mid_phase_match[:,mid_cell_ind+1]=responses_mid[:,mid_cell_ind+1] # I cell
# # Calculate model response for superficial layer cells and save it to responses_sup_phase_match
# if phase_ind==0:
#     # Superficial layer response per grid point
#     ssn_sup=SSN_sup(ssn_pars=ssn_pars, grid_pars=untrained_pars.grid_pars, J_2x2=J_2x2_s, p_local=ssn_pars.p_local_s, oris=untrained_pars.oris, s_2x2=ssn_pars.s_2x2_s, sigma_oris = ssn_pars.sigma_oris, ori_dist = untrained_pars.ori_dist, train_ori = untrained_pars.stimuli_pars.ref_ori)
#     _, _, [_, _], [_, _], [_, _, _, _], [fp_mid, responses_sup] = vmap_evaluate_model_response(ssn_mid, ssn_sup, train_data, untrained_pars.conv_pars, c_E, c_I, f_E, f_I, untrained_pars.gabor_filters)
#     sup_cell_ind = i*x_map.shape[0]+j
#     responses_sup_phase_match[:,sup_cell_ind]=responses_sup[:,sup_cell_ind]
#     responses_sup_phase_match[:,grid_size+sup_cell_ind]=responses_sup[:,grid_size+sup_cell_ind]