Skip to content

Commit

Permalink
Merge pull request #90 from maltelau/pr-premade-agent
Browse files Browse the repository at this point in the history
premade agent: continuous rescorla wagner
  • Loading branch information
PTWaade authored Mar 1, 2024
2 parents 2fc24a7 + 964b488 commit 131b1ec
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/ActionModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export get_posteriors

function __init__()
premade_agents["binary_rw_softmax"] = premade_binary_rw_softmax
premade_agents["continuous_rescorla_wagner"] = premade_continuous_rescorla_wagner
end

#Types for agents and errors
Expand Down
32 changes: 32 additions & 0 deletions src/premade_models/premade_action_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,35 @@ function binary_rw_softmax(agent::Agent, input::Union{Bool,Integer})

return action_distribution
end

function continuous_rescorla_wagner(agent::Agent, input::Real)

## Read in parameters from the agent
learning_rate = agent.parameters["learning_rate"]
action_noise = agent.parameters["action_noise"]

## Read in states with an initial value
old_value = agent.states["value"]

##We dont have any settings in this model. If we had, we would read them in as well.
##-----This is where the update step starts -------

##Get new value state
new_value = old_value + learning_rate * (input - old_value)


##-----This is where the update step ends -------
##Create Bernoulli normal distribution our action probability which we calculated in the update step
action_distribution = Distributions.Normal(new_value, action_noise)

##Update the states and save them to agent's history

agent.states["value"] = new_value
agent.states["input"] = input

push!(agent.history["value"], new_value)
push!(agent.history["input"], input)

## return the action distribution to sample actions from
return action_distribution
end
37 changes: 37 additions & 0 deletions src/premade_models/premade_agents.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,40 @@ function premade_binary_rw_softmax(config::Dict)
settings = settings,
)
end


function premade_continuous_rescorla_wagner(config::Dict)

#Default parameters and settings
default_config = Dict(
"learning_rate" => 0.1,
"action_noise" => 1,
("initial", "value") => 0,
)

#Warn the user about used defaults and misspecified keys
warn_premade_defaults(default_config, config)

#Merge to overwrite defaults
config = merge(default_config, config)

## Create agent
action_model = continuous_rescorla_wagner
parameters = Dict(
"learning_rate" => config["learning_rate"],
"action_noise" => config["action_noise"],
("initial", "value") => config[("initial", "value")],
)
states = Dict(
"input" => missing,
"value" => missing
)
settings = Dict()

return init_agent(
action_model,
parameters = parameters,
states = states,
settings = settings,
)
end

0 comments on commit 131b1ec

Please sign in to comment.