In [1]:
using GLMakie

In [2]:
using RxInfer, Rocket

import ReactiveMP: getrecent, messageout
import Rocket: subscribe!
import Base: show

In [50]:
# Dynamical parameters of the pendulum: mass, rod length, friction coefficient, gravity.
# NOTE(review): earlier comments gave the mass in grams and the length in cm, which does not
# match g = 9.81 (m/s²) — confirm the intended unit system.
m, l, b, g = 0.65, 0.85, 0.6, 9.81

# N is not used by the code in this section — TODO confirm whether it is still needed.
N = 200

T = 20     # planning horizon: number of future steps the agent plans over
c = 6.0    # control limit used by f_tanh / f_arctanh
S = 1e-2   # world noise variance (variance of the noisy angle measurements)

# Other settings noted during experimentation: (C = 4, T = 30) and (C = 5, T = 5).

# Simulation time step size (an earlier value was Δt = 0.007).
Δt = 1 / 60;

In [51]:
# Internal dynamic model.
# For now we assume the dynamical model (and its parameters m, l, b, g) is known; it would be
# interesting to extend this and actually infer all of the parameters.

# One explicit-Euler step of size `Δt` of the damped, driven pendulum
#     m l² θ̈ = -m g l sin(θ) - b θ̇ + u
# `z_t_min` holds (θ, θ̇) at the previous step and `u` is the control input.
# Returns the next state (θ, θ̇) as a vector.
function dzdt(z_t_min, u)
    θ, ω = z_t_min
    inertia = m * l^2
    α = 1 / inertia * (-m * g * l * sin(θ) - b * ω + u)   # angular acceleration θ̈
    return z_t_min .+ Δt .* [ω, α]
end

# Squash an unbounded control value into (-c, c).
f_tanh(u) = c * tanh(u)

# Inverse of `f_tanh`; the argument is first clamped slightly inside (-c, c) so that `atanh`
# stays finite.
function f_arctanh(ū)
    margin = c - 1e-3
    return atanh(clamp(ū, -margin, margin) / c)
end

f_arctanh (generic function with 1 method)

In [52]:
# BEHOLD THE WORLD: the simulated pendulum environment.
# The true angle and angular velocity are held here. The two subjects are the channels on which
# the true angle and the noisy angle measurements are published.
mutable struct PendulumWorld
    real_pendulum_position :: Float64   # true angle θ
    real_pendulum_velocity :: Float64   # true angular velocity θ̇
    next_registered_action :: Float64   # control to apply on the next time step
    real_observations      :: Any       # subject carrying the true angle
    noisy_observations     :: Any       # subject carrying noisy angle measurements

    # A world at rest (θ = 0, θ̇ = 0) with no pending action.
    function PendulumWorld()
        at_rest = 0.0
        return new(at_rest, at_rest, at_rest, RecentSubject(Float64), RecentSubject(Float64))
    end
end

Base.show(io::IO, ::PendulumWorld) = print(io, "PendulumWorld()")
    
# Store `action` as the control the world will apply on its next time step; returns `action`.
register_next_action(world::PendulumWorld, action) = (world.next_registered_action = action)

# `tick` moves the state of the world one time step forward and is independent of any agent.
# An agent can only influence the world by registering a new action in between ticks with
# `register_next_action`.
#
# The registered action is squashed with `f_tanh`, applied for one `dzdt` step, and then reset
# to 0.0, so each registered action affects exactly one tick. Afterwards the true angle is
# published on `world.real_observations`. A noisy measurement (Gaussian, variance `S`) is
# published on `world.noisy_observations`.
#
# The noisy measurement is intentionally NOT wrapped with `mod(…, 2pi)`. The true state is not
# wrapped, and neither is the agent's observation model
# (`Normal(mean = dot(cO, s), variance = S)`). Wrapping here made measurements of an angle near
# 0 (the starting state) jump to ≈ 2π at random.
function tick(world::PendulumWorld)
    hidden_state = (world.real_pendulum_position, world.real_pendulum_velocity)
    next_hidden_state = dzdt(hidden_state, f_tanh(world.next_registered_action))

    world.next_registered_action = 0.0
    world.real_pendulum_position = next_hidden_state[1]
    world.real_pendulum_velocity = next_hidden_state[2]

    noisy_observation = rand(NormalMeanVariance(world.real_pendulum_position, S))

    next!(world.real_observations, world.real_pendulum_position)
    next!(world.noisy_observations, noisy_observation)
