# Preliminaries

In [1]:
# import ForwardDiff
using PyCall
# import PyPlot
using PyPlot
using ForwardDiff
using DiffBase

pygui(true)

import Base.convert

# Strip the derivative information from a single ForwardDiff Dual, keeping only its value.
# The result is a plain Float64, so ForwardDiff can no longer differentiate through it.
convert(::Type{Float64}, x::ForwardDiff.Dual) = Float64(x.value)

# Element-wise version: convert(Array{Float64}, x) for an array x of Duals returns a new
# Float64 array of the same size.
#
# Two things differ from the previous definition, which could never be reached:
#   * the first argument is the *type* Array{Float64} (::Type{...}), not an array instance;
#   * the element type is a method parameter D<:Dual, because Array{ForwardDiff.Dual} does
#     not match arrays whose elements are a concrete Dual{N,T} (type parameters are invariant).
# The f{T}(...) form of method parameters is used to stay within the Julia version this
# notebook was written for.
function convert{D<:ForwardDiff.Dual}(::Type{Array{Float64}}, x::Array{D})
    y = zeros(size(x))
    for i in eachindex(x)
        y[i] = convert(Float64, x[i])
    end
    return y
end

include("hessian_utils.jl")

"""
We define functions to convert Duals, the variable types used by ForwardDiff, 
to Floats. This is useful if we want to print out the value of a variable 
(since print doesn't know how to print Duals). Note that after being converted to a Float, no
differentiation by ForwardDiff can happen!  e.g. after
    x = convert(Float64, y)
ForwardDiff can still differentiate y, but it can't differentiate x
"""





"We define functions to convert Duals, the variable types used by ForwardDiff, \nto Floats. This is useful if we want to print out the value of a variable \n(since print doesn't know how to Duals). Note that after being converted to a Float, no\ndifferentiation by ForwardDiff can happen!  e.g. after\n    x = convert(Float64, y)\nForwardDiff can still differentiate y, but it can't differentiate x\n"

# The main model dynamics function

Note that the documentation indicates some default values for the optional parameters; these values need to be updated to match what the code below actually uses. The actual defaults are much closer to those of Marino's May 2017 model.

