From ee370b6d782d3e835de2fd8628b227de28da79e5 Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Sun, 14 Apr 2024 03:54:47 +0200 Subject: [PATCH] added save_history and update_states --- src/ActionModels.jl | 4 +++- src/create_agent/init_agent.jl | 6 +++++- src/fitting/fit_model.jl | 16 ++++++++------- src/fitting/prefit_checks.jl | 3 +-- .../binary_rescorla_wagner_softmax.jl | 14 ++++++------- .../continuous_rescorla_wagner_gaussian.jl | 10 ++++------ src/structs.jl | 5 +++-- src/utils/set_save_history.jl | 14 +++++++++++++ src/utils/update_states.jl | 20 +++++++++++++++++++ 9 files changed, 66 insertions(+), 26 deletions(-) create mode 100644 src/utils/set_save_history.jl create mode 100644 src/utils/update_states.jl diff --git a/src/ActionModels.jl b/src/ActionModels.jl index f1fe79d..aa3db6b 100644 --- a/src/ActionModels.jl +++ b/src/ActionModels.jl @@ -10,7 +10,7 @@ export create_agent_model, fit_model export plot_parameter_distribution, plot_predictive_simulation, plot_trajectory, plot_trajectory! export get_history, get_states, get_parameters, set_parameters!, reset!, give_inputs!, single_input! -export get_posteriors +export get_posteriors, update_states!, set_save_history! #Load premade agents function __init__() @@ -47,6 +47,8 @@ include("utils/set_parameters.jl") include("utils/warn_premade_defaults.jl") include("utils/get_posteriors.jl") include("utils/pretty_printing.jl") +include("utils/update_states.jl") +include("utils/set_save_history.jl") #Premade agents include("premade_models/binary_rescorla_wagner_softmax.jl") diff --git a/src/create_agent/init_agent.jl b/src/create_agent/init_agent.jl index 809195b..bce1b88 100644 --- a/src/create_agent/init_agent.jl +++ b/src/create_agent/init_agent.jl @@ -83,9 +83,10 @@ function init_agent( action_model::Function; substruct::Any = nothing, parameters::Dict = Dict(), + shared_parameters::Dict = Dict(), states::Union{Dict,Vector} = Dict(), settings::Dict = Dict(), - shared_parameters::Dict = Dict(), + save_history::Bool = true, ) ##Create action model struct @@ -96,6 +97,7 @@ function init_agent( initial_state_parameters = Dict(), states = Dict(), settings = settings, + save_history = save_history, ) @@ -199,6 +201,7 @@ function init_agent( parameters::Dict = Dict(), states::Dict = Dict(), settings::Dict = Dict(), + save_history::Bool = true, ) #If a setting called action_models has been specified manually @@ -221,6 +224,7 @@ function init_agent( parameters = parameters, states = states, settings = settings, + save_history = save_history, ) return agent diff --git a/src/fitting/fit_model.jl b/src/fitting/fit_model.jl index 1be815e..e41b167 100644 --- a/src/fitting/fit_model.jl +++ b/src/fitting/fit_model.jl @@ -38,7 +38,7 @@ fit_model(agent, priors, inputs, actions, n_chains = 1, n_iterations = 10) ``` """ function fit_model( - agent::Agent, + original_agent::Agent, priors::Dict, data::DataFrame; independent_group_cols::Vector = [], @@ -66,11 +66,8 @@ function fit_model( input_cols = Symbol.(input_cols) action_cols = Symbol.(action_cols) - ## Store old parameters for resetting the agent later ## - old_parameters = get_parameters(agent) - - ## Set fixed parameters to agent ## - set_parameters!(agent, fixed_parameters) + ## Copy the agent to avoid changing the original ## + agent = deepcopy(original_agent) ## Run checks ## prefit_checks( @@ -82,11 +79,16 @@ function fit_model( input_cols = input_cols, action_cols = action_cols, fixed_parameters = fixed_parameters, - old_parameters = old_parameters, n_cores = n_cores, verbose = verbose, ) + ## Set save_history to false ## + set_save_history!(agent, false) + + ## Set fixed parameters to agent ## + set_parameters!(agent, fixed_parameters) + ## Set logger ## #If sample rejection warnings are to be shown if show_sample_rejections diff --git a/src/fitting/prefit_checks.jl b/src/fitting/prefit_checks.jl index 052a084..08c6644 100644 --- a/src/fitting/prefit_checks.jl +++ b/src/fitting/prefit_checks.jl @@ -10,7 +10,6 @@ function prefit_checks(; input_cols, action_cols, fixed_parameters, - old_parameters, n_cores, verbose = true, ) @@ -20,7 +19,7 @@ function prefit_checks(; #If there are any of the agent's parameters which have not been set in the fixed or sampled parameters if any( key -> !(key in keys(priors)) && !(key in keys(fixed_parameters)), - keys(old_parameters), + keys(agent.parameters), ) @warn "the agent has parameters which are not specified in the fixed or sampled parameters. The agent's current parameter values are used as fixed parameters" end diff --git a/src/premade_models/binary_rescorla_wagner_softmax.jl b/src/premade_models/binary_rescorla_wagner_softmax.jl index 9f0fa84..c2951c9 100644 --- a/src/premade_models/binary_rescorla_wagner_softmax.jl +++ b/src/premade_models/binary_rescorla_wagner_softmax.jl @@ -28,13 +28,12 @@ function binary_rescorla_wagner_softmax(agent::Agent, input::Union{Bool,Integer} action_distribution = Distributions.Bernoulli(action_probability) #Update states - agent.states["value"] = new_value - agent.states["value_probability"] = 1 / (1 + exp(-new_value)) - agent.states["action_probability"] = action_probability - #Add to history - push!(agent.history["value"], new_value) - push!(agent.history["value_probability"], 1 / (1 + exp(-new_value))) - push!(agent.history["action_probability"], action_probability) + update_states!(agent, Dict( + "value" => new_value, + "value_probability" => 1 / (1 + exp(-new_value)), + "action_probability" => action_probability, + "input" => input, + )) return action_distribution end @@ -77,6 +76,7 @@ function premade_binary_rescorla_wagner_softmax(config::Dict) "value" => missing, "value_probability" => missing, "action_probability" => missing, + "input" => missing, ) settings = Dict() diff --git a/src/premade_models/continuous_rescorla_wagner_gaussian.jl b/src/premade_models/continuous_rescorla_wagner_gaussian.jl index f418f4f..5ba2274 100644 --- a/src/premade_models/continuous_rescorla_wagner_gaussian.jl +++ b/src/premade_models/continuous_rescorla_wagner_gaussian.jl @@ -19,12 +19,10 @@ function continuous_rescorla_wagner_gaussian(agent::Agent, input::Real) action_distribution = Distributions.Normal(new_value, action_noise) ##Update the states and save them to agent's history - - agent.states["value"] = new_value - agent.states["input"] = input - - push!(agent.history["value"], new_value) - push!(agent.history["input"], input) + update_states!(agent, Dict( + "value" => new_value, + "input" => input, + )) ## return the action distribution to sample actions from return action_distribution diff --git a/src/structs.jl b/src/structs.jl index 705ada7..be91911 100644 --- a/src/structs.jl +++ b/src/structs.jl @@ -5,10 +5,11 @@ Base.@kwdef mutable struct Agent substruct::Any parameters::Dict = Dict() initial_state_parameters::Dict{String,Any} = Dict() - states::Dict{String,Any} = Dict("action" => missing) - settings::Dict{String,Any} = Dict() shared_parameters::Dict = Dict() + states::Dict{String,Any} = Dict("action" => missing) history::Dict{String,Vector{Any}} = Dict("action" => [missing]) + settings::Dict{String,Any} = Dict() + save_history::Bool = true end diff --git a/src/utils/set_save_history.jl b/src/utils/set_save_history.jl new file mode 100644 index 0000000..e870881 --- /dev/null +++ b/src/utils/set_save_history.jl @@ -0,0 +1,14 @@ +#Function for changing the save_history setting +function set_save_history!(agent::Agent, save_history::Bool) + + #Change it in the agent + agent.save_history = save_history + + #And in its substruct + set_save_history!(agent.substruct, save_history) +end + +#If there is an empty substruct, do nothing +function set_save_history!(substruct::Nothing, save_history::Bool) + return nothing +end \ No newline at end of file diff --git a/src/utils/update_states.jl b/src/utils/update_states.jl new file mode 100644 index 0000000..23a895c --- /dev/null +++ b/src/utils/update_states.jl @@ -0,0 +1,20 @@ +### Function which updates a single state, and saves it history +function update_states!(agent::Agent, state::String, value) + #Update state + agent.states[state] = value + + #Save to history + if agent.save_history + push!(agent.history[state], value) + end +end + + +### Function which updates a dictionary of states to their values +function update_states!(agent::Agent, states::Dict) + #For each state and value + for (state, value) in states + #Update it + update_states!(agent, state, value) + end +end \ No newline at end of file