# User preference learning

In [1]:
using Revise

using ReactiveMP
using Rocket
using GraphPPL

using Optim
using LinearAlgebra
using Random
using PyPlot
using StatsFuns: normcdf
using ForwardDiff
using BenchmarkTools


include("../src/environment/user.jl");

## Flow model

In [2]:
# Flow model on a 2-D input: three stacked additive coupling layers, each wrapping a
# radial flow. The first two rely on the constructor's default permutation setting
# (the constructor itself defaults to `AdditiveCouplingLayer(PlanarFlow(); permute=true)`);
# the last layer explicitly disables the permutation.
flow_model = FlowModel(
    2,
    (
        AdditiveCouplingLayer(RadialFlow()),
        AdditiveCouplingLayer(RadialFlow()),
        AdditiveCouplingLayer(RadialFlow(); permute = false),
    ),
);

## Inference model optimization

In [3]:
# Probit classifier whose latent feature map is a normalizing flow, over `nr_samples`
# labelled observations.
#
# Arguments:
#   nr_samples - number of (x, y) observation pairs in the graph
#   model      - the flow model; it is compiled with `params` into the Flow node meta
#   params     - flow parameters; when this graph is built inside `calculate_parameters`
#                they are ForwardDiff dual numbers (gradient of the free energy)
#
# Returns the data variables `(x, y)`, which the caller feeds with `ReactiveMP.update!`.
@model function flow_classifier(nr_samples::Int64, model::FlowModel, params)
    
    # initialize variables
    x_lat  = randomvar(nr_samples)   # latent copy of each observed input
    y_lat1 = randomvar(nr_samples)   # flow output for each latent input
    y_lat2 = randomvar(nr_samples)   # scalar score: sum of the flow-output components
    y      = datavar(Float64, nr_samples)           # observed responses
    x      = datavar(Vector{Float64}, nr_samples)   # observed inputs

    # compile flow model
    meta  = FlowMeta(compile(model, params)) # default: FlowMeta(model, Linearization())

    # specify observations
    for k = 1:nr_samples

        # specify latent state: tied to the observed input with precision 1e3 per dimension
        x_lat[k] ~ MvNormalMeanPrecision(x[k], 1e3*diagm(ones(2)))

        # specify transformed latent value and reduce it to a scalar (dot with [1, 1])
        y_lat1[k] ~ Flow(x_lat[k]) where { meta = meta }
        y_lat2[k] ~ dot(y_lat1[k], [1, 1])

        # specify observations through a probit link on the scalar score
        y[k] ~ Probit(y_lat2[k]) # default: where { pipeline = RequireInbound(in = NormalMeanPrecision(0, 1.0)) }

    end

    # return the data variables so the caller can push observations into them
    return x, y

end;

In [4]:
# Run message passing on `flow_classifier` for the observations (`data_x`, `data_y`)
# and return the Bethe free energy recorded after the final update.
#
# `params` are the flow parameters; the free-energy score is tracked with
# `eltype(params)` so that ForwardDiff dual numbers (see `calculate_parameters`) flow
# through it. `nr_its` is the number of times the observations are pushed into the graph.
function inference_flow_classifier(data_y::Array{Float64,1}, data_x::Array{Array{Float64,1},1}, flow_model::FlowModel, params; nr_its=5)

    # build the graph with one (x, y) pair per observation
    model, (x, y) = flow_classifier(length(data_y), flow_model, params)

    # track the Bethe free energy in the element type of the parameters
    T = eltype(params)
    fe_actor = ScoreActor(T)
    fe_subscription = subscribe!(score(T, BetheFreeEnergy(), model), fe_actor)

    # push the same observations into the graph `nr_its` times
    for _ in 1:nr_its
        ReactiveMP.update!(y, data_y)
        ReactiveMP.update!(x, data_x)
    end

    unsubscribe!(fe_subscription)

    # free energy after the last update (not a marginal)
    return last(getvalues(fe_actor))
end;

