Skip to content

Commit

Permalink
Merge branch 'main' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
PTWaade committed Mar 1, 2024
2 parents ec56840 + 9636064 commit 4f09310
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
7 changes: 5 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
DataFrames = "1"
Distributions = "0.25"
Distributions = "0"
RecipesBase = "1"
Turing = "0.29"
Turing = "≥ 0.30.2"
julia = "1.9"
4 changes: 2 additions & 2 deletions src/ActionModels.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
module ActionModels

#Load packages
using Turing, Distributions, DataFrames, RecipesBase, Logging, Distributed

using ReverseDiff, Turing, Distributions, DataFrames, RecipesBase, Logging, Distributed
using Turing: DynamicPPL, AutoReverseDiff
#Export functions
export Agent, RejectParameters, SharedParameter, Multilevel
export init_agent, premade_agent, warn_premade_defaults, multiple_actions, check_agent
Expand Down
17 changes: 12 additions & 5 deletions src/fitting/fit_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ Use Turing to fit the parameters of an agent to a set of inputs and correspondin
- 'inputs:Array': array of inputs. Each row is a timestep, and each column is a single input value.
- 'actions::Array': array of actions. Each row is a timestep, and each column is a single action.
- 'fixed_parameters::Dict = Dict()': dictionary containing parameter values for parameters that are not fitted. Keys are parameter names, values are priors. For parameters not specified here and without priors, the parameter values of the agent are used instead.
- 'sampler = NUTS()': specify the type of Turing sampler.
- 'sampler::Union{Missing, DynamicPPL.Sampler} = missing': specify the type of Turing sampler. Defaults to `NUTS(-1, 0.65; adtype=AutoReverseDiff(true))`
- 'n_cores = 1': set number of cores to use for parallelization. If set to 1, no parallelization is used.
- 'n_iterations = 1000': set number of iterations per chain.
- 'n_chains = 2': set number of amount of chains.
- 'verbose = true': set to false to hide warnings.
- 'show_sample_rejections = false': set whether to show warnings whenever samples are rejected.
- 'impute_missing_actions = false': set whether the values of missing actions should also be estimated by Turing.
- 'sampler_kwargs...': additional keyword arguments to be passed to the `sampler`` call.
# Examples
```julia
Expand All @@ -45,7 +46,7 @@ function fit_model(
input_cols::Vector = [:input],
action_cols::Vector = [:action],
fixed_parameters::Dict = Dict(),
sampler = NUTS(),
sampler::Union{Missing, DynamicPPL.Sampler} = missing,
n_cores::Integer = 1,
n_iterations::Integer = 1000,
n_chains::Integer = 2,
Expand All @@ -55,7 +56,10 @@ function fit_model(
sampler_kwargs...,
)
### SETUP ###

# If no sampler has been specified, use NUTS
if ismissing(sampler)
sampler = NUTS(-1, 0.65; adtype=AutoReverseDiff(true))
end
#Convert column names to symbols
independent_group_cols = Symbol.(independent_group_cols)
multilevel_group_cols = Symbol.(multilevel_group_cols)
Expand Down Expand Up @@ -293,13 +297,14 @@ Use Turing to fit the parameters of an agent to a set of inputs and correspondin
- 'inputs:Array': array of inputs. Each row is a timestep, and each column is a single input value.
- 'actions::Array': array of actions. Each row is a timestep, and each column is a single action.
- 'fixed_parameters::Dict = Dict()': dictionary containing parameter values for parameters that are not fitted. Keys are parameter names, values are priors. For parameters not specified here and without priors, the parameter values of the agent are used instead.
- 'sampler = NUTS()': specify the type of Turing sampler.
- 'sampler::Union{Missing, DynamicPPL.Sampler} = missing': specify the type of Turing sampler, defaults to `NUTS(-1, 0.65; adtype=AutoReverseDiff(true))`
- 'n_cores = 1': set number of cores to use for parallelization. If set to 1, no parallelization is used.
- 'n_iterations = 1000': set number of iterations per chain.
- 'n_chains = 2': set number of amount of chains.
- 'verbose = true': set to false to hide warnings.
- 'show_sample_rejections = false': set whether to show warnings whenever samples are rejected.
- 'impute_missing_actions = false': set whether the values of missing actions should also be estimated by Turing.
- 'sampler_kwargs...': additional keyword arguments to be passed to the `sampler`` call.
# Examples
```julia
#Create a premade agent: binary Rescorla-Wagner
Expand All @@ -319,13 +324,14 @@ function fit_model(
inputs::Array,
actions::Array;
fixed_parameters::Dict = Dict(),
sampler = NUTS(),
sampler::Union{Missing, DynamicPPL.Sampler} = missing,
n_cores::Integer = 1,
n_iterations::Integer = 1000,
n_chains = 2,
verbose = true,
show_sample_rejections = false,
impute_missing_actions::Bool = false,
sampler_kwargs...,
)

#Create column names
Expand All @@ -350,6 +356,7 @@ function fit_model(
verbose = verbose,
show_sample_rejections = show_sample_rejections,
impute_missing_actions = impute_missing_actions,
sampler_kwargs...,
)

return chains
Expand Down

0 comments on commit 4f09310

Please sign in to comment.