# Marino's model in Marino's parameter settings

In [1]:
# In its own cell so we can run it just once per kernel session; the cells below
# can then be edited and re-run without reloading the model code.

include("pro_anti.jl")   # Loads all the necessary pre-requisites (presumably defines JJ, JJ_opto, run_ntrials, etc., which this notebook uses but never defines — confirm)


# Bare reference as the cell's last expression: the notebook simply displays it,
# which serves as a quick check that the include above succeeded.
# NOTE(review): `load_run` is not defined in this notebook; presumably it comes from pro_anti.jl.
load_run

In [38]:
# Define core model parameters.
# This Dict is the base parameter set: the optimization cells layer the fitted
# parameters (see `args`) and other overrides on top of it with `merge`.
model_params = Dict(
:dt     =>  0.002, 
:tau    =>  0.02, 
:vW     =>  -1.58,
:hW     =>  -0.05,
:sW     =>  0,
:dW     =>  0,
:nsteps =>  301, 
:noise  =>  [], 
:sigma  =>  0.08, 
:input  =>  0, 
:g_leak =>  1, 
:U_rest =>  0,
:theta  =>  0.05, 
:beta   =>  0.5, 
:constant_excitation      => 0, 
:anti_rule_strength       => 0.05,
:pro_rule_strength        => 0.05, 
:target_period_excitation => 0,
:right_light_excitation   => 0.6, 
:right_light_pro_extra    => 0,
:const_add                => 0, 
:init_add                 => 0, 
# Trial epoch durations. The opto comment below states that times are in
# seconds; presumably these durations are too — confirm in pro_anti.jl.
:rule_and_delay_period    => 0.2,
:target_period            => 0.1,
:post_target_period       => 0.1,
:const_pro_bias           => 0.0427,
# Passed as the two positional arguments of JJ / JJ_opto during optimization
# (presumably the number of Pro and Anti trials).
:nPro                     => 100,
:nAnti                    => 100,
# Passed explicitly as JJ's `theta1` / `theta2` keywords.
:theta1                   => 0.05,
:theta2                   => 0.15,
:opto_strength  => .9,
# Set of opto conditions: each row is one condition, given as [start end].
# Numeric entries are times in seconds, with 0 the start of the trial
# (i.e. the start of rule_and_delay_period). String entries name trial events
# ("target_start", "target_end", "trial_end"), possibly with simple arithmetic
# ("target_start/2"); presumably pro_anti.jl resolves them — confirm.
# Anything before 0 or after the end of the trial gets ignored.
# Row 1, [0 0], is a zero-length window (presumably the no-opto control).
:opto_periods   => [
    0               0 ; 
    0              "trial_end" ;
    "target_start/2"  "target_start";
    "target_start"    "target_end"],  
# Older 5-condition target set, kept for reference:
#:opto_targets   => [.75 .73;.77 .58;.75 .74;.72 .66;.73 .75] 
# Target fraction of hits per opto condition: first column is Pro, second column
# is Anti. Row i is the target for row i of :opto_periods, so both must have the
# same number of rows.
:opto_targets => [
    .9      .7; 
    .9      .5; 
    .9      .55; 
    .9      .7;
    ],
);


# ======= ARGUMENTS AND SEED VALUES:
# Each parameter to optimize is listed next to its starting value, so the names
# and values cannot drift out of alignment. They are unzipped, in this order,
# into `args` (Vector{String}) and `seed` (Vector{Float64}).
args, seed = let arg_seed_pairs = [
        ("sW",                        0.001),
        ("vW",                       -1.58),
        ("hW",                       -0.05),
        ("dW",                        0.001),
        ("constant_excitation",       0.001),
        ("right_light_excitation",    0.6),
        ("target_period_excitation",  0.001),
        ("const_pro_bias",            0.0427),
        ("sigma",                     0.05),
        ("opto_strength",             0.9),
        ("pro_rule_strength",         0.05),
        ("anti_rule_strength",        0.05),
    ]
    (first.(arg_seed_pairs), last.(arg_seed_pairs))