In [5]:
"""
    calculate_parameters(params, model, data_x, data_y; nr_iterations=10000, λ=1e-1)

Fit the flow parameters of `model` to the observations by gradient descent on the
Bethe free energy returned by `inference_flow_classifier`.

`params` is the initial parameter vector. Gradients are computed with
`ForwardDiff.gradient!`, and an `Adam` optimizer with learning rate `λ` takes
`nr_iterations` steps. The defaults reproduce the previously hard-coded
settings (10000 steps, `λ = 1e-1`).

Returns `(fe, params_opt)`: the free energy at the final parameters and those parameters.
"""
function calculate_parameters(params, model, data_x, data_y; nr_iterations::Integer=10000, λ=1e-1)

    # objective: Bethe free energy of the classifier as a function of the flow parameters
    f(θ) = inference_flow_classifier(data_y, data_x, model, θ)

    optimizer = Adam(params; λ=λ)

    # gradient buffer, reused across all iterations
    ∇ = zeros(nr_params(model))

    for _ in 1:nr_iterations

        # backward pass (forward-mode AD through the whole inference procedure)
        ForwardDiff.gradient!(∇, f, optimizer.x)

        # gradient update
        ReactiveMP.update!(optimizer, ∇)

    end

    return f(optimizer.x), optimizer.x

end

calculate_parameters (generic function with 1 method)

In [6]:
# function repeat_calculate_parameters(model, data_x, data_y)
#     try
#         return calculate_parameters(randn(nr_params(model)), model, data_x, data_y)
#     catch e
#         #println("   ERROR: calculate_parameters() failed")
#         #return repeat_calculate_parameters(model, data_x, data_y)
#         println(e)
#     end
# end

In [7]:
"""
    calculate_parameters_tries(model, data_x, data_y; nr_tries=1)

Fit the flow parameters `nr_tries` times from independent `randn` initialisations.
The tries run in parallel over threads. Return the parameter vector whose final free
energy is lowest. The free energies of all tries are printed.

Tries that end with a non-finite free energy (`NaN`/`±Inf`, e.g. a diverged
optimisation) are ignored. If every try is non-finite, the selection falls back to
all tries.
"""
function calculate_parameters_tries(model, data_x, data_y; nr_tries=1)
    params = Vector{Vector{Float64}}(undef, nr_tries)
    fe = Vector{Float64}(undef, nr_tries)
    Threads.@threads for k = 1:nr_tries
        fe[k], params[k] = calculate_parameters(randn(nr_params(model)), model, data_x, data_y)
    end
    println(fe)
    # `argmin` treats NaN as smaller than every number, so without this filter a
    # single diverged try would be returned as the "best" parameters
    candidates = any(isfinite, fe) ? findall(isfinite, fe) : collect(eachindex(fe))
    return params[candidates[argmin(fe[candidates])]]
end

calculate_parameters_tries (generic function with 1 method)

## Inference input estimation

In [8]:
# Form constraint that pins a variable to the fixed value `point`: whatever message
# it is applied to is replaced by a `PointMass` at `point`. It is checked with the
# `FormConstraintCheckLast()` strategy and reports itself as a point-mass constraint.
# NOTE(review): presumably a local variant of ReactiveMP's own point-mass constraint
# that takes a user-supplied point — confirm.
struct PointMassFormConstraint2{P}
    point :: P
end

ReactiveMP.default_form_check_strategy(::PointMassFormConstraint2) = FormConstraintCheckLast()

ReactiveMP.is_point_mass_form_constraint(::PointMassFormConstraint2) = true

# Replace the message's distribution with a point mass at the stored point, keeping
# the message's clamped / initial flags.
function ReactiveMP.constrain_form(constraint::PointMassFormConstraint2, msg::Message)
    clamped = ReactiveMP.is_clamped(msg)
    initial = ReactiveMP.is_initial(msg)
    pinned = PointMass(constraint.point)
    return Message(pinned, clamped, initial)
end

