In [19]:
using JLD
using Distributions
using LinearAlgebra
using Random
using WAV
using Plots
using ImageCore
using GraphPPL
using ReactiveMP
using Optim
using Parameters
using AIDA
import ProgressMeter

┌ Info: Precompiling AIDA [02ab3c64-7f6e-4624-92b1-4056b28faae1]
└ @ Base loading.jl:1317


In [53]:
"""
    init_environment(; from_scratch=false,
                     context_dir="../src/memory/context_params",
                     training_dir="../sound/AIDA/separated_jld/training/",
                     test_dir="../sound/AIDA/separated_jld/test/")

Load the environment from disk: the stored context parameters and the processed
(separated) training and test recordings, one `JLD.load`ed dict per file listed
by `get_files` in each directory.

Returns `(context_priors, training_jlds, test_jlds)`, each a vector of those dicts.

Throws an error when `from_scratch=true`: building the environment from raw audio
is not implemented here, the JLD files must be created beforehand.
"""
function init_environment(;
    from_scratch=false,
    context_dir="../src/memory/context_params",
    training_dir="../sound/AIDA/separated_jld/training/",
    test_dir="../sound/AIDA/separated_jld/test/",
)
    # Generating the separated JLD files is out of scope for this function.
    from_scratch && error("Please create training and test files")

    load_all(dir) = [JLD.load(file) for file in get_files(dir)]

    context_priors = load_all(context_dir)   # context memory
    training_jlds = load_all(training_dir)   # processed audio (training)
    test_jlds = load_all(test_dir)           # processed audio (test)
    return context_priors, training_jlds, test_jlds
end

init_environment (generic function with 1 method)

In [34]:
"""
    init_agent(; preferences=nothing, contexts=nothing)

Initialise the agent's preferences.

- Without `preferences`: load and return the stored preferences from
  `../src/memory/preferences.jld`.
- With `preferences`: run `preference_learning` on them for every context in
  `contexts` (with `record=true`), then return `preferences`.

Throws `ArgumentError` if `preferences` is given but `contexts` is not, since
learning needs to know which contexts to fit.
"""
function init_agent(; preferences=nothing, contexts=nothing)
    isnothing(preferences) && return JLD.load("../src/memory/preferences.jld")

    # Iterating `nothing` would otherwise fail with an opaque MethodError.
    isnothing(contexts) && throw(ArgumentError(
        "`contexts` must be provided when `preferences` is given"))

    for context in contexts
        preference_learning(preferences, context; record=true)
    end
    return preferences
end

init_agent (generic function with 1 method)

In [19]:
"""
    preference_learning(preferences, context_num;
                        augment_data=true, tolerance=1e-12, record=false)

Fit parameters for context `context_num` by minimising, with L-BFGS and
forward-mode AD, the value returned by `inference_flow_classifier` (the original
code named it `fe`, presumably a free energy).

`preferences` is the preference memory dict. With `augment_data=true` the data come
from `get_learning_data(preferences, context_num)`; otherwise `preferences["gains"]`
(inputs, one sample per row) and `preferences["appraisals"]` (targets) are used as-is.
`tolerance` is passed to Optim as `f_tol`.

If `record` is true, the minimiser is saved under the key `"params"` in
`../src/memory/preference_params_<context_num>.jld`.

Returns the `Optim` optimisation result.
"""
function preference_learning(preferences, context_num;
                             augment_data=true, tolerance=1e-12, record=false)
    data_x, data_y = if augment_data
        get_learning_data(preferences, context_num)
    else
        (preferences["gains"], preferences["appraisals"])
    end

    # Data do not depend on the parameters, so convert them once rather than on
    # every objective evaluation. The classifier takes one input vector per row.
    targets = Float64.(data_y)
    inputs = [data_x[k, :] for k in axes(data_x, 1)]
    objective(θ) = inference_flow_classifier(targets, inputs, θ)

    # NOTE(review): `model` is not defined in this section; it must be a global
    # accepted by `nr_params` — confirm it exists before calling.
    θ0 = randn(nr_params(model))
    options = Optim.Options(store_trace=true, show_trace=true, f_tol=tolerance)
    res = optimize(objective, θ0, LBFGS(), options; autodiff=:forward)

    if record
        # The previous call passed no variables to `JLD.save`, so nothing learned
        # was ever persisted; store the optimum explicitly.
        JLD.save("../src/memory/preference_params_$(context_num).jld",
                 "params", Optim.minimizer(res))
    end
    return res
end

preference_learning (generic function with 1 method)

In [None]:
"""
    listen(sound, fs=8000)

Play `sound` through `WAV.wavplay` at sample rate `fs` (default 8000 Hz, the rate
used throughout this notebook).
"""
listen(sound, fs=8000) = WAV.wavplay(sound, fs)

