From 2c00dac95effca334a2e36d753f2bcda0ae2ff0f Mon Sep 17 00:00:00 2001 From: zeyus Date: Sat, 11 May 2024 15:34:58 +0200 Subject: [PATCH 1/5] Fixes #106, allow external samplers. --- src/fitting/fit_model.jl | 9 ++++----- test/Project.toml | 2 ++ test/testsuite/fitting_tests.jl | 29 ++++++++++++++++++++++++++++- 3 files changed, 34 insertions(+), 6 deletions(-) diff --git a/src/fitting/fit_model.jl b/src/fitting/fit_model.jl index e6037db..80ea9db 100644 --- a/src/fitting/fit_model.jl +++ b/src/fitting/fit_model.jl @@ -12,7 +12,7 @@ Use Turing to fit the parameters of an agent to a set of inputs and correspondin - 'inputs:Array': array of inputs. Each row is a timestep, and each column is a single input value. - 'actions::Array': array of actions. Each row is a timestep, and each column is a single action. - 'fixed_parameters::Dict = Dict()': dictionary containing parameter values for parameters that are not fitted. Keys are parameter names, values are priors. For parameters not specified here and without priors, the parameter values of the agent are used instead. - - 'sampler::Union{Missing, DynamicPPL.Sampler} = missing': specify the type of Turing sampler. Defaults to `NUTS(-1, 0.65; adtype=AutoReverseDiff(true))` + - 'sampler::Union{DynamicPPL.AbstractSampler, Turing.Inference.InferenceAlgorithm}': specify the type of Turing sampler. Defaults to `NUTS(-1, 0.65; adtype=AutoReverseDiff(true))` - 'n_cores = 1': set number of cores to use for parallelization. If set to 1, no parallelization is used. - 'n_iterations = 1000': set number of iterations per chain. - 'n_chains = 2': set number of amount of chains. @@ -46,7 +46,7 @@ function fit_model( input_cols::Vector = [:input], action_cols::Vector = [:action], fixed_parameters::Dict = Dict(), - sampler::Turing.Inference.InferenceAlgorithm = NUTS( + sampler::Union{DynamicPPL.AbstractSampler, Turing.Inference.InferenceAlgorithm} = NUTS( -1, 0.65; adtype = AutoReverseDiff(true), @@ -103,7 +103,6 @@ function fit_model( ## Store whether there are multiple inputs and actions ## multiple_inputs = length(input_cols) > 1 multiple_actions = length(action_cols) > 1 - multilevel = length(multilevel_group_cols) > 0 ## Structure multilevel parameter information ## general_parameters_info = extract_structured_parameter_info(; @@ -307,7 +306,7 @@ Use Turing to fit the parameters of an agent to a set of inputs and correspondin - 'inputs:Array': array of inputs. Each row is a timestep, and each column is a single input value. - 'actions::Array': array of actions. Each row is a timestep, and each column is a single action. - 'fixed_parameters::Dict = Dict()': dictionary containing parameter values for parameters that are not fitted. Keys are parameter names, values are priors. For parameters not specified here and without priors, the parameter values of the agent are used instead. - - 'sampler::Union{Missing, DynamicPPL.Sampler} = missing': specify the type of Turing sampler, defaults to `NUTS(-1, 0.65; adtype=AutoReverseDiff(true))` + - 'sampler::Union{DynamicPPL.AbstractSampler, Turing.Inference.InferenceAlgorithm}: specify the type of Turing sampler, defaults to `NUTS(-1, 0.65; adtype=AutoReverseDiff(true))` - 'n_cores = 1': set number of cores to use for parallelization. If set to 1, no parallelization is used. - 'n_iterations = 1000': set number of iterations per chain. - 'n_chains = 2': set number of amount of chains. @@ -334,7 +333,7 @@ function fit_model( inputs::Array, actions::Array; fixed_parameters::Dict = Dict(), - sampler::Turing.Inference.InferenceAlgorithm = NUTS( + sampler::Union{DynamicPPL.AbstractSampler, Turing.Inference.InferenceAlgorithm} = NUTS( -1, 0.65; adtype = AutoReverseDiff(true), diff --git a/test/Project.toml b/test/Project.toml index 48b9af1..3fff499 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,9 +1,11 @@ [deps] +AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" +MicroCanonicalHMC = "234d2aa0-2291-45f7-9047-6fa6f316b0a8" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/testsuite/fitting_tests.jl b/test/testsuite/fitting_tests.jl index b322794..7009159 100644 --- a/test/testsuite/fitting_tests.jl +++ b/test/testsuite/fitting_tests.jl @@ -1,11 +1,12 @@ +using AdvancedMH using ActionModels using Test using Distributions using DataFrames using Plots using StatsPlots - +using Turing: externalsampler @testset "simulate actions and fit" begin @@ -183,3 +184,29 @@ end @test get_parameters(agent) == initial_parameters end + + +@testset "Make sure fitting allows using a custom sampler" begin + + agent = premade_agent("binary_rescorla_wagner_softmax", verbose = false) + + param_priors = Dict("learning_rate" => Uniform(0, 1)) + + inputs = [1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0] + + actions = give_inputs!(agent, inputs) + + rwmh = externalsampler(AdvancedMH.RWMH(10)) + + chains = fit_model( + agent, + param_priors, + inputs, + actions, + n_chains = 1, + n_iterations = 10, + verbose = false, + sampler = ext_sampler, + ) + +end \ No newline at end of file From 8fc13eea1cc2c894d247f8c1ca3d3107ea6825bf Mon Sep 17 00:00:00 2001 From: zeyus Date: Sat, 11 May 2024 15:44:22 +0200 Subject: [PATCH 2/5] Update fitting_tests.jl --- test/testsuite/fitting_tests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/testsuite/fitting_tests.jl b/test/testsuite/fitting_tests.jl index 7009159..87ddf66 100644 --- a/test/testsuite/fitting_tests.jl +++ b/test/testsuite/fitting_tests.jl @@ -206,7 +206,7 @@ end n_chains = 1, n_iterations = 10, verbose = false, - sampler = ext_sampler, + sampler = rwmh, ) -end \ No newline at end of file +end From a5bc30cd60f05ea2ec6aa214dfd990671f9c9382 Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Sun, 12 May 2024 21:27:30 +0200 Subject: [PATCH 3/5] Added helper functions for dealing with re-using sampled actions. --- Project.toml | 17 ++++++++--------- src/ActionModels.jl | 3 ++- src/fitting/create_model.jl | 4 ++-- src/fitting/fitting_helper_functions.jl | 24 ++++++++++++++++++++++++ 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index d423c52..7024ff7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,26 +1,25 @@ name = "ActionModels" uuid = "320cf53b-cc3b-4b34-9a10-0ecb113566a3" -authors = ["Peter Thestrup Waade ptw@cas.au.dk", - "Anna Hedvig Møller hedvig.2808@gmail.com", - "Jacopo Comoglio jacopo.comoglio@gmail.com", - "Christoph Mathys chmathys@cas.au.dk"] +authors = ["Peter Thestrup Waade ptw@cas.au.dk", "Anna Hedvig Møller hedvig.2808@gmail.com", "Jacopo Comoglio jacopo.comoglio@gmail.com", "Christoph Mathys chmathys@cas.au.dk"] version = "0.5.2" [deps] DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" [compat] DataFrames = "1.6" +Distributed = "1.10" Distributions = "0.25" -RecipesBase = "1" +Logging = "1.10" +RecipesBase = "1.3" +ReverseDiff = "1.15" +ForwardDiff = "0.10" Turing = "0.31" julia = "1.9" -ReverseDiff = "1" -Distributed = "1.10" -Logging = "1.10" diff --git a/src/ActionModels.jl b/src/ActionModels.jl index 2451c49..2a86dad 100644 --- a/src/ActionModels.jl +++ b/src/ActionModels.jl @@ -1,7 +1,8 @@ module ActionModels #Load packages -using ReverseDiff, Turing, Distributions, DataFrames, RecipesBase, Logging, Distributed +using ReverseDiff, + ForwardDiff, Turing, Distributions, DataFrames, RecipesBase, Logging, Distributed using Turing: DynamicPPL, AutoReverseDiff #Export functions export Agent, RejectParameters, GroupedParameters, Multilevel diff --git a/src/fitting/create_model.jl b/src/fitting/create_model.jl index 3fe3a9b..38078bc 100644 --- a/src/fitting/create_model.jl +++ b/src/fitting/create_model.jl @@ -158,7 +158,7 @@ Create a Turing model object used for fitting an ActionModels agent. actions[group][timestep] ~ action_distribution #Save the action to the agent in case it needs it in the future - agent.states["action"] = actions[group][timestep] + agent.states["action"] = ad_val(actions[group][timestep]) #If there are multiple actions @@ -185,7 +185,7 @@ Create a Turing model object used for fitting an ActionModels agent. @inbounds actions[group][timestep, action_idx] ~ single_distribution #Save the action - push!(actions, actions[group][timestep, action_idx]) + push!(actions, ad_val(actions[group][timestep, action_idx])) end #Save the action to the agent, for models that need previous action diff --git a/src/fitting/fitting_helper_functions.jl b/src/fitting/fitting_helper_functions.jl index 8ba6dc0..375e4c5 100644 --- a/src/fitting/fitting_helper_functions.jl +++ b/src/fitting/fitting_helper_functions.jl @@ -315,3 +315,27 @@ function rename_chains(chains::Chains, independent_group_info::NamedTuple) return chains end + + + + + + + + +####### FUNCTIONS FOR EXTRACTING A VALUE WHICH WORKS WITH DIFFERENT AUTODIFFERENTIATION BACKENDS #### + +function ad_val(x::ReverseDiff.TrackedReal) + return ReverseDiff.value(x) +end +function ad_val(x::ReverseDiff.TrackedArray) + return ReverseDiff.value(x) +end + +function ad_val(x::ForwardDiff.Dual) + return ForwardDiff.value(x) +end + +function ad_val(x::Real) + return x +end From 35187f6574e1f0474aac91884277c823b7744888 Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Thu, 16 May 2024 12:07:04 +0200 Subject: [PATCH 4/5] made actionmdoels return vectors not subarrays --- src/fitting/create_model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fitting/create_model.jl b/src/fitting/create_model.jl index 38078bc..3ffbfed 100644 --- a/src/fitting/create_model.jl +++ b/src/fitting/create_model.jl @@ -131,7 +131,7 @@ Create a Turing model object used for fitting an ActionModels agent. iterator = enumerate(inputs[group]) else #Iterate over rows of inputs - iterator = enumerate(eachrow(inputs[group])) + iterator = enumerate(Vector.(eachrow(inputs[group]))) end #Go through each timestep From e0d8928ff3856d445ebc31db1984ab8fda56f25b Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Thu, 16 May 2024 12:31:11 +0200 Subject: [PATCH 5/5] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7024ff7..02e8d34 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ActionModels" uuid = "320cf53b-cc3b-4b34-9a10-0ecb113566a3" authors = ["Peter Thestrup Waade ptw@cas.au.dk", "Anna Hedvig Møller hedvig.2808@gmail.com", "Jacopo Comoglio jacopo.comoglio@gmail.com", "Christoph Mathys chmathys@cas.au.dk"] -version = "0.5.2" +version = "0.5.3" [deps] DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"