In [9]:
# Single-query variant of the flow classifier, used to score a candidate input.
#
# Arguments:
#   input  - candidate 2-D input. The form constraint on `xprior` replaces its posterior
#            by a point mass at `input`, so the free energy of this graph depends on
#            `input` (it is minimised over `input` in `calculate_input`)
#   model  - the flow model, compiled with `params` into the Flow node meta
#   params - flow parameters
#
# Returns `(x_lat, y_lat1, y_lat2, y)`; `y` is the data variable for the response.
@model function flow_classifier_input(input, model, params)
    
    # initialize variables
    x_lat  = randomvar()
    y_lat1 = randomvar()
    y_lat2 = randomvar()
    xprior = randomvar() where { form_constraint = PointMassFormConstraint2(input)} # pinned to `input`
    y = datavar(Float64)

    # specify model
    meta  = FlowMeta(compile(model, params))

    # prior on the candidate input (zero mean, precision 0.1 per dimension); the form
    # constraint above turns its posterior into a point mass at `input`
    xprior ~ MvNormalMeanPrecision([0.0,0.0], 0.1*diagm(ones(2))) where { q = MeanField() }

    # specify latent state: tied to xprior with precision 1e3 per dimension
    x_lat ~ MvNormalMeanPrecision(xprior, 1e3*diagm(ones(2))) where { q = MeanField() }

    # specify transformed latent value and reduce it to a scalar (dot with [1, 1])
    y_lat1 ~ Flow(x_lat) where { meta = meta }
    y_lat2 ~ dot(y_lat1, [1, 1])

    # specify observations through a probit link on the scalar score
    y ~ Probit(y_lat2) # default where { pipeline = RequireInbound(in = NormalMeanPrecision(0, 1.0)) }

    # return variables
    return x_lat, y_lat1, y_lat2, y

end;

In [10]:
# Score a candidate `input`: run message passing on `flow_classifier_input` with the
# response observed as 1.0 and return the Bethe free energy recorded after the last
# update. The score is tracked in `eltype(input)`, which lets Optim's forward-mode
# differentiation in `calculate_input` pass dual numbers through. `nr_its` is the number
# of times the observation is pushed into the graph.
function inference_flow_classifier_input(input, flow_model, params; nr_its=5)

    # build the single-query graph; only x_lat and y are needed here
    model, (x_lat, _, _, y) = flow_classifier_input(input, flow_model, params)

    T = eltype(input)
    fe_actor = ScoreActor(T)
    fe_subscription = subscribe!(score(T, BetheFreeEnergy(), model), fe_actor)

    # start the latent input from a vague 2-D Gaussian marginal
    setmarginal!(x_lat, vague(MvNormalMeanPrecision, 2))

    # observe the target response 1.0, `nr_its` times
    for _ in 1:nr_its
        ReactiveMP.update!(y, 1.0)
    end

    unsubscribe!(fe_subscription)

    # free energy after the last update (not a marginal)
    return last(getvalues(fe_actor))
end;

In [11]:
"""
    calculate_input(params, model; lower=fill(-0.5, 2), upper=fill(0.5, 2), iterations=500)

Propose a new input by minimising the free energy of `inference_flow_classifier_input`
over the box `[lower, upper]`. The search uses `Fminbox(LBFGS())` with forward-mode
automatic differentiation, starting from a uniformly random point in the box.

Returns `(minimum free energy, minimiser)`. The defaults reproduce the previously
hard-coded box `[-0.5, 0.5]^2` and 500 iterations.

Throws `DimensionMismatch` if `lower` and `upper` differ in length.
"""
function calculate_input(params, model; lower=fill(-0.5, 2), upper=fill(0.5, 2), iterations::Integer=500)

    length(lower) == length(upper) ||
        throw(DimensionMismatch("lower and upper must have equal length, got $(length(lower)) and $(length(upper))"))

    # free energy as a function of the candidate input
    f_input(input) = inference_flow_classifier_input(input, model, params)

    # random start, uniform in [lower, upper); for the default box this equals rand(2) .- 0.5
    initial = lower .+ rand(length(lower)) .* (upper .- lower)

    options = Optim.Options(iterations = iterations, store_trace = false, show_trace = false)
    res = optimize(f_input, lower, upper, initial, Fminbox(LBFGS()), options, autodiff=:forward)

    return Optim.minimum(res), Optim.minimizer(res)

