Skip to content

Commit

Permalink
added save_history and update_states
Browse files Browse the repository at this point in the history
  • Loading branch information
PTWaade committed Apr 14, 2024
1 parent 22a1bc8 commit ee370b6
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 26 deletions.
4 changes: 3 additions & 1 deletion src/ActionModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export create_agent_model, fit_model
export plot_parameter_distribution,
plot_predictive_simulation, plot_trajectory, plot_trajectory!
export get_history, get_states, get_parameters, set_parameters!, reset!, give_inputs!, single_input!
export get_posteriors
export get_posteriors, update_states!, set_save_history!

#Load premade agents
function __init__()
Expand Down Expand Up @@ -47,6 +47,8 @@ include("utils/set_parameters.jl")
include("utils/warn_premade_defaults.jl")
include("utils/get_posteriors.jl")
include("utils/pretty_printing.jl")
include("utils/update_states.jl")
include("utils/set_save_history.jl")

#Premade agents
include("premade_models/binary_rescorla_wagner_softmax.jl")
Expand Down
6 changes: 5 additions & 1 deletion src/create_agent/init_agent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,10 @@ function init_agent(
action_model::Function;
substruct::Any = nothing,
parameters::Dict = Dict(),
shared_parameters::Dict = Dict(),
states::Union{Dict,Vector} = Dict(),
settings::Dict = Dict(),
shared_parameters::Dict = Dict(),
save_history::Bool = true,
)

##Create action model struct
Expand All @@ -96,6 +97,7 @@ function init_agent(
initial_state_parameters = Dict(),
states = Dict(),
settings = settings,
save_history = save_history,
)


Expand Down Expand Up @@ -199,6 +201,7 @@ function init_agent(
parameters::Dict = Dict(),
states::Dict = Dict(),
settings::Dict = Dict(),
save_history::Bool = true,
)

#If a setting called action_models has been specified manually
Expand All @@ -221,6 +224,7 @@ function init_agent(
parameters = parameters,
states = states,
settings = settings,
save_history = save_history,
)

return agent
Expand Down
16 changes: 9 additions & 7 deletions src/fitting/fit_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ fit_model(agent, priors, inputs, actions, n_chains = 1, n_iterations = 10)
```
"""
function fit_model(
agent::Agent,
original_agent::Agent,
priors::Dict,
data::DataFrame;
independent_group_cols::Vector = [],
Expand Down Expand Up @@ -66,11 +66,8 @@ function fit_model(
input_cols = Symbol.(input_cols)
action_cols = Symbol.(action_cols)

## Store old parameters for resetting the agent later ##
old_parameters = get_parameters(agent)

## Set fixed parameters to agent ##
set_parameters!(agent, fixed_parameters)
## Copy the agent to avoid changing the original ##
agent = deepcopy(original_agent)

## Run checks ##
prefit_checks(
Expand All @@ -82,11 +79,16 @@ function fit_model(
input_cols = input_cols,
action_cols = action_cols,
fixed_parameters = fixed_parameters,
old_parameters = old_parameters,
n_cores = n_cores,
verbose = verbose,
)

## Set save_history to false ##
set_save_history!(agent, false)

## Set fixed parameters to agent ##
set_parameters!(agent, fixed_parameters)

## Set logger ##
#If sample rejection warnings are to be shown
if show_sample_rejections
Expand Down
3 changes: 1 addition & 2 deletions src/fitting/prefit_checks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ function prefit_checks(;
input_cols,
action_cols,
fixed_parameters,
old_parameters,
n_cores,
verbose = true,
)
Expand All @@ -20,7 +19,7 @@ function prefit_checks(;
#If there are any of the agent's parameters which have not been set in the fixed or sampled parameters
if any(
key -> !(key in keys(priors)) && !(key in keys(fixed_parameters)),
keys(old_parameters),
keys(agent.parameters),
)
@warn "the agent has parameters which are not specified in the fixed or sampled parameters. The agent's current parameter values are used as fixed parameters"
end
Expand Down
14 changes: 7 additions & 7 deletions src/premade_models/binary_rescorla_wagner_softmax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,12 @@ function binary_rescorla_wagner_softmax(agent::Agent, input::Union{Bool,Integer}
action_distribution = Distributions.Bernoulli(action_probability)

#Update states
agent.states["value"] = new_value
agent.states["value_probability"] = 1 / (1 + exp(-new_value))
agent.states["action_probability"] = action_probability
#Add to history
push!(agent.history["value"], new_value)
push!(agent.history["value_probability"], 1 / (1 + exp(-new_value)))
push!(agent.history["action_probability"], action_probability)
update_states!(agent, Dict(
"value" => new_value,
"value_probability" => 1 / (1 + exp(-new_value)),
"action_probability" => action_probability,
"input" => input,
))

return action_distribution
end
Expand Down Expand Up @@ -77,6 +76,7 @@ function premade_binary_rescorla_wagner_softmax(config::Dict)
"value" => missing,
"value_probability" => missing,
"action_probability" => missing,
"input" => missing,
)
settings = Dict()

Expand Down
10 changes: 4 additions & 6 deletions src/premade_models/continuous_rescorla_wagner_gaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@ function continuous_rescorla_wagner_gaussian(agent::Agent, input::Real)
action_distribution = Distributions.Normal(new_value, action_noise)

##Update the states and save them to agent's history

agent.states["value"] = new_value
agent.states["input"] = input

push!(agent.history["value"], new_value)
push!(agent.history["input"], input)
update_states!(agent, Dict(
"value" => new_value,
"input" => input,
))

## return the action distribution to sample actions from
return action_distribution
Expand Down
5 changes: 3 additions & 2 deletions src/structs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ Base.@kwdef mutable struct Agent
substruct::Any
parameters::Dict = Dict()
initial_state_parameters::Dict{String,Any} = Dict()
states::Dict{String,Any} = Dict("action" => missing)
settings::Dict{String,Any} = Dict()
shared_parameters::Dict = Dict()
states::Dict{String,Any} = Dict("action" => missing)
history::Dict{String,Vector{Any}} = Dict("action" => [missing])
settings::Dict{String,Any} = Dict()
save_history::Bool = true
end


Expand Down
14 changes: 14 additions & 0 deletions src/utils/set_save_history.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#Function for changing the save_history setting
function set_save_history!(agent::Agent, save_history::Bool)

#Change it in the agent
agent.save_history = save_history

#And in its substruct
set_save_history!(agent.substruct, save_history)
end

#If there is an empty substruct, do nothing
function set_save_history!(substruct::Nothing, save_history::Bool)
return nothing
end
20 changes: 20 additions & 0 deletions src/utils/update_states.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
### Function which updates a single state, and saves it history
function update_states!(agent::Agent, state::String, value)
#Update state
agent.states[state] = value

#Save to history
if agent.save_history
push!(agent.history[state], value)
end
end


### Function which updates a dictionary of states to their values
function update_states!(agent::Agent, states::Dict)
#For each state and value
for (state, value) in states
#Update it
update_states!(agent, state, value)
end
end

0 comments on commit ee370b6

Please sign in to comment.