# INTRODUCTION

This notebook tries to reproduce what Marino did where he got good training even with tens of trials from a variety of starting points.  It uses his run_dynamics function.

# ==========


# Preliminaries

In [1]:
using PyCall
using PyPlot
using ForwardDiff
using DiffBase

pygui(true)

import Base.convert
convert(::Type{Float64}, x::ForwardDiff.Dual) = Float64(x.value)
function convert(::Array{Float64}, x::Array{ForwardDiff.Dual}) 
    y = zeros(size(x)); 
    for i in 1:prod(size(x)) 
        y[i] = convert(Float64, x[i]) 
    end
    return y
end

include("hessian_utils.jl")

"""
We define functions to convert Duals, the variable types used by ForwardDiff, 
to Floats. This is useful if we want to print out the value of a variable 
(since print doesn't know how to Duals). Note that after being converted to a Float, no
differentiation by ForwardDiff can happen!  e.g. after
    x = convert(Float64, y)
ForwardDiff can still differentiate y, but it can't differentiate x
"""





"We define functions to convert Duals, the variable types used by ForwardDiff, \nto Floats. This is useful if we want to print out the value of a variable \n(since print doesn't know how to Duals). Note that after being converted to a Float, no\ndifferentiation by ForwardDiff can happen!  e.g. after\n    x = convert(Float64, y)\nForwardDiff can still differentiate y, but it can't differentiate x\n"

# Marino's main model dynamics function  (adapted from model_gradient_kwargs_fast.jl on 20-July-2017)

Note that the documenation indicates some default values for the optional parameters; these values need to be updated to what the code actually says below. The actual defaults are much closer to what Marino's May 2017 model has