end

calculate_input (generic function with 1 method)

In [12]:
"""
    repeat_calculate_input(params, model; max_tries=100)

Call `calculate_input(params, model)` until it succeeds and return its result.

Each call of `calculate_input` starts from a fresh random point, so retrying after a
failure is meaningful. At most `max_tries` attempts are made, and the error of the last
attempt is rethrown. An `InterruptException` (e.g. Ctrl-C) is never retried. Earlier
failures are reported at `@debug` level.

Throws `ArgumentError` if `max_tries < 1`.
"""
function repeat_calculate_input(params, model; max_tries::Integer=100)
    max_tries >= 1 || throw(ArgumentError("max_tries must be at least 1, got $max_tries"))
    # bounded loop instead of unbounded recursion: no stack growth on repeated failures
    for attempt in 1:max_tries
        try
            return calculate_input(params, model)
        catch e
            # a user interrupt must stop the loop, and the final failure must surface
            (e isa InterruptException || attempt == max_tries) && rethrow()
            @debug "calculate_input() failed; retrying" attempt exception = e
        end
    end
    # unreachable: the last attempt either returns or rethrows
end

repeat_calculate_input (generic function with 1 method)

In [13]:
"""
    calculate_input_tries(params, model; nr_tries=1)

Propose an input `nr_tries` times via `repeat_calculate_input`, with each try using its
own random start and the tries running in parallel over threads. Return the proposal
with the lowest free energy.

Proposals with a non-finite free energy (`NaN`/`±Inf`) are ignored. If every try is
non-finite, the selection falls back to all tries.
"""
function calculate_input_tries(params, model; nr_tries=1)
    input = Vector{Vector{Float64}}(undef, nr_tries)
    fe = Vector{Float64}(undef, nr_tries)
    Threads.@threads for k = 1:nr_tries
        fe[k], input[k] = repeat_calculate_input(params, model)
    end
    # `argmin` treats NaN as smaller than every number, so without this filter a
    # single diverged try would be returned as the "best" proposal
    candidates = any(isfinite, fe) ? findall(isfinite, fe) : collect(eachindex(fe))
    return input[candidates[argmin(fe[candidates])]]
end

calculate_input_tries (generic function with 1 method)

## Create plot

