In [1]:
import IJulia

# The julia kernel has built in support for Revise.jl, so this is the 
# recommended approach for long-running sessions:
# https://github.com/JuliaLang/IJulia.jl/blob/9b10fa9b879574bbf720f5285029e07758e50a5e/src/kernel.jl#L46-L51

# Users should enable revise within .julia/config/startup_ijulia.jl:
# https://timholy.github.io/Revise.jl/stable/config/#Using-Revise-automatically-within-Jupyter/IJulia-1

# clear console history
IJulia.clear_history()

# Figure defaults: width/height in inches, output format, and dots per inch
fig_width, fig_height = 5.5, 3.5
fig_format = :pdf
fig_dpi = 300

# Format-specific adjustments
if fig_format == :pdf
  # PDF output uses 96 dpi and needs the PDF MIME type registered with IJulia
  fig_dpi = 96
  IJulia.register_mime(MIME("application/pdf"))
elseif fig_format == :retina
  # no retina format type, use svg for high quality type/marks
  fig_format = :svg
end

# From here on the figure size is in pixels (inches * dpi)
fig_width *= fig_dpi
fig_height *= fig_dpi

# Initialize Plots w/ default fig width/height
try
  import Plots

  # Plots.jl doesn't support PDF output for versions < 1.28.1, so fall back to png there
  # (fig_dpi has already been lowered from the default of 300 to 96 for PDF output).
  # `&&` short-circuits and is the idiomatic boolean operator; `&` is the bitwise form.
  local plots_fmt = fig_format
  if Plots._current_plots_version < v"1.28.1" && fig_format == :pdf
    plots_fmt = :png
  end
  Plots.gr(size=(fig_width, fig_height), fmt = plots_fmt, dpi = fig_dpi)
catch e
  # Failures are ignored silently here (presumably because Plots is optional);
  # uncomment the line below to diagnose a failing initialisation.
  # @warn "Plots init" exception=(e, catch_backtrace())
end

# Initialize CairoMakie with default fig width/height
try
  import CairoMakie

  # Output type is the figure format as a string; theme resolution is the pixel size
  local makie_type = string(fig_format)
  local makie_resolution = (fig_width, fig_height)
  CairoMakie.activate!(type = makie_type)
  CairoMakie.update_theme!(resolution = makie_resolution)
catch e
  # Failures are ignored silently; uncomment to diagnose.
  # @warn "CairoMakie init" exception=(e, catch_backtrace())
end
  
# Set run_path if specified: change the working directory to it unless it is empty.
# A failing `cd` is reported as a warning instead of aborting the setup cell.
try
  run_path = raw"/home/kyle/Dropbox (GaTech)/quarto/proposal"
  isempty(run_path) || cd(run_path)
catch e
  @warn "Run path init:" exception=(e, catch_backtrace())
end


# emulate old Pkg.installed behavior, see
# https://discourse.julialang.org/t/how-to-use-pkg-dependencies-instead-of-pkg-installed/36416/9
import Pkg
# Return `true` when `pkg` is the name of a direct dependency reported by
# `Pkg.dependencies()`, and `false` otherwise.
function isinstalled(pkg::String)
  for dep in values(Pkg.dependencies())
    if dep.is_direct_dep && dep.name == pkg
      return true
    end
  end
  return false
end

# ojs_define
# `ojs_define(; name = value, ...)` serialises each keyword argument as a
# {"name": ..., "value": ...} entry under "contents", wraps the JSON in a
# <script type='ojs-define'> tag and displays it as text/html through IJulia.
# Which variant gets defined depends on the packages that are direct dependencies.
if isinstalled("JSON") && isinstalled("DataFrames")
  import JSON, DataFrames
  global function ojs_define(; kwargs...)
    # Data frames are sent as a collection of rows; every other value is passed through.
    convert(x) = x
    # `Tables` is never imported in this file, so the bare name `Tables.rows` raised
    # UndefVarError for data frame arguments. Reach the module through DataFrames
    # instead (DataFrames imports Tables.jl internally — confirm for your version).
    convert(x::DataFrames.AbstractDataFrame) = DataFrames.Tables.rows(x)
    content = Dict("contents" => [Dict("name" => k, "value" => convert(v)) for (k, v) in kwargs])
    tag = "<script type='ojs-define'>$(JSON.json(content))</script>"
    IJulia.display(MIME("text/html"), tag)
  end
elseif isinstalled("JSON")
  import JSON
  global function ojs_define(; kwargs...)
    # No DataFrames available: values are serialised as they are.
    content = Dict("contents" => [Dict("name" => k, "value" => v) for (k, v) in kwargs])
    tag = "<script type='ojs-define'>$(JSON.json(content))</script>"
    IJulia.display(MIME("text/html"), tag)
  end