In [73]:
"""
function cost, vec_ans_out, vec_react = run_dynamics(
                       target1=0.9,target2=0.9,target3=0.9,target4=0.6,target5=0.4,target6=0.6,
                       vwi=9, hwi=0.25, pro_bias=0.2135, opto_effect=0.9,
                       delay_input=0.25, light_input=3, noisefr=0.005, sigma=0.005,
                       threshold=0.18, ntrials=50, random_seed=321,
                       tau=17.6, dt=10, start_U = [-25, -25, -25, -25],
                       g_leak = 1, U_rest = 0, theta1 = 5, beta1 = 50, theta2=0.15, theta3=0.15,
                       cue_period = 200, delay_period = 200, choice_period = 200, nderivs=0, difforder=0,
                       do_plot = false, fignum=1, plot_trials=[])

"""
function run_dynamics( ;target1=0.9,target2=0.9,target3=0.9,target4=0.6,target5=0.4,target6=0.6,
                       vwi=9, hwi=0.25, pro_bias=0.2135, opto_effect=0.9,
                       delay_input=0.25, light_input=3, noisefr=0.005, sigma=0.005,
                       threshold=0.18, ntrials=50, random_seed=321,
                       tau=17.6, dt=10, start_U = [-25, -25, -25, -25],
                       g_leak = 1, U_rest = 0, theta1 = 5, beta1 = 50, theta2=0.15, theta3=0.15,
                       cue_period = 200, delay_period = 200, choice_period = 200, nderivs=0, difforder=0,
                       do_plot = false, fignum=1, plot_trials=[])

    vec_ans_out = ForwardDiffZeros(ntrials, 6; nderivs=nderivs, difforder=difforder)
    vec_reac = ForwardDiffZeros(ntrials,6; nderivs=nderivs, difforder=difforder)

    if isempty(plot_trials)
        plot_trials = [1:ntrials;]
    end
    
    if !isnan(random_seed)
        srand(random_seed)
    else  # if the random seed is passed as NaN, use the system time in milliseconds
        srand(convert(Int64, round(1000*time())))
    end


    titles = ["pro", "pro delay", "pro choice", "anti", "anti delay", "anti choice"]
    for jjj in [1:6;] #trial types: pro, pro delay, pro chioce, anti, anti delay, anti choice

        if (jjj==1)||(jjj==2)||(jjj==3) 
            trial_type="pro"
        else     
            trial_type="anti"
        end

        if (jjj==1)||(jjj==4)
            opto_delay=1;
            opto_choice=1;
        elseif (jjj==2)||(jjj==5)
            opto_delay=opto_effect;
            opto_choice=1;
        elseif (jjj==3)||(jjj==6)
            opto_delay=1;
            opto_choice=opto_effect;
        end
        
        for iii in [1:ntrials;]
            t = [0 : dt : cue_period + delay_period + choice_period;]

            V = ForwardDiffZeros(4, length(t); nderivs=nderivs, difforder=difforder)
            U = ForwardDiffZeros(4, length(t); nderivs=nderivs, difforder=difforder)

            U[:,1] = start_U

            W = [0 -vwi -hwi 0; -vwi 0 0 -hwi;
                 -hwi 0 0 -vwi; 0 -hwi -vwi 0]

            for i in [2:length(t);]  # the funny semicolon appears to be necessary in Julia
                dUdt = W * V[:,i-1] + g_leak*(U_rest - U[:,i-1])/tau
                
                if t[i] < cue_period + delay_period
                    if trial_type=="anti"
                        dUdt[[2,4]] += delay_input
                    elseif trial_type == "pro"
                        dUdt[[1,3]] += delay_input
                    else
                        error("invalid trial type")
                    end
                elseif t[i] < cue_period + delay_period + choice_period
                    dUdt[[1,2]] += light_input
                end
                
                dUdt[[1,3]] += pro_bias

                U[:,i] = U[:,i-1] +  dt*dUdt + sqrt(dt)*sigma*randn(4)
                V[:,i] = 0.5*tanh((U[:,i]-theta1)/beta1) + 0.5

                if t[i] < cue_period+delay_period
                    V[:,i]=V[:,i]*opto_delay
                elseif t[i] < cue_period+delay_period+choice_period
                    V[:,i]=V[:,i]*opto_choice
                end

                V[:,i] = V[:,i] + noisefr*sqrt(dt/10)*randn(4)
            end

            if trial_type=="anti"    
                answer_out  = 0.5*(1 + tanh.((V[3,end]  - V[1,end])/theta2))
            elseif trial_type == "pro"
                answer_out  = 0.5*(1 + tanh.((V[1,end]  - V[3,end])/theta2))
            else
                error("invalid trial type")
            end
            
            #compute reaction time
            reac=NaN;
            for i in [161:length(t)-15;]                
                val1=mean(V[1,i-15:i+15]);
                val3=mean(V[3,i-15:i+15]);
                if(abs(val1-val3)>threshold)
                    reac=t[i];
                    break
                end
            end

            vec_ans_out[iii,jjj]=answer_out;
            vec_reac[iii,jjj]=reac;
            
            if do_plot && ~isempty(find(plot_trials.==iii))
                figure(fignum); 
                ax = subplot(6,1,jjj)

                h = plot(t, V'); 
                setp(h[1], color=[0, 0, 1])
                setp(h[2], color=[1, 0, 0])
                setp(h[3], color=[1, 0.5, 0.5])
                setp(h[4], color=[0, 1, 1])
                ylabel("V")

                ax = gca()
                yl = [ylim()[1], ylim()[2]]
                vlines([cue_period, cue_period+delay_period, 
                    cue_period+delay_period+choice_period], 
                    -0.05, 1.05, linewidth=2)
                if yl[1]<0.02
                    yl[1] = -0.02
                end
                if yl[2]>0.98
                    yl[2] = 1.02
                end
                ylim(yl)
                grid(true)
                title(titles[jjj])
                
                if jjj==6; 
                    xlabel("t");  
                else
                    setp(ax, xticks=[])
                end
            end
            
        end  # end loop over trials
        
            
    end # end loop over trial types

    cost = (mean(vec_ans_out[:,1]) - target1)^2 + (mean(vec_ans_out[:,2]) - target2)^2 + 
        (mean(vec_ans_out[:,3]) - target3)^2 + (mean(vec_ans_out[:,4]) - target4)^2 + 
        (mean(vec_ans_out[:,5]) - target5)^2 + (mean(vec_ans_out[:,6]) - target6)^2 
        
    if do_plot
        for jjj=[1:6;]
            subplot(6,1,jjj)
            title(titles[jjj] * " " * string(round(100*mean(vec_ans_out[:,jjj]))) * " %")
        end
    end
    vec_ans_out=mean(vec_ans_out,1)

    return cost, vec_ans_out, vec_reac
end

figure(1); clf();
# run_dynamics(ntrials=120, do_plot=true, sigma=0.1, noisefr=0.0005, random_seed=NaN)
# MARINO PARAMS: with noise added to V, not U
# run_dynamics(ntrials=120, do_plot=true, sigma=0, noisefr=0.005, random_seed=NaN)

c = run_dynamics(ntrials=200, do_plot=true, sigma=0, noisefr=0.01, random_seed=NaN, dt=2, plot_trials=1:10)





In [8]:
c

(0.11068650092605362,
[0.998313 0.998413 … 0.260838 0.725305],

[NaN NaN … NaN NaN; NaN NaN … NaN NaN; … ; NaN NaN … NaN NaN; NaN NaN … NaN NaN])

In [3]:
a = "a"
b = [a * " and this"]
string(23.34566800001)
isnan(NaN)

function glug(;pt=NaN)
    if isnan(pt)
        @printf "Yep\n"
    end
end

glug()

Yep


In [74]:
args=["vwi", "hwi", "pro_bias", "delay_input"];
goods=[4.0, 4.0, 0.2 ,0.2];
bbox = [-10   10;
        -10   10;
        -4    4;
        -4    4]

func = (;pars...) -> run_dynamics(;ntrials=100, target1=0.8, target2=0.8, target3=0.8, 
target4=0.7, target5=0.5, target6=0.7, dt=2, noisefr=0.015, pars...)[1]

@time params, trajectory = bbox_Hessian_keyword_minimization(goods, args, bbox, func, verbose=true, 
start_eta=0.1, tol=1e-12)

figure(1); clf();
func(; do_plot=true, ntrials=1000, plot_trials=1:10, make_dict(args, params)...)


0: eta=0.1 ps=[4.000, 4.000, 0.200, 0.200]
1: eta=0.11 cost=1.3264 jtype=constrained costheta=-0.999 ps=[4.000, 3.991, 0.100, 0.202]
2: eta=0.121 cost=1.2991 jtype=constrained costheta=-0.999 ps=[4.010, 3.984, -0.009, 0.205]
3: eta=0.1331 cost=1.2303 jtype=constrained costheta=-0.998 ps=[4.033, 3.980, -0.128, 0.208]
4: eta=0.14641 cost=1.0686 jtype=constrained costheta=-0.998 ps=[4.074, 3.980, -0.254, 0.212]
5: eta=0.161051 cost=0.7840 jtype=constrained costheta=-0.996 ps=[4.138, 3.987, -0.386, 0.216]
6: eta=0.177156 cost=0.5104 jtype=constrained costheta=-0.988 ps=[4.235, 4.001, -0.513, 0.224]
7: eta=0.194872 cost=0.4369 jtype=constrained costheta=-0.684 ps=[4.362, 4.018, -0.558, 0.338]
8: eta=0.214359 cost=0.4170 jtype=constrained costheta=-0.861 ps=[4.410, 3.995, -0.514, 0.520]
9: eta=0.235795 cost=0.3853 jtype=constrained costheta=-0.917 ps=[4.522, 3.990, -0.419, 0.676]
10: eta=0.259374 cost=0.3340 jtype=constrained costheta=-0.848 ps=[4.685, 3.997, -0.323, 0.817]
11: eta=0.285312 

0.05733695207051093

In [72]:
figure(1); ax =subplot(6,1,1)
setp(ax, xticks=[])

func = (;pars...) -> run_dynamics(;ntrials=100, target1=0.8, target2=0.8, target3=0.8, 
target4=0.7, target5=0.5, target6=0.7, dt=2, noisefr=0.005, pars...)[1]

figure(1); clf();
func(; do_plot=true, ntrials=1000, plot_trials=1:10, make_dict(args, params)...)


0.05716030970507975

In [12]:
function make_dict(args, x::Vector)
    kwargs = Dict();    
    for i in [1:length(args);]    
        kwargs = merge(kwargs, Dict(Symbol(args[i]) => x[i]))        
    end    
    return kwargs
end 


make_dict (generic function with 1 method)

In [57]:
figure(1); clf();
func(; do_plot=true, ntrials=10, plot_trials=1:10, delay_period=220, # noisefr=0.005*sqrt(0.5), dt=5, 
    sigma=0.1, dt=1, noisefr=0,
    opto_effect=0.948, 
    make_dict(args, [6.336, -1.292, 0.274, 0.310])...)

0.7094960145662674

In [None]:
params

In [5]:
eta = 0.5;

params1=["vwi", "hwi", "pro_bias", "delay_input"];
params2=[4.0,4.0,0.2,0.2];


out = DiffBase.GradientResult(params2)  # out must be same length as whatever we will differentiate w.r.t.
keyword_gradient!(out, (;pars...) -> run_dynamics(;pars...)[1], params1, params2)  # note initial values must be floats
grad = DiffBase.gradient(out)
cost    = DiffBase.value(out)

badstuff,results=run_dynamics(vwi=params2[1],hwi=params2[2],pro_bias=params2[3],delay_input=params2[4])



i=0; 
while eta > 1e-6

    i=i+1
    new_params2 = params2 - eta*grad

    out = DiffBase.GradientResult(new_params2)  # out must be same length as whatever we will differentiate w.r.t.
    keyword_gradient!(out, (;pars...) -> run_dynamics(;pars...)[1], params1, new_params2)  # note initial values must be floats
    grad = DiffBase.gradient(out)
    new_cost    = DiffBase.value(out)

    new_cost2,new_results=run_dynamics(vwi=new_params2[1],hwi=new_params2[2],pro_bias=new_params2[3],delay_input=new_params2[4])    
    if abs(new_cost-new_cost2)>0.0001
        println((new_cost-new_cost2)/new_cost)
        error("yyy")
    end

    new_grad=grad;
    if new_cost < cost
        params2 = new_params2
        cost   = new_cost
        grad   = new_grad
        results = new_results
        eta = eta*1.1
    else    
        eta = eta/2
    end


    if rem(i, 100)==0
        @printf "%d: eta=%f, cost=%.5f, params=[%.3f, %.3f, %.3f, %.3f]\n" i eta cost params2[1] params2[2] params2[3] params2[4]
        # println("eta")
        # println(eta)
        # println("cost")
        # println(cost)
        # println("params")
        # println(params2)
        println("GRADIENT")
        println(grad)
        println("RESULTS")
        println(results)
        println("*********************")
        println("*********************")
    end
end







100: eta=0.022881, cost=0.14007, params=[4.834, 3.819, -0.027, 0.942]
GRADIENT
[-0.114364,0.0568412,-0.0154793,-0.0819561]
RESULTS
[0.761156 0.846898 0.80333 0.4524 0.499643 0.322676]
*********************
*********************
200: eta=0.063229, cost=0.10024, params=[5.060, 3.519, 0.006, 1.134]
GRADIENT
[-0.00649711,0.0379099,-0.00110136,-0.00700344]
RESULTS
[0.858891 0.92714 0.842813 0.513528 0.576572 0.363587]
*********************
*********************
300: eta=0.004433, cost=0.09502, params=[5.071, 3.340, 0.006, 1.136]
GRADIENT
[0.0232605,0.0218964,-0.174215,0.0112899]
RESULTS
[0.882063 0.922636 0.834386 0.537525 0.586192 0.373473]
*********************
*********************
400: eta=0.001994, cost=0.09202, params=[5.084, 3.208, 0.010, 1.117]
GRADIENT
[0.00322983,0.0184863,-0.0434583,0.00600916]
RESULTS
[0.894429 0.921359 0.830187 0.539561 0.577592 0.373146]
*********************
*********************
500: eta=0.869254, cost=0.09045, params=[5.097, 3.129, 0.013, 1.100]
GRADIENT
[-

LoadError: InterruptException: