diff --git a/Project.toml b/Project.toml index 1f53a9c..798cd50 100644 --- a/Project.toml +++ b/Project.toml @@ -8,12 +8,15 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] DataFrames = "1" -Distributions = "0.25" +Distributions = "0" RecipesBase = "1" -Turing = "0.29" +Turing = "≥ 0.30.2" julia = "1.9" diff --git a/src/ActionModels.jl b/src/ActionModels.jl index 499e2e2..5f15a64 100644 --- a/src/ActionModels.jl +++ b/src/ActionModels.jl @@ -1,8 +1,8 @@ module ActionModels #Load packages -using Turing, Distributions, DataFrames, RecipesBase, Logging, Distributed - +using ReverseDiff, Turing, Distributions, DataFrames, RecipesBase, Logging, Distributed +using Turing: DynamicPPL, AutoReverseDiff #Export functions export Agent, RejectParameters, SharedParameter, Multilevel export init_agent, premade_agent, warn_premade_defaults, multiple_actions, check_agent diff --git a/src/fitting/fit_model.jl b/src/fitting/fit_model.jl index 53afbbf..928cf18 100644 --- a/src/fitting/fit_model.jl +++ b/src/fitting/fit_model.jl @@ -12,13 +12,14 @@ Use Turing to fit the parameters of an agent to a set of inputs and correspondin - 'inputs:Array': array of inputs. Each row is a timestep, and each column is a single input value. - 'actions::Array': array of actions. Each row is a timestep, and each column is a single action. - 'fixed_parameters::Dict = Dict()': dictionary containing parameter values for parameters that are not fitted. Keys are parameter names, values are priors. For parameters not specified here and without priors, the parameter values of the agent are used instead. - - 'sampler = NUTS()': specify the type of Turing sampler. + - 'sampler::Union{Missing, DynamicPPL.Sampler} = missing': specify the type of Turing sampler. Defaults to `NUTS(-1, 0.65; adtype=AutoReverseDiff(true))` - 'n_cores = 1': set number of cores to use for parallelization. If set to 1, no parallelization is used. - 'n_iterations = 1000': set number of iterations per chain. - 'n_chains = 2': set number of amount of chains. - 'verbose = true': set to false to hide warnings. - 'show_sample_rejections = false': set whether to show warnings whenever samples are rejected. - 'impute_missing_actions = false': set whether the values of missing actions should also be estimated by Turing. + - 'sampler_kwargs...': additional keyword arguments to be passed to the `sampler`` call. # Examples ```julia @@ -45,7 +46,7 @@ function fit_model( input_cols::Vector = [:input], action_cols::Vector = [:action], fixed_parameters::Dict = Dict(), - sampler = NUTS(), + sampler::Union{Missing, DynamicPPL.Sampler} = missing, n_cores::Integer = 1, n_iterations::Integer = 1000, n_chains::Integer = 2, @@ -55,7 +56,10 @@ function fit_model( sampler_kwargs..., ) ### SETUP ### - + # If no sampler has been specified, use NUTS + if ismissing(sampler) + sampler = NUTS(-1, 0.65; adtype=AutoReverseDiff(true)) + end #Convert column names to symbols independent_group_cols = Symbol.(independent_group_cols) multilevel_group_cols = Symbol.(multilevel_group_cols) @@ -293,13 +297,14 @@ Use Turing to fit the parameters of an agent to a set of inputs and correspondin - 'inputs:Array': array of inputs. Each row is a timestep, and each column is a single input value. - 'actions::Array': array of actions. Each row is a timestep, and each column is a single action. - 'fixed_parameters::Dict = Dict()': dictionary containing parameter values for parameters that are not fitted. Keys are parameter names, values are priors. For parameters not specified here and without priors, the parameter values of the agent are used instead. - - 'sampler = NUTS()': specify the type of Turing sampler. + - 'sampler::Union{Missing, DynamicPPL.Sampler} = missing': specify the type of Turing sampler, defaults to `NUTS(-1, 0.65; adtype=AutoReverseDiff(true))` - 'n_cores = 1': set number of cores to use for parallelization. If set to 1, no parallelization is used. - 'n_iterations = 1000': set number of iterations per chain. - 'n_chains = 2': set number of amount of chains. - 'verbose = true': set to false to hide warnings. - 'show_sample_rejections = false': set whether to show warnings whenever samples are rejected. - 'impute_missing_actions = false': set whether the values of missing actions should also be estimated by Turing. + - 'sampler_kwargs...': additional keyword arguments to be passed to the `sampler`` call. # Examples ```julia #Create a premade agent: binary Rescorla-Wagner @@ -319,13 +324,14 @@ function fit_model( inputs::Array, actions::Array; fixed_parameters::Dict = Dict(), - sampler = NUTS(), + sampler::Union{Missing, DynamicPPL.Sampler} = missing, n_cores::Integer = 1, n_iterations::Integer = 1000, n_chains = 2, verbose = true, show_sample_rejections = false, impute_missing_actions::Bool = false, + sampler_kwargs..., ) #Create column names @@ -350,6 +356,7 @@ function fit_model( verbose = verbose, show_sample_rejections = show_sample_rejections, impute_missing_actions = impute_missing_actions, + sampler_kwargs..., ) return chains