else
  global function ojs_define(; kwargs...)
    # Without JSON nothing can be serialised; only warn.
    @warn "JSON package not available. Please install the JSON.jl package to use ojs_define."
  end
end


# don't return kernel dependencies (b/c Revise should take care of dependencies)
nothing


In [2]:
using JuMP
using Plots
using LinearAlgebra

# Define all parameters for the model and the control problem.
# State matrix of the 4-state discrete-time model; the simulations below step it as
# x = A*x + B*u.
A = [1 -6.66e-13 -2.03e-9 -4.14e-6;
        9.83e-4 1 -4.09e-8 -8.32e-5;
        4.83e-7 9.83e-4 1 -5.34e-4;
        1.58e-10 4.83e-7 9.83e-4 .9994;
]

# Input column (the trailing ' turns the 1x4 row literal into a 4x1 column).
B = [9.83e-4 4.83e-7 1.58e-10 3.89e-14]'
# Input matrix of the single-input model.
B1 = B
# fake opsin: the second input column is -B scaled elementwise by 1 + N(0, 1)/5.
# NOTE: no random seed is set in the visible code, so Binh differs between runs.
Binh = -B .* (1 .+ randn(4, 1)./5)
# Input matrix of the two-input model.
B2 = [B Binh]

# Output row: y = C*x.
C = [-.0096 .0135 .005 -.0095]

# Map between the model output y and the firing rate z; the two are mutual inverses.
y2z(y) = exp(61.4*y - 5.468)
z2y(z) = (log(z) + 5.468)/61.4

t_step = 0.001 #1 millisecond
pred = 25 #prediction horizon
ctr = 5 #control horizon
sample = 200 #number of steps between sampling points for control

# Constant target firing rate, the matching output, and the input that holds it at steady state.
zD = 0.2
yD = z2y(zD)
uD = inv(C*inv(I - A)*B) * yD
#xD = inv(I - A) * B * uD

# Quadratic weights: Q penalises the output through the state, R penalises the input.
Q = C'*C
R = 1e-5*I

# Number of simulated time steps, and number of time steps covered by the prediction horizon.
T = 13000
Tpred = pred*sample
# Reference trajectory: a list of firing-rate values, one per time step. The controllers
# convert it to output space with z2y and hold the last value once the list runs out.
##zref = [i < 3000 ? 0.1 : 0.15 for i in 1:T] #trial firing rate reference
# trial firing rate reference: one sine period around 0.1 with amplitude 0.08,
# sampled at 1.5*Tpred points
zref = .1 .+ .08*sin.(range(start=0, stop=2*pi, length=Int(1.5*Tpred)));

#zref = [0.2 for i in 1:(T)] #trial firing rate reference

In [3]:
using MatrixEquations: ared

