Refactored premade models, added Rescorla-Wagner
PTWaade committed Apr 14, 2024
1 parent 869b0af commit 22a1bc8
Showing 33 changed files with 157 additions and 150 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@ Find premade agent, and define agent with default parameters
````@example Introduction
premade_agent("help")
agent = premade_agent("premade_binary_rw_softmax")
agent = premade_agent("premade_binary_rescorla_wagner_softmax")
````

Set inputs and give inputs to agent
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -4,6 +4,7 @@ CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
HierarchicalGaussianFiltering = "63d42c3e-681c-42be-892f-a47f35336a79"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
2 changes: 1 addition & 1 deletion docs/src/Using_the_package/Introduction.jl
@@ -19,7 +19,7 @@ using ActionModels
# Find premade agent, and define agent with default parameters
premade_agent("help")

agent = premade_agent("premade_binary_rw_softmax")
agent = premade_agent("premade_binary_rescorla_wagner_softmax")

# Set inputs and give inputs to agent

2 changes: 1 addition & 1 deletion docs/src/Using_the_package/Simulation_with_an_agent.jl
@@ -26,7 +26,7 @@
# Let us define our agent and use the default parameter configurations
using ActionModels

agent = premade_agent("premade_binary_rw_softmax")
agent = premade_agent("premade_binary_rescorla_wagner_softmax")

# ## Give a single input
# We can now give the agent a single input with the give_inputs!() function. The inputs for the Rescorla-Wagner agent are binary, so we input the value 1.
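The workflow this file walks through can be condensed into one sketch. This is a hedged example, assuming give_inputs! accepts either a single input or a vector, as the surrounding docs indicate:

```julia
using ActionModels

# Create the default binary Rescorla-Wagner agent
agent = premade_agent("premade_binary_rescorla_wagner_softmax")

# Give a single binary input; the returned action is sampled from the
# agent's action distribution
action = give_inputs!(agent, 1)

# Give a whole sequence of binary inputs at once
actions = give_inputs!(agent, [1, 0, 0, 1, 1])
```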
@@ -56,7 +56,7 @@ using ActionModels
using Distributions
using StatsPlots

agent = premade_agent("premade_binary_rw_softmax")
agent = premade_agent("premade_binary_rescorla_wagner_softmax")

# Let's give the agent some input and simulate a set of actions:

4 changes: 2 additions & 2 deletions docs/src/Using_the_package/premade_agents_and_models.jl
@@ -21,7 +21,7 @@ using ActionModels #hide
premade_agent("help")

# Let's create an agent. We will use the "premade\_binary\_rescorla\_wagner\_softmax" agent with the "binary\_rescorla\_wagner\_softmax" action model. You define a default premade agent with the syntax below:
agent = premade_agent("premade_binary_rw_softmax")
agent = premade_agent("premade_binary_rescorla_wagner_softmax")

# In the ActionModels.jl package, an agent struct consists of the action model name (which can be premade or custom), parameters, and states.
# The premade agents are initialized with a set of configurations for parameters, states and initial state parameters.
@@ -71,7 +71,7 @@ set_parameters!(

# If you know which parameter values you wish to use when defining your agent, you can specify them from the start as a Dict() with each parameter name as a string key followed by its value.
agent_custom_parameters = premade_agent(
"premade_binary_rw_softmax",
"premade_binary_rescorla_wagner_softmax",
Dict(
"learning_rate" => 0.7,
"softmax_action_precision" => 0.8,
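The parameter workflow above can be sketched in a few lines. This is hedged: it assumes get_parameters and set_parameters! behave as the surrounding docs show:

```julia
using ActionModels

# Create the agent with custom parameter values from the start
agent = premade_agent(
    "premade_binary_rescorla_wagner_softmax",
    Dict("learning_rate" => 0.7, "softmax_action_precision" => 0.8),
)

# Inspect a single parameter, then overwrite it
get_parameters(agent, "learning_rate")
set_parameters!(agent, Dict("learning_rate" => 0.5))
```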
2 changes: 1 addition & 1 deletion docs/src/Using_the_package/prior_predictive_sim.jl
@@ -41,7 +41,7 @@
using ActionModels

#Define an agent
agent = premade_agent("premade_binary_rw_softmax")
agent = premade_agent("premade_binary_rescorla_wagner_softmax")
#Define input
inputs = [1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0]

2 changes: 1 addition & 1 deletion docs/src/Using_the_package/variations_of_util.jl
@@ -13,7 +13,7 @@
# We will define an agent to use during demonstrations of the utility functions:
using ActionModels #hide

agent = premade_agent("premade_binary_rw_softmax")
agent = premade_agent("premade_binary_rescorla_wagner_softmax")

# ## Getting States
# The get_states() function can give you a single state, multiple states and all states of an agent.
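The three call patterns for get_states() just described can be sketched as follows (assuming the signatures shown elsewhere in these docs):

```julia
using ActionModels

agent = premade_agent("premade_binary_rescorla_wagner_softmax")
give_inputs!(agent, [1, 0, 1])

get_states(agent, "value")                          # a single state
get_states(agent, ["value", "action_probability"])  # multiple states
get_states(agent)                                   # all states
```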
2 changes: 1 addition & 1 deletion docs/src/index.md
@@ -25,7 +25,7 @@ Find premade agent, and define agent with default parameters
````@example index
premade_agent("help")
agent = premade_agent("premade_binary_rw_softmax")
agent = premade_agent("premade_binary_rescorla_wagner_softmax")
````

Set inputs and give inputs to agent
2 changes: 1 addition & 1 deletion docs/src/julia_src_files/fitting_an_agent_model_to_data.jl
@@ -56,7 +56,7 @@ using ActionModels
using Distributions
using StatsPlots

agent = premade_agent("binary_rw_softmax")
agent = premade_agent("binary_rescorla_wagner_softmax")

# Let's give the agent some input and simulate a set of actions:

2 changes: 1 addition & 1 deletion docs/src/julia_src_files/index.jl
@@ -19,7 +19,7 @@ using ActionModels
# Find premade agent, and define agent with default parameters
premade_agent("help")

agent = premade_agent("binary_rw_softmax")
agent = premade_agent("binary_rescorla_wagner_softmax")

# Set inputs and give inputs to agent

4 changes: 2 additions & 2 deletions docs/src/julia_src_files/premade_agents_and_models.jl
@@ -21,7 +21,7 @@ using ActionModels #hide
premade_agent("help")

# Let's create an agent. We will use the "binary\_rescorla\_wagner\_softmax" premade agent, which uses the action model of the same name. You define a default premade agent with the syntax below:
agent = premade_agent("binary_rw_softmax")
agent = premade_agent("binary_rescorla_wagner_softmax")

# In the ActionModels.jl package, an agent struct consists of the action model name (which can be premade or custom), parameters, and states.
# The premade agents are initialized with a set of configurations for parameters, states and initial state parameters.
@@ -71,7 +71,7 @@ set_parameters!(

# If you know which parameter values you wish to use when defining your agent, you can specify them from the start as a Dict() with each parameter name as a string key followed by its value.
agent_custom_parameters = premade_agent(
"binary_rw_softmax",
"binary_rescorla_wagner_softmax",
Dict(
"learning_rate" => 0.7,
"softmax_action_precision" => 0.8,
2 changes: 1 addition & 1 deletion docs/src/julia_src_files/prior_predictive_sim.jl
@@ -41,7 +41,7 @@
using ActionModels

#Define an agent
agent = premade_agent("binary_rw_softmax")
agent = premade_agent("binary_rescorla_wagner_softmax")
#Define input
inputs = [1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0]

2 changes: 1 addition & 1 deletion docs/src/julia_src_files/simulation_with_an_agent.jl
@@ -26,7 +26,7 @@
# Let us define our agent and use the default parameter configurations
using ActionModels

agent = premade_agent("binary_rw_softmax")
agent = premade_agent("binary_rescorla_wagner_softmax")

# ## Give a single input
# We can now give the agent a single input with the give_inputs!() function. The inputs for the Rescorla-Wagner agent are binary, so we input the value 1.
2 changes: 1 addition & 1 deletion docs/src/julia_src_files/variations_of_util.jl
@@ -13,7 +13,7 @@
# We will define an agent to use during demonstrations of the utility functions:
using ActionModels #hide

agent = premade_agent("binary_rw_softmax")
agent = premade_agent("binary_rescorla_wagner_softmax")

# ## Getting States
# The get_states() function can give you a single state, multiple states and all states of an agent.
2 changes: 1 addition & 1 deletion docs/src/markdowns/Introduction.md
@@ -25,7 +25,7 @@ Find premade agent, and define agent with default parameters
````@example Introduction
premade_agent("help")
agent = premade_agent("premade_binary_rw_softmax")
agent = premade_agent("premade_binary_rescorla_wagner_softmax")
````

Set inputs and give inputs to agent
2 changes: 1 addition & 1 deletion docs/src/markdowns/fitting_an_agent_model_to_data.md
@@ -57,7 +57,7 @@ using ActionModels
using Distributions
using StatsPlots
agent = premade_agent("premade_binary_rw_softmax")
agent = premade_agent("premade_binary_rescorla_wagner_softmax")
````

Let's give the agent some input and simulate a set of actions:
4 changes: 2 additions & 2 deletions docs/src/markdowns/premade_agents_and_models.md
@@ -31,7 +31,7 @@ premade_agent("help")
Let's create an agent. We will use the "premade\_binary\_rescorla\_wagner\_softmax" agent with the "binary\_rescorla\_wagner\_softmax" action model. You define a default premade agent with the syntax below:

````@example premade_agents_and_models
agent = premade_agent("premade_binary_rw_softmax")
agent = premade_agent("premade_binary_rescorla_wagner_softmax")
````

In the ActionModels.jl package, an agent struct consists of the action model name (which can be premade or custom), parameters, and states.
@@ -88,7 +88,7 @@ set_parameters!(agent, Dict("learning_rate" => 0.79,
If you know which parameter values you wish to use when defining your agent, you can specify them from the start as a Dict() with each parameter name as a string key followed by its value.

````@example premade_agents_and_models
agent_custom_parameters = premade_agent("premade_binary_rw_softmax", Dict("learning_rate" => 0.7,
agent_custom_parameters = premade_agent("premade_binary_rescorla_wagner_softmax", Dict("learning_rate" => 0.7,
"softmax_action_precision" => 0.8,
("initial", "value") => 1)
)
2 changes: 1 addition & 1 deletion docs/src/markdowns/prior_predictive_sim.md
@@ -42,7 +42,7 @@ Let us go through a prior predictive simulation
using ActionModels
#Define an agent
agent = premade_agent("premade_binary_rw_softmax")
agent = premade_agent("premade_binary_rescorla_wagner_softmax")
#Define input
inputs = [1,0,0,1,1,1,1,0,1,0,1,0,0,1,1,0,1,0,0]
2 changes: 1 addition & 1 deletion docs/src/markdowns/simulation_with_an_agent.md
@@ -34,7 +34,7 @@ Let us define our agent and use the default parameter configurations
````@example simulation_with_an_agent
using ActionModels
agent = premade_agent("premade_binary_rw_softmax")
agent = premade_agent("premade_binary_rescorla_wagner_softmax")
````

### Give a single input
2 changes: 1 addition & 1 deletion docs/src/markdowns/variations_of_util.md
@@ -18,7 +18,7 @@ We will define an agent to use during demonstrations of the utility functions:
````@example variations_of_util
using ActionModels #hide
agent = premade_agent("premade_binary_rw_softmax")
agent = premade_agent("premade_binary_rescorla_wagner_softmax")
````

### Getting States
13 changes: 7 additions & 6 deletions src/ActionModels.jl
@@ -12,9 +12,10 @@ export plot_parameter_distribution,
export get_history, get_states, get_parameters, set_parameters!, reset!, give_inputs!, single_input!
export get_posteriors

#Load premade agents
function __init__()
premade_agents["binary_rw_softmax"] = premade_binary_rw_softmax
premade_agents["continuous_rescorla_wagner"] = premade_continuous_rescorla_wagner
premade_agents["binary_rescorla_wagner_softmax"] = premade_binary_rescorla_wagner_softmax
premade_agents["continuous_rescorla_wagner_gaussian"] = premade_continuous_rescorla_wagner_gaussian
end

#Types for agents and errors
@@ -36,10 +37,6 @@ include("plots/plot_predictive_simulation.jl")
include("plots/plot_parameter_distribution.jl")
include("plots/plot_trajectory.jl")

#Functions for making premade agent
include("premade_models/premade_agents.jl")
include("premade_models/premade_action_models.jl")

#Utility functions for agents
include("utils/get_history.jl")
include("utils/get_parameters.jl")
@@ -50,4 +47,8 @@ include("utils/set_parameters.jl")
include("utils/warn_premade_defaults.jl")
include("utils/get_posteriors.jl")
include("utils/pretty_printing.jl")

#Premade agents
include("premade_models/binary_rescorla_wagner_softmax.jl")
include("premade_models/continuous_rescorla_wagner_gaussian.jl")
end
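The __init__ block above suggests a simple registry pattern: premade model names map to constructor functions in the premade_agents Dict, which premade_agent presumably resolves at lookup time. A simplified, hypothetical illustration — my_premade_agent and make_agent are stand-ins for this sketch, not the package's real internals:

```julia
# Registry mapping premade model names to constructor functions
premade_registry = Dict{String,Function}()

# Hypothetical stand-in for a real constructor like
# premade_binary_rescorla_wagner_softmax
make_agent(config::Dict) = (name = "binary_rescorla_wagner_softmax", config = config)

premade_registry["binary_rescorla_wagner_softmax"] = make_agent

# Hypothetical lookup, analogous to what premade_agent presumably does
function my_premade_agent(model_name::String, config::Dict = Dict())
    haskey(premade_registry, model_name) ||
        error("'$model_name' is not a premade agent")
    return premade_registry[model_name](config)
end
```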
4 changes: 2 additions & 2 deletions src/create_agent/init_agent.jl
@@ -20,7 +20,7 @@ In this case the action models will be stored in the agent's settings. In that c
## Create agent with a binary Rescorla-Wagner action model ##
## Create action model function
function binary_rw_softmax(agent::Agent, input::Union{Bool,Integer})
function binary_rescorla_wagner_softmax(agent::Agent, input::Union{Bool,Integer})
#Read in parameters
learning_rate = agent.parameters["learning_rate"]
@@ -69,7 +69,7 @@ states = Dict(
#Create agent
agent = init_agent(
binary_rw_softmax,
binary_rescorla_wagner_softmax,
parameters = parameters,
states = states,
settings = settings,
4 changes: 2 additions & 2 deletions src/fitting/fit_model.jl
@@ -24,7 +24,7 @@ Use Turing to fit the parameters of an agent to a set of inputs and correspondin
# Examples
```julia
#Create a premade agent: binary Rescorla-Wagner
agent = premade_agent("premade_binary_rw_softmax")
agent = premade_agent("premade_binary_rescorla_wagner_softmax")
#Set priors for the learning rate
priors = Dict("learning_rate" => Uniform(0, 1))
@@ -308,7 +308,7 @@ Use Turing to fit the parameters of an agent to a set of inputs and correspondin
# Examples
```julia
#Create a premade agent: binary Rescorla-Wagner
agent = premade_agent("premade_binary_rw_softmax")
agent = premade_agent("premade_binary_rescorla_wagner_softmax")
#Set priors for the learning rate
param_priors = Dict("learning_rate" => Uniform(0, 1))
#Set inputs and actions
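The docstring examples above can be assembled into one end-to-end fitting sketch. This is hedged: it assumes fit_model takes the agent, priors, inputs, and actions in that order, as the docstring indicates:

```julia
using ActionModels
using Distributions

# Simulate data from a premade agent, then recover the learning rate
agent = premade_agent("premade_binary_rescorla_wagner_softmax")
inputs = [1, 0, 0, 1, 1, 1, 1, 0, 1, 0]
actions = give_inputs!(agent, inputs)
reset!(agent)

priors = Dict("learning_rate" => Uniform(0, 1))
chains = fit_model(agent, priors, inputs, actions)
```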
@@ -1,12 +1,12 @@
"""
binary_rw_softmax(agent::Agent, input::Bool)
binary_rescorla_wagner_softmax(agent::Agent, input::Bool)
Action model that learns from binary inputs with a classic Rescorla-Wagner model. Passes learnt probabilities through a softmax to get the action probability distribution.
Parameters: "learning_rate" and "softmax_action_precision".
States: "value", "value_probability", "action_probability".
"""
function binary_rw_softmax(agent::Agent, input::Union{Bool,Integer})
function binary_rescorla_wagner_softmax(agent::Agent, input::Union{Bool,Integer})

#Read in parameters
learning_rate = agent.parameters["learning_rate"]
@@ -39,34 +39,51 @@ function binary_rw_softmax(agent::Agent, input::Union{Bool,Integer})
return action_distribution
end

Removed (the continuous_rescorla_wagner action model, moved out of this file):

function continuous_rescorla_wagner(agent::Agent, input::Real)

    ## Read in parameters from the agent
    learning_rate = agent.parameters["learning_rate"]
    action_noise = agent.parameters["action_noise"]

    ## Read in states with an initial value
    old_value = agent.states["value"]

    ## We don't have any settings in this model. If we had, we would read them in as well.
    ##----- This is where the update step starts -------

    ## Get new value state
    new_value = old_value + learning_rate * (input - old_value)

    ##----- This is where the update step ends -------

    ## Create a normal distribution from the value we calculated in the update step
    action_distribution = Distributions.Normal(new_value, action_noise)

    ## Update the states and save them to the agent's history
    agent.states["value"] = new_value
    agent.states["input"] = input

    push!(agent.history["value"], new_value)
    push!(agent.history["input"], input)

    ## Return the action distribution to sample actions from
    return action_distribution
end

Added (the premade constructor for the binary Rescorla-Wagner softmax agent):

"""
    premade_binary_rescorla_wagner_softmax(config::Dict)

Create premade agent that uses the binary_rescorla_wagner_softmax action model.

# Config defaults:
 - "learning_rate": 1
 - "softmax_action_precision": 1
 - ("initial", "value"): 0
"""
function premade_binary_rescorla_wagner_softmax(config::Dict)

    #Default parameters and settings
    default_config = Dict(
        "learning_rate" => 1,
        "softmax_action_precision" => 1,
        ("initial", "value") => 0,
    )

    #Warn the user about used defaults and misspecified keys
    warn_premade_defaults(default_config, config)

    #Merge to overwrite defaults
    config = merge(default_config, config)

    ## Create agent
    action_model = binary_rescorla_wagner_softmax
    parameters = Dict(
        "learning_rate" => config["learning_rate"],
        "softmax_action_precision" => config["softmax_action_precision"],
        ("initial", "value") => config[("initial", "value")],
    )
    states = Dict(
        "value" => missing,
        "value_probability" => missing,
        "action_probability" => missing,
    )
    settings = Dict()

    return init_agent(
        action_model,
        parameters = parameters,
        states = states,
        settings = settings,
    )
end
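The update step in both action models is the classic Rescorla-Wagner rule, value_new = value_old + α(input − value_old). A quick worked example:

```julia
learning_rate = 0.5   # α
old_value = 0.0
input = 1.0

# The prediction error is (input - old_value) = 1.0;
# the value moves a fraction α of the way toward the input
new_value = old_value + learning_rate * (input - old_value)
# new_value == 0.5
```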
