**DIFFERENTIABLE FUNCTIONS FOR RUNNING RATE-BASED NEURAL NETWORKS**

This notebook contains functions to run a rate-based network model, both forward and backward in time (using equations similar to those of Hopfield, PNAS, 1984, and many others), plus some simple examples that use them. The code (without the illustrative examples) gets extracted into

    rate_networks.jl

There is also an example of using those networks together with ForwardDiff to differentiate one of the outputs with respect to various network parameters.
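ForwardDiff computes derivatives by forward-mode automatic differentiation: it propagates dual numbers (a value plus a derivative) through ordinary code. The sketch below is a hand-rolled illustration of that idea, not ForwardDiff's own API. The `Dual` type, the one-unit network `run_unit`, and all parameter values are hypothetical. It differentiates the final V of a single Euler-integrated unit with respect to its self-weight `w` and compares the result with a central finite difference.

```julia
# Hypothetical sketch: forward-mode differentiation with a hand-rolled dual number.
# `Dual` carries a value and its derivative with respect to one chosen parameter.
struct Dual
    val::Float64
    der::Float64
end

# Arithmetic rules needed by the Euler update below (sum, product, negation, tanh)
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:+(a::Dual, b::Real) = Dual(a.val + b, a.der)
Base.:+(a::Real, b::Dual) = b + a
Base.:-(a::Dual) = Dual(-a.val, -a.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.:*(a::Dual, b::Real) = Dual(a.val * b, a.der * b)
Base.:*(a::Real, b::Dual) = b * a
Base.tanh(a::Dual) = Dual(tanh(a.val), (1 - tanh(a.val)^2) * a.der)

g(u) = 0.5 * tanh(u) + 0.5

# One unit with self-weight w: tau dU/dt = -U + w*g(U) + I, Euler-integrated;
# returns V = g(U) at the last timestep
function run_unit(w; U0 = 0.1, I = 0.1, dt = 0.01, tau = 0.1, nsteps = 100)
    U = U0
    for _ in 2:nsteps
        U = U + (dt / tau) * (-U + w * g(U) + I)
    end
    return g(U)
end

w = -2.0
dVdw_ad = run_unit(Dual(w, 1.0)).der                    # seed: dw/dw = 1
h = 1e-6
dVdw_fd = (run_unit(w + h) - run_unit(w - h)) / (2h)   # central finite difference
println((dVdw_ad, dVdw_fd))
```

The two numbers should agree to many digits; the dual-number version gets its derivative exactly (up to rounding) in a single forward pass.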

<h1 id="tocheading">TABLE OF CONTENTS</h1>
<div id="toc"></div>

**Updates to the table of contents are periodic, but run the cell below to first start or force an update.**

In [None]:
macro javascript_str(s) display("text/javascript", s); end

javascript"""
$.getScript('make_table_of_contents.js')
"""

In [None]:
#@include_me  rate_networks.jl

using PyCall
using PyPlot
using ForwardDiff
using DiffBase
using MAT

# pygui(true)

include("general_utils.jl")
include("constrained_parabolic_minimization.jl")
include("hessian_utils.jl")




# Setup -- definitions of forwardModel() and backwardsModel()

These are functions that run arbitrary $0.5 \cdot (1+\tanh(x))$-style rate networks, either forwards in time or backwards in time. The backwards-in-time part was written for an idea that we're no longer pursuing, but is kept here for completeness.

The equations are similar to those in Hopfield, PNAS, 1984 and in many papers since: For unit $i$,

$$
    \tau \frac{{\rm d}U_i}{{\rm d}t} \; = \; g_{\rm leak} \cdot (U_{\rm rest} - U_i) \; + \;
    \sum_j W_{ij}\, g(U_j) \; + \; I_i \; + \; \sigma \eta
$$

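where

$$
    g(U) = 0.5\cdot \left(1 + \tanh\frac{U-\theta}{\beta} \right)
$$

Here $\theta$ is the value of $U$ at which $g = 0.5$, and $\beta$ sets the width of the transition. In `forwardModel()` these are the keyword parameters `theta` (default 0) and `beta` (default 1).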
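A direct transcription of $g$, written to match how `forwardModel()` calls `g((U-theta)/beta)`: `theta` is the midpoint (where $g = 0.5$) and `beta` the width of the transition. The function name `squash` is hypothetical.

```julia
# Hypothetical helper: the squashing function g, with theta as the midpoint
# (g = 0.5 at U = theta) and beta as the width, as in g((U - theta)/beta)
squash(U; theta = 0.0, beta = 1.0) = 0.5 .* (1 .+ tanh.((U .- theta) ./ beta))

Us = [-3.0, 0.0, 3.0]
squash(Us)                  # each value lies strictly between 0 and 1
squash(1.0; theta = 1.0)    # 0.5, since U equals the midpoint theta
squash(0.5; beta = 0.25)    # larger than squash(0.5): a smaller beta gives a steeper transition
```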

The **forwards integration** is done using simple Euler integration 

$$ 
U(t+\delta t) =  U(t) + \delta t \frac{{\rm d}U}{{\rm d}t}.
$$
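Because the model equation is written for $\tau\, {\rm d}U/{\rm d}t$, each Euler step adds $\delta t/\tau$ times the right-hand side. A minimal sketch of that update, assuming no noise and no opto manipulation; `euler_run` and its keyword names are hypothetical and only mirror the corresponding `forwardModel()` parameters.

```julia
# Hypothetical sketch of the forward Euler update, without noise or opto effects
g(U; theta = 0.0, beta = 1.0) = 0.5 .* (1 .+ tanh.((U .- theta) ./ beta))

function euler_run(U0, W, I; dt = 0.01, tau = 0.1, nsteps = 100,
                   g_leak = 1.0, U_rest = 0.0)
    U = zeros(length(U0), nsteps)
    U[:, 1] = U0
    for i in 2:nsteps
        # right-hand side of tau dU/dt
        dUdt = g_leak .* (U_rest .- U[:, i-1]) .+ W * g(U[:, i-1]) .+ I
        U[:, i] = U[:, i-1] .+ (dt / tau) .* dUdt
    end
    return U, g(U)      # full trajectories of U and V
end

U, V = euler_run([0.1, 0.1], [0 -5; -5 0], [0.1, 0])
```

With mutual inhibition `W = [0 -5; -5 0]` and input only to unit 1, unit 1 ends up with the larger V.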

The **backwards integration** is more complicated, and involves doing a minimization search to find the $U(t-\delta t)$ that would most closely produce $U(t)$ after one $\delta t$ timestep.
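For a single unit, the backwards step can be illustrated with a one-dimensional search. `backwardsModel()` uses `trust_region_Hessian_minimization()`; the sketch below instead swaps in plain Newton iterations on the residual $r(x) = f(x) - U(t)$, where $f$ is one forward Euler step, which drives the cost $J(x) = r(x)^2$ to its minimum of zero. All names and parameter values here are hypothetical.

```julia
# Hypothetical sketch: undo one forward Euler step for a single unit by a 1-D search
w, input, dt, tau = -2.0, 0.1, 0.01, 0.1     # illustrative parameter values

g(u)  = 0.5 * tanh(u) + 0.5
dg(u) = 0.5 * (1 - tanh(u)^2)                # derivative of g

# One forward Euler step f(u), and its derivative f'(u)
euler_step(u)  = u + (dt / tau) * (-u + w * g(u) + input)
deuler_step(u) = 1 + (dt / tau) * (-1 + w * dg(u))

# Find x with euler_step(x) ≈ U_next, i.e. minimize J(x) = (euler_step(x) - U_next)^2,
# using Newton iterations on the residual r(x) = euler_step(x) - U_next
function step_back(U_next; tol = 1e-14, maxiter = 50)
    x = U_next                               # initial guess: the current value
    for _ in 1:maxiter
        r = euler_step(x) - U_next
        abs(r) < tol && break
        x -= r / deuler_step(x)
    end
    return x
end

U_prev = 1.1
U_next = euler_step(U_prev)
step_back(U_next)                            # recovers U_prev to within the tolerance
```

For these parameters the forward step is strictly increasing in $x$ (its derivative stays between 0.8 and 0.9), so the previous state is unique and the iteration converges in a few steps.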

In [None]:
#@include_me rate_networks.jl

    

"""
forwardModel(startU; dt=0.01, tau=0.1, nsteps=100, input=[], noise=[], W=[0 -5;-5 0], 
    init_add=0, start_add=0, const_add=0, sigma=0, g_leak=1, U_rest=0, theta=0, beta=1,
    do_plot=false, nderivs=0, difforder=0, clearfig=true, fignum=1, dUdt_mag_only=false,
    warn_if_unused_params=false, opto_strength=1, opto_units=[], opto_times=zeros(0,2))

Runs a tanh()-style network forwards in time, given its starting point, using simple Euler integration:

    tau dU/dt = g_leak*(U_rest - U) + W*V + I
    V = 0.5*tanh((U - theta)/beta) + 0.5

# PARAMETERS:

- startU     A column vector, nunits-by-1, indicating the values of U at time zero


# OPTIONAL PARAMETERS

- dt      Scalar, timestep size

- tau     Scalar, in seconds

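- g_leak  dUdt will have a term equal to g_leak*(U_rest - U)

- U_rest  dUdt will have a term equal to g_leak*(U_rest - U)

- theta, beta   Midpoint and width of the squashing function, V = 0.5*tanh((U - theta)/beta) + 0.5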

- nsteps  Number of timesteps to run, including time=0.

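- input   Either a scalar (the same constant input to every unit), an nunits-by-1 vector, in which case inputs to each unit are constant
        across time, or a matrix, nunits-by-nsteps, indicating input for each unit at each timepoint.
        The default, the empty matrix [], means zero input.

- noise   Same formats as input. noise[:,1] is added to startU, and noise[:,i] is added to U[:,i] after
        each Euler step. Passing in a fixed noise sample allows the same sample to be reused later.

- W       Weight matrix, nunits-by-nunits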

- init_add    DEPRECATED: Vector or scalar that gets added to the input current at very first timestep.
            Deprecated because this made it dt-dependent. Replaced by start_add.

- start_add   Vector or scalar that gets added, once, to the initial U[:,1], before the integration process begins.

- const_add   Scalar that gets added to the input current at every timestep

- sigma       After each timestep, add sigma*sqrt(dt)*randn() to each element of U

- opto_strength    The outputs V, after being computed, will get multiplied by this number. opto_strength should *EITHER* be a scalar, in which case optional params opto_units and opto_times below are also relevant; *OR* it should be an nunits-by-nsteps matrix, completely specifying how much each unit's V should be multiplied by at each timestep, in which case opto_times and opto_units are irrelevant

- opto_units       A list of the unit numbers that will have their V multiplied by opto_strength. For example, [1,3] would affect only units 1 and 3.  Can be the empty matrix (equivalent to no opto effect). Irrelevant if opto_strength = 1

- opto_times    An n-by-2 matrix, where each row lists t_start_of_opto_effect, t_end_of_opto_effect. For example,
                [1 3 ; 6 8]  would mean "have an opto effect during both 1 <= t <= 3 and 6 <= t <= 8". With the 
                code as currently configured, the same opto_strength and opto_units apply across all 
                the listed time intervals in a run.

- do_plot   Default false. If true, plots V for up to the first two units (V1 vs. t for a single unit, V2 vs. V1 otherwise)

- fignum     Figure number on which to plot

- clearfig  If true, the figure is first cleared; otherwise any plot is overlaid

- nderivs, difforder     Required for making sure function can create its own arrays and 
                       still be differentiated

- dUdt_mag_only  If true, returns |dUdt|^2 from the first timestep only, then stops.

- warn_if_unused_params     If true, prints out a warning if some of the passed parameters are not used.



# RETURNS:

- Uend, Vend      nunits-by-1 vectors representing the final values of U and V that were found.

- U, V            nunits-by-nsteps matrices containing the full trajectories

- t               A time vector, so one could do things like plot(t, U[1,:])

If dUdt_mag_only is true, only the scalar sum(dUdt.^2) from the first timestep is returned instead.

"""
function forwardModel(startU; opto_strength=1, opto_units=[], opto_times=zeros(0,2),
    dt=0.01, tau=0.1, nsteps=100, input=[], noise=[], W=[0 -5;-5 0], 
    init_add=0, start_add=0, const_add=0, do_plot=false, nderivs=0, difforder=0, clearfig=true, fignum=1,
    dUdt_mag_only=false, sigma=0, g_leak=1, U_rest=0, theta=0, beta=1, 
    warn_if_unused_params=false, other_unused_params...)

    
    """
    o = g(z)    squashing tanh function, running from 0 to 1, is equal to 0.5 when input is 0.
    """
    function g(z)
        return 0.5*tanh.(z)+0.5
    end
    
    if warn_if_unused_params && length(other_unused_params)>0
        @printf("\n\n=== forwardModel warning, had unused params ")
        for k in keys(Dict(other_unused_params))
            @printf("%s, ", k)
        end
    end
    
    if length(size(opto_times))==1
        opto_times = reshape(opto_times, 1, 2)
    end
    
    my_input = ForwardDiffZeros(size(input,1), size(input,2), nderivs=nderivs, difforder=difforder)
    for i=1:prod(size(input)); my_input[i] = input[i]; end
    input = my_input;
    
    nunits = length(startU)
    if size(startU,2) > size(startU,1)
        error("startU must be a column vector")
    end
    
    # --- formatting input ---
    if ~(typeof(input)<:Array) || prod(size(input))==1  # was a scalar
        input = input[1]*(1+ForwardDiffZeros(nunits, nsteps, nderivs=nderivs, difforder=difforder))
    elseif length(input)==0 # was the empty matrix
        input = ForwardDiffZeros(nunits, nsteps, nderivs=nderivs, difforder=difforder)
    elseif size(input,2)==1     # was a column vector
        input = input*(1+ForwardDiffZeros(1, nsteps, nderivs=nderivs, difforder=difforder))
    end    
    # --- formatting noise ---
    if ~(typeof(noise)<:Array) || prod(size(noise))==1  # was a scalar
        noise = noise*(1+ForwardDiffZeros(nunits, nsteps, nderivs=nderivs, difforder=difforder))
    elseif length(noise)==0 # was the empty matrix
        noise = ForwardDiffZeros(nunits, nsteps, nderivs=nderivs, difforder=difforder)
    elseif size(noise,2)==1     # was a column vector
        noise = noise*(1+ForwardDiffZeros(1, nsteps, nderivs=nderivs, difforder=difforder))
    end    
    # --- formatting opto fraction ---
    if typeof(opto_strength)<:Array
        if size(opto_strength,1) != nunits || size(opto_strength,2) != nsteps
            error("opto_strength must be either a scalar or an nunits-by-nsteps matrix")
        end
        opto_matrix = opto_strength
    else # We assume that if opto_strength is not an Array, then it is a scalar
        opto_matrix = ForwardDiffZeros(nunits, nsteps, nderivs=nderivs, difforder=difforder) + 1
        time_axis = dt*(0:nsteps-1)
        for i=1:size(opto_times,1)
            opto_matrix[opto_units, (opto_times[i,1] .<= time_axis) .& (time_axis .<= opto_times[i,2])] = opto_strength
        end
    end
    
    U = ForwardDiffZeros(nunits, nsteps, nderivs=nderivs, difforder=difforder)
    V = ForwardDiffZeros(nunits, nsteps, nderivs=nderivs, difforder=difforder)
    
    if ~(typeof(W)<:Array); W = [W]; end

    W     = reshape(W, nunits, nunits)
    U     = reshape(U, nunits, nsteps)
    V     = reshape(V, nunits, nsteps)
    input = reshape(input, nunits, nsteps)
    noise = reshape(noise, nunits, nsteps)

    input[:,1] += init_add
    input      += const_add

    U[:,1] = startU + noise[:,1] + start_add
    V[:,1] = g((U[:,1]-theta)/beta)
    V[:,1] .*= opto_matrix[:,1]
    
    for i=2:nsteps
        dUdt = g_leak*(U_rest -U[:,i-1]) + W*V[:,i-1] + input[:,i-1]
        if dUdt_mag_only; return sum(dUdt.*dUdt); end;
        # Euler step, plus the externally supplied noise and internally generated Gaussian noise
        U[:,i] = U[:,i-1] + (dt/tau)*dUdt + noise[:,i] + sigma*sqrt(dt)*randn(nunits)
        V[:,i] = g((U[:,i]-theta)/beta)
        V[:,i] .*= opto_matrix[:,i]
    end

    if do_plot
        figure(fignum)
        if length(startU)==1
            if clearfig; clf(); end;
            t = (0:nsteps-1)*dt
            plot(t, V[1,:], "b-")
            plot(t[1], V[1,1], "g.")
            plot(t[end], V[1,end], "r.")
            xlabel("t"); ylabel("V1"); ylim([-0.01, 1.01])
        elseif length(startU)>=2
            if clearfig; clf(); end;
            plot(V[1,:], V[2,:], "b-")
            plot(V[1,1], V[2,1], "g.")
            plot(V[1,end], V[2,end], "r.")
            xlabel("V1"); ylabel("V2"); 
            xlim([-0.01, 1.01]); ylim([-0.01, 1.01])
        end
    end

    return U[:,end], V[:,end], U, V, (0:nsteps-1)*dt
end


"""
backwardsModel(endU; nsteps=100, start_eta=10, tol=1e-15, maxiter=400, do_plot=false, 
    init_add=0, start_add=0, dt=0.01, input=[], noise=[], nderivs=0, difforder=0, 
    clearfig=false, fignum=1, params...)

Runs a tanh()-style network BACKWARDS in time, given its ending point, by making a backwards
guess at each timepoint and then using Hessian minimization to find the backwards vector that correctly
leads to the current timestep value. Uses forwardModel(). The forwards equations are:

    tau dU/dt = g_leak*(U_rest - U) + W*V + I
    V = 0.5*tanh((U - theta)/beta) + 0.5

**PARAMETERS:**

endU     A column vector, nunits-by-1, indicating the values of U at time=end


**OPTIONAL PARAMETERS:**

dt      Scalar, timestep size

tau     Scalar, in seconds

nsteps  Number of timesteps to run, including time=0.

input   Either an nunits-by-1 vector, in which case inputs to each unit are constant
        across time, or a matrix, nunits-by-nsteps, indicating input for each unit at each timepoint.

W       Weight matrix, nunits-by-nunits

do_plot   Default false, if true, plots V of up to the first two dimensions

tol       Tolerance in the minimization procedure for finding each backwards timestep. Passed on
          to trust_region_Hessian_minimization()

start_eta, maxiter   Passed on to trust_region_Hessian_minimization()

fignum     Figure number on which to plot

clearfig  If true, the figure is first cleared; otherwise any plot is overlaid

nderivs, difforder     Required for making sure function can create its own arrays and 
                       still be differentiated

params...   Any other keyword parameters (for example W, tau, g_leak, theta, beta) are passed on to forwardModel()



**RETURNS:**

Ustart, Vstart  nunits-by-1 vectors representing the starting values of U and V that were found.
U, V            nunits-by-nsteps matrices containing the full trajectories
costs           nsteps-by-1 vector with the final cost from the minimization procedure for each
                timestep. This is the summed squared difference between the U[t+1] produced by the U[t] 
                guess and the actual U[t+1]

"""
function backwardsModel(endU; nsteps=100, start_eta=10, tol=1e-15, maxiter=400, 
    do_plot=false, init_add=0, start_add=0, dt=0.01, 
    input=[], noise=[], nderivs=0, difforder=0, clearfig=false, fignum=1, params...)    

    """
    o = g(z)    squashing tanh function, running from 0 to 1, is equal to 0.5 when input is 0.
    """
    function g(z)
        return 0.5*tanh.(z)+0.5
    end
    
    nunits = length(endU)

    # --- formatting input ---
    if ~(typeof(input)<:Array) || prod(size(input))==1  # was a scalar
        input = input[1]*(1+ForwardDiffZeros(nunits, nsteps, nderivs=nderivs, difforder=difforder))
    elseif length(input)==0 # was the empty matrix
        input = ForwardDiffZeros(nunits, nsteps, nderivs=nderivs, difforder=difforder)
    elseif size(input,2)==1     # was a column vector
        input = input*(1+ForwardDiffZeros(1, nsteps, nderivs=nderivs, difforder=difforder))
    end    
    # --- formatting noise ---
    if ~(typeof(noise)<:Array)  # was a scalar
        noise = noise*(1+ForwardDiffZeros(nunits, nsteps, nderivs=nderivs, difforder=difforder))
    elseif length(noise)==0 # was the empty matrix
        noise = ForwardDiffZeros(nunits, nsteps, nderivs=nderivs, difforder=difforder)
    elseif size(noise,2)==1     # was a column vector
        noise = noise*(1+ForwardDiffZeros(1, nsteps, nderivs=nderivs, difforder=difforder))
    end    
    
    # Cost: squared difference between the U2 that one forward step from U1 produces and the target U2
    function J(U1, U2; nderivs=0, difforder=0, noise=[], input=[], pars...)
        U2hat = forwardModel(U1; nsteps=2, dt=dt, noise=noise, input=input, nderivs=nderivs, difforder=difforder, pars...)[1]
        DU = U2hat - U2
        return sum(DU.*DU)
    end
    
    if length(noise)==0
        noise = ForwardDiffZeros(nunits, nsteps, nderivs=nderivs, difforder=difforder)
    end

    U = ForwardDiffZeros(nunits, nsteps, nderivs=nderivs, difforder=difforder)
    U = reshape(U, nunits, nsteps)
    costs = ForwardDiffZeros(nsteps, 1, nderivs=nderivs, difforder=difforder)    
    
    U[:,end] = endU
    for i=(nsteps-1):-1:1
        if i==1
            my_init_add = init_add
            my_start_add = start_add
        else
            my_init_add = 0
            my_start_add = 0
        end
                
        U[:,i], costs[i] = trust_region_Hessian_minimization(U[:,i+1], 
            (x) -> J(x, U[:,i+1]; nderivs=length(endU), difforder=2, 
            input=input[:,i:i+1], noise = noise[:,i:i+1], 
            init_add=my_init_add, start_add=my_start_add, params...); 
            verbose=false, start_eta=start_eta, tol=tol, maxiter=maxiter)
        if i>1; U[:,i] += noise[:,i]; end
    end
    
    
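    # Use the same squashing as forwardModel(), including any theta and beta passed in params
    pdict = Dict(params)
    V = g((U - get(pdict, :theta, 0))/get(pdict, :beta, 1))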
    
    if do_plot
        figure(fignum)   
        if typeof(params)<:Array; params = Dict(params); end;
        if haskey(params, :dt);     dt     = params[:dt];     end
        if haskey(params, :nsteps); nsteps = params[:nsteps]; end
        if length(endU)==1
            if clearfig; clf(); end;
            t = (0:nsteps-1)*dt
            plot(t, V[1,:], "m-")
            plot(t[1], V[1,1], "go")
            plot(t[end], V[1,end], "ro")            
            ylim([-0.01, 1.01])
        elseif length(endU)>=2
            if clearfig; clf(); end;            
            plot(V[1,:], V[2,:], "m-")
            plot(V[1,1], V[2,1], "go")
            plot(V[1,end], V[2,end], "ro")
            xlim([-0.01, 1.01]); ylim([-0.01, 1.01])
        end
    end
    
    return U[:,1], V[:,1], U, V, costs
end

### Testing forward and backwards models with only 1 dimension

In [None]:
# First run a simple one-dimensional model (only one unit).  
#
# We've included the start_add parameter here, but it is a bad parameter
# and is best not used: the reason is that it does not scale with dt.
pygui(true); figure(1); clf();
params = Dict(:noise => [0.1], :W => [-2], :nsteps=>10, :start_add=>-1.9)
Uend = forwardModel([1.1]; do_plot=true, params...)[1]

# Now an example of running the backwards model.
# Note that because of the start_add parameter, the plot and the backwards model
# look a little different: the backwards model correctly returns the starting value of U,
# while the forwards plot uses the initial value of V(U) *after* the start_add.
#
Ustart = backwardsModel(Uend; do_plot=true, tol=1e-30, params...)[1]
@printf("Ustart came back as %g\n", Ustart[1])



### Testing forward and backwards models now with 2 dimensions

In [None]:
# If instead of letting the noise be generated internally in forwardModel(), 
# we pass a snapshot of it to forwardModel(), that same snapshot can then be passed to backwardsModel()
# which can then know what the noise sample at each timestep is, and therefore can take it into account.

nsteps=50
params = Dict(:noise =>0.03*randn(2,nsteps) + [0.1,0]*ones(1,nsteps), :W => [0 -5; -5 0], :nsteps=>nsteps)

Uend, Vend, U, V              = forwardModel([0.1,0.1]; do_plot=true, params...);
Ustart, Vstart, bU, bV, costs = backwardsModel(Uend; do_plot=true, tol=1e-30, params...)

@printf("Ustart came back as : "); print_vector_g(Ustart); print("\n")

### Testing opto in forwardModel()
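When `opto_strength` is a scalar, `forwardModel()` turns it, together with `opto_units` and `opto_times`, into an nunits-by-nsteps multiplier on V. A sketch mirroring that logic, with the ForwardDiff-compatible array allocation replaced by a plain `ones()`; the function name `opto_mask` is hypothetical.

```julia
# Hypothetical sketch: build the per-unit, per-timestep multiplier applied to V
function opto_mask(nunits, nsteps, dt, opto_strength, opto_units, opto_times)
    M = ones(nunits, nsteps)
    t = dt .* (0:nsteps-1)                   # same time axis as forwardModel()
    for i in 1:size(opto_times, 1)
        # timesteps inside the i-th [t_start, t_end] interval
        on = (opto_times[i, 1] .<= t) .& (t .<= opto_times[i, 2])
        M[opto_units, on] .= opto_strength
    end
    return M
end

M = opto_mask(2, 50, 0.01, 0.1, [1], [0.05 0.15; 0.4 0.45])
```

With the values used in the cell below, unit 1's V is multiplied by 0.1 during 0.05 <= t <= 0.15 and 0.4 <= t <= 0.45, and unit 2 is unaffected.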

In [None]:
#
#  Specifying opto_strength as a scalar
#

pygui(true)

nsteps=50
params = Dict(:sigma=>0.02, :opto_strength=>0.1, :opto_units=>1, :opto_times=>[0.05 0.15 ; 0.4 0.45],
:W => [0 -5; -5 0], :nsteps=>nsteps)

Uend, Vend, U, V, t = forwardModel([0.1,0.1]; input=[0.1,0], do_plot=true, fignum=1, params...);

figure(2); clf();
subplot(2,1,1)
plot(t, V[1,:], t, V[2,:])
ylabel("V")
remove_xtick_labels()
subplot(2,1,2)
plot(t, U[1,:], t, U[2,:])
ylabel("U")
xlabel("t")


# And now a derivative w.r.t. opto_strength
func = (;pars...) -> forwardModel([0.1, 0.1]; input=[0.1,0], do_plot=false, merge(params, Dict(pars))...)[1][1]

val, grad, hess = keyword_vgh(func, ["opto_strength"], [0.09])

In [None]:
#
#  Specifying opto_strength as a full nunits-by-nsteps matrix
#


nsteps=50
opto_strength = ones(2, nsteps)
opto_strength[1, 10:20] = 0.2
params = Dict(:sigma=>0.2, :opto_strength=>opto_strength, :opto_units=>1:2, :opto_times=>[0.05 0.15 ; 0.4 0.45],
:W => [0 -5; -5 0], :nsteps=>nsteps)

Uend, Vend, U, V, t = forwardModel([0.1,0.1]; input=[0.1,0], do_plot=true, fignum=1, params...);


# Exploring dt-dependence of gradients and hessian

If we're doing things correctly, and $dt$ is small enough that we're starting to approximate the continuous-time solution, we should find that the output of our network does not depend very much on the choice of timestep $dt$. Correspondingly, gradients of the output with respect to network parameters should also be relatively $dt$-independent.

In this example, we'll work with a two-dimensional mutual-inhibition network. We'll have one parameter, $W$, that represents the weight of the connection between the two units. Note, however, that `forwardModel()` takes in a full connection matrix $W$, since `forwardModel()` makes no assumptions about any structure in that connection matrix. So what we do is wrap `forwardModel()` in a local function called `forward()` that transforms the scalar $W$ into a full 2-by-2 matrix $W$ and then calls `forwardModel()`.
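The dt-independence argument rests on forward Euler's error shrinking roughly in proportion to $dt$. A minimal sketch of that check on a single unit (the one-unit network, `final_V`, and the parameter values are hypothetical simplifications of the two-unit example): integrate to the same final time with different step sizes and compare the final $V$.

```julia
# Hypothetical sketch: one unit, tau dU/dt = -U + w*g(U) + I, integrated to time T
g(u) = 0.5 * tanh(u) + 0.5

function final_V(dt; T = 0.2, tau = 0.1, w = -2.0, I = 0.1, U0 = 0.5)
    nsteps = round(Int, T / dt) + 1          # number of timepoints, including t = 0
    U = U0
    for _ in 2:nsteps
        U += (dt / tau) * (-U + w * g(U) + I)
    end
    return g(U)
end

V_coarse = final_V(0.02)
V_fine   = final_V(0.005)
V_ref    = final_V(1e-4)                     # very small dt as a stand-in for continuous time
println((V_coarse, V_fine, V_ref))
```

The coarse and fine runs should differ only slightly, and the smaller step should land closer to the reference; the gradient comparison below applies the same reasoning to derivatives.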

In [None]:

# If you want to freeze the noise, provide a seed to the random number generator, e.g.: srand(111)
# If you want a new random seed every time you run, but also want to preserve the seed so as to be able to re-run
# the exact same noise again, you can use the local time, e.g. here in tenths of milliseconds:
sr = Int64(round(10000*time()))
srand(sr)
startU=randn(100,2)-3

sigma = 0  # This will be the default; below we set it to something else as a parameter


dt = 0.02
t = 0:dt:1
tau = 0.1
nsteps = length(t)
t = t[1:nsteps]

W = -4
noise = 0
input = 0
sigma = 0


model_params = Dict(:dt=>dt, :tau=>tau, :W=>[0 W; W 0], :nsteps=>nsteps, 
:noise=>noise, :input=>input, :sigma=>sigma, :const_add=>0, :init_add=>0)

# wrapper function that will take our scalar W, indicating weight between the two units, and 
# turn that into a 2-by-2 weight matrix:
forward = (startU; pars...) -> begin
    pars = Dict(pars)
    if haskey(pars, :W); 
        W=pars[:W];   # mess with it only if it is not already a matrix:
        if length(W)==1; pars=make_dict(["W"], [[0 W;W 0]], pars); end;
    end;     
    forwardModel(startU; pars...)
end




args = ["W", "const_add", ["start_add" 2], "sigma"]
params = [-4.01, 0.5, 0.2, -0.2, 0.01]

# --- first with dt = 0.02
figure(1); clf();
value1, grad1, hess1 = keyword_vgh((;pars...)->forward([-0.2, 0.3]; do_plot=true, merge(model_params, Dict(pars))...)[1][1], args, params)
title(@sprintf("Running with dt=%g", dt))


# --- now with dt = 0.005

dt = 0.005
t = 0:dt:1
tau = 0.1
nsteps = length(t)
t = t[1:nsteps]

model_params = Dict(:dt=>dt, :tau=>tau, :W=>[0 W; W 0], :nsteps=>nsteps, 
:noise=>noise, :input=>input, :sigma=>sigma, :const_add=>0, :init_add=>0)

figure(2); clf();
value2, grad2, hess2 = keyword_vgh((;pars...)->forward([-0.2, 0.3]; do_plot=true, fignum=2, merge(model_params, Dict(pars))...)[1][1], args, params)
title(@sprintf("Running with dt=%g", dt))


# --- and again with dt = 0.02 but different instantiation of the noise at each timestep

dt = 0.02
t = 0:dt:1
tau = 0.1
nsteps = length(t)
t = t[1:nsteps]

model_params = Dict(:dt=>dt, :tau=>tau, :W=>[0 W; W 0], :nsteps=>nsteps, 
:noise=>noise, :input=>input, :sigma=>sigma, :const_add=>0, :init_add=>0)

figure(3); clf();
value3, grad3, hess3 = keyword_vgh((;pars...)->forward([-0.2, 0.3]; do_plot=true, fignum=3, merge(model_params, Dict(pars))...)[1][1], args, params)
title(@sprintf("Running again with dt=%g", dt))

# As you will see below, the gradient values are all pretty stable across runs, except for the gradient with
# respect to sigma. That is largely because of different noise instantiations: the two runs at dt=0.02 differ
# from each other by about as much as they differ from the run at dt=0.005
["first dt=0.02 run" grad1[:]' ; "dt=0.005 run" grad2[:]'; "second dt=0.02 run" grad3[:]']