# Preliminaries

In [1]:
using PyCall
using PyPlot
using ForwardDiff
using DiffBase

pygui(true)



true

# The main model dynamics function

Note that the documenation indicates some default values for the optional parameters; these values need to be updated to what the code actually says below. The actual defaults are much closer to what Marino's May 2017 model has

In [2]:


""" 
function t, U, V, W = run_dynamics(trial_type, params)

    Runs the 4-way mutual inhibition model
    
    Params: vert w; horiz w; pro bias; delay input
    
"""
function run_dynamics(trial_type, params::Vector ; opto_cue=1, opto_delay=1, opto_choice=1,
    light_input=12, noisefr=0.1, threshold=0.18,
    tau=4.4, dt=0.05, start_U = [-25, -25, -25, -25],
    g_leak = 1, U_rest = 0, theta = 5, beta = 50, do_plot = false, fignum=1,
    cue_period = 200, delay_period = 200, choice_period = 50)
    
    vwi = params[1]; hwi = params[2]; pro_bias = params[3]; delay_input = params[4]
    
    t = [0 : dt : cue_period + delay_period + choice_period;]

    V = zeros(eltype(params), 4, length(t))   # the eltype(params) is for ForwardDiff
    U = zeros(eltype(params), 4, length(t))

    U[:,1] = start_U

    W = [0 -vwi -hwi 0; -vwi 0 0 -hwi;
        -hwi 0 0 -vwi; 0 -hwi -vwi 0]


    for i in [2:length(t);]  # the funny semicolon appears to be necessary in Julia
        
        dUdt = W * V[:,i-1] + g_leak*(U_rest - U[:,i-1])/tau
        
        if t[i] < cue_period + delay_period
            if trial_type=="anti"
                dUdt[[2,4]] += delay_input
            elseif trial_type == "pro"
                dUdt[[1,3]] += delay_input
            else
                error("invalid trial type")
            end
            
        elseif t[i] < cue_period + delay_period + choice_period
            dUdt[[1,2]] += light_input
        end
    
        dUdt[[1,3]] += pro_bias
        
        

        U[:,i] = U[:,i-1] +  dt*dUdt

        V[:,i] = 0.5*tanh((U[:,i]-theta)/beta) + 0.5
        
        if t[i] < cue_period
            V[:,i]=V[:,i]*opto_cue
        elseif t[i] < cue_period+delay_period
            V[:,i]=V[:,i]*opto_delay
        elseif t[i] < cue_period+delay_period+choice_period
            V[:,i]=V[:,i]*opto_choice
        end

        V[:,i] = V[:,i] + noisefr*randn(4)*sqrt(dt)


    end







    if do_plot
        figure(fignum);
        h = plot(t, V');
        setp(h[1], color=[0, 0, 1])
        setp(h[2], color=[1, 0, 0])
        setp(h[3], color=[1, 0.5, 0.5])
        setp(h[4], color=[0, 1, 1])
         
        ax = gca()
        yl = [ylim()[1], ylim()[2]]
        vlines([cue_period, cue_period+delay_period,
                cue_period+delay_period+choice_period],
                0.05, 1.05, linewidth=2)
        if yl[1]<0.02
                 yl[1] = -0.02
        end
        if yl[2]>0.98
                 yl[2] = 1.02
        end
        ylim(yl)
        grid(true)
     end
                 





    answer  = V[1,end] - V[3,end]


    #compute reaction time

    reac=NaN;
    for i in [8001:length(t)-15;]
        val1=mean(V[1,i-15:i+15]);
        val3=mean(V[3,i-15:i+15]);
        if(abs(val1-val3)>threshold)
            reac=t[i];
            break
        end
    end




    return answer, reac, t, U, V
end



run_dynamics

# Run the dynamics with Marino's parameters just to test it

In [3]:



params = [36, 1, 0.854, 1]
ntrys = 5
opto_delay_use=1
opto_choice_use=1

                
do_plot_use=true;

                 
for i in [1:ntrys;]
    answer, reac, U, V = run_dynamics("pro", params,opto_delay=opto_delay_use,opto_choice=opto_choice_use,do_plot=do_plot_use,fignum=1)
    println(answer)
    println(reac)
    println("******************")
end



for i in [1:ntrys;]
answer, reac, U, V = run_dynamics("anti", params,opto_delay=opto_delay_use,opto_choice=opto_choice_use,do_plot=do_plot_use,fignum=2)
println(answer)
println(reac)
println("******************")
end

                
# show()

                 



0.5930530293725472
402.05
******************
0.5774225131020866
402.0
******************
0.6226678672508857
401.85
******************
0.5473068688826903
402.15
******************
0.5784606534618995
401.85
******************
-0.3048699289028684
428.6
******************
-0.38437661772862275
400.0
******************
0.701529045962731
400.8
******************
-0.3578939463526034
419.95
******************
-0.3299615763379042
427.9
******************


# Here we play with testing with differentiating the main dynamics function, and defining a cost function

In [4]:


"""
function J(params, targets; ntrials=10, sigma=3.2, random_seed=321)

Computes a cost function for certain parameters and fraction correct targets
"""
             
function J(params, targets; ntrials=10, noisefr=0.1, random_seed=321)
    
    srand(random_seed)
                 
    pro_perf  = 0;
    anti_perf = 0;
    
    pro_perf_opto_delay  = 0;
    anti_perf_opto_delay = 0;
    
    pro_perf_opto_choice  = 0;
    anti_perf_opto_choice = 0;

             
    for i in [1:ntrials;]
        answer, reac, t, U, V = run_dynamics("pro", params, noisefr=noisefr)
        pro_perf += V[1,end] - V[3,end]

        answer, reac, t, U, V = run_dynamics("anti", params, noisefr=noisefr)
        anti_perf += V[3,end] - V[1,end]
        
        
        answer, reac, t, U, V = run_dynamics("pro", params, noisefr=noisefr, opto_delay=0.95)
        pro_perf_opto_delay += V[1,end] - V[3,end]
        
        answer, reac, t, U, V = run_dynamics("anti", params, noisefr=noisefr, opto_delay=0.95)
        anti_perf_opto_delay += V[1,end] - V[3,end]
        
        
        answer, reac, t, U, V = run_dynamics("pro", params, noisefr=noisefr, opto_choice=0.95)
        pro_perf_opto_choice += V[1,end] - V[3,end]
        
        answer, reac, t, U, V = run_dynamics("anti", params, noisefr=noisefr, opto_choice=0.95)
        anti_perf_opto_choice += V[1,end] - V[3,end]
        
        
        
    end

    cost = (pro_perf - ntrials*targets[1])^2 + (anti_perf - ntrials*targets[2])^2 + 
    (pro_perf_opto_delay - ntrials*targets[3])^2 + (anti_perf_opto_delay - ntrials*targets[4])^2 +
    (pro_perf_opto_choice - ntrials*targets[5])^2 + (anti_perf_opto_choice - ntrials*targets[5])^2

    return cost/ntrials
                 
end
                 
                 
                 
params = [36, 1, 0.854, 1]

targets = [0.8, 0.7, 0.8, 0.5, 0.8, 0.7]   # Fraction correct in Pro and Anti

println(J(params, targets))

grad = ForwardDiff.gradient(x -> J(x, targets), params)

println(grad)
                 


20.887484490092884
[1.40753,-14.4571,-73.8544,78.3756]


# Beginning to test gradient descent

It's kind of working?  Two issues:

(a) I think we're trapped in the final attractor values-- maybe time to explore adding reaction time, or not computing unit values so late in the trial? Actually, on further inspection, it is really asking for something like 80% performance without having defined outputs as hit=1, miss=0.  Really need the sigmoid.

(b) Right now I only know how to differentiate J when J computes a single scalar.  But sometimes we want to stash some values as we go. Don't yet know how to do that.

(c) Would also be nice to save values in a file or something while the gradient descent search is occurring, so as to have a trace of what happened, for later debugging

In [None]:

                 
                 
 params = [36, 1, 0.854, 1]
 
 
 targets = [0.8, 0.7, 0.8, 0.5, 0.8, 0.7]   # Fraction correct in Pro and Anti
                 

 
 ntrials = 100
 
  eta = 0.001;
 
               
                 
 # --------------
 
 out = DiffBase.GradientResult(params)
 ForwardDiff.gradient!(out, x -> J(x, targets, ntrials=ntrials), params)
 cost = DiffBase.value(out)
 grad = DiffBase.gradient(out)
 
 i=0; while eta > 1e-6
 
 i=i+1
 new_params = params - eta*grad
 
 ForwardDiff.gradient!(out, x -> J(x, targets, ntrials=ntrials), new_params)
 new_cost = DiffBase.value(out)
 new_grad = DiffBase.gradient(out)
 
 if new_cost < cost
 params = new_params
 cost   = new_cost
 grad   = new_grad
 eta = eta*1.1
 else
 eta = eta/2
 end
 
    if rem(i, 1)==0
 @printf "%d: eta=%f, cost=%.5f, params=[%.3f, %.3f, %.3f, %.3f]\n" i eta cost params[1] params[2] params[3] params[4]
 end
 end
 
 
 
                 
       