"""
    lqr(T, x0, zref, nu=1)

Simulate `T` steps of the model under LQR tracking of the firing-rate reference `zref`.

The feedback gain comes from the discrete-time Riccati solution for the single-input
model `(A, B1)` with weights `Q` and `R`; the plant is stepped with `B1` (`nu == 1`) or
`B2` (`nu == 2`). The model matrices, the weights and `y2z`/`z2y` are file-level globals.

# Arguments
- `T`: number of time steps to simulate.
- `x0`: initial state.
- `zref`: reference firing rates, one per step; the last value is held once `t`
  exceeds `length(zref)`.
- `nu`: number of inputs, `1` or `2`.

# Returns
`(zs, us)`: `zs` is `1 x T` with the firing rate after each state update, `us` is
`nu x T` with the input applied at each step.

# Throws
- `ArgumentError` if `nu` is neither `1` nor `2`.
"""
function lqr(T, x0, zref, nu=1)
    if nu == 1
        B = B1
    elseif nu == 2
        B = B2
    else
        # Previously `B` stayed undefined and the failure surfaced later as an unrelated error.
        throw(ArgumentError("nu must be 1 or 2, got $nu"))
    end

    # The gain is always designed on the single-input model, also when nu == 2.
    P, _, _ = ared(A, B1, R, Q)
    K = inv(R + B1'*P*B1) * B1'*P*A

    # Steady-state maps, hoisted out of the loop because they do not depend on t:
    # target output -> steady-state input, and steady-state input -> steady-state state.
    y2u = inv(C*inv(I - A)*B1)
    u2x = inv(I - A) * B1

    zs = zeros(1, T)
    us = zeros(nu, T)
    x = x0
    for t in 1:T
        # Hold the last reference value once the reference runs out.
        zref_current = t > length(zref) ? zref[end] : zref[t]
        yD = z2y(zref_current)
        uD = y2u * yD
        xD = u2x * uD

        u = -K*(x - xD) + uD
        if nu == 2
            # use other opsin if negative
            u = u[1] < 0 ? [0, -u[1]] : [u[1], 0]
        elseif u[1] < 0
            # A single input cannot go negative: clip at zero.
            u = [0]
        end

        us[:, t] = u
        x = A*x + B*u
        y = C*x
        zs[:, t] = y2z.(y)
    end
    return zs, us
end

# Run the LQR simulation from the zero state with one and with two inputs;
# each result is the (zs, us) tuple returned by lqr.
lqr1res = lqr(T, zeros(4), zref, 1);
lqr2res = lqr(T, zeros(4), zref, 2);
# plot(plotctrl(lqr1res...), plotctrl(lqr2res...))

In [4]:
using OSQP

"""
    mpc(steps, x0, zref; nu=1, u_clamp=nothing, sample=250)

Simulate receding-horizon control of the model towards the firing-rate reference `zref`.

One JuMP/OSQP model is built once and re-solved every `sample` time steps. Each solve
covers `pred*sample` time steps, with a separate input for each of the first `ctr`
sample periods and one further input held for the rest of the horizon. The first
input of every solve is applied to the plant for `sample` time steps. The model
matrices, `R`, `pred`, `ctr` and `y2z`/`z2y` are file-level globals.

# Arguments
- `steps`: number of control updates (solves).
- `x0`: initial state.
- `zref`: reference firing rates per time step; the last value is held when the
  reference is shorter than `steps*sample + pred*sample`.

# Keywords
- `nu`: number of inputs, `1` (uses `B1`) or `2` (uses `B2`).
- `u_clamp`: when not `nothing`, the first input of every solve is fixed to this value.
- `sample`: number of time steps per control update.

# Returns
`(zs, us)`: `zs` is `1 x steps*sample` firing rates and `us` is `nu x steps*sample`
applied inputs.

# Throws
- `ArgumentError` if `nu` is neither `1` nor `2`.
"""
function mpc(steps, x0, zref; nu=1, u_clamp=nothing, sample=250)
    if nu == 1
        B = B1
    elseif nu == 2
        B = B2
    else
        # Previously `B` stayed undefined and the failure surfaced while building constraints.
        throw(ArgumentError("nu must be 1 or 2, got $nu"))
    end

    Tpred = pred*sample
    Tall = steps*sample
    # Every solve reads Tpred reference samples starting at its first time step, so make
    # sure Tall + Tpred samples exist by holding the final value. `zrefpad` used to be
    # assigned only when padding was needed, which made a long enough `zref` fail with
    # an UndefVarError further down.
    if length(zref) < Tall + Tpred
        zrefpad = cat(zref, fill(zref[end], Tall + Tpred - length(zref)), dims=1)
    else
        zrefpad = zref
    end

    neuron = Model(OSQP.Optimizer)
    set_silent(neuron)

    # Decision variables: the state trajectory, the non-negative inputs (one column per
    # control move), and the output reference, which is fixed before every solve.
    @variables(neuron, begin
        x[i=1:4, t=1:Tpred]
        0 ≤ u[1:nu, 1:(ctr+1)]
        yD[i=1:1, t=1:Tpred]
    end)

    # Total number of scalar inputs in `u`.
    n_u = nu*(ctr+1)

    @expressions(
        neuron,
        begin
            y, C*x
            y_error[t=1:Tpred], y[:, t] - yD[:, t]
            y_cost[t=1:Tpred], y_error'[t]*y_error[t]
            # the tracking error is only penalised at the sampling instants
            sampled_y_cost[t=1:pred], y_cost[t*sample]
            # One penalty term per scalar input. The previous `u[t]` for t in 1:ctr+1
            # indexed the nu x (ctr+1) matrix linearly, so with nu == 2 only the first
            # three control moves were penalised. Summing the scalar terms equals u'*R*u
            # per move as long as R is a scalar multiple of the identity
            # (NOTE(review): revisit if R ever becomes a general matrix).
            u_cost[k=1:n_u], u[k]'*R*u[k]
        end
    )

    #fix first sample steps
    @constraint(neuron, x[:, 2:(sample)] .== A*x[:, 1:sample-1] + B*u[:, 1])

    #fix each sample period
    for i in 1:(ctr-1)
        @constraint(neuron, x[:, (sample*i+1):(sample*i+sample)] .== A*x[:, (sample*i):(sample*i+sample-1)] + B*u[:, (i+1)])
    end

    #fix rest of inputs
    @constraint(neuron, x[:, (ctr*sample+1):(Tpred)] .== A*x[:, (ctr*sample):(Tpred-1)] + B*u[:, (ctr+1)])

    # Reference converted from firing rate to model output.
    yDall = z2y.(zrefpad)

    @objective(neuron, Min, sum(sampled_y_cost[t] for t in 2:(pred)) + sum(u_cost[k] for k in 1:n_u))

    zs = zeros(1, steps*sample)
    us = zeros(nu, steps*sample)
    x_current = x0
    tfine = 1  # index of the current time step on the fine (per-step) grid
    for t in 1:steps
        fix.(x[:, 1], x_current; force=true)
        if u_clamp !== nothing
            fix.(u[:, 1], u_clamp; force=true)
        end
        # slide the reference window so that it starts at the current time step
        fix.(yD[:], yDall[tfine:tfine+Tpred-1], force=true)

        optimize!(neuron)
        zs[1, tfine] = y2z.(value(y[1]))
        us[:, tfine] = value.(u[:, 1])
        x_current = value.(x[:, 2])
        # now apply the optimal first input for the remaining sample-1 steps of this period
        const_u = value.(u[:, 1])
        tfine += 1
        for i in 1:(sample-1)
            x_current = A*x_current + B*const_u
            zs[1, tfine] = y2z.((C*x_current)[1])
            us[:, tfine] = const_u
            tfine += 1
        end
    end
    return zs, us
end

# Number of control updates that fit into T time steps (integer division).
steps = T ÷ sample
# Two-input MPC run from the zero state.
mpc2res = mpc(steps, zeros(4), zref, nu=2, u_clamp=nothing, sample=sample);

In [5]:
# Single-input MPC run from the zero state.
mpc1res = mpc(steps, zeros(4), zref, nu=1, u_clamp=nothing, sample=sample);

In [6]:
#| label: fig-mpc-sim
#| fig-cap: 'Simulated control of an linear dynamical system with 1- and 2-input control, using LQR and MPC controllers. The top panel of each contains the reference and the actual firing rate, in spikes/second. The bottom contains the light intensity, in terms of mW/mm^2^, where blue represents light for an excitatory opsin (such as ChR2) and red-orange that for an inhibitory opsin (such as Jaws).'
using LaTeXStrings
"""
    plotctrl(zs, us; title=nothing, plotargs...)

Build a two-panel figure for one simulation result: the global reference `zref` and
the firing rate `zs` on top, the input(s) `us` below, sharing the time axis.

# Arguments
- `zs`: `1 x n` matrix of firing rates.
- `us`: `nu x n` matrix of inputs; a second row, if present, is drawn in orange-red.

# Keywords
- `title`: optional title for the top panel.
- `plotargs...`: forwarded to the final `plot` call that combines both panels.

# Returns
The combined plot.
"""
function plotctrl(zs, us; title=nothing, plotargs...)
    nu = size(us, 1)
    nsteps = size(zs, 2)

    # Length-matched local copy of the reference: hold the last value when it is too
    # short, truncate when it is too long. The global `zref` used to be padded in
    # place here, which silently changed it for every later cell.
    nref = length(zref)
    ref = if nref < nsteps
        vcat(zref, fill(zref[end][1], nsteps - nref))
    else
        zref[1:nsteps]
    end

    # time axis in seconds (one step per millisecond)
    tsec = (1:nsteps) ./ 1000
    zs_plot = plot(tsec, [ref zs'], label=[L"r" L"z"],
        color=["green" "black"], lw=2)
    plot!(legend=false)
    if title !== nothing
        plot!(title=title)
    end

    u_plot = plot(tsec, us[1, :], color="#72b5f2", lw=2, xlabel="time (s)", legend=false, label=L"u_{exc}")
    if nu == 2
        plot!(tsec, us[2, :], color=:orangered, lw=2, label=L"u_{inh}")
    end

    return plot(zs_plot, u_plot; layout=(2, 1), link=:x, plotargs...)
end

# One two-panel plot per controller/input combination; only the 2-input LQR plot
# gets a legend (top right, two columns).
lqr1plot = plotctrl(lqr1res...; title="1-input LQR")
lqr2plot = plotctrl(lqr2res...; title="2-input LQR", legend=:topright, legendcolumns=2)
mpc1plot = plotctrl(mpc1res...; title="1-input MPC")
mpc2plot = plotctrl(mpc2res...; title="2-input MPC")

# Combine the four plots into a single figure with four subplots.
plot(lqr1plot, lqr2plot, mpc1plot, mpc2plot, layout=4)