# In [None]:
# Move into the project directory so the relative `include`s below resolve.
# Set the FORAGING_DIR environment variable to run on another machine; when it
# is unset this falls back to the original hard-coded path.
cd(get(ENV, "FORAGING_DIR", "/Users/evanrussek/foraging/"))

using CSV
using DataFrames
using DataFramesMeta
using CategoricalArrays
using Gadfly
using Statistics
using Distributions
using SpecialFunctions
using StatsFuns
using Optim
using ForwardDiff

# Load the project helper files, in this order. Functions used below but not
# defined in this script (e.g. sim_forage_learn, make_exit_plot,
# forage_learn_lik2) presumably come from these files — TODO confirm.
foreach(include, ("sim_lag_functions.jl", "sim_learn_funcs.jl"))

# Generating parameters for the simulation below.
# Note from an earlier version: start reward was 120, 90, 60.
# Trailing comments keep the random-draw alternatives that were previously left
# commented out next to each fixed value.
# Entries are inserted one at a time in this exact order (the later
# `param_names` order comes from iterating this Dict).
param_dict = Dict();
for (pname, pval) in (
        "harvest_cost" => 1.0,                # .1 + 10*rand()
        "travel_cost_easy" => 2.0,            # harvest_cost + 5*rand()
        "travel_cost_hard" => 8.0,            # travel_cost_easy + 8*rand()
        "r_hat_start_reward_weight" => 0.2,   # .01 + .4*rand()
        "r_hat_start_easy_weight" => 1,       # .01 * 10*rand()
        "harvest_lag_hat_start" => 2.0,       # .01 + rand()*5 ; 1.0 ; don't fit this...
        "harvest_bias" => 0,                  # 100 ; -10 + rand()*20
        "choice_beta" => 1,                   # .001 + rand()*5
        "lag_beta" => 2.0,                    # .001 + rand()*8.
        "lr_R_hat_pre" => -2.7,               # -4. + 4*rand()
        "lr_harvest_lag_hat_pre" => -2.0,     # -4 + 5*rand()
    )
    param_dict[pname] = pval
end
transform_lr(param_dict["lr_R_hat_pre"])

# show original params
param_dict

# Simulate the task under the generating parameters, then inspect the output.
sim_df = sim_forage_learn(param_dict);
make_exit_plot(sim_df)
make_lag_plot(sim_df)

# Aesthetics shared by the per-trial line plots below.
trajectory_aes = (group = :trial_num, color = :start_reward, linestyle = :travel_key_cond)

# observed reward over time, faceted by travel condition (columns) x start reward (rows)
plot(sim_df, Geom.subplot_grid(Geom.line);
    x = :time, y = :reward_obs, xgroup = :travel_key_cond, ygroup = :start_reward,
    color = :start_reward)

# R_hat per trial, faceted by travel condition
plot(sim_df, Geom.subplot_grid(Geom.line);
    x = :time, y = :R_hat, xgroup = :travel_key_cond, trajectory_aes...)

# harvest_lag_hat per trial
plot(sim_df, Geom.line; x = :time, y = :harvest_lag_hat, trajectory_aes...)

# threshold per trial
plot(sim_df, Geom.line; x = :time, y = :threshold, trajectory_aes...)


########## likelihood function
# Flatten the generating parameters into parallel name / value vectors.
# keys() and values() of the same Dict iterate in the same order, so entry i of
# each vector describes the same parameter. That order is Dict iteration order,
# and it has to match the parameter order forage_learn_lik2 expects.
param_names = collect(keys(param_dict));            # Vector{Any}, like the original `[]`
param_vals = collect(Float64, values(param_dict));  # Int entries converted to Float64
print(param_names) # check that this matches the order in the likelihood function...
# reload the likelihood helpers (presumably to pick up recent edits — TODO confirm)
include("sim_learn_funcs.jl")


