Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change default AD, add plots deps, less restrictive compatibility requirements #89

Merged
merged 5 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
name = "ActionModels"
uuid = "320cf53b-cc3b-4b34-9a10-0ecb113566a3"
authors = ["Peter Thestrup Waade <peter@waade.net>"]
version = "0.4.2"
version = "0.4.3"

[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
DataFrames = "1"
Distributions = "0.25"
Distributions = "0"
RecipesBase = "1"
Turing = "0.29"
Turing = "≥ 0.30.2"
julia = "1.9"
4 changes: 2 additions & 2 deletions src/ActionModels.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
module ActionModels

#Load packages
using Turing, Distributions, DataFrames, RecipesBase, Logging, Distributed

using ReverseDiff, Turing, Distributions, DataFrames, RecipesBase, Logging, Distributed
using Turing: DynamicPPL, AutoReverseDiff
#Export functions
export Agent, RejectParameters, SharedParameter, Multilevel
export init_agent, premade_agent, warn_premade_defaults, multiple_actions, check_agent
Expand Down
17 changes: 12 additions & 5 deletions src/fitting/fit_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ Use Turing to fit the parameters of an agent to a set of inputs and correspondin
- 'inputs:Array': array of inputs. Each row is a timestep, and each column is a single input value.
- 'actions::Array': array of actions. Each row is a timestep, and each column is a single action.
- 'fixed_parameters::Dict = Dict()': dictionary containing parameter values for parameters that are not fitted. Keys are parameter names, values are priors. For parameters not specified here and without priors, the parameter values of the agent are used instead.
- 'sampler = NUTS()': specify the type of Turing sampler.
- 'sampler::Union{Missing, DynamicPPL.Sampler} = missing': specify the type of Turing sampler. Defaults to `NUTS(-1, 0.65; adtype=AutoReverseDiff(true))`
- 'n_cores = 1': set number of cores to use for parallelization. If set to 1, no parallelization is used.
- 'n_iterations = 1000': set number of iterations per chain.
- 'n_chains = 2': set number of amount of chains.
- 'verbose = true': set to false to hide warnings.
- 'show_sample_rejections = false': set whether to show warnings whenever samples are rejected.
- 'impute_missing_actions = false': set whether the values of missing actions should also be estimated by Turing.
- 'sampler_kwargs...': additional keyword arguments to be passed to the `sampler`` call.

# Examples
```julia
Expand All @@ -45,7 +46,7 @@ function fit_model(
input_cols::Vector = [:input],
action_cols::Vector = [:action],
fixed_parameters::Dict = Dict(),
sampler = NUTS(),
sampler::Union{Missing, DynamicPPL.Sampler} = missing,
n_cores::Integer = 1,
n_iterations::Integer = 1000,
n_chains::Integer = 2,
Expand All @@ -55,7 +56,10 @@ function fit_model(
sampler_kwargs...,
)
### SETUP ###

# If no sampler has been specified, use NUTS
if ismissing(sampler)
sampler = NUTS(-1, 0.65; adtype=AutoReverseDiff(true))
end
#Convert column names to symbols
independent_group_cols = Symbol.(independent_group_cols)
multilevel_group_cols = Symbol.(multilevel_group_cols)
Expand Down Expand Up @@ -293,13 +297,14 @@ Use Turing to fit the parameters of an agent to a set of inputs and correspondin
- 'inputs:Array': array of inputs. Each row is a timestep, and each column is a single input value.
- 'actions::Array': array of actions. Each row is a timestep, and each column is a single action.
- 'fixed_parameters::Dict = Dict()': dictionary containing parameter values for parameters that are not fitted. Keys are parameter names, values are priors. For parameters not specified here and without priors, the parameter values of the agent are used instead.
- 'sampler = NUTS()': specify the type of Turing sampler.
- 'sampler::Union{Missing, DynamicPPL.Sampler} = missing': specify the type of Turing sampler, defaults to `NUTS(-1, 0.65; adtype=AutoReverseDiff(true))`
- 'n_cores = 1': set number of cores to use for parallelization. If set to 1, no parallelization is used.
- 'n_iterations = 1000': set number of iterations per chain.
- 'n_chains = 2': set number of amount of chains.
- 'verbose = true': set to false to hide warnings.
- 'show_sample_rejections = false': set whether to show warnings whenever samples are rejected.
- 'impute_missing_actions = false': set whether the values of missing actions should also be estimated by Turing.
- 'sampler_kwargs...': additional keyword arguments to be passed to the `sampler`` call.
# Examples
```julia
#Create a premade agent: binary Rescorla-Wagner
Expand All @@ -319,13 +324,14 @@ function fit_model(
inputs::Array,
actions::Array;
fixed_parameters::Dict = Dict(),
sampler = NUTS(),
sampler::Union{Missing, DynamicPPL.Sampler} = missing,
n_cores::Integer = 1,
n_iterations::Integer = 1000,
n_chains = 2,
verbose = true,
show_sample_rejections = false,
impute_missing_actions::Bool = false,
sampler_kwargs...,
)

#Create column names
Expand All @@ -350,6 +356,7 @@ function fit_model(
verbose = verbose,
show_sample_rejections = show_sample_rejections,
impute_missing_actions = impute_missing_actions,
sampler_kwargs...,
)

return chains
Expand Down
2 changes: 1 addition & 1 deletion test/testsuite/HGF_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ agent = premade_agent(
verbose = false,
)

priors = Dict(("x1", "evolution_rate") => Normal(-5, 1))
priors = Dict(("x1", "volatility") => Normal(-5, 1))

inputs = [1, 1.2, 1.4]

Expand Down
Loading