In [21]:
""" 
function answer, t, U, V, W = run_dynamics(trial_type, params)

    Runs the 4-way mutual inhibition model: four units with state U and output
    V = 0.5*tanh((U-theta)/beta) + 0.5, stepped forward with a fixed timestep dt
    and additive Gaussian noise sigma*randn*sqrt(dt) on every unit at every step.

    Row order of U and V: ProContra on right side, AntiIpsi on right side,
    ProContra on left side, AntiIpsi on left side.
    
    OBLIGATORY PARAMS:
    ------------------

        trial_type    "pro" or "anti". Any other value is accepted, but then no
                      rule input is applied during the rule and delay periods.
        params        A vector whose first four components are interpreted, in turn, as
            vwi   vertical inhibitory weights between Pro and Anti units. 
                    "36" means a -36 connection.
            hwi   horizontal inhibitory weights between like units on the two sides. 
                    "36" means a -36 connection.
            const_pro_bias   Extra input added to the two Pro units at every timestep
            const_E          Constant input added to all four units at every timestep
                      U and V are allocated with eltype(params), so that ForwardDiff
                      Duals passed in params propagate through the dynamics.


    OPTIONAL PARAMS (defaults as in the function signature):
    ----------------
    
        U_rest = 0      Resting point for U in the absence of other inputs
        g_leak = 1      Multiplies (U_rest - U) for the dynamics
        theta = 5       Threshold on U for sigmoidal transform from U to V 
        beta  = 50      Scaling on sigmoid going from U to V:   V = 0.5*tanh((U-theta)/beta) + 0.5
        dt = 0.05       Timestep
        sigma = 3.2     Scale of the Gaussian noise added to U (multiplied by sqrt(dt) per step)
        rule_period = 200         Duration of the rule period (same time units as dt)
        delay_period = 200        Duration of the delay period. Rule input stays on during it.
        target_period = 50        right_light_input, right_light_pro_extra and target_extra_E 
                                  are applied during this period; Pro v Anti input is off
        post_target_period = 0.01 target_extra_E is still applied, right_light_input is not
        tau = 4.4       Time constant of dynamics
        marino_tau = true    If true, tau divides only the leak term; if false, tau divides the
                             whole deterministic step (dt/tau)*dUdt
        start_U = [-25, -25, -25, -25]   Initial value of U
        right_light_input = 12    Extra input to the two right-side units during the target period
        right_light_pro_extra = 0 Further input to the right-side Pro unit during the target period
        pro_self_ex = 0    Self excitation weight of Pro units
        anti_self_ex = 0   Self excitation weight of Anti units
        pfc_anti_input = 1.6     Input to Anti units during rule and delay periods of "anti" trials
        pfc_pro_input = 0.05     Input to Pro units during rule and delay periods of "pro" trials
        target_extra_E = 0       Extra input added to all units during target and post_target periods
        opto = "off", opto_scaling = 0.8, opto_scale_on_E = 1, opto_conductance = 0, opto_current = 0
                        Accepted for interface compatibility; this implementation does not
                        use any of the opto parameters.
        do_plot = false whether or not to plot the results (V, U, and Pro R - Pro L)
        fignum = 1      figure on which to plot
        decision_threshold = 0.3   Threshold on V(Pro_R) - V(Pro_L) used to compute answer (see below)


    RETURNS:
    --------

        answer   A smooth version of the Pro/Anti decision, between -1 and +1. With
                 d = V(Pro_R) - V(Pro_L) at the final timestep, 
                 answer = 0.5*tanh((d - decision_threshold)/0.1) - 0.5*tanh((-d - decision_threshold)/0.1),
                 which approaches +1 for d well above threshold (Pro), -1 for d well below
                 -threshold (Anti), and 0 in between.
        t    Time vector
        U    U matrix, size 4-by-length(t)
        V    V matrix, size 4-by-length(t). Its first column is left at zero; the sigmoid
             is only evaluated from the second timestep on.
        W    Weight matrix between units
        
"""
function run_dynamics(trial_type, params::Vector ; opto="off", opto_scaling=0.8, opto_scale_on_E=1,
    opto_conductance = 0, opto_current=0, 
    right_light_pro_extra = 0, right_light_input=12, 
    pro_self_ex = 0, anti_self_ex = 0, 
    tau=4.4, marino_tau = true, dt=0.05, target_extra_E = 0,
    pfc_anti_input = 1.6, pfc_pro_input = 0.05,
    sigma=3.2, start_U = [-25, -25, -25, -25], do_plot = false, fignum=1, 
    g_leak = 1, U_rest = 0, theta = 5, beta = 50,
    rule_period = 200, delay_period = 200, target_period = 50,
    post_target_period = 0.01, decision_threshold = 0.3)
    
    vwi, hwi, const_pro_bias, const_E = params[1], params[2], params[3], params[4]
    
    # Times at which the rule+delay epoch and the target epoch end
    t_rule_end   = rule_period + delay_period
    t_target_end = t_rule_end + target_period

    t = [0 : dt : t_target_end + post_target_period;] 

    V = zeros(eltype(params), 4, length(t))   # the eltype(params) is for ForwardDiff
    U = zeros(eltype(params), 4, length(t))

    U[:,1] = start_U

    W = [pro_self_ex -vwi -hwi 0; -vwi anti_self_ex 0 -hwi; 
        -hwi 0 pro_self_ex -vwi; 0 -hwi -vwi anti_self_ex]

    for i in 2:length(t)
        # Leak towards U_rest; with marino_tau the time constant acts on the leak only
        leak = g_leak .* (U_rest .- U[:,i-1])
        if marino_tau
            leak = leak ./ tau
        end
        dUdt = const_E .+ W * V[:,i-1] .+ leak
    
        if t[i] < t_rule_end
            # Rule and delay periods: task-rule input to the Anti or the Pro units
            if trial_type == "anti"
                dUdt[[2,4]] .+= pfc_anti_input
            elseif trial_type == "pro"
                dUdt[[1,3]] .+= pfc_pro_input
            end
        elseif t[i] < t_target_end
            # Target period: light input to the two right-side units
            dUdt[[1,2]] .+= right_light_input
            dUdt[1]      += right_light_pro_extra
            dUdt         = dUdt .+ target_extra_E
        else
            dUdt         = dUdt .+ target_extra_E
        end
    
        dUdt[[1,3]] .+= const_pro_bias
        
        # Euler step. Errors here (e.g. an element-type mismatch when storing into U) are
        # deliberately NOT caught: swallowing them would leave U[:,i] at zero and silently
        # corrupt the rest of the trajectory.
        step_size = marino_tau ? dt : dt/tau
        noise     = sigma .* randn(4) .* sqrt(dt)
        U[:,i] = U[:,i-1] .+ step_size .* dUdt .+ noise
    
        V[:,i] = 0.5 .* tanh.((U[:,i] .- theta) ./ beta) .+ 0.5
    end    

    if do_plot
        unit_colors = ([0, 0, 1], [1, 0, 0], [1, 0.5, 0.5], [0, 1, 1])
        period_ends = [rule_period, t_rule_end, t_target_end]

        figure(fignum); 
        subplot(3,1,1)
        h = plot(t, V'); 
        for k in 1:4
            setp(h[k], color=unit_colors[k])
        end
        ylabel("V")

        # Read the axis limits before the period markers are drawn
        yl = [ylim()[1], ylim()[2]]
        vlines(period_ends, -0.05, 1.05, linewidth=2)
        if yl[1]<0.02
            yl[1] = -0.02
        end
        if yl[2]>0.98
            yl[2] = 1.02
        end
        ylim(yl)
        grid(true)
        
        subplot(3,1,2)
        hu = plot(t, U')
        for k in 1:4
            setp(hu[k], color=unit_colors[k])
        end
        ylabel("U"); ylim(-100, 100)
        vlines(period_ends, ylim()[1], ylim()[2], linewidth=2)
        grid(true)
    
        subplot(3,1,3)
        plot(t, V[1,:] - V[3,:])
        ylim([-0.9, 0.9])
        vlines(period_ends, ylim()[1], ylim()[2], linewidth=2)
        xlabel("t"); ylabel("Pro R - Pro L")
        grid(true)
    end
    
    # Smooth stand-in for the hard rule "+1 if d > threshold, -1 if d < -threshold, else 0"
    # (presumably so that the answer stays differentiable for ForwardDiff).
    d = V[1,end] - V[3,end]
    answer1 = 0.5 + 0.5*tanh(( d - decision_threshold)/0.1)
    answer2 = 0.5 + 0.5*tanh((-d - decision_threshold)/0.1)
    answer  = answer1 - answer2
    
    return answer, t, U, V, W 
end





run_dynamics

# Run the dynamics with Marino's parameters just to test it

In [23]:
# Plot a few noisy runs with Marino's parameters: "pro" trials on figure 1,
# "anti" trials on figure 2.
trial_type = "pro"
params = [36, 1.8, 0.854, 1]
ntrys = 4

for (fig, kind) in [(1, "pro"), (2, "anti")]
    figure(fig); clf();
    for i in 1:ntrys
        run_dynamics(kind, params, do_plot=true, sigma=0.4,
            pfc_anti_input = 1.6, pfc_pro_input = 0.05, fignum=fig)
    end
end


# Here we experiment with differentiating the main dynamics function and with defining a cost function

In [None]:


"""
function pro_answers, anti_answers, rms, cost1, cost2 = 
    test_params(params, targets; ntrials=10, sigma=3.2, random_seed=321)

Runs ntrials "pro" trials and ntrials "anti" trials of run_dynamics() (interleaved:
pro, anti, pro, anti, ...) and collects the final outputs of the two Pro units.
The random number generator is re-seeded with random_seed on every call, so two
calls with the same arguments see the same noise sequence.

OBLIGATORY PARAMS:
------------------

params     The same obligatory vector as in run_dynamics()
targets    A 2-element vector; targets[1] is used for pro trials and targets[2] for
           anti trials when computing cost1 (see RETURNS)

OPTIONAL PARAMS:
----------------

ntrials=10          Number of trials of each type
sigma=3.2           Noise scale, passed on to run_dynamics()
random_seed=321     Seed passed to srand() at the start of the call

RETURNS:
--------

pro_answers     A 2-by-ntrials matrix. First row is final value of Pro R unit in "pro" trials; second row is final value of Pro L unit
anti_answers    As pro_answers, but for anti trials
rms             Mean of the pro and anti root-mean-square values of (Pro R - Pro L) across trials
cost1           (sum over pro trials of (Pro R - Pro L) - ntrials*targets[1])^2 +
                (sum over anti trials of (Pro L - Pro R) - ntrials*targets[2])^2
cost2           Minus the sum, over all trials of both types, of (Pro R - Pro L)^2

"""
function test_params(params, targets; ntrials=10, sigma=3.2, random_seed=321)
    srand(random_seed)
    pro_answers  = zeros(eltype(params), 2, ntrials)
    anti_answers = zeros(eltype(params), 2, ntrials)
    cost      = 0
    pro_perf  = 0
    anti_perf = 0
    
    for i in 1:ntrials
        # Only the V matrix (4th return value) of each run is needed
        V = run_dynamics("pro", params, do_plot=false, sigma=sigma)[4]
        pro_perf += V[1,end] - V[3,end]        
        cost += (V[1,end] - V[3,end])^2
        pro_answers[1,i] = V[1,end]
        pro_answers[2,i] = V[3,end]
        
        V = run_dynamics("anti", params, do_plot=false, sigma=sigma)[4]
        anti_perf += V[3,end] - V[1,end]        
        cost += (V[1,end] - V[3,end])^2
        anti_answers[1,i] = V[1,end]
        anti_answers[2,i] = V[3,end]
    end

    RMSp = sqrt(sum((pro_answers[1,:]  - pro_answers[2,:]).^2) / ntrials)
    RMSa = sqrt(sum((anti_answers[1,:] - anti_answers[2,:]).^2) / ntrials)

    cost1 = (pro_perf - ntrials*targets[1])^2 + (anti_perf - ntrials*targets[2])^2 
    cost2 = - cost

    return pro_answers, anti_answers, (RMSa+RMSp)/2, cost1, cost2
end




"""
function cost, pro_out, ant_out, pro_dif, ant_dif = 
    Jcost2(params, targets; ntrials=20, theta1=0.15, theta2=0.15, beta=0.5, verbose=false,
           sigma=3.2, random_seed=321)

Computes a cost function for certain parameters and fraction correct targets, using
the final Pro unit outputs returned by test_params().

For each trial, with d = (Pro R - Pro L) at the end of the trial:
    pro_out = 0.5*(1 + tanh( d/theta1))   on pro trials   (soft "correct" indicator)
    ant_out = 0.5*(1 + tanh(-d/theta1))   on anti trials
    pro_dif, ant_dif = tanh(d/theta2)^2   on pro and anti trials respectively

    cost1 = (mean(pro_out) - targets[1])^2 + (mean(ant_out) - targets[2])^2
    cost2 = -mean(pro_dif) - mean(ant_dif)

OBLIGATORY PARAMS:
------------------

params     The same obligatory vector as in run_dynamics()
targets    A 2-element vector, first element is target fraction correct for Pro, second for Anti. E.g., [0.8 0.7]

OPTIONAL PARAMS:
----------------

ntrials=20          Number of trials of each type
theta1=0.15         Scale of the tanh used for pro_out and ant_out
theta2=0.15         Scale of the tanh used for pro_dif and ant_dif
beta=0.5            Weight of cost2 in the total cost
verbose=false       If true, prints cost1, beta*cost2 and the mean pro_out and ant_out
sigma=3.2           Passed on to test_params()
random_seed=321     Passed on to test_params()

RETURNS:
--------

cost      A scalar, cost1 + beta*cost2
pro_out, ant_out, pro_dif, ant_dif     Per-trial vectors as defined above
"""
function Jcost2(params, targets; ntrials=20, theta1=0.15, theta2=0.15, beta=0.5, verbose=false,
    sigma=3.2, random_seed=321)
    pro_answers, anti_answers = test_params(params, targets, ntrials=ntrials, 
        sigma=sigma, random_seed=random_seed)
    
    pro_out = 0.5 .* (1 .+ tanh.((pro_answers[1,:]  - pro_answers[2,:])/theta1))
    ant_out = 0.5 .* (1 .+ tanh.((anti_answers[2,:] - anti_answers[1,:])/theta1))

    pro_dif = tanh.((pro_answers[1,:]  - pro_answers[2,:]) /theta2).^2
    ant_dif = tanh.((anti_answers[1,:] - anti_answers[2,:])/theta2).^2
    
    cost1 = (mean(pro_out) - targets[1])^2 + (mean(ant_out) - targets[2])^2
    cost2 = -mean(pro_dif) 
    cost2 -= mean(ant_dif)
    
    if verbose
        # convert(Float64, ...) strips any ForwardDiff Dual wrapper so that @printf can format the value
        @printf("                              cost1=%.3f, cost2=%.3f, pro_out=%.3f, anti_out=%.3f\n", convert(Float64, cost1),
        beta*convert(Float64, cost2), convert(Float64, mean(pro_out)), convert(Float64, mean(ant_out)))
    end
    
    return cost1 + beta*cost2, pro_out, ant_out, pro_dif, ant_dif
end


# Evaluate the cost, then its ForwardDiff gradient, at a single point in parameter space.
targets = [0.8, 0.7]   # Fraction correct in Pro and Anti
# NOTE(review): run_dynamics reads params as [vwi, hwi, const_pro_bias, const_E]; the
# values 1 and 1.8 look swapped relative to [36, 1.8, 0.854, 1] used elsewhere -- confirm.
params  = [36, 1, 0.854, 1.8]

println(Jcost2(params, targets))

grad = ForwardDiff.gradient(params) do x
    Jcost2(x, targets)[1]
end


# Manual test of the derivatives for sanity check

In [None]:
# Does this match what we got above?
# One-sided finite difference of the cost with respect to params[i].

delta = 0.001
i = 4
params2 = copy(params)
params2[i] += delta
(Jcost2(params2, targets)[1] - Jcost2(params, targets)[1])/delta

In [None]:
# Getting the derivatives doesn't seem to add too much time, only about 60%:
# time a plain cost evaluation, then a gradient evaluation with the same number of trials.

@time Jcost2(params, targets, ntrials=20)
@time ForwardDiff.gradient(x -> Jcost2(x, targets, ntrials=20)[1], params)

In [None]:
# === RUN THE MINIMIZATION WITH CONSTRAINED HESSIAN MINIMIZATION ===
#
# The optimizer works on x; the cost is evaluated at [exp(x[1]), exp(x[2]), x[3], x[4]],
# which keeps the first two parameters passed to Jcost2 (vwi and hwi) positive.
# constrained_Hessian_minimization is not defined in this notebook (presumably it comes
# from hessian_utils.jl).
# NOTE(review): if the returned vector is the optimizer's x, its first two entries are
# log(vwi) and log(hwi), not the weights themselves -- confirm before passing it to Jcost2.

targets = [0.8, 0.7]  # Pro and Anti desired per cent correct, respectively

ntrials=20
theta1=0.15
theta2=0.2

params = @time(constrained_Hessian_minimization([2.0, 2.0, 1, 1], 
x -> Jcost2([exp(x[1]), exp(x[2]), x[3], x[4]], targets, ntrials=ntrials, theta1=theta1, theta2=theta2, beta=0.05, verbose=true)[1], 
    verbose=true, tol=1e-7, start_eta=2.0))

In [None]:
# TEST THE RESULTS GIVEN PARAMS, TARGETS, NTRIALS, THETA1, THETA2
# NOTE(review): params is passed to Jcost2 as-is; if it came from a minimization run over
# log-transformed weights, it needs to be mapped back first -- confirm.

tc, pro_out, ant_out, pro_dif, ant_dif = Jcost2(params, targets, 
    ntrials=ntrials, theta1=theta1, theta2=theta2, beta=0.05, verbose=true);

# Figure 1: one panel per per-trial quantity returned by Jcost2 (pro in blue, anti in red)
figure(1); clf();
for (row, ys, fmt) in [(1, pro_out, "b."), (2, ant_out, "r."), (3, pro_dif, "b."), (4, ant_dif, "r.")]
    subplot(4,1,row)
    plot(1:ntrials, ys, fmt)
end

# Should really fold reporting this proans and antans into Jcost2, we don't need to run the model twice!
proans, antans = test_params(params, targets, ntrials=ntrials)

# Figure 2: final outputs of the two Pro units; row 1 of each matrix in blue, row 2 in red
figure(2); clf();

subplot(2,1,1)
plot(1:ntrials, proans[1,:], "b.")
plot(1:ntrials, proans[2,:], "r.")
title("pro")

subplot(2,1,2)
plot(1:ntrials, antans[1,:], "b.")
plot(1:ntrials, antans[2,:], "r.")
title("anti")


In [None]:
# === RUN THE MINIMIZATION WITH ADAPTIVE GRADIENT DESCENT ===
#
# Here the optimizer's vector is passed to Jcost2 unchanged (no exp() transform of the
# first two entries). adaptive_gradient_minimization is not defined in this notebook
# (presumably it comes from hessian_utils.jl).

targets = [0.8, 0.7]  # Pro and Anti desired per cent correct, respectively

ntrials=20
theta1=0.15
theta2=0.2

params = @time(adaptive_gradient_minimization([18, 1.8, 0.854, 1], 
x -> Jcost2(x, targets, ntrials=ntrials, theta1=theta1, theta2=theta2, beta=0.05, verbose=true)[1], 
verbose=true, tol=1e-5, start_eta=2.0))

In [None]:
# Same starting point and cost as the adaptive-gradient run, but using the constrained
# Hessian minimizer; the optimizer's vector is passed to Jcost2 unchanged.
# Relies on targets, ntrials, theta1 and theta2 having been set by an earlier cell.
params = @time(constrained_Hessian_minimization([18, 1.8, 0.854, 1], 
x -> Jcost2(x, targets, ntrials=ntrials, theta1=theta1, theta2=theta2, beta=0.05, verbose=true)[1], 
    verbose=true, tol=1e-5, start_eta=2.0))

# TESTING PASSING KWARGS TO THE DIFFERENTIATION

In [None]:
# Toy function for checking that keyword arguments can be differentiated:
# prints the element type of each keyword argument, then returns 10a + 20b + 30c.
function tester(; a=10, b=20, c=30)
    @printf "eltype(a)=%s\n" eltype(a)
    @printf "eltype(b)=%s\n" eltype(b)
    @printf "eltype(c)=%s\n" eltype(c)   # previously printed eltype(b) under the "c" label
    return a*10 + b*20 + c*30
end

# Pass-through wrapper: forwards every keyword argument to tester() unchanged.
cost_function(; args...) = tester(; args...)

In [None]:
# Toy function mixing a positional vector with keyword arguments. Prints any keyword
# arguments other than par1/par2, prints eltype(x), and returns
# x[1] + 10*x[2] + 100*x[3]^2. par1 and par2 are accepted but not used.
function tester2(x::Vector; par1=10, par2=20, pars...)
    println(pars)
    @printf "eltype is: %s" eltype(x)
    linear_part = x[1] + 10*x[2]
    return linear_part + 100*x[3]^2
end

# Hessian of tester2 with respect to x, with the keyword par1 held fixed
ForwardDiff.hessian([1.0, 0, 4]) do x
    tester2(x, par1=5)
end

In [None]:

# Build a keyword-argument dictionary pairing the i-th name in args (converted to a
# Symbol) with the i-th value in x, so that the result can be splatted into a call as
# f(; make_dict(args, x)...). If a name appears more than once, the last value wins.
# Values are stored as Any so that the entries may be ForwardDiff Duals.
# The dictionary is filled in place instead of being rebuilt with merge() for every entry.
function make_dict(args, x::Vector)
    kwargs = Dict{Symbol,Any}()
    for i in eachindex(args)
        kwargs[Symbol(args[i])] = x[i]
    end
    return kwargs
end

# Allocate an m-by-n matrix of zeros whose element type can hold the derivatives that
# ForwardDiff will propagate:
#   nderivs == 0 or order == 0  ->  Float64
#   order == 1                  ->  Dual{nderivs, Float64}                  (gradient)
#   order == 2                  ->  Dual{nderivs, Dual{nderivs, Float64}}   (hessian)
# Any other order raises an error.
#
# The order can be given as derivorder or as difforder: both spellings are used by
# callers in this notebook. If difforder is non-zero it takes precedence.
function ForwardDiffZeros(m, n; nderivs=0, derivorder=0, difforder=0)
    order = difforder != 0 ? difforder : derivorder
    if nderivs == 0 || order == 0
        return zeros(m, n)
    elseif order == 1
        return zeros(ForwardDiff.Dual{nderivs, Float64}, m , n)
    elseif order == 2
        return zeros(ForwardDiff.Dual{nderivs, ForwardDiff.Dual{nderivs, Float64}}, m, n)
    else
        # Report the offending order (the old message appended nderivs instead)
        error("Don't know how to do that order of derivatives! derivorder=", order)
    end
end
                

# Quick check of make_dict on a two-entry example: names "a" and "b" paired with 200 and 300
args = ["a", "b"]
init = [200, 300]
make_dict(args, init)

# ForwardDiffZeros(1, 2, nderivs=2, derivorder=1)

In [None]:
# Toy function for keyword-argument differentiation tests (quiet version, no printing):
# returns the weighted sum 10a + 20b + 30c of its keyword arguments.
function tester(; a=10, b=20, c=30)
    total  = a*10
    total += b*20
    total += c*30
    return total
end

# Pass-through wrapper: forwards every keyword argument to tester() unchanged.
cost_function(; args...) = tester(; args...)

In [None]:
# Differentiate cost_function with respect to the keyword arguments named in args.
# make_dict turns the vector x into keyword arguments, so x[1] is passed as c and x[2] as b.
# If tester is the weighted sum a*10 + b*20 + c*30, the expected gradient is [30, 20].
args = ["c", "b"]
init = [200.0, 3000]


# ForwardDiff.gradient(x -> tester(;make_dict(args, x)...), init)
ForwardDiff.gradient(x -> cost_function(;make_dict(args, x)...), init)


In [None]:
# Inspect what kind of object eltype() returns when applied to a Dual type
typeof(eltype(ForwardDiff.Dual{2,Float64}))

In [None]:

"""
function pro_answers, anti_answers, rms, cost1, cost2 = 
    test_params2(targets; ntrials=10, sigma=3.2, random_seed=321, nderivs=0, derivorder=0, params...)

Keyword-argument version of test_params(). Runs ntrials "pro" trials and ntrials
"anti" trials of run_dynamics2() (interleaved: pro, anti, pro, anti, ...) and collects
the final outputs of the two Pro units. The random number generator is re-seeded with 
random_seed on every call, so two calls with the same arguments see the same noise sequence.

OBLIGATORY PARAMS:
------------------

targets    A 2-element vector; targets[1] is used for pro trials and targets[2] for
           anti trials when computing cost1 (see RETURNS)

OPTIONAL PARAMS:
----------------

ntrials=10          Number of trials of each type
sigma=3.2           Noise scale, passed on to run_dynamics2()
random_seed=321     Seed passed to srand() at the start of the call
nderivs=0           Number of parameters being differentiated; used, together with
derivorder=0        the derivative order, to allocate the answer matrices with
                    ForwardDiffZeros() and passed on to run_dynamics2(). They must match 
                    the Dual type of the values in params..., otherwise storing the 
                    results in the answer matrices will fail.
params...           Any further keyword arguments are passed on to run_dynamics2()
                    (e.g. vwi, hwi, const_pro_bias, const_E)

RETURNS:
--------

pro_answers     A 2-by-ntrials matrix. First row is final value of Pro R unit in "pro" trials; second row is final value of Pro L unit
anti_answers    As pro_answers, but for anti trials
rms             Mean of the pro and anti root-mean-square values of (Pro R - Pro L) across trials
cost1           (sum over pro trials of (Pro R - Pro L) - ntrials*targets[1])^2 +
                (sum over anti trials of (Pro L - Pro R) - ntrials*targets[2])^2
cost2           Minus the sum, over all trials of both types, of (Pro R - Pro L)^2

"""
function test_params2(targets; ntrials=10, sigma=3.2, random_seed=321, nderivs=0, derivorder=0, params...)
    srand(random_seed)
    pro_answers  = ForwardDiffZeros(2, ntrials, nderivs=nderivs, derivorder=derivorder)
    anti_answers = ForwardDiffZeros(2, ntrials, nderivs=nderivs, derivorder=derivorder)

    cost      = 0
    pro_perf  = 0
    anti_perf = 0
    
    for i in 1:ntrials
        # Only the V matrix (4th return value) of each run is needed
        V = run_dynamics2("pro", do_plot=false, 
            nderivs=nderivs, derivorder=derivorder, sigma=sigma; params...)[4]
        pro_perf += V[1,end] - V[3,end]        
        cost += (V[1,end] - V[3,end])^2
        pro_answers[1,i] = V[1,end]
        pro_answers[2,i] = V[3,end]
        
        V = run_dynamics2("anti", do_plot=false, 
            nderivs=nderivs, derivorder=derivorder, sigma=sigma; params...)[4]
        anti_perf += V[3,end] - V[1,end]        
        cost += (V[1,end] - V[3,end])^2
        anti_answers[1,i] = V[1,end]
        anti_answers[2,i] = V[3,end]
    end

    RMSp = sqrt(sum((pro_answers[1,:]  - pro_answers[2,:]).^2) / ntrials)
    RMSa = sqrt(sum((anti_answers[1,:] - anti_answers[2,:]).^2) / ntrials)

    cost1 = (pro_perf - ntrials*targets[1])^2 + (anti_perf - ntrials*targets[2])^2 
    cost2 = - cost

    return pro_answers, anti_answers, (RMSa+RMSp)/2, cost1, cost2
end



"""
function cost, pro_out, ant_out, pro_dif, ant_dif = 
    Jcost3(targets; ntrials=20, theta1=0.15, theta2=0.15, beta=0.5, verbose=false, params...)

Keyword-argument version of Jcost2(): computes the same cost, but obtains the final
Pro unit outputs from test_params2(), to which all extra keyword arguments are passed.

For each trial, with d = (Pro R - Pro L) at the end of the trial:
    pro_out = 0.5*(1 + tanh( d/theta1))   on pro trials
    ant_out = 0.5*(1 + tanh(-d/theta1))   on anti trials
    pro_dif, ant_dif = tanh(d/theta2)^2   on pro and anti trials respectively

    cost1 = (mean(pro_out) - targets[1])^2 + (mean(ant_out) - targets[2])^2
    cost2 = -mean(pro_dif) - mean(ant_dif)

OBLIGATORY PARAMS:
------------------

targets    A 2-element vector, first element is target for mean(pro_out), second for mean(ant_out)

OPTIONAL PARAMS:
----------------

ntrials=20          Number of trials of each type
theta1=0.15         Scale of the tanh used for pro_out and ant_out
theta2=0.15         Scale of the tanh used for pro_dif and ant_dif
beta=0.5            Weight of cost2 in the total cost
verbose=false       If true, prints cost1, beta*cost2 and the mean pro_out and ant_out
params...           Passed on to test_params2() (e.g. nderivs, derivorder, sigma, and
                    any model parameters it forwards to run_dynamics2())

RETURNS:
--------

cost      A scalar, cost1 + beta*cost2
pro_out, ant_out, pro_dif, ant_dif     Per-trial vectors as defined above
"""
function Jcost3(targets; ntrials=20, theta1=0.15, theta2=0.15, beta=0.5, verbose=false, params...)
    pro_answers, anti_answers = test_params2(targets, ntrials=ntrials; params...)
    
    pro_out = 0.5 .* (1 .+ tanh.((pro_answers[1,:]  - pro_answers[2,:])/theta1))
    ant_out = 0.5 .* (1 .+ tanh.((anti_answers[2,:] - anti_answers[1,:])/theta1))

    pro_dif = tanh.((pro_answers[1,:]  - pro_answers[2,:]) /theta2).^2
    ant_dif = tanh.((anti_answers[1,:] - anti_answers[2,:])/theta2).^2
    
    cost1 = (mean(pro_out) - targets[1])^2 + (mean(ant_out) - targets[2])^2
    cost2 = -mean(pro_dif) 
    cost2 -= mean(ant_dif)
    
    if verbose
        # convert(Float64, ...) strips any ForwardDiff Dual wrapper so that @printf can format the value
        @printf("                              cost1=%.3f, cost2=%.3f, pro_out=%.3f, anti_out=%.3f\n", convert(Float64, cost1),
        beta*convert(Float64, cost2), convert(Float64, mean(pro_out)), convert(Float64, mean(ant_out)))
    end
    
    return cost1 + beta*cost2, pro_out, ant_out, pro_dif, ant_dif
end


# test_params2([0.8, 0.7], taking_derivative=true)
# Jcost3([0.8, 0.7], taking_derivative=false)


# Names of the run_dynamics2 keyword arguments to differentiate with respect to, and the
# point at which to evaluate; args[i] is paired with params[i] by make_dict.
args = ["vwi", "hwi", "const_pro_bias", "const_E", "pfc_anti_input", "pfc_pro_input"]
params = [36, 1.8, 0.854, 1, 0.05, 0.05]
# args = ["vwi", "hwi", "const_pro_bias"]
# params = [36, 1.8, 0.854]
# args = ["hwi", "vwi", "const_pro_bias", "const_E"]
# params = [1.8, 36, 0.854, 1]

targets = [0.8, 0.7]

# Plot one "pro" run at these parameter values (the trailing ; suppresses the cell output)
figure(1); clf();
run_dynamics2("pro"; make_dict(args, params)..., do_plot=true)[4][end];
# test_params2(targets)[1][1,1]

srand(300)
#ForwardDiff.hessian(x->run_dynamics2("anti"; make_dict(args, x)..., nderivs=length(args),
#    derivorder = 2, do_plot=true)[4][end], params)
#ForwardDiff.hessian(x -> test_params2(targets; nderivs=length(args),
#    derivorder=2, make_dict(args, x)...)[1][1,1], params)

# Time the Hessian of the total cost with respect to all the parameters named in args.
# derivorder=2 requests storage for second-order (nested) Duals.
@time ForwardDiff.hessian(x -> Jcost3(targets; nderivs=length(args), derivorder=2, 
    make_dict(args, x)...)[1], params)

In [None]:
# One-sided finite difference of the last element of V from a single "pro" run with
# respect to params[i]; the same seed is set before each run so both see the same noise.
i = 1
dx = 0.001
new_params = copy(params)
new_params[i] = params[i]+dx

srand(300)
A = run_dynamics2("pro"; nderivs=0, make_dict(args, new_params)...)[4][end]
srand(300)
B = run_dynamics2("pro"; nderivs=0, make_dict(args, params)...)[4][end]

(A-B)/dx

In [None]:
# One-sided finite difference of pro_answers[1,1] from test_params2 with respect to
# params[i]. test_params2 re-seeds the random number generator itself on entry.
i = 4
dx = 0.0001
new_params = copy(params)
new_params[i] = params[i]+dx

srand(300)
A = test_params2(targets; make_dict(args, new_params)...)[1][1,1]
srand(300)
B = test_params2(targets; make_dict(args, params)...)[1][1,1]

(A-B)/dx

In [None]:
# One-sided finite difference of the total cost from Jcost3 with respect to params[i]
i = 3
dx = 0.00001
new_params = copy(params)
new_params[i] = params[i]+dx

srand(300)
A = Jcost3(targets; make_dict(args, new_params)...)[1]
srand(300)
B = Jcost3(targets; make_dict(args, params)...)[1]

(A-B)/dx

In [None]:
# === RUN THE MINIMIZATION WITH CONSTRAINED HESSIAN MINIMIZATION ===
#
# Only the set-up is live in this cell; the gradient, Hessian and minimization calls
# below are commented out.

targets = [0.8, 0.7]  # Pro and Anti desired per cent correct, respectively

ntrials=20
theta1=0.15
theta2=0.2

# args and seed are assigned several times; only the last assignment of each takes
# effect (six parameter names, and seed = [1, 1, 1, 1, 1.6, 0.05]).
args = ["vwi", "hwi", "const_pro_bias", "const_E"]
seed = [36, 1.8, 0.854, 1]
args = ["vwi", "hwi", "const_pro_bias", "const_E", "pfc_anti_input", "pfc_pro_input"]
seed = [36, 1.8, 0.854, 1, 1.6, 0.05]
seed = [1, 1, 1, 1, 1.6, 0.05]

# ForwardDiff.gradient(x -> Jcost3(targets; nderivs=length(args), ntrials=ntrials, theta1=theta1, theta2=theta2,
#    beta=0.05, verbose=true, make_dict(args, x)...)[1], seed)

# ForwardDiff.hessian(x->Jcost3(targets; nderivs=length(args), derivorder=2, 
#    ntrials=ntrials, theta1=theta1, theta2=theta2, beta=0.05, verbose=true, make_dict(args, x)...)[1], seed)

# params = @time(constrained_Hessian_minimization(seed, 
# x -> Jcost3(targets; nderivs=length(args), derivorder=2, ntrials=ntrials, theta1=theta1, theta2=theta2, 
# beta=0.05, verbose=true, make_dict(args, [exp(x[1]), exp(x[2]), x[3], x[4], x[5], x[6]])...)[1], verbose=true, tol=1e-7, start_eta=2.0))

In [None]:
  

        

# ================ now test them

# Toy function that routes its keyword arguments through a ForwardDiffZeros matrix:
# the diagonal is filled with 10a^2, 20b^3 and 30.1*a*sqrt(c), and the trace is returned.
# nderivs and difforder are forwarded to ForwardDiffZeros to choose the element type of M.
function tester(;a=10, b=20, c=30, nderivs=0, difforder=0)
    M = ForwardDiffZeros(3, 3; nderivs=nderivs, difforder=difforder)
    diag_entries = (a^2*10, b^3*20, a*sqrt(c)*30.1)
    for k in 1:3
        M[k,k] = diag_entries[k]
    end
    return trace(M)
end

# keyword_gradient, keyword_gradient!, keyword_hessian and keyword_hessian! are not
# defined in this notebook (presumably they come from hessian_utils.jl). Each call names
# the keyword arguments of tester to differentiate and the point at which to evaluate.

# Gradient with respect to a, b, c: allocating form, then in-place form writing into `out`
res1 = keyword_gradient((;pars...) -> tester(;pars...), ["a", "b", "c"], [10, 20, 3.1])

out = DiffBase.GradientResult([10, 20, 30.1])
keyword_gradient!(out, (;pars...) -> tester(;pars...), ["a", "b", "c"], [10, 20, 3.1])
res2 = DiffBase.gradient(out)

# Printed side by side for comparison (presumably expected to agree)
println(res1)
println(res2)

# Hessian with respect to a and c: allocating form, then in-place form
res3 = keyword_hessian((;pars...) -> tester(;pars...), ["a", "c"], [1.0, 2.0])
out = DiffBase.HessianResult([10, 20.1])
keyword_hessian!(out, (;pars...) -> tester(;pars...), ["a", "c"], [1.0, 2.0])
res4 = DiffBase.hessian(out)

println(res3)
println(res4)


In [7]:
# Toy function that routes its keyword arguments through a ForwardDiffZeros matrix:
# the diagonal is filled with 10a^2, 20b^3 and 30.1*a*sqrt(c), and the trace is returned.
# nderivs and difforder are forwarded to ForwardDiffZeros to choose the element type of M.
function tester(;a=10, b=20, c=30, nderivs=0, difforder=0)
    M = ForwardDiffZeros(3, 3; nderivs=nderivs, difforder=difforder)
    diag_entries = (a^2*10, b^3*20, a*sqrt(c)*30.1)
    for k in 1:3
        M[k,k] = diag_entries[k]
    end
    return trace(M)
end

# Hessian of tester with respect to the keyword arguments b and c, evaluated at b=10, c=3.1
hess_b_c = keyword_hessian((;pars...) -> tester(;pars...), ["b", "c"], [10, 3.1])  # note initial values must be floats


# Hessian with respect to all three keyword arguments, evaluated at a=10, b=2, c=3.1
hess_a_b_c = keyword_hessian((;pars...) -> tester(;pars...), ["a", "b", "c"], [10, 2, 3.1])



3×3 Array{Float64,2}:
 20.0        0.0    8.54783
  0.0      240.0    0.0    
  8.54783    0.0  -13.7868 

In [19]:
# Allocating form: Hessian of tester with respect to a and c at a=1.1, c=2.2
keyword_hessian((;pars...)->tester(;pars...), ["a", "c"], [1.1, 2.2])

# In-place form at a=1.0, c=2.0: the result is written into `out` and read back with
# DiffBase.hessian
out = DiffBase.HessianResult([10, 20.1])
keyword_hessian!(out, (;pars...) -> tester(;pars...), ["a", "c"], [1.0, 2.0])
DiffBase.hessian(out)

2×2 Array{Float64,2}:
 20.0    10.642  
 10.642  -2.66049

In [None]:
""" 
function response, t, U, V, W = run_dynamics2(trial_type)

    Runs the 4-way mutual inhibition model
    
    OBLIGATORY PARAMS:
    ------------------

        trial_type    Must be either "pro" or "anti"

    OPTIONAL PARAMS:
    ----------------
    
        vwi = 36        vertical inhibitory weights between Pro and Anti units. 
                          "36" means a -36 connection.
        hwi=1.8         horizontal inhibitory weights between like units on the two sides. 
        const_pro_bias=0.854   Extra positive input to the two Pro units
        const_E=1              Constant positive input to all four units
        U_rest = -1     Resting point for U in the absance of other inputs
        g_leak = 0.5    Mutliplies (U_rest - U) for the dynamics
        theta = 1       Threshold on U for sigmoidal transform from U to V 
        beta  = 1       Scaling on sigmoid going from U to V:   V = 0.5*tanh((U-theta)/beta) + 0.5
        dt=0.02         Timestep
        sigma=0.1,      added standard deviation on U per unit time
        rule_period = 0.5       in seconds
        delay_period = 0.5      in seconds. Opto will happen during this period
        target_period = 0.1     right_light_input and target_extra_E will happen during this period; 
                                Pro v Anti input will be turned off; 
        post_target_period = 0.5  target_extra_E will still be on, but right_side_input won't
        tau=0.1         Time constant of dynamics, in secs
        marino_tau = False   If True, tau applies only to leak term; if False, applies to entire dUdt
        start_U = [-7, -7, -7, -7]
        const_E = 0.15  Constant excitation added to all units
        right_light_input=1     Extra excitation to right side of the brain units during the target period
        right_light_pro_extra   Even further excitation added to pro Right side units during target
        vwi = 1.5       Weight between ProContra and AntiIpsi units (on each side of the brain)
        hwi = 1.5       Weight between ProContra units across the brain; also between AntiContra
        pro_self_ex = 0   Self excitation weight of Pro units
        anti_self_ex = 0  Self excitation weight of Anti units
        pfc_anti_input = 0.05    Input to Anti units during rule and delay periods
        pfc_pro_input = np.nan   Input to Pro units during rule and delay periods (default means same as Anti)
        const_pro_bias = 0       A constant extra input to the Pro units
        target_extra_E = 0.25    Extra excitation added to all units during target and post_target periods
        opto='off'    Whether there is optogenetic-induced scaling of outputs.
            ='on'     Opto will be done during the delay period only
            ='dt'     Opto will be done during the delay plus the target period
           opto_scaling=0.5      Factor by which to scale the weight matrix during opto 
           opto_scale_on_E=1     Factor by which to scale the constant excitation during opto
           opto_conductance=0    How much conductance to add to gleak during opto
           opto_current=0        Added to dUdt at each time step during opto
        unilateral_opto = False If True, then opto_scale_on_E will be forced to 0, and opto_scaling will apply
                                to only one side
        do_plot = True  whether or not to plot the results
        fignum=1        figure on which to plot
        decision_threshold   If |V(Pro_R) - V(Pro_L)| >= this number, a proper answer is produced. 
                The target light is presented to the right, so Right means "pro"
        taking_derivative = true    If true, will return ForwardDiff.Dual{, Float64} types; otherwise Float64 types


    RETURNS:
    --------

        response     +1 for Pro, -1 for Anti, 0 for undefined 
                        if |V(Pro_R) - V(Pro_L)| < decision_threshold
        t    Time vector
        U   U matrix, size 4-by-len(t). Order is ProContra on right side, AntiIpsi on right
                    side, ProContra on left side, AntiIpsi on left.
        V   V matrix = 0.5*np.tanh((U-theta)/beta) + 0.5
        W   Weight matrix between units
        
"""
function run_dynamics2(trial_type  ; opto="off", opto_scaling=0.8, opto_scale_on_E=1,
    vwi = 36, hwi = 1.8, const_pro_bias = 0.854, const_E = 1,
    opto_conductance = 0, opto_current=0, 
    right_light_pro_extra = 0, right_light_input=12, 
    pro_self_ex = 0, anti_self_ex = 0, 
    tau=4.4, marino_tau = true, dt=0.05, target_extra_E = 0,
    pfc_anti_input = 1.6, pfc_pro_input = 0.05,
    sigma=3.2, start_U = [-25, -25, -25, -25], do_plot = false, fignum=1, 
    g_leak = 1, U_rest = 0, theta = 5, beta = 50,
    rule_period = 200, delay_period = 200, target_period = 50,
    post_target_period = 0.01, decision_threshold = 0.3, nderivs=0, derivorder=0)

    # NOTE: the opto* keywords are accepted but not used anywhere in this body.

    # Boundaries between trial epochs: [0, delay_end) is rule + delay,
    # [delay_end, target_end) is the target period, and the rest is post-target.
    delay_end  = rule_period + delay_period
    target_end = delay_end + target_period

    t = collect(0 : dt : target_end + post_target_period)

    # Allocated through ForwardDiffZeros so the element type can carry derivatives
    V = ForwardDiffZeros(4, length(t), nderivs=nderivs, derivorder=derivorder)
    U = ForwardDiffZeros(4, length(t), nderivs=nderivs, derivorder=derivorder)

    U[:,1] = start_U

    # Unit order: Pro right, Anti right, Pro left, Anti left.
    # vwi couples Pro<->Anti on the same side, hwi couples like units across sides.
    W = [pro_self_ex -vwi -hwi 0; -vwi anti_self_ex 0 -hwi; 
        -hwi 0 pro_self_ex -vwi; 0 -hwi -vwi anti_self_ex]

    for i in 2:length(t)
        # With marino_tau the time constant divides only the leak term;
        # otherwise it divides the whole Euler step below.
        leak = g_leak*(U_rest .- U[:,i-1])
        dUdt = const_E .+ W * V[:,i-1] .+ (marino_tau ? leak/tau : leak)

        if t[i] < delay_end
            # Rule and delay periods: task-dependent input.
            # Any trial_type other than "anti" or "pro" gets no rule input.
            if trial_type == "anti"
                dUdt[[2,4]] .+= pfc_anti_input
            elseif trial_type == "pro"
                dUdt[[1,3]] .+= pfc_pro_input
            end
        elseif t[i] < target_end
            # Target period: the light drives both right-side units
            dUdt[[1,2]] .+= right_light_input
            dUdt[1]      += right_light_pro_extra
            # Rebind (rather than update in place) so the element type can widen
            dUdt = dUdt .+ target_extra_E
        else
            dUdt = dUdt .+ target_extra_E
        end

        dUdt[[1,3]] .+= const_pro_bias

        # Euler-Maruyama step. Errors are deliberately allowed to propagate:
        # swallowing them would silently leave U[:,i] at zero.
        step_size = marino_tau ? dt : dt/tau
        U[:,i] = U[:,i-1] + step_size*dUdt + sigma*randn(4)*sqrt(dt)

        V[:,i] = 0.5*tanh.((U[:,i] .- theta)/beta) .+ 0.5
    end

    if do_plot
        unit_colors  = ([0, 0, 1], [1, 0, 0], [1, 0.5, 0.5], [0, 1, 1])
        period_edges = [rule_period, delay_end, target_end]

        figure(fignum)

        # Panel 1: unit outputs V
        subplot(3,1,1)
        hv = plot(t, V')
        for (line, c) in zip(hv, unit_colors)
            setp(line, color=c)
        end
        ylabel("V")

        # Read the autoscaled limits before the period markers stretch them,
        # then pad them slightly when the traces reach the 0 or 1 bounds.
        yl = [ylim()[1], ylim()[2]]
        vlines(period_edges, -0.05, 1.05, linewidth=2)
        if yl[1] < 0.02
            yl[1] = -0.02
        end
        if yl[2] > 0.98
            yl[2] = 1.02
        end
        ylim(yl)
        grid(true)

        # Panel 2: unit activations U
        subplot(3,1,2)
        hu = plot(t, U')
        for (line, c) in zip(hu, unit_colors)
            setp(line, color=c)
        end
        ylabel("U"); ylim(-100, 100)
        vlines(period_edges, ylim()[1], ylim()[2], linewidth=2)
        grid(true)

        # Panel 3: decision variable, Pro right minus Pro left
        subplot(3,1,3)
        plot(t, V[1,:] - V[3,:])
        ylim([-0.9, 0.9])
        vlines(period_edges, ylim()[1], ylim()[2], linewidth=2)
        xlabel("t"); ylabel("Pro R - Pro L")
        grid(true)
    end

    # Smooth stand-in for a hard threshold on the final Pro R - Pro L difference:
    # answer approaches +1 when the difference is well above decision_threshold,
    # -1 when it is well below -decision_threshold, and is near 0 in between.
    final_diff = V[1,end] - V[3,end]
    answer1 = 0.5 + 0.5*tanh(( final_diff - decision_threshold)/0.1)
    answer2 = 0.5 + 0.5*tanh((-final_diff - decision_threshold)/0.1)
    answer  = answer1 - answer2

    return answer, t, U, V, W 
end


# Quick visual check: clear figure 1 and plot a single "pro" trial on it.
# (run_dynamics2 plots on fignum=1 by default when do_plot=true.)
figure(1); clf();
run_dynamics2("pro", do_plot=true) # , nderivs=2, derivorder=2)


# OLD -- stuff for testing the constrained Hessian minimization

In [None]:
# Evaluate the cost, its gradient and its Hessian at a starting parameter vector.
# `targets`, `ntrials`, `theta1`, `theta2` and `Jcost2` must already be defined
# by earlier cells.
params = [18, 1.8, 0.854, 1]
# Scalar cost function of the parameters: the first element returned by Jcost2
func = x -> Jcost2(x, targets, ntrials=ntrials, theta1=theta1, theta2=theta2, beta=0.05, verbose=true)[1]

# `out` is filled by hessian! and then read back as value, gradient and Hessian
out = DiffBase.HessianResult(params)
ForwardDiff.hessian!(out, func, params)
cost = DiffBase.value(out)
grad = DiffBase.gradient(out)
hess = DiffBase.hessian(out)



In [None]:
# Display the gradient after applying the adjoint (') twice.
# NOTE(review): presumably this turns the gradient vector into a column matrix,
# the same form passed to constrained_parabolic_minimization below -- confirm.
grad''

In [None]:
    hessdelta = -(hess \ grad)   # unconstrained Newton step; solve instead of forming inv(hess)

# Constrained step: first output of constrained_parabolic_minimization.
# NOTE(review): the third argument (2) is presumably the bound on the step size -- confirm.
chessdelta = constrained_parabolic_minimization(hess, grad'', 2)[1] # , doplot=true, lambdastepsize=0.01, efactor=3)[1]
norm(hessdelta)

# Re-evaluate cost, gradient and Hessian at the proposed new parameters
new_params = params + chessdelta
ForwardDiff.hessian!(out, func, new_params)
new_cost = DiffBase.value(out)
new_grad = DiffBase.gradient(out)
new_hess = DiffBase.hessian(out)

println(params)
println(chessdelta)
println(new_params)
ylim(-0.0001, 0.001); grid();
# Cell output: cost before and after the constrained step
[cost new_cost]


In [None]:
# Cosine of the angle between the constrained step and the gradient
println(dot(chessdelta,grad)/(norm(chessdelta)*norm(grad)))
println(chessdelta)
println(grad)
ylim(-0.1, 10)
# Cell output: gaps between consecutive eigenvalues of the Hessian, in sorted order
diff(sort(eig(hess)[1]))


In [None]:
# Display the current parameter vector
params

In [None]:
# Plot the raw difference between the two rows of `antans` (defined by an
# earlier cell) and the same difference passed through a sigmoid of width theta.
figure(2);
clf();
theta = 0.4
antdiff = antans[2,:] - antans[1,:]
ant_out = 0.5*(1 .+ tanh.(antdiff/theta))

subplot(2,1,1)
plot(antdiff, "b."); grid()

subplot(2,1,2)
plot(ant_out, "b."); grid()

# Kept as the last expression so the spread of the differences is displayed
# as the cell output instead of being discarded.
std(antdiff)

# OLD SANDLOT FROM HERE ON

# Beginning to test gradient descent

It's kind of working?  Three issues:

(a) I think we're trapped in the final attractor values-- maybe time to explore adding reaction time, or not computing unit values so late in the trial? Actually, on further inspection, it is really asking for something like 80% performance without having defined outputs as hit=1, miss=0.  Really need the sigmoid.

(b) Right now I only know how to differentiate J when J computes a single scalar.  But sometimes we want to stash some values as we go. Don't yet know how to do that.

(c) Would also be nice to save values in a file or something while the gradient descent search is occurring, so as to have a trace of what happened, for later debugging

In [None]:
# Gradient descent on the cost J with an adaptive step size eta.
# `J` and `test_params` must already be defined by earlier cells.
# The commented-out parameter sets below are a log of earlier starting points.

# trial_type = "pro"
# params = [36, 1.8, 0.854, 1]
# params = [36, 0.3, 0.254, 1]
targets = [0.8, 0.7]   # Fraction correct in Pro and Anti

ntrials = 10
# eta = 0.001;

# # This is close to Marino's params. Not much changes.
# params = [36.006098, 1.655863, 0.568341, 1.019982]; 
# eta = 0.000221
# params =[36.023674, 1.694372, 0.619004, 1.092042]; 
# eta = 0.00001

# # Starting from the one immediately below, which produces all-Pro responses,
# # the algorithm seems to be working!  Certainly the cost goes down, and we
# # get to 6 correct Pro, 8 correct Anti (after having started at [10,0]).
# params =[18.023674, 0.894372, 0.619004, 1.092042]; 
# eta = 0.0001
# # The next few are a few stops along the way as the algorithm went on. I halted it,
# # but it was still going and still getting better.
# params = [23.211391, 3.799635, -0.432648, 4.131169]
# params = [24.649633, 3.440672, -0.157887, 3.355422]
# eta = 0.013
# params = [26.855623, 3.449642, -0.222252, 2.867700]


# NOTE: this eta/params pair is overwritten a few lines below before it is used.
eta = 0.004
params = [34.296820, 2.514509, 0.535576, 1.585929]

# 22050: eta=0.013335, cost=1.64289, [P,A : D : c1,c2]=[9,9 : 0.584 : 23.685, -7.256], 
# params=[36.384866, 2.597524, 0.570819, 0.942388]

# --------------

# Starting point and initial step size actually used by the loop below
params = [34.296820, 2.514509, 0.535576, 1.585929] 
eta = 0.1

# --------------

# Cost and gradient at the starting point
out = DiffBase.GradientResult(params)
ForwardDiff.gradient!(out, x -> J(x, targets, ntrials=ntrials), params)
cost = DiffBase.value(out)
grad = DiffBase.gradient(out)

# Iterate until the step size has shrunk below 1e-6
i=0; while eta > 1e-6

    i=i+1
    # Candidate step straight down the current gradient
    new_params = params - eta*grad
    
    ForwardDiff.gradient!(out, x -> J(x, targets, ntrials=ntrials), new_params)
    new_cost = DiffBase.value(out)
    new_grad = DiffBase.gradient(out)
    
    # Accept the step and grow eta by 10% if the cost went down;
    # otherwise keep the old parameters and halve eta.
    if new_cost < cost
        params = new_params
        cost   = new_cost
        grad   = new_grad
        eta = eta*1.1
    else
        eta = eta/2
    end
    
    # Progress statistics for the current (accepted) parameters; computed every iteration
    pro_answers, anti_answers, D, c1, c2 = test_params(params, targets, ntrials=ntrials)
    P = sum((sign(pro_answers[1,:]  - pro_answers[2,:])+1)/2)
    A = sum((sign(anti_answers[2,:] - anti_answers[1,:])+1)/2)
    
    # P is number of times Pro  trials had Pro_R > Pro_L  (i.e., were correct). If this is equal to ntrials*target[1], we're golden
    # A is number of times Anti trials had Pro_R < Pro_L  (i.e., were correct). If this is equal to ntrials*target[2], we're golden
    # D is the root mean square separation between Pro_R and Pro_L, across all trials.  Large D is good.
    
    # rem(i, 1) is always 0, so this prints every iteration; raise the divisor to print less often
    if rem(i, 1)==0
        @printf "%d: eta=%f, cost=%.5f, [P,A : D : c1,c2]=[%d,%d : %.3f : %.3f, %.3f], params=[%.3f, %.3f, %.3f, %.3f]\n" i eta cost P A D c1 c2 params[1] params[2] params[3] params[4]
    end
end

        


In [None]:
#  Plot out some final values to see what things are looking like

new_params = [26.855623, 3.449642, -0.222252, 2.867700]
# NOTE(review): other cells call test_params(params, targets, ntrials=...); here
# `targets` is not passed -- confirm this matches the current test_params signature.
pro_answers2, anti_answers2 = test_params(new_params, ntrials=10)

# Row 1 in blue and row 2 in red, one dot per trial; Pro trials on top, Anti below
figure(3); clf()
subplot(2,1,1)
plot(pro_answers2[1,:], "b.", pro_answers2[2,:], "r.")
title("PRO")

subplot(2,1,2)
plot(anti_answers2[1,:], "b.", anti_answers2[2,:], "r.")
title("ANTI")

In [None]:
# Stash a parameter vector and step size; the same values appear in the log of
# earlier runs in the gradient-descent cell.
new_params =[23.211391, 3.799635, -0.432648, 4.131169]; new_eta = 0.000221


In [None]:
"""
    plot_many(params; ntrials=5, random_seed=321, sigma=3.2, target_period=50, start_plotting_at=5)

Plot full dynamics of some example trials, using same random seed and trial order as was 
used during the minimization.

Trials are run in Pro/Anti pairs through `run_dynamics`. The first `start_plotting_at - 1`
pairs are run without plotting, so the random number stream advances exactly as it would
have. The remaining `ntrials - start_plotting_at + 1` pairs are drawn on figure 2 (Pro) 
and figure 3 (Anti), both of which are cleared first.

# Keywords
- `ntrials`: total number of Pro/Anti pairs to run.
- `random_seed`: seed passed to `srand` before the first trial.
- `sigma`, `target_period`: forwarded to `run_dynamics`.
- `start_plotting_at`: index of the first trial pair that is plotted.

# Returns
`(pro_perf, anti_perf)`: over all trials run, the sum of `V[1,end] - V[3,end]` on Pro
trials and the sum of `V[3,end] - V[1,end]` on Anti trials.
"""
function plot_many(params; ntrials=5, random_seed = 321, sigma=3.2, target_period=50, start_plotting_at=5)
    srand(random_seed)

    # Runs one Pro trial then one Anti trial (this order fixes the noise each one
    # sees) and returns the final signed Pro R / Pro L differences.
    function run_pair(do_plot)
        Vpro  = run_dynamics("pro",  params, do_plot=do_plot, fignum=2, 
            sigma=sigma, target_period=target_period)[4]
        Vanti = run_dynamics("anti", params, do_plot=do_plot, fignum=3, 
            sigma=sigma, target_period=target_period)[4]
        return Vpro[1,end] - Vpro[3,end], Vanti[3,end] - Vanti[1,end]
    end

    pro_perf  = 0.0
    anti_perf = 0.0

    # Warm-up trials: not plotted, since the figures are cleared right after
    # anyway, but still run so they count toward performance and consume the
    # same random numbers.
    for i in 1:start_plotting_at-1
        dpro, danti = run_pair(false)
        pro_perf  += dpro
        anti_perf += danti
    end

    figure(2); clf();
    figure(3); clf();

    for i in 1:ntrials-start_plotting_at+1
        dpro, danti = run_pair(true)
        pro_perf  += dpro
        anti_perf += danti
    end
    
    @printf "pro_perf=%.2f  anti_perf=%.2f"  pro_perf  anti_perf
    return pro_perf, anti_perf
end

In [None]:
# Plot example trials and print the c1, c2 values returned by test_params.
# NOTE: the first assignment is overwritten by the one two lines below.
params = [34.296820, 2.514509, 0.535576, 1.585929]
# params = [36, 1.8, 0.854, 1]
params=[36.384866, 2.597524, 0.570819, 0.942388]

# start_plotting_at=1 means every one of the 10 trial pairs is plotted
plot_many(params, start_plotting_at=1,sigma=0.8, ntrials=10)

# `targets` must already be defined by an earlier cell
pa, aa, rms, c1, c2 = test_params(params, targets, ntrials=10)

println(c1)
println(c2)

# Wondering about playing with using the Hessian for faster searching...

# From here on, various pieces of trash:


In [None]:
# Set matplotlib defaults. Assigning into `mpl.rcParams[...]` has no effect,
# because PyCall hands back a converted Julia copy of the dictionary. Wrapping
# the raw Python object in a PyDict gives a live view, so these assignments
# reach matplotlib itself.

@pyimport matplotlib as mpl

rcParams = PyDict(PyPlot.matplotlib["rcParams"])
rcParams["font.size"] = 32
rcParams["font.family"] = "Arial"
rcParams["lines.linewidth"] = 1.5
rcParams["lines.markersize"] = 8