end

tick (generic function with 1 method)

In [53]:
# Generative model of the agent over a planning horizon of `T` future steps.
#   - `s_t_min` / `u_t_min`: priors on the previous 2-D state and the previous action. Their
#     parameters are supplied through the data variables `m_s_t_min`, `V_s_t_min`,
#     `m_u_t_min` and `v_u_t_min`.
#   - `s_t`: current state. It is `dzdt(s_t_min, f_tanh(u_t_min))` plus Gaussian noise with
#     precision `Gamma`.
#   - `x_t`: current observation, i.e. the first state component (`dot(cO, s_t)`) plus noise
#     with variance `S` (a global set elsewhere in the notebook).
#   - For k = 1..T:
#       - a future control `u[k]` with prior (`m_u[k]`, `V_u[k]`);
#       - a predicted state `s[k]`;
#       - a predicted observation `x[k]`. A second `~` statement also ties `x[k]` to the goal
#         prior (`m_x[k]`, `V_x[k]`).
# Returns a one-element tuple holding the vector of future-state random variables `s`.
@model function pendulum(T)
    # Internal model parameters
    Gamma = constvar(1e10 * diageye(2)) # Transition precision (very large => nearly deterministic transitions)
    cO = constvar([1.0; 0.0]) # Observation matrix: dot(cO, s) picks the angle s[1]
    
    # Previous state prior
    m_s_t_min = datavar(Vector{Float64})
    V_s_t_min = datavar(Matrix{Float64})
    
    # Previous action prior
    m_u_t_min = datavar(Float64)
    v_u_t_min = datavar(Float64)
    
    # Current observation
    x_t = datavar(Float64)
    
    # Future control priors
    m_u = datavar(Float64, T)
    V_u = datavar(Float64, T)
    
    # Future goal priors
    m_x = datavar(Float64, T)
    V_x = datavar(Float64, T)
    
    u   = randomvar(T) # Future actions
    u_s = randomvar(T) # Future deterministic states
    s   = randomvar(T) # Future states with uncertainty
    x   = randomvar(T) # Future observations

    s_t_min ~ MvNormal(mean = m_s_t_min, cov = V_s_t_min) # Prior for previous state
    u_t_min ~ Normal(mean = m_u_t_min, var = v_u_t_min)   # Prior for previous action
    u_s_min ~ dzdt(s_t_min, f_tanh(u_t_min))              # Deterministic state transition function
    s_t     ~ MvNormal(mean = u_s_min, precision = Gamma) # Transition uncertainty
    x_t     ~ Normal(mean = dot(cO, s_t), variance = S)   # Observational function
    
    # Chain the future steps starting from the current state.
    s_k_min = s_t
    
    for k in 1:T
        u[k]    ~ Normal(mean = m_u[k], var = V_u[k])         # Prior on the k-th future control
        u_s[k]  ~ dzdt(s_k_min, f_tanh(u[k]))                  # Deterministic transition with squashed control
        s[k]    ~ MvNormal(mean = u_s[k], precision = Gamma)   # Transition uncertainty
        x[k]    ~ Normal(mean = dot(cO, s[k]), variance = S)   # Predicted observation (angle)
        x[k]    ~ Normal(mean = m_x[k], variance = V_x[k])     # Goal prior on the predicted observation
        s_k_min = s[k]
    end
    
    return (s, )
end