end;

# ======= BOUNDING BOX:
# Hard [lower upper] limits (1x2 row matrices), keyed by parameter Symbol.
# NOTE(review): "pro_rule_strength" and "anti_rule_strength" are in `args` but have
# no entry here or in `sbox` — confirm they are meant to be unbounded.
bbox = Dict(
    :sW                       => [0 3],
    :vW                       => [-3 3],
    :hW                       => [-3 3],
    :dW                       => [-3 3],
    :constant_excitation      => [-2 2],
    :right_light_excitation   => [0.05 4],
    :target_period_excitation => [0 4],
    :const_pro_bias           => [-2 2],
    :sigma                    => [0.01 0.2],
    :opto_strength            => [0 1],
);

# ======== SEARCH ZONE:
# Narrower [lower upper] ranges, keyed the same way as `bbox`. In this notebook it is
# only saved alongside the results, not passed to the optimizer.
sbox = Dict(
    :sW                       => [0 0.5],
    :vW                       => [-0.5 0.5],
    :hW                       => [-0.5 0.5],
    :dW                       => [-0.5 0.5],
    :constant_excitation      => [-0.5 0.5],
    :right_light_excitation   => [0.05 0.5],
    :target_period_excitation => [0.001 0.5],
    :const_pro_bias           => [-0.5 0.5],
    :sigma                    => [0.02 0.19],
    :opto_strength            => [0.7 0.99],
);

# define a few hyper parameters
cbetas                 = [0.04];
rule_and_delay_periods = [0.2];
post_target_periods    = [0.1];
num_eval_runs          = 100;    # used in place of nPro/nAnti when evaluating on test noise
num_optimize_iter      = 2000;   # `maxiter` for bbox_Hessian_keyword_minimization
num_optimize_restarts  = 1;      # NOTE(review): not referenced anywhere in this notebook


# Passed to JJ / JJ_opto as `cbeta`. Derived from `cbetas` rather than retyped, so
# the two values cannot silently disagree.
cb = cbetas[1]

# Seed the global RNG from the wall clock (rounded to whole seconds). The seed is
# kept in `sr`: it is passed to JJ as `seedrand` and saved with the results.
sr = round(Int64, time())
srand(sr);

# Parameter set used for this demo run; the objective `func` also uses it as the base.
# The overrides set :opto_times to ["target_start", "target_end"] and the rule-and-delay
# period to 0.2, and raise the anti-rule strength to 0.06.
# (An alternative tried earlier used :opto_times => [0 0] with
# :rule_and_delay_period => 0.1.)
mypars = merge(model_params, Dict(
    :opto_times            => ["target_start", "target_end"],
    :rule_and_delay_period => 0.2,
    :anti_rule_strength    => 0.06,
))
pygui(true)   # PyPlot: draw figures in GUI windows rather than inline
proVs, antiVs, pro_fullV, anti_fullV, opto_fraction, pro_input, anti_input =
    run_ntrials(15, 15; plot_list = collect(1:15), profig = 1, antifig = 2,
                opto_units = 1:4, mypars...)

# Objective for the optimizer. It takes the parameters being fit as keywords,
# layers them over `mypars`, runs JJ on the training noise (seed `sr`), and keeps
# only JJ's first output (the cost).
func = (; params...) -> begin
    trial_pars = merge(mypars, Dict(params))
    all_outputs = JJ(model_params[:nPro], model_params[:nAnti];
        rule_and_delay_periods = rule_and_delay_periods,
        theta1                 = model_params[:theta1],
        theta2                 = model_params[:theta2],
        post_target_periods    = post_target_periods,
        seedrand               = sr,
        cbeta                  = cb,
        verbose                = true,
        trial_pars...)
    all_outputs[1]
end