In [None]:
"""
    act(goal=1.0; params, priors_gs=[repeat([0.5], 2), diagm(ones(2))])

Plan gains for `goal` with `inference_flow_planner` and return the mean of the
planned gains.

`priors_gs[1]` is passed as the prior mean and `priors_gs[2]` as the prior
covariance of the gains (default: 0.5 for each of two gains, identity covariance).
`params` is forwarded to the planner unchanged.
"""
function act(goal=1.0; params, priors_gs=[repeat([0.5], 2), diagm(ones(2))])
    prior_mean = priors_gs[1]
    prior_cov = priors_gs[2]
    # The planner's first output is not needed here; only the gains are used.
    _, planned_gains = inference_flow_planner(prior_mean, prior_cov, goal, params)
    return mean(planned_gains)
end

In [None]:
"""
    observe(user)

Placeholder for the agent's observation step. Not implemented yet: `user` is
ignored and `nothing` is returned.
"""
observe(user) = nothing

In [57]:
# Let there be an environment: stored context priors plus the processed
# training and test recordings.
(context_priors, training_jlds, test_jlds) = init_environment();

In [35]:
# Let there be an agent: called without arguments, `init_agent` loads the stored
# preferences from ../src/memory/preferences.jld and returns them.
preferences = init_agent()

Dict{String, Any} with 3 entries:
  "gains"      => Any[]
  "appraisals" => Any[]
  "contexts"   => Any[]

In [None]:
# Without arguments this only loads the stored preferences from
# ../src/memory/preferences.jld; the returned value is not bound to a name here.
init_agent()

In [60]:
# Pool all recordings into one sequence: training recordings first, then test.
# TODO: order them so that the context changes along the sequence are sensible.
sounds = [training_jlds; test_jlds];

In [None]:
# Play each recording as a gain-weighted mix of its two separated signals:
# gs[1] scales the "rmx" signal and gs[2] the "rmz" signal. `fs` is passed to
# `get_signal` (presumably the sample rate — confirm).
fs = 8000
gs = [1.5, 0.0]
for rec in sounds
    rz = get_signal(rec["rmz"], fs)
    rx = get_signal(rec["rmx"], fs)
    listen(gs[1] * rx + gs[2] * rz)
    # TODO: close the loop with observe(...) and act(...) once they are implemented.
end

In [7]:
# State shared between the notebook and the background task spawned below.
# `stop` is written here and read from another thread, so it is an atomic flag
# (a plain `Ref{Bool}` would be a data race); it still supports `stop[]` and
# `stop[] = true`.
stop = Threads.Atomic{Bool}(false)
# NOTE(review): `params[]` is replaced from another thread without a lock;
# guard it with a ReentrantLock if readers need consistent snapshots.
params = Ref([ 1.0, 1.0 ])

Base.RefValue{Vector{Float64}}([1.0, 1.0])

In [8]:
# Background worker: until `stop[]` becomes true, overwrite `params[]` with a
# fresh `rand(2)` vector once per second, then print "Stopped".
thread_task = Threads.@spawn begin
    while true
        stop[] && break
        params[] = rand(2)
        sleep(1)
    end
    println("Stopped")
end

Task (runnable) @0x0000000149ef8780

In [14]:
# Inspect the latest vector written to `params` by the background task.
params[]

2-element Vector{Float64}:
 0.5435637769124955
 0.6544541146155975

In [13]:
# Ask the background task to exit; it prints "Stopped" once it sees the flag.
stop[] = true

true

Stopped


In [63]:
# Inspect the second recording in `sounds` (its dict of stored arrays and source filename).
sounds[2]

Dict{String, Any} with 12 entries:
  "rvz"      => [0.000101767 5.94421e-5 … 4.27124e-5 4.77129e-5; 0.00015319 7.8…
  "rmz"      => [0.00393807 0.00172847 … 0.00398806 0.00186795; -0.00781915 -0.…
  "rvθ"      => [0.0086706 -0.00395446 … 0.00018301 7.04482e-5; 0.00997423 -0.0…
  "rvη"      => [0.0169373 -0.00292943; 0.0140704 -0.00232894; … ; 0.018825 -0.…
  "rmη"      => [0.213052 -0.0603124; 0.199062 -0.0806925; … ; 0.269107 -0.1362…
  "rγ"       => [(41.0, 0.00213666), (41.0, 0.00273831), (41.0, 0.00219114), (4…
  "fe"       => [344.243 321.8 … -477.472 -478.108; 344.244 321.8 … -466.351 -4…
  "rmx"      => [0.000677795 0.00490101 … 0.00841169 0.000695601; 0.00368916 -0…
  "rτ"       => [(81.0, 0.0106904), (81.0, 0.0114115), (81.0, 0.0107034), (81.0…
  "filename" => "../sound/AIDA/training/babble/5dB/sp01_babble_sn5.wav"
  "rvx"      => [0.000100599 5.956e-5 … 4.27913e-5 4.77129e-5; 0.000151011 7.84…
  "rmθ"      => [0.524894 0.180105 … -0.0179563 -0.00511929; 0.51373 0.127555 ……