# Approximation settings for the nonlinear (delta) nodes of the `pendulum` model:
#   - `dzdt` nodes are approximated by linearization;
#   - `f_tanh` nodes use the unscented transform with kappa = 1e-2 and get `f_arctanh` as their
#     inverse (presumably used for messages towards the control input — confirm in the
#     ReactiveMP `DeltaMeta` docs).
@meta function pendulum_meta()
    dzdt() -> DeltaMeta(method = Linearization())
    f_tanh() -> DeltaMeta(method = Unscented(kappa=1e-2), inverse=f_arctanh)
end


pendulum_meta (generic function with 1 method)

In [62]:
# Agent that plans `T` future controls for the pendulum with an RxInfer inference engine.
#
# Fields:
#   rxinfer_engine           – running inference engine, or `nothing` when not attached
#   the_goal_in_radians      – mean of the goal prior (3.14, i.e. approximately π)
#   the_goal_variance        – variance of the goal prior
#   mean_control_priors      – length-T prior means of the future controls
#   var_control_priors       – length-T prior variances of the future controls
#   mean_goal_priors         – length-T prior means of the future observations
#   var_goal_priors          – length-T prior variances of the future observations
#   mean_current_state_prior – mean of the prior on the previous 2-D state
#   cov_current_state_prior  – covariance of that prior
mutable struct SuperSmartRxInferAgent
    rxinfer_engine           :: Union{Nothing, RxInferenceEngine}
    the_goal_in_radians      :: Float64
    the_goal_variance        :: Float64
    mean_control_priors      :: Vector{Float64}
    var_control_priors       :: Vector{Float64}
    mean_goal_priors         :: Vector{Float64}
    var_goal_priors          :: Vector{Float64}
    mean_current_state_prior :: Vector{Float64}
    cov_current_state_prior  :: Matrix{Float64}

    # Build an agent with a planning horizon of `T` steps.
    # Control and goal priors start with zero means and `huge` variances. The state prior has
    # mean [0, 0] and covariance `tiny * I`.
    function SuperSmartRxInferAgent(T)
        horizon = 1:T
        # Each call allocates a fresh vector. These vectors are later shifted in place, so no
        # two fields may share the same array.
        zeros_over_horizon = () -> Float64[0.0 for _ in horizon]
        huges_over_horizon = () -> Float64[huge for _ in horizon]

        return new(
            nothing,               # rxinfer_engine
            3.14,                  # the_goal_in_radians
            1e-4,                  # the_goal_variance
            zeros_over_horizon(),  # mean_control_priors
            huges_over_horizon(),  # var_control_priors
            zeros_over_horizon(),  # mean_goal_priors
            huges_over_horizon(),  # var_goal_priors
            [0.0, 0.0],            # mean_current_state_prior
            tiny * diageye(2),     # cov_current_state_prior
        )
    end
end

Base.show(io::IO, ::SuperSmartRxInferAgent) = print(io, "SuperSmartRxInferAgent()")

# Return a one-argument callback that, whenever it is invoked, shifts `vector` one position
# towards its start in place and stores `value` in the last slot. The callback ignores its
# argument and returns the mutated `vector` itself.
function shift(vector, value)
    shifter = function (_ignored)
        head, tail = firstindex(vector), lastindex(vector)
        @inbounds for j in (head + 1):tail
            vector[j - 1] = vector[j]
        end
        vector[tail] = value
        return vector
    end
    return shifter
end

# Attach `agent` to `datastream`, which must provide the `x_t` observations. Then start online
# inference with the `pendulum` model over a horizon of `length(agent.mean_control_priors)`.
#
# Returns a tuple `(recent_action, teardown)`:
#   - `recent_action` is a `RecentSubject(Float64)`. It starts at 0.0 and, after every posterior
#     update of `u`, receives the mode of the posterior over the first planned control `u[1]`.
#   - `teardown` is a zero-argument function. It unsubscribes the internal callbacks, stops the
#     engine and clears `agent.rxinfer_engine`. It is safe to call more than once.
function Rocket.subscribe!(agent::SuperSmartRxInferAgent, datastream)
    T = length(agent.mean_control_priors)

    recent_action = RecentSubject(Float64)
    next!(recent_action, 0.0)

    # Sliding-window updates: each call shifts the given prior vector in place and appends a
    # fresh value at the end of the horizon.
    shift_mean_control_priors = shift(agent.mean_control_priors, 0.0)
    shift_var_control_priors  = shift(agent.var_control_priors, huge)
    shift_mean_goal_priors    = shift(agent.mean_goal_priors, agent.the_goal_in_radians)
    shift_var_goal_priors     = shift(agent.var_goal_priors, agent.the_goal_variance)
    pick_first_action         = (actions) -> mean_var(first(actions))

    # A simple logic to update the agent's prior automatically
    autoupdates = @autoupdates begin 
        m_u = shift_mean_control_priors(q(u))
        V_u = shift_var_control_priors(q(u))
        m_x = shift_mean_goal_priors(q(x))
        V_x = shift_var_goal_priors(q(x))
        m_u_t_min, v_u_t_min = pick_first_action(q(u))
    end

    # One initial marginal per future control `u[k]`, built from the control priors.
    # This previously paired the 2-element state mean with the T control variances, a length
    # mismatch (2 vs. T) that did not cover all controls.
    initial_forces = map(agent.mean_control_priors, agent.var_control_priors) do m, v
        return NormalMeanVariance(m, v)
    end

    engine = rxinference(
        model = pendulum(T),
        meta = pendulum_meta(),
        datastream = datastream,
        autoupdates = autoupdates,
        initmarginals = (u = initial_forces,),
        autostart = false,
        returnvars = (:u, ),
    )

    update!(engine.model[:m_s_t_min], agent.mean_current_state_prior)
    update!(engine.model[:V_s_t_min], agent.cov_current_state_prior)

    # Slide logic: after each update, use a predictive message of `s[1]` as the new prior on
    # the previous state. This is a bit tricky; a proper API for it is still needed.
    slide_callback = (_) -> begin
        # This is model dependent.
        # NOTE(review): index 3 is assumed to be the predictive message on `s[1]`; verify it
        # against the edge order of `pendulum`.
        slide_msg_idx = 3
        (s, ) = engine.returnval # Retrieve a reference to the `states` random variables
        predictive_message = getrecent(messageout(s[1], slide_msg_idx))

        (m_s_t_min, V_s_t_min) = mean_cov(predictive_message)

        agent.mean_current_state_prior = m_s_t_min
        agent.cov_current_state_prior = V_s_t_min

        update!(engine.model[:m_s_t_min], agent.mean_current_state_prior)
        update!(engine.model[:V_s_t_min], agent.cov_current_state_prior)
    end

    slide_subscription = subscribe!(engine.posteriors[:u], slide_callback)

    recent_action_subscription = subscribe!(engine.posteriors[:u], (actions) -> begin
        next!(recent_action, mode(first(actions)))
    end)

    agent.rxinfer_engine = engine

    RxInfer.start(engine)

    # Guarded teardown. Repeated calls are no-ops, instead of calling `RxInfer.stop(nothing)`.
    # A stale teardown also never clears an engine that a later `subscribe!` call attached.
    is_active = Ref(true)
    teardown = () -> begin
        is_active[] || return nothing
        is_active[] = false
        unsubscribe!(slide_subscription)
        unsubscribe!(recent_action_subscription)
        RxInfer.stop(engine)
        if agent.rxinfer_engine === engine
            agent.rxinfer_engine = nothing
        end
        return nothing
    end

    return recent_action, teardown
end