In [14]:
# Save a contour plot of the fitted classifier together with the observations.
#
# `model` is a compiled flow (`compile(flow_model, params)`). The plotted value at a grid
# point p is normcdf(dot([1, 1], forward(model, p))), evaluated on a 101×101 grid over
# [-0.5, 0.5]^2 and drawn on axes shifted by +0.5 (i.e. [0, 1]^2). Observed inputs are
# drawn as white dots overlaid with crosses coloured by their response (colour range
# [0, 1]). The figure is written as .eps and .png under `exports/`, then all figures
# are closed.
function plot_figure(model, data_x, data_y, it)

    # classifier output on the centred grid; entry [i, j] belongs to (grid[i], grid[j])
    grid = -0.5:0.01:0.5
    flow_outputs = map(pt -> forward(model, [pt...]), collect(Iterators.product(grid, grid)))
    classification_map = map(z -> normcdf(dot([1, 1], z)), flow_outputs)

    # plotting coordinates: the same grid shifted into [0, 1]
    X = repeat(0:0.01:1, 1, 101)
    _, ax = plt.subplots(ncols = 1)
    contours = ax.contourf(X, X', classification_map)

    # observations, shifted the same way
    points = hcat(data_x...)
    px = points[1, :] .+ 0.5
    py = points[2, :] .+ 0.5
    ax.scatter(px, py, c="white")
    ax.scatter(px, py, c=data_y, marker="x", vmin=0, vmax=1)

    plt.colorbar(contours, ax=ax)
    ax.grid()
    ax.set_xlabel("gain 1")
    ax.set_ylabel("gain 2")
    ax.set_title(string("Iteration ", it))

    for extension in (".eps", ".png")
        plt.savefig(string("exports/NF_preferences_continuous_", it, extension))
    end
    return plt.close("all")
end

plot_figure (generic function with 1 method)

## User preference function

In [15]:
"""
    learn_user_preferences(model; jitter=1, μ=[0.8, 0.2], a=1, b=1, c=25, d=-0.4,
                           nr_tries_params=1, nr_tries_input=1, nr_its=20)

Run the preference-learning loop against the simulated user `generate_user_response`.

Two random initial queries are drawn in the user's `[0, 1]^2` coordinates. Each
iteration then does the following:

1. fits the flow parameters of `model` to all responses so far;
2. saves a plot of the fitted classifier;
3. proposes a new input by free-energy minimisation;
4. queries the user at that input.

The loop stops after `nr_its` iterations or as soon as a response equals `1`.

Inputs are stored in centred coordinates (`[-0.5, 0.5]^2`); the user is queried at the
matching point shifted by `+0.5`. Every response is stored `jitter` times, each copy
with its input perturbed by Gaussian noise with standard deviation `0.01`.

# Keywords
- `jitter`: number of noisy copies stored per query.
- `μ`, `a`, `b`, `c`, `d`: forwarded to `generate_user_response`.
- `nr_tries_params`, `nr_tries_input`: number of random restarts for parameter
  fitting and for the input proposal.
- `nr_its`: maximum number of query iterations.
"""
function learn_user_preferences(model; jitter=1, μ=[0.8, 0.2], a=1, b=1, c=25, d=-0.4, nr_tries_params=1, nr_tries_input=1, nr_its=20)

    # observed responses and the (centred, jittered) inputs they belong to; typed
    # containers keep the element types `inference_flow_classifier` dispatches on,
    # even when `jitter == 0`
    data_y = Float64[]
    data_x = Vector{Float64}[]

    # store `jitter` noisy copies of the response `r` observed at centred input `x_centred`
    function observe!(x_centred, r)
        for _ in 1:jitter
            push!(data_y, r)
            push!(data_x, x_centred + randn(2) * 0.01)
        end
        return nothing
    end

    # two random initial queries, drawn in the user's [0, 1]^2 coordinates
    for _ in 1:2
        query = rand(2)
        r = generate_user_response(query; μ=μ, a=a, b=b, c=c, d=d, binary=false)
        observe!(query .- 0.5, r)
    end

    # preference learning loop
    for it in 1:nr_its

        # fit the flow parameters of the given model (not a global) to all data so far
        params = calculate_parameters_tries(model, data_x, data_y; nr_tries=nr_tries_params)

        # create plot of the fitted classifier
        plot_figure(compile(model, params), data_x, data_y, it)

        # propose preferences (centred coordinates)
        proposal = calculate_input_tries(params, model; nr_tries=nr_tries_input)

        # query the user at the proposal shifted back to [0, 1]^2, then update data
        r = generate_user_response(proposal .+ 0.5; μ=μ, a=a, b=b, c=c, d=d, binary=false)
        observe!(proposal, r)

        # print summary
        println(string("Iteration ", it, ":"))
        println(string("    proposal: ", proposal))
        println(string("    response: ", r))

        # stop when the response equals 1
        # NOTE(review): with `binary=false` the response is continuous, so this exact
        # comparison may rarely trigger — confirm the intended stopping rule
        r == 1 && break

    end

    return nothing

end

learn_user_preferences (generic function with 1 method)

In [16]:
# run the loop for at most 20 queries, with 3 restarts per parameter fit and 10 per input proposal
learn_user_preferences(flow_model; nr_its=20, nr_tries_input=10, nr_tries_params=3)

[0.017429278680207716, 0.017535393644127595, 0.017555696822892486]
Iteration 1:
    proposal: [0.499999998759431, 0.4999999978435874]
    response: 4.605371001232599e-8
[0.018173378644686267, 0.022467656486945486, 0.014434504535081771]
Iteration 2:
    proposal: [-0.23993509753942596, -0.49994579224357866]
    response: 1.603191860826699e-5
[-0.42045443474683, 0.02178659241348413, 0.03771044250232336]
Iteration 3:
    proposal: [0.4999999561962362, 0.49999967769152726]
    response: 4.605405759545056e-8
[0.12887668289231868, 0.01836449351364422, 0.018397960641564737]
Iteration 4:
    proposal: [0.49999937290057656, -0.2947892925942674]
    response: 0.09928577711860842
[0.8988123473613427, 0.36802850214213834, 0.35106952415829795]
Iteration 5:
    proposal: [0.31061591141214706, 0.49997571942632707]
    response: 8.20856263221042e-8
[0.3589840358117513, 0.37283945508360716, 0.3584722470594457]
Iteration 6:
    proposal: [0.49999999999999994, -0.02549285692530829]
    response: 0.004120

In [None]:
# params = randn(nr_params(flow_model))
# optimum = rand(2)
# jitter = 1
# r = generate_user_response(optimum; binary=false)
# data_y = [r*ones(jitter)...]
# data_x = [[optimum + randn(2)*0.01 for k=1:jitter]...]
# optimum = rand(2)
# r = generate_user_response(optimum; binary=false)
# data_y = [data_y..., r*ones(jitter)...]
# data_x = [data_x..., [optimum + randn(2)*0.01 for k=1:jitter]...]

In [None]:
# 20-23s (compiled)
# 16s (threads)

In [None]:
# _, ax = plt.subplots(ncols=1)
# classification_map = map((x) -> generate_user_response([x...]; μ=[0.8, 0.2], binary=false), collect(Iterators.product(0:0.01:1, 0:0.01:1)))
# im = ax.contourf(repeat(0:0.01:1, 1, 101), repeat(0:0.01:1, 1, 101)', classification_map)
# plt.colorbar(im, ax=ax)
# ax.grid()
# ax.set_xlabel("gain 1"), ax.set_ylabel("gain 2")
# ax.set_title("Actual user response");

In [None]:
# params = randn(nr_params(flow_model))
# optimum = rand(2)
# jitter = 1
# r = generate_user_response(optimum; binary=false)
# data_y = [r*ones(jitter)...]
# data_x = [[optimum + randn(2)*0.01 for k=1:jitter]...]
# optimum = rand(2)
# r = generate_user_response(optimum; binary=false)
# data_y = [data_y..., r*ones(jitter)...]
# data_x = [data_x..., [optimum + randn(2)*0.01 for k=1:jitter]...]

In [None]:
# calculate_parameters(params, flow_model, data_x, data_y)

In [None]:
# @code_warntype calculate_parameters(params, flow_model, data_x, data_y)

In [None]:
# @benchmark calculate_parameters($params, $flow_model, $data_x, $data_y)

In [None]:
# calculate_parameters(params, flow_model, data_x, data_y)

In [None]:
# using Profile

# @profile calculate_parameters(params, flow_model, data_x, data_y)

In [None]:
# Profile.print()

In [None]:
# # type stable
# BenchmarkTools.Trial: 
#   memory estimate:  243.89 MiB
#   allocs estimate:  3022996
#   --------------
#   minimum time:     623.854 ms (4.87% GC)
#   median time:      644.346 ms (4.91% GC)
#   mean time:        656.192 ms (4.89% GC)
#   maximum time:     713.424 ms (4.75% GC)
#   --------------
#   samples:          8
#   evals/sample:     1

In [None]:
# # initial
# BenchmarkTools.Trial: 
#   memory estimate:  244.15 MiB
#   allocs estimate:  3035514
#   --------------
#   minimum time:     751.122 ms (4.73% GC)
#   median time:      825.649 ms (4.95% GC)
#   mean time:        849.803 ms (4.95% GC)
#   maximum time:     1.027 s (4.25% GC)
#   --------------
#   samples:          7
#   evals/sample:     1