In [1]:
using GLMakie

In [50]:
using RxInfer, Rocket

import ReactiveMP: getrecent, messageout
import Rocket: subscribe!
import Base: show

In [119]:
# Dynamical parameters
# NOTE: `m`, `l`, `b`, `g`, `c` and `Δt` are read as globals by `dzdt` / `f_tanh` / `f_arctanh`.
m = 0.65 # grams (mass; unit as noted by the original author — TODO confirm it is consistent with `g`)
l = 0.85 # cm (rod length; also used to scale the plot axes — unit TODO confirm)
b = 0.7 # friction
g = 9.81 # gravity
N = 200 # NOTE(review): not referenced anywhere else in the visible notebook
# Time horizon
T = 30
O = [1.0; 0.0] # wrapped as the observation vector inside the `pendulum` model (`dot(cO, s[k])`)
c = 0.1 # scale of the control squashing: `f_tanh(u) = c * tanh(u)`

# Parameter combinations noted by the original author:
#C = 4, T = 30
#C = 5, T = 5

# Time step size
Δt = 0.07;

In [99]:
# Internal dynamic model
# We assume the dynamical model is known for now
# It's interesting to extend this further and try to actually infer all parameters
"""
    dzdt(z_t_min, u)

Advance the pendulum state by one explicit Euler step of size `Δt`.

`z_t_min` is a two-element collection `(angle, angular velocity)` and `u` is the control
input added to the gravity and friction terms. The globals `m`, `l`, `b`, `g` and `Δt`
parameterise the dynamics. Returns the next state `z_t_min .+ [velocity, acceleration] .* Δt`.
"""
function dzdt(z_t_min, u)
    angle, angular_velocity = z_t_min
    # Gravity, friction and the control input, scaled by 1/(m*l^2)
    angular_acceleration = 1/(m*l^2)*(-m*g*l*sin(angle) - b*angular_velocity .+ u)
    return z_t_min .+ [ angular_velocity, angular_acceleration ] .* Δt
end

"""
    f_tanh(u)

Squash `u` with `tanh` and scale the result by the global `c`.
"""
function f_tanh(u)
    return c * tanh(u)
end
"""
    f_arctanh(ū)

Inverse of `f_tanh`: clamp `ū` to `[-c + 1e-3, c - 1e-3]`, divide by the global `c`
and apply `atanh`. The clamp keeps the argument of `atanh` strictly inside (-1, 1)
for the value of `c` used in this notebook.
"""
function f_arctanh(ū)
    lower = -c+1e-3
    upper = c-1e-3
    clamped = clamp(ū, lower, upper)
    return atanh(clamped / c)
end

f_arctanh (generic function with 1 method)

In [106]:
# BEHOLD THE WORLD
"""
    PendulumWorld()

Mutable container for the true pendulum state: its position and velocity.
A fresh world starts with both set to `0.0`.
"""
mutable struct PendulumWorld
    real_pendulum_position::Float64
    real_pendulum_velocity::Float64

    function PendulumWorld()
        return new(0.0, 0.0)
    end
end

# Compact display that does not reveal the hidden state.
function Base.show(io::IO, world::PendulumWorld)
    return print(io, "PendulumWorld()")
end

"""
    act(world::PendulumWorld, action::Float64)

Apply `action` to the world: the action is squashed with `f_tanh`, the hidden state is
advanced with `dzdt`, the new position and velocity are stored in `world`, and a fresh
observation from `observe(world)` is returned.
"""
function act(world::PendulumWorld, action::Float64)
    current_state = (world.real_pendulum_position, world.real_pendulum_velocity)
    next_position, next_velocity = dzdt(current_state, f_tanh(action))

    world.real_pendulum_position = next_position
    world.real_pendulum_velocity = next_velocity

    return observe(world)
end

"""
    observe(world::PendulumWorld)

Return a noisy reading of the pendulum position: a sample from a `Normal` centred on
the true position (second parameter `0.001`), wrapped into `[0, 2π)` with `mod`.
"""
function observe(world::PendulumWorld)
    sensor = Normal(world.real_pendulum_position, 0.001)
    reading = rand(sensor)
    return mod(reading, 2pi)
end

observe (generic function with 1 method)

In [101]:
# Generative model of the pendulum over a horizon of `T` steps.
# Data inputs: `m_s_t_min`/`V_s_t_min` (prior on the state preceding the horizon),
# `m_u`/`V_u` (control priors per step) and `m_x`/`V_x` (goal priors per step).
# Returns the tuple `(s, )` with the state variables.
@model function pendulum(T)
    # Internal model parameters
    Gamma = 1e10*diageye(2) # Transition precision
    Theta = 1e-4 # Observation variance
    cO = constvar(O) # Observation matrix
    
    # Mean and covariance of the prior on the state preceding the horizon
    m_s_t_min = datavar(Vector{Float64})
    V_s_t_min = datavar(Matrix{Float64})

    s_t_min ~ MvNormal(mean = m_s_t_min, cov = V_s_t_min)
    s_k_min = s_t_min # running "previous state" used inside the loop
    
    # Control prior statistics, one per step
    m_u = datavar(Float64, T)
    V_u = datavar(Float64, T)
    
    # Goal prior statistics, one per step
    m_x = datavar(Float64, T)
    V_x = datavar(Float64, T)
    
    u = randomvar(T) # unconstrained controls
    s = randomvar(T) # states
    x = randomvar(T) # observations
    
    u_s = randomvar(T)           # output of `dzdt` (predicted next state)
    u_constrained = randomvar(T) # output of `f_tanh` (squashed control)
    
    for k in 1:T
        u[k] ~ Normal(mean = m_u[k], var = V_u[k])
        u_constrained[k] ~ f_tanh(u[k])
        u_s[k] ~ dzdt(s_k_min, u_constrained[k])
        s[k] ~ MvNormal(mean = u_s[k], precision = Gamma)
        x[k] ~ Normal(mean = dot(cO, s[k]), variance = Theta)
        x[k] ~ Normal(mean = m_x[k], variance = V_x[k]) # goal
        s_k_min = s[k] # the next step conditions on this state
    end
    
    return (s, )
end

# Approximation methods for the two nonlinear (delta) nodes of the `pendulum` model:
# `dzdt` nodes use `Linearization()`, `f_tanh` nodes use `Unscented(kappa=1e-2)` and are
# given `f_arctanh` as their inverse.
@meta function pendulum_meta()
    dzdt() -> DeltaMeta(method = Linearization())
    f_tanh() -> DeltaMeta(method = Unscented(kappa=1e-2), inverse=f_arctanh)
end


pendulum_meta (generic function with 1 method)

In [102]:
"""
    SuperSmartRxInferAgent(T)

State of the streaming agent: the (optional) running inference engine, the goal and its
variance, the per-step control and goal priors (each of length `length(1:T)`), and the
prior mean/covariance of the current state.

A new agent has no engine, a goal of `3.14` with variance `1e-4`, zero-mean priors with
`huge` variance, and a `[0.0, 0.0]` state prior with covariance `tiny * diageye(2)`.
"""
mutable struct SuperSmartRxInferAgent
    rxinfer_engine::Union{Nothing, RxInferenceEngine}
    the_goal_in_radians::Float64
    the_goal_variance::Float64
    mean_control_priors::Vector{Float64}
    var_control_priors::Vector{Float64}
    mean_goal_priors::Vector{Float64}
    var_goal_priors::Vector{Float64}
    mean_current_state_prior::Vector{Float64}
    cov_current_state_prior::Matrix{Float64}

    function SuperSmartRxInferAgent(T)
        # Same number of entries as iterating `1:T`
        horizon = length(1:T)

        return new(
            nothing,                                       # no engine until `subscribe!`
            3.14,                                          # the goal in radians
            1e-4,                                          # the goal variance
            zeros(Float64, horizon),                       # control prior means
            fill!(Vector{Float64}(undef, horizon), huge),  # control prior variances
            zeros(Float64, horizon),                       # goal prior means
            fill!(Vector{Float64}(undef, horizon), huge),  # goal prior variances
            [ 0.0, 0.0 ],                                  # current state prior mean
            tiny * diageye(2)                              # current state prior covariance
        )
    end
end

# Compact display that does not reveal the agent's internal beliefs.
function Base.show(io::IO, agent::SuperSmartRxInferAgent)
    return print(io, "SuperSmartRxInferAgent()")
end

# Shift a vector and put a new value at the end
"""
    shift(vector, value)

Build a one-argument function that, when called (its argument is ignored), moves every
element of `vector` one position towards the front in place, stores `value` in the last
slot and returns `vector`. Nothing is modified until the returned function is invoked.
"""
function shift(vector, value)
    function shifter(_)
        positions = eachindex(vector)
        # Pair each slot with its right-hand neighbour and copy leftwards
        for (dst, src) in zip(positions, Iterators.drop(positions, 1))
            vector[dst] = vector[src]
        end
        vector[end] = value
        return vector
    end
    return shifter
end

"""
    Rocket.subscribe!(agent::SuperSmartRxInferAgent, datastream)

Create a reactive inference engine for the `pendulum` model, feed it from `datastream`
and start it. The horizon is the length of `agent.mean_control_priors`.

Returns a tuple `(recent_action, teardown)`:
- `recent_action` is a `RecentSubject(Float64)`, seeded with `0.0` and updated with
  `mode(actions[2])` each time the posteriors over `u` are updated;
- `teardown` is a zero-argument function that unsubscribes the internal subscriptions,
  stops the engine and resets `agent.rxinfer_engine` to `nothing`.
"""
function Rocket.subscribe!(agent::SuperSmartRxInferAgent, datastream)
    T = length(agent.mean_control_priors)

    recent_action = RecentSubject(Float64)

    next!(recent_action, 0.0)

    shift_mean_control_priors = (_) -> begin
        # `shift` only builds a closure; it has to be invoked for the priors to move
        shift(agent.mean_control_priors, 0.0)(nothing)
        # Clamp the first control prior to the most recently selected action
        agent.mean_control_priors[begin] = Rocket.getrecent(recent_action)
        return agent.mean_control_priors
    end

    shift_var_control_priors = (_) -> begin
        shift(agent.var_control_priors, huge)(nothing)
        agent.var_control_priors[begin] = tiny
        return agent.var_control_priors
    end

    # A simple logic to update the agent's prior automatically
    autoupdates = @autoupdates begin
        m_u = shift_mean_control_priors(q(u))
        V_u = shift_var_control_priors(q(u))
    end

    engine = rxinference(
        model = pendulum(T),
        meta = pendulum_meta(),
        datastream = datastream,
        autoupdates = autoupdates,
        initmarginals = (
            # One initial marginal per control variable: pair the control prior means with
            # the control prior variances (both have length T). Pairing with the 2-element
            # state prior would make `map` stop after two entries.
            u = map((m, v) -> NormalMeanVariance(m, v), agent.mean_control_priors, agent.var_control_priors),
        ),
        autostart = false,
        returnvars = (:u, ),
    )

    # Provide the initial state prior before the engine is started
    update!(engine.model[:m_s_t_min], agent.mean_current_state_prior)
    update!(engine.model[:V_s_t_min], agent.cov_current_state_prior)

    # Slide logic
    slide_callback = (_) -> begin
        slide_msg_idx = 3 # This is model dependent
        (s, ) = engine.returnval

        (m_s_t_min, V_s_t_min) = mean_cov(getrecent(messageout(s[2], slide_msg_idx))) # Reset prior state statistics;

        agent.mean_current_state_prior = m_s_t_min
        agent.cov_current_state_prior = V_s_t_min

        update!(engine.model[:m_s_t_min], agent.mean_current_state_prior)
        update!(engine.model[:V_s_t_min], agent.cov_current_state_prior)
    end

    slide_subscription = subscribe!(engine.posteriors[:u], slide_callback)

    # Publish the mode of the second control posterior as the next action
    recent_action_subscription = subscribe!(engine.posteriors[:u], (actions) -> begin
        next!(recent_action, mode(actions[2]))
    end)

    agent.rxinfer_engine = engine

    RxInfer.start(engine)

    return recent_action, () -> begin
        unsubscribe!(slide_subscription)
        unsubscribe!(recent_action_subscription)
        RxInfer.stop(agent.rxinfer_engine)
        agent.rxinfer_engine = nothing
    end