In [63]:
# Redraw the pendulum at angle `Θ` (radians, with Θ = 0 pointing straight down).
# The rod endpoints are written to `rod` and the bob position to `bob`. `bob` defaults to the
# global `balls` observable, so existing two-argument calls keep working.
function animstep!(rod, Θ, bob = balls)
    tip = Point2f(l * sin(Θ), -l * cos(Θ))
    rod[] = [Point2f(0, 0), tip]
    bob[] = [tip]
    return nothing
end

animstep! (generic function with 1 method)

In [64]:
# Initial rod/bob placement; it is overwritten by the first `animstep!` call below.
x1 = 0.5
y1 = 0.5
rod   = Observable([Point2f(0, 0), Point2f(x1, y1)])
balls = Observable([Point2f(x1, y1)])
fig = Figure(); display(fig)
ax = Axis(fig[1,1])

lines!(ax, rod; linewidth = 4, color = :blue)
scatter!(ax, balls; marker = 'o', strokewidth = 2,
    strokecolor = :black,
    color = :black, markersize = [30]
)
xlims!(ax, -1.5l, 1.5l)
ylims!(ax, -1.5l, 1.5l)
ax.title = "Pendulum"
ax.aspect = DataAspect()

# Wire the world and the agent together:
#   noisy observations -> agent (as the `x_t` data stream) -> action -> world
world = PendulumWorld()

agent = SuperSmartRxInferAgent(T)

datastream = labeled(Val((:x_t, )), combineLatest(world.noisy_observations))

action, subscription = subscribe!(agent, datastream);

action_subscription = subscribe!(action, (a) -> register_next_action(world, a))

# Advance the world once and draw the first frame.
tick(world)

animstep!(rod, Rocket.getrecent(world.real_observations))

# Disconnect the agent and the action wiring. Both the "stop" button and the end of the
# animation loop call this, so it is guarded to run at most once. The agent's teardown stops
# `agent.rxinfer_engine` and then resets it to `nothing`, so a second call may fail.
is_subscribed = Ref(true)
function teardown!()
    is_subscribed[] || return nothing
    is_subscribed[] = false
    unsubscribe!(subscription)
    unsubscribe!(action_subscription)
    return nothing
end

# Buttons below the plot: "run" starts the animation loop; "stop" halts it and disconnects
# the agent.
run = Button(fig[2,1]; label = "run", tellwidth = false)
stop = Button(fig[3,1]; label = "stop", tellwidth = false)
isrunning = Observable(true)
loop_task = Ref{Union{Nothing, Task}}(nothing)

on(run.clicks) do clicks
    isrunning[] = true
end

on(stop.clicks) do clicks
    isrunning[] = false
    teardown!()
end

on(run.clicks) do clicks
    # Ignore clicks while a loop is already running. Otherwise every click would start another
    # concurrent loop, and the world would be ticked several times per frame.
    task = loop_task[]
    if task !== nothing && !istaskdone(task)
        return nothing
    end
    loop_task[] = @async begin
        try
            while isrunning[]
                isopen(fig.scene) || break # ensures computations stop if the window is closed
                tick(world)
                animstep!(rod, Rocket.getrecent(world.real_observations))
                sleep(Δt)
            end
        catch err
            bt = catch_backtrace()
            println()
            showerror(stderr, err, bt)
        finally
            teardown!()
        end
    end
end

# Disabled experiment: drag the pendulum with the left mouse button.
# on(ax.scene.events.mousebutton) do mpos
#     global y2
#     if ispressed(ax.scene, Mouse.left)
#        pos = to_world(ax.scene, Point2f(ax.scene.events.mouseposition[]))
#        y2 = atan(pos[2],pos[1]) + π/2
#        balls[] = [Point2f(pos[1]-1, pos[2]-0.5)]
#        rod[] = [Point2f(0, 0), Point2f(pos[1]-1, pos[2]-0.5)]
#    end
#    return
# end

ObserverFunction defined at In[64]:51 operating on Observable{Any}(0)

**Credits:** The original code was written by Sepideh Adamiat and adapted by Dmitry Bagaev.