# get start values for search
# NOTE(review): start values are drawn with the "choice"-only likelihood and
# reused for all three fits below — confirm this is intended.
start_p = generate_start_vals((x) -> forage_learn_lik2(x,sim_df,"choice"))

# check that we can take the gradient at the first value...
#ForwardDiff.gradient((x) -> forage_learn_lik2(x, sim_df, "both"), start_p)

"""
    fit_lik_mode(data, mode, start)

Minimise `forage_learn_lik2(x, data, mode)` over `x`, starting from `start`.
Uses LBFGS with forward-mode automatic differentiation, up to 4000 iterations,
and allows temporary increases of the objective. Returns the Optim result object.
`mode` is passed straight through to `forage_learn_lik2`; this script uses
"both", "choice" and "lag".
"""
function fit_lik_mode(data, mode, start)
    opts = Optim.Options(allow_f_increases=true, iterations = 4000, show_trace = false)
    return optimize((x) -> forage_learn_lik2(x, data, mode), start, LBFGS(), opts;
                    autodiff = :forward)
end

# fit the simulated data under each likelihood mode.
a_both = fit_lik_mode(sim_df, "both", start_p)
a_choice = fit_lik_mode(sim_df, "choice", start_p)
# lag-only fit: must use the "lag" likelihood (it previously refit "choice")
a_lag = fit_lik_mode(sim_df, "lag", start_p)


# Parameter-recovery tables. Each fit is paired with the likelihood of the same
# mode evaluated on sim_df.
lik_for_mode(mode) = (x) -> forage_learn_lik2(x, sim_df, mode)
fit_df_both = make_recov_df(a_both, param_names, param_dict, lik_for_mode("both"))
fit_df_choice = make_recov_df(a_choice, param_names, param_dict, lik_for_mode("choice"))
fit_df_lag = make_recov_df(a_lag, param_names, param_dict, lik_for_mode("lag"))














# To use Python's scipy.optimize instead (needs `pyimport`, e.g. from PyCall, which this script does not load):
#so = pyimport("scipy.optimize")
#@time a = so.minimize(cost_fun, start_x, method="L-BFGS-B", jac = (x->ForwardDiff.gradient(cost_fun,x)))

# Collect the recovered parameter estimates into a Dict keyed like `param_dict`,
# so they can be passed back into `sim_forage_learn`.
# `p_hat` was never assigned in this script (UndefVarError on a fresh run), so
# take it from the joint ("both") fit.
# NOTE(review): assumes the minimizer is in the same parameter space and order
# as `param_names` — confirm against forage_learn_lik2.
p_hat = Optim.minimizer(a_both)
# zip would silently drop unmatched entries, so fail loudly on a size mismatch.
length(p_hat) == length(param_names) || throw(DimensionMismatch(
    "fit has $(length(p_hat)) parameters but param_names has $(length(param_names))"))
p_hat_dict = Dict()
for (pname, pval) in zip(param_names, p_hat)
    p_hat_dict[pname] = pval
end

# Re-simulate with the recovered parameters. Each diagnostic plot is stacked
# vertically: original parameters on top, recovered parameters below.
sim_df_rec = sim_forage_learn(p_hat_dict);

# exit plots
p_rec_choice, p_orig_choice = make_exit_plot(sim_df_rec), make_exit_plot(sim_df)
vstack([p_orig_choice, p_rec_choice])

# lag plots — the lag looks correct...
p_rec_lag, p_orig_lag = make_lag_plot(sim_df_rec), make_lag_plot(sim_df)
vstack([p_orig_lag, p_rec_lag])


# Plot R_hat per trial over time, faceted by travel condition, with `title`.
function plot_R_hat_by_trial(df, title)
    return plot(df, x = :time, y = :R_hat, xgroup = :travel_key_cond,
        group = :trial_num, color = :start_reward, linestyle = :travel_key_cond,
        Geom.subplot_grid(Geom.line), Guide.title(title))
end

plot_R_hat_by_trial(sim_df_rec, "Recovered Params")
plot_R_hat_by_trial(sim_df, "Original Params")