end

In [103]:
# Because states of the agent are unknown to the world, we wrap them in a closure.
# The closure only returns functions for interacting with the agent.
# Internal beliefs cannot be directly observed, and interaction is only allowed through the Markov blanket
"""
    initializeAgent(T)

Create a batch-inference agent with horizon `T` and return the tuple of functions
`(infer, act, slide)` that share its internal state:

- `infer(upsilon_t, y_hat_t)` clamps the first control prior to the performed action and
  the first goal prior to the observation, runs `inference` on the `pendulum` model and
  returns (and stores) the result;
- `act()` returns `0.0` before the first inference, otherwise the first element of the
  mode of the second control posterior;
- `slide(slide_msg_idx = 3)` replaces the state prior with the statistics of a message
  on `s[2]` and moves all control and goal priors one step forward.
"""
function initializeAgent(T)
    control_var_default = huge # Control prior variance
    control_means = Float64[ 0.0 for _ in 1:T ] # Set control priors
    control_vars  = Float64[ control_var_default for _ in 1:T ]

    goal_mean = 3.14 # Goal state
    goal_var  = 1e-4 # Goal prior variance
    goal_means = [ 0.0 for _ in 1:T ]
    goal_means[end] = goal_mean
    goal_vars = Float64[ huge for _ in 1:T ]
    goal_vars[end] = goal_var # Set prior to reach goal at t=T

    state_mean = [0.0, 0.0] # Set initial brain state prior
    state_cov  = tiny*diageye(2)

    latest = nothing # most recent inference result, `nothing` until `infer` has run

    function infer(upsilon_t::Float64, y_hat_t::Float64)
        control_means[1] = upsilon_t # Register action with the generative model
        control_vars[1]  = tiny      # Clamp control prior to performed action

        goal_means[1] = y_hat_t # Register observation with the generative model
        goal_vars[1]  = tiny    # Clamp goal prior to observation

        latest = inference(
            model = pendulum(T),
            meta = pendulum_meta(),
            data = Dict(:m_u       => control_means,
                        :V_u       => control_vars,
                        :m_x       => goal_means,
                        :V_x       => goal_vars,
                        :m_s_t_min => state_mean,
                        :V_s_t_min => state_cov),
        )
        return latest
    end

    function act()
        latest === nothing && return 0.0
        return mode(latest.posteriors[:u][2])[1]
    end

    function slide(slide_msg_idx = 3)
        states = first(latest.returnval)
        # Reset prior state statistics
        state_mean, state_cov = mean_cov(getrecent(messageout(states[2], slide_msg_idx)))

        # Move every prior one step forward and refill the last slot with its default
        control_means = circshift(control_means, -1)
        control_vars  = circshift(control_vars, -1)
        goal_means    = circshift(goal_means, -1)
        goal_vars     = circshift(goal_vars, -1)

        control_means[end] = 0.0
        control_vars[end]  = control_var_default
        goal_means[end]    = goal_mean
        goal_vars[end]     = goal_var
        return goal_var
    end

    return (infer, act, slide)
end

initializeAgent (generic function with 1 method)

In [104]:
"""
    animstep!(rod, Θ)

Redraw the pendulum at angle `Θ`: `rod` is set to the segment from the origin to the
tip `(l*sin(Θ), -l*cos(Θ))`, and the global observable `balls` is set to that tip.
"""
function animstep!(rod, Θ)
    tip = Point2f(l*sin(Θ), -l*cos(Θ))
    rod[] = [Point2f(0, 0), tip]
    # Last expression: its value (the one-element vector) is what the function returns
    balls[] = [tip]
end

animstep! (generic function with 1 method)

In [105]:
# Placeholder end point used to initialise the plot; `animstep!` overwrites both
# observables as soon as it is called.
x1 = 0.5
y1 = 0.5
# Observables read by the plot: the rod segment and the ball at its end
rod   = Observable([Point2f(0, 0), Point2f(x1, y1)])
balls = Observable([Point2f(x1, y1)])
fig = Figure(); display(fig)
ax = Axis(fig[1,1])