# cost, cost1s, cost2s, hP, hA, dP, dA, hBP, hBA = JJ(model_params[:nPro], model_params[:nAnti]; 
#     rule_and_delay_periods=rule_and_delay_periods, theta1=model_params[:theta1], theta2=model_params[:theta2], 
#     post_target_periods=post_target_periods, seedrand=sr, cbeta=cb, verbose=true, model_params...)
    

In [39]:
# Minimize `func` over the parameters named in `args`, starting from `seed` and
# bounded by `bbox`. The captured output below shows this run was interrupted.
print("Going with seed = ")
print_vector_g(seed)
print("\n")
pars, traj, cost, cpm_traj, ftraj = bbox_Hessian_keyword_minimization(
    seed, args, bbox, func;
    start_eta = 1, tol = 1e-12, verbose = true, verbose_every = 1,
    maxiter = num_optimize_iter)
@printf("Came out with cost %g and pars = ", cost)
print_vector_g(pars)
print("\n\n")


Going with seed = [0.001, -1.58, -0.05, 0.001, 0.001, 0.6, 0.001, 0.0427, 0.05, 0.9, 0.05, 0.05]
0: eta=1 ps=[0.001, -1.580, -0.050, 0.001, 0.001, 0.600, 0.001, 0.043, 0.050, 0.900, 0.050, 0.050]
Opto condition # 1
     - 1 - cost=0.00967201, cost1=0.0128087, cost2=-0.00313672
     - 1 - mean(hitsP)=0.969158, mean(diffsP)=0.278507 mean(hitsA)=0.387451, mean(diffsA)=0.348837
Opto condition # 2
     - 2 - cost=0.0117545, cost1=0.0139329, cost2=-0.00217838
     - 2 - mean(hitsP)=0.967123, mean(diffsP)=0.268395 mean(hitsA)=0.172956, mean(diffsA)=0.167281
Opto condition # 3
     - 3 - cost=0.0126378, cost1=0.0158546, cost2=-0.00321686
     - 3 - mean(hitsP)=0.973269, mean(diffsP)=0.297372 mean(hitsA)=0.201476, mean(diffsA)=0.346
Opto condition # 4
     - 4 - cost=0.0212839, cost1=0.0241292, cost2=-0.00284533
     - 4 - mean(hitsP)=0.976805, mean(diffsP)=0.316004 mean(hitsA)=0.267409, mean(diffsA)=0.253061
OVERALL
     -- cost=0.013837, cost1=0.0166814, cost2=-0.00284432
Opto condition # 1
 

LoadError: InterruptException:

In [None]:
# run optimization with all parameters, then evaluate the fit and save everything
@printf("Going with seed = "); print_vector_g(seed); print("\n")
pars, traj, cost, cpm_traj, ftraj = bbox_Hessian_keyword_minimization(seed, args, bbox, func,
    start_eta = 1, tol=1e-12, verbose=true, verbose_every=1, maxiter=num_optimize_iter)
@printf("Came out with cost %g and pars = ", cost); print_vector_g(pars); print("\n\n")

# get gradient and hessian of the objective at the end of optimization
value, grad, hess = keyword_vgh(func, args, pars)

## Evaluate on the training noise (seed `sr`, the seed used during optimization)
# NOTE(review): the objective `func` layers parameters over `mypars`, whereas these
# evaluations layer them over `model_params`, which has no `:opto_times`.
# Confirm this is intended.
t_standard_func = (; params...) -> JJ_opto(model_params[:nPro], model_params[:nAnti];
    rule_and_delay_periods=rule_and_delay_periods, theta1=model_params[:theta1],
    theta2=model_params[:theta2], post_target_periods=post_target_periods,
    seedrand=sr, cbeta=cb, verbose=false, merge(model_params, Dict(params))...)

# run opto model with all outputs, evaluate on training noise
(t_opto_scost, t_opto_scost1, t_opto_scost2, t_opto_hitsP, t_opto_hitsA,
    t_opto_diffsP, t_opto_diffsA, t_opto_bP, t_opto_bA) =
    t_standard_func(; make_dict(args, pars, model_params)...)

