From 0f48883c6025035dd0d3ac2c76d4b8a9b8dbc35a Mon Sep 17 00:00:00 2001 From: zeyus Date: Tue, 6 Feb 2024 00:59:51 +0100 Subject: [PATCH 1/5] Update Project.toml --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 386ba08..f4d4c1b 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,7 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] DataFrames = "1" -Distributions = "0.25" +Distributions = "0" RecipesBase = "1" -Turing = "0.29" +Turing = "0" julia = "1.9" From 277ec8bbfe000708b0d9424217381dc076b66d61 Mon Sep 17 00:00:00 2001 From: zeyus Date: Tue, 6 Feb 2024 01:03:08 +0100 Subject: [PATCH 2/5] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f4d4c1b..72a0544 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ActionModels" uuid = "320cf53b-cc3b-4b34-9a10-0ecb113566a3" authors = ["Peter Thestrup Waade "] -version = "0.4.2" +version = "0.4.3" [deps] DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" From 0a36e9c1b8d58e9d060a8020979477397e94fb94 Mon Sep 17 00:00:00 2001 From: zeyus Date: Tue, 27 Feb 2024 23:02:58 +0100 Subject: [PATCH 3/5] Changed default sampler config, Updated test --- Project.toml | 3 +++ src/ActionModels.jl | 4 ++-- src/fitting/fit_model.jl | 17 ++++++++++++----- test/testsuite/HGF_tests.jl | 2 +- 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 72a0544..37511c9 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,10 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] diff --git a/src/ActionModels.jl b/src/ActionModels.jl index 6d953c7..34de0d5 100644 --- a/src/ActionModels.jl +++ b/src/ActionModels.jl @@ -1,8 +1,8 @@ module ActionModels #Load packages -using Turing, Distributions, DataFrames, RecipesBase, Logging, Distributed - +using ReverseDiff, Turing, Distributions, DataFrames, RecipesBase, Logging, Distributed +using Turing: DynamicPPL, AutoReverseDiff #Export functions export Agent, RejectParameters, SharedParameter, Multilevel export init_agent, premade_agent, warn_premade_defaults, multiple_actions, check_agent diff --git a/src/fitting/fit_model.jl b/src/fitting/fit_model.jl index 53afbbf..928cf18 100644 --- a/src/fitting/fit_model.jl +++ b/src/fitting/fit_model.jl @@ -12,13 +12,14 @@ Use Turing to fit the parameters of an agent to a set of inputs and correspondin - 'inputs:Array': array of inputs. Each row is a timestep, and each column is a single input value. - 'actions::Array': array of actions. Each row is a timestep, and each column is a single action. - 'fixed_parameters::Dict = Dict()': dictionary containing parameter values for parameters that are not fitted. Keys are parameter names, values are priors. For parameters not specified here and without priors, the parameter values of the agent are used instead. - - 'sampler = NUTS()': specify the type of Turing sampler. + - 'sampler::Union{Missing, DynamicPPL.Sampler} = missing': specify the type of Turing sampler. Defaults to `NUTS(-1, 0.65; adtype=AutoReverseDiff(true))` - 'n_cores = 1': set number of cores to use for parallelization. If set to 1, no parallelization is used. - 'n_iterations = 1000': set number of iterations per chain. - 'n_chains = 2': set number of amount of chains. - 'verbose = true': set to false to hide warnings. - 'show_sample_rejections = false': set whether to show warnings whenever samples are rejected. - 'impute_missing_actions = false': set whether the values of missing actions should also be estimated by Turing. + - 'sampler_kwargs...': additional keyword arguments to be passed to the `sampler`` call. # Examples ```julia @@ -45,7 +46,7 @@ function fit_model( input_cols::Vector = [:input], action_cols::Vector = [:action], fixed_parameters::Dict = Dict(), - sampler = NUTS(), + sampler::Union{Missing, DynamicPPL.Sampler} = missing, n_cores::Integer = 1, n_iterations::Integer = 1000, n_chains::Integer = 2, @@ -55,7 +56,10 @@ function fit_model( sampler_kwargs..., ) ### SETUP ### - + # If no sampler has been specified, use NUTS + if ismissing(sampler) + sampler = NUTS(-1, 0.65; adtype=AutoReverseDiff(true)) + end #Convert column names to symbols independent_group_cols = Symbol.(independent_group_cols) multilevel_group_cols = Symbol.(multilevel_group_cols) @@ -293,13 +297,14 @@ Use Turing to fit the parameters of an agent to a set of inputs and correspondin - 'inputs:Array': array of inputs. Each row is a timestep, and each column is a single input value. - 'actions::Array': array of actions. Each row is a timestep, and each column is a single action. - 'fixed_parameters::Dict = Dict()': dictionary containing parameter values for parameters that are not fitted. Keys are parameter names, values are priors. For parameters not specified here and without priors, the parameter values of the agent are used instead. - - 'sampler = NUTS()': specify the type of Turing sampler. + - 'sampler::Union{Missing, DynamicPPL.Sampler} = missing': specify the type of Turing sampler, defaults to `NUTS(-1, 0.65; adtype=AutoReverseDiff(true))` - 'n_cores = 1': set number of cores to use for parallelization. If set to 1, no parallelization is used. - 'n_iterations = 1000': set number of iterations per chain. - 'n_chains = 2': set number of amount of chains. - 'verbose = true': set to false to hide warnings. - 'show_sample_rejections = false': set whether to show warnings whenever samples are rejected. - 'impute_missing_actions = false': set whether the values of missing actions should also be estimated by Turing. + - 'sampler_kwargs...': additional keyword arguments to be passed to the `sampler`` call. # Examples ```julia #Create a premade agent: binary Rescorla-Wagner @@ -319,13 +324,14 @@ function fit_model( inputs::Array, actions::Array; fixed_parameters::Dict = Dict(), - sampler = NUTS(), + sampler::Union{Missing, DynamicPPL.Sampler} = missing, n_cores::Integer = 1, n_iterations::Integer = 1000, n_chains = 2, verbose = true, show_sample_rejections = false, impute_missing_actions::Bool = false, + sampler_kwargs..., ) #Create column names @@ -350,6 +356,7 @@ function fit_model( verbose = verbose, show_sample_rejections = show_sample_rejections, impute_missing_actions = impute_missing_actions, + sampler_kwargs..., ) return chains diff --git a/test/testsuite/HGF_tests.jl b/test/testsuite/HGF_tests.jl index eb591a8..98f2403 100644 --- a/test/testsuite/HGF_tests.jl +++ b/test/testsuite/HGF_tests.jl @@ -10,7 +10,7 @@ agent = premade_agent( verbose = false, ) -priors = Dict(("x1", "evolution_rate") => Normal(-5, 1)) +priors = Dict(("x1", "volatility") => Normal(-5, 1)) inputs = [1, 1.2, 1.4] From 151a2373afa56f2dc2788d577e636d4663462830 Mon Sep 17 00:00:00 2001 From: zeyus Date: Fri, 1 Mar 2024 09:38:40 +0100 Subject: [PATCH 4/5] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 37511c9..4e062f9 100644 --- a/Project.toml +++ b/Project.toml @@ -18,5 +18,5 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" DataFrames = "1" Distributions = "0" RecipesBase = "1" -Turing = "0" +Turing = "≥ 0.3.2" julia = "1.9" From 9d0b70c1b515e1687c96ea27d43dc3d8a1a44958 Mon Sep 17 00:00:00 2001 From: zeyus Date: Fri, 1 Mar 2024 09:39:55 +0100 Subject: [PATCH 5/5] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 4e062f9..ff1c827 100644 --- a/Project.toml +++ b/Project.toml @@ -18,5 +18,5 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" DataFrames = "1" Distributions = "0" RecipesBase = "1" -Turing = "≥ 0.3.2" +Turing = "≥ 0.30.2" julia = "1.9"