# Draw the rod as a line and the ball as a marker; both follow their observables
lines!(ax, rod; linewidth = 4, color = :blue)
scatter!(ax, balls; marker = 'o', strokewidth = 2,
    strokecolor = :black,
    color = :black, markersize = [30]
)
# Axis limits are 1.5 times the rod length `l` in every direction
xlims!(ax, -1.5l, 1.5l)
ylims!(ax, -1.5l, 1.5l)
ax.title = "Pendulum"
ax.aspect = DataAspect()

world = PendulumWorld()

agent = SuperSmartRxInferAgent(T)

# Subject through which new observations are pushed towards the agent
ys = Subject(Float64)

# subscribe!(ys, logger())

# For every observation `cx` (paired with the variance `cV = tiny`): slide the goal priors
# one step, append the goal and its variance at the end, then overwrite the first entry
# with the observation. The agent's own vectors are mutated in place and emitted.
observations = combineLatest(ys, of(tiny)) |> map(Tuple{Vector{Float64}, Vector{Float64}}, ((cx, cV),) -> begin
    shift(agent.mean_goal_priors, agent.the_goal_in_radians)(nothing)
    shift(agent.var_goal_priors, agent.the_goal_variance)(nothing)
    agent.mean_goal_priors[begin] = cx
    agent.var_goal_priors[begin] = cV
    return (agent.mean_goal_priors, agent.var_goal_priors)
end)

# Label the emitted tuple as the `m_x` / `V_x` data inputs of the model
datastream = labeled(Val((:m_x, :V_x)), observations) |> async()

# `action` holds the most recent action; `subscription` is the teardown function
action, subscription = subscribe!(agent, datastream);

# One manual step of the action/observation loop before the buttons take over
a2 = Rocket.getrecent(action) # Evoke an action from the agent
y2 = act(world, a2) # The action influences hidden external states

next!(ys, y2)

println(a2, " ", y2, " ", Rocket.getrecent(action))

animstep!(rod,y2)
# sleep(0.01)

# The run button is actually pretty simple, we'll add it below the plot
run = Button(fig[2,1]; label = "run", tellwidth = false)
stop = Button(fig[3,1]; label = "stop", tellwidth = false)
# These buttons start/stop the animation through the `isrunning` flag
isrunning = Observable(true)

on(run.clicks) do clicks
    isrunning[] = true
end

# Stopping also tears down the agent's subscriptions and its inference engine
on(stop.clicks) do clicks
    isrunning[] = false
    unsubscribe!(subscription)
end

# Each click on "run" launches a background task that keeps stepping the
# world/agent loop until `isrunning` is cleared or the window is closed.
on(run.clicks) do clicks
    @async begin 
        try 
            # The loop is bounded only by the flag and the window state (the former
            # `iters` counter was never incremented, so it never limited anything).
            while isrunning[]
                global y2
                global a2

                isopen(fig.scene) || break # ensures computations stop if closed window

                a2 = Rocket.getrecent(action) # Evoke an action from the agent
                y2 = act(world, a2) # The action influences hidden external states
                next!(ys, y2) 
                animstep!(rod,y2)
                sleep(0.01)
            end
        catch err
            # Errors inside an `@async` task are otherwise silent, so report them
            bt = catch_backtrace()
            println()
            showerror(stderr, err, bt)
        end
    end
end

# on(ax.scene.events.mousebutton) do mpos
#     #global a2
#     global y2
    
#     if ispressed(ax.scene, Mouse.left)
#        pos = to_world(ax.scene, Point2f(ax.scene.events.mouseposition[]))
#        #print(pos)
#        y2 = atan(pos[2],pos[1]) + π/2 
#        balls[] = [Point2f(pos[1]-1, pos[2]-0.5)]
#        rod[] = [Point2f(0, 0), Point2f(pos[1]-1, pos[2]-0.5)]
        
#    end
#    return
# end

0.0 0.0005612499457319316 0.0


ObserverFunction defined at In[105]:65 operating on Observable{Any}(0)

**Credits** The original code was written by Sepideh Adamiat and adapted by Dmitry Bagaev.