## Evaluate on fresh test noise, with num_eval_runs trials of each type
# reset random number generator for testing purposes
test_sr = convert(Int64, round(time()))
srand(test_sr)

# opto model with all outputs, evaluated on test noise
opto_func = (; params...) -> JJ_opto(num_eval_runs, num_eval_runs;
    rule_and_delay_periods=rule_and_delay_periods, theta1=model_params[:theta1],
    theta2=model_params[:theta2], post_target_periods=post_target_periods,
    seedrand=test_sr, cbeta=cb, verbose=false, merge(model_params, Dict(params))...)

(opto_scost, opto_scost1, opto_scost2, opto_hitsP, opto_hitsA,
    opto_diffsP, opto_diffsA, opto_bP, opto_bA) =
    opto_func(; make_dict(args, pars, model_params)...)

# JJ with all outputs, on the same test noise, as a cross-check of the opto model.
# NOTE(review): the optimization log shows JJ itself iterating over opto conditions,
# so this may not be a purely non-opto check — confirm.
standard_func = (; params...) -> JJ(num_eval_runs, num_eval_runs;
    rule_and_delay_periods=rule_and_delay_periods, theta1=model_params[:theta1],
    theta2=model_params[:theta2], post_target_periods=post_target_periods,
    seedrand=test_sr, cbeta=cb, verbose=false, merge(model_params, Dict(params))...)

scost, scost1, scost2, hitsP, hitsA, diffsP, diffsA =
    standard_func(; make_dict(args, pars, model_params)...)

## Save this run out to a file
# This notebook does not define `fbasename`. Use it if it was set elsewhere in the
# session; otherwise fall back to a default prefix. (next_file(fbasename, 4)
# presumably returns the next unused numbered filename — confirm.)
if !isdefined(Main, :fbasename)
    fbasename = "MarinoOpto"
end
myfilename = next_file(fbasename, 4) * ".mat"

# "myseed" holds the starting point `seed`. No "dista" value is saved because this
# notebook starts from a fixed seed rather than a random draw.
matwrite(myfilename, Dict(
    "args"=>args, "myseed"=>seed, "pars"=>pars, "traj"=>traj, "cost"=>cost,
    "cpm_traj"=>cpm_traj, "nPro"=>model_params[:nPro], "nAnti"=>model_params[:nAnti],
    "sr"=>sr, "cb"=>cb, "theta1"=>model_params[:theta1], "theta2"=>model_params[:theta2],
    "value"=>value, "grad"=>grad, "hess"=>hess,
    "scost"=>scost, "scost1"=>scost1, "scost2"=>scost2,
    "hitsP"=>hitsP, "hitsA"=>hitsA, "diffsP"=>diffsP, "diffsA"=>diffsA,
    "model_params"=>ascii_key_ize(model_params), "bbox"=>ascii_key_ize(bbox),
    "sbox"=>ascii_key_ize(sbox),
    "rule_and_delay_periods"=>rule_and_delay_periods,
    "post_target_periods"=>post_target_periods,
    "opto_scost"=>opto_scost, "opto_scost1"=>opto_scost1, "opto_scost2"=>opto_scost2,
    "opto_hitsP"=>opto_hitsP, "opto_hitsA"=>opto_hitsA,
    "opto_diffsP"=>opto_diffsP, "opto_diffsA"=>opto_diffsA,
    "test_sr"=>test_sr, "opto_bP"=>opto_bP, "opto_bA"=>opto_bA,
    "t_opto_scost"=>t_opto_scost, "t_opto_scost1"=>t_opto_scost1,
    "t_opto_scost2"=>t_opto_scost2,
    "t_opto_hitsP"=>t_opto_hitsP, "t_opto_hitsA"=>t_opto_hitsA,
    "t_opto_diffsP"=>t_opto_diffsP, "t_opto_diffsA"=>t_opto_diffsA,
    "t_opto_bP"=>t_opto_bP, "t_opto_bA"=>t_opto_bA))