Skip to content

Commit

Permalink
Merge pull request #89 from zeyus/main
Browse files Browse the repository at this point in the history
Change default AD, add plots deps, less restrictive compatibility requirements
  • Loading branch information
PTWaade authored Mar 1, 2024
2 parents af025cc + 9d0b70c commit 9636064
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 11 deletions.
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
name = "ActionModels"
uuid = "320cf53b-cc3b-4b34-9a10-0ecb113566a3"
authors = ["Peter Thestrup Waade <peter@waade.net>"]
version = "0.4.2"
version = "0.4.3"

[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
DataFrames = "1"
Distributions = "0.25"
Distributions = "0"
RecipesBase = "1"
Turing = "0.29"
Turing = "≥ 0.30.2"
julia = "1.9"
4 changes: 2 additions & 2 deletions src/ActionModels.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
module ActionModels

#Load packages
using Turing, Distributions, DataFrames, RecipesBase, Logging, Distributed

using ReverseDiff, Turing, Distributions, DataFrames, RecipesBase, Logging, Distributed
using Turing: DynamicPPL, AutoReverseDiff
#Export functions
export Agent, RejectParameters, SharedParameter, Multilevel
export init_agent, premade_agent, warn_premade_defaults, multiple_actions, check_agent
Expand Down
17 changes: 12 additions & 5 deletions src/fitting/fit_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ Use Turing to fit the parameters of an agent to a set of inputs and correspondin
- 'inputs:Array': array of inputs. Each row is a timestep, and each column is a single input value.
- 'actions::Array': array of actions. Each row is a timestep, and each column is a single action.
- 'fixed_parameters::Dict = Dict()': dictionary containing parameter values for parameters that are not fitted. Keys are parameter names, values are priors. For parameters not specified here and without priors, the parameter values of the agent are used instead.
- 'sampler = NUTS()': specify the type of Turing sampler.
- 'sampler::Union{Missing, DynamicPPL.Sampler} = missing': specify the type of Turing sampler. Defaults to `NUTS(-1, 0.65; adtype=AutoReverseDiff(true))`
- 'n_cores = 1': set number of cores to use for parallelization. If set to 1, no parallelization is used.
- 'n_iterations = 1000': set number of iterations per chain.
- 'n_chains = 2': set number of amount of chains.
- 'verbose = true': set to false to hide warnings.
- 'show_sample_rejections = false': set whether to show warnings whenever samples are rejected.
- 'impute_missing_actions = false': set whether the values of missing actions should also be estimated by Turing.
- 'sampler_kwargs...': additional keyword arguments to be passed to the `sampler`` call.
# Examples
```julia
Expand All @@ -45,7 +46,7 @@ function fit_model(
input_cols::Vector = [:input],
action_cols::Vector = [:action],
fixed_parameters::Dict = Dict(),
sampler = NUTS(),
sampler::Union{Missing, DynamicPPL.Sampler} = missing,
n_cores::Integer = 1,
n_iterations::Integer = 1000,
n_chains::Integer = 2,
Expand All @@ -55,7 +56,10 @@ function fit_model(
sampler_kwargs...,
)
### SETUP ###

# If no sampler has been specified, use NUTS
if ismissing(sampler)
sampler = NUTS(-1, 0.65; adtype=AutoReverseDiff(true))
end
#Convert column names to symbols
independent_group_cols = Symbol.(independent_group_cols)
multilevel_group_cols = Symbol.(multilevel_group_cols)
Expand Down Expand Up @@ -293,13 +297,14 @@ Use Turing to fit the parameters of an agent to a set of inputs and correspondin
- 'inputs:Array': array of inputs. Each row is a timestep, and each column is a single input value.
- 'actions::Array': array of actions. Each row is a timestep, and each column is a single action.
- 'fixed_parameters::Dict = Dict()': dictionary containing parameter values for parameters that are not fitted. Keys are parameter names, values are priors. For parameters not specified here and without priors, the parameter values of the agent are used instead.
- 'sampler = NUTS()': specify the type of Turing sampler.
- 'sampler::Union{Missing, DynamicPPL.Sampler} = missing': specify the type of Turing sampler, defaults to `NUTS(-1, 0.65; adtype=AutoReverseDiff(true))`
- 'n_cores = 1': set number of cores to use for parallelization. If set to 1, no parallelization is used.
- 'n_iterations = 1000': set number of iterations per chain.
- 'n_chains = 2': set number of amount of chains.
- 'verbose = true': set to false to hide warnings.
- 'show_sample_rejections = false': set whether to show warnings whenever samples are rejected.
- 'impute_missing_actions = false': set whether the values of missing actions should also be estimated by Turing.
- 'sampler_kwargs...': additional keyword arguments to be passed to the `sampler`` call.
# Examples
```julia
#Create a premade agent: binary Rescorla-Wagner
Expand All @@ -319,13 +324,14 @@ function fit_model(
inputs::Array,
actions::Array;
fixed_parameters::Dict = Dict(),
sampler = NUTS(),
sampler::Union{Missing, DynamicPPL.Sampler} = missing,
n_cores::Integer = 1,
n_iterations::Integer = 1000,
n_chains = 2,
verbose = true,
show_sample_rejections = false,
impute_missing_actions::Bool = false,
sampler_kwargs...,
)

#Create column names
Expand All @@ -350,6 +356,7 @@ function fit_model(
verbose = verbose,
show_sample_rejections = show_sample_rejections,
impute_missing_actions = impute_missing_actions,
sampler_kwargs...,
)

return chains
Expand Down
2 changes: 1 addition & 1 deletion test/testsuite/HGF_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ agent = premade_agent(
verbose = false,
)

priors = Dict(("x1", "evolution_rate") => Normal(-5, 1))
priors = Dict(("x1", "volatility") => Normal(-5, 1))

inputs = [1, 1.2, 1.4]

Expand Down

0 comments on commit 9636064

Please sign in to comment.