# T-Maze Multi-Agent Bargaining Simulation

In [None]:
# Activate the project environment located one directory above the current
# working directory, then install the dependencies that environment records,
# so the notebook runs against its intended package versions.
using Pkg
Pkg.activate("..")
Pkg.instantiate()

In [None]:
# Load the inference framework, linear algebra and plotting libraries, then the
# project-local helpers.
# NOTE(review): goal_observation.jl presumably defines GoalObservation,
# GeneralizedMeta and GeneralizedPipeline used by the models below — confirm.
using RxInfer, LinearAlgebra, Plots

include("helpers.jl")
include("../goal_observation.jl")
;

In [None]:
# Simulation parameters
#   αs   - candidate offers the secondary agent can choose from
#   L    - number of candidate offers
#   c    - scalar forwarded to constructPrimaryBCD when building the primary agent
#   S    - number of trials, i.e. iterations of the experimental protocol loop
#   seed - seed for generateGoalSequence, making the goal sequence reproducible
αs = [0.8, 0.85, 0.9, 0.95, 1.0]
L = length(αs)
c, S, seed = 2.0, 30, 666
;

## Primary Agent

In [None]:
# Generative model of the primary agent over a horizon of two steps.
# Model arguments:
#   A - observation matrix handed to every GoalObservation node
#   D - statistics of the categorical prior on the initial state z_0
#   x - per-step metadata; x[k] is wrapped in GeneralizedMeta for step k
# Data variables (bound at inference time):
#   u[k] - control matrix for the transition from the previous state to z[k]
#   c[k] - goal-prior statistics attached to step k
# Note: `c` here is a model-local data variable; it shadows the global `c`
# defined in the simulation-parameters cell.
@model function t_maze_primary(A, D, x)
    u = datavar(Matrix{Int64}, 2) # Policy for evaluations
    z = randomvar(2) # Latent states
    c = datavar(Vector{Float64}, 2) # Goal prior statistics

    z_0 ~ Categorical(D) # State prior

    z_k_min = z_0 # Previous state in the chain; starts at the initial state
    for k=1:2
        z[k] ~ Transition(z_k_min, u[k]) # Controlled transition into step k
        c[k] ~ GoalObservation(z[k], A) where { # Observation matrix depends on offer by secondary agent
            meta=GeneralizedMeta(x[k]), 
            pipeline=GeneralizedPipeline(vague(Categorical,16))} # NOTE(review): 16 presumably matches the state dimension expected by GoalObservation — confirm

        z_k_min = z[k] # Reset for next slice
    end
end

In [None]:
include("primary_agent.jl")
include("primary_environment.jl") # Environment for primary agent

(B, C, D) = constructPrimaryBCD(c)

rs = generateGoalSequence(seed, S) # Sets random seed and returns reproducible goal sequence
(reset, execute, observe) = initializePrimaryWorld(B, rs) # Define interation (Markov blanket) with the T-maze environment
(infer, act) = initializePrimaryAgent(B, C, D)
;

## Secondary Agent

In [None]:
# Variables in the secondary agent are indicated by "prime"
# Generative model of the secondary agent (single observation, learned A_prime).
# Model arguments:
#   A_prime_s - parameters of the MatrixDirichlet prior on the observation matrix A_prime
#   x_prime   - metadata wrapped in GeneralizedMeta for the GoalObservation node
#   alpha_s   - passed as the state input of GoalObservation
#               (NOTE(review): presumably the secondary agent's offer — confirm in secondary_agent.jl)
# Data variables (bound at inference time):
#   c_prime   - presumably goal-prior statistics, mirroring `c` in t_maze_primary — confirm
@model function t_maze_secondary(A_prime_s, x_prime, alpha_s)
    c_prime = datavar(Vector{Float64})

    A_prime ~ MatrixDirichlet(A_prime_s) # Prior over the observation matrix that is learned
    c_prime ~ GoalObservation(alpha_s, A_prime) where {
                meta=GeneralizedMeta(x_prime),
                pipeline=GeneralizedPipeline()}
end

# Constraint specification for the secondary agent's inference.
# When `approximate` is true, the posterior q(A_prime) is restricted to a
# sample-based (SampleList) representation with 20 samples; when false, this
# specification adds no constraint on q(A_prime).
@constraints function structured(approximate::Bool)
    if approximate
        q(A_prime) :: SampleList(20) # 20-sample approximation of q(A_prime)
    end
end
;

## Simulation

In [None]:
include("secondary_agent.jl")
include("secondary_environment.jl") # The secondary agent's "environment" is its interaction with the primary agent

# Prior statistics for the secondary agent's observation matrix
A_prime_0 = constructSecondaryPriors()

# Interface to the primary agent (the secondary agent's Markov blanket), plus the
# secondary agent's own inference and action-selection closures
(execute_prime, observe_prime) = initializeSecondaryWorld()
(infer_prime, act_prime) = initializeSecondaryAgent(A_prime_0)

# Per-trial records of the experimental protocol
A_primes = Vector{Matrix}(undef, S)                    # posterior statistics for A_prime after learning
G_primes = Vector{Vector}(undef, S)                    # free-energy values computed at t = 1
a_primes = Union{Int64, Missing}[missing for _ in 1:S] # offer (action) chosen in each trial
o_primes = Union{Vector, Missing}[missing for _ in 1:S] # one-hot observation recorded in each trial

for s in 1:S
    # t = 1: evaluate free energies, choose an offer and present it to the primary agent
    G_primes[s], _ = infer_prime(1, a_primes[s], o_primes[s])
    a_primes[s] = act_prime(G_primes[s])
    execute_prime(s, a_primes[s]) # runs inference in the primary agent
    o_primes[s] = observe_prime() # the primary agent's cue visit, as seen by the secondary agent

    # t = 2: learn from the observed response
    _, A_primes[s] = infer_prime(2, a_primes[s], o_primes[s])
end
;

## Results

In [None]:
include("visualizations.jl")

# Plot the secondary agent's free energies, chosen offers and observations per trial.
plotOffers(G_primes, a_primes, o_primes)

# savefig does not create missing directories, so ensure the output folder exists;
# mkpath is a no-op when it is already present.
mkpath("figures")
savefig("figures/GFE_offers")