In [5]:
cd("/home/jshrager/active_learning/alsims")

In [67]:
using Pkg; Pkg.activate("ALSims.jl")
using ProgressMeter
using ALSims
using PyPlot; const plt = PyPlot
using Statistics
using Printf
using Serialization

In [339]:
# For convenience, policies is globally set, but it gets stored in each result, so that they can be
# recovered in context, even if the global value changes
# policies = Dict(:ptw => ALSimAs.ptw_policy, :ts => ALSims.ts_policy, :ucb => ALSims.ucb_policy, :ura => ALSims.ura_policy)
policies = Dict(:ts => ALSims.ts_policy, :ucb => ALSims.ucb_policy)

# Results dicts get push here; plotrun defaults to taking the last (first) one
# !!!!!!!!!! DON'T DO THIS UNLESS YOU WANT TO RESET ALL RESULTS!!!!!!!!!!!!!
G_results = []
# G_results will get auto-loaded whne you run, if they don't already exist

function listresults()
    for i = 1:length(G_results)
        r = G_results[i]
        @printf("[%s]:int=%s,tx=%s,iters=%s,bm=%s,snr=%s,pr=%s\n",i,r["n_int"],r["n_tx"],r["n_iters"],r["n_bm"],r["snr"],r["n_pt"])
    end
end

listresults (generic function with 1 method)

In [340]:
function run(;n_tx,int_flag=false,n_iters=100,n_pt=500,snr=10,n_bm=0)
    # Reload global results inly if they are empty
    if G_results == []
        G_results = deserialize("/home/jshrager/al_results.serialized")
    end
    n_int=int_flag ? n_tx : 0
    int_ind = Vector{NTuple{2, Int}}()
    n_x = 1 + n_bm + n_tx + n_int
    X_bm = rand([0, 1], n_pt, n_bm)
    X_tx = rand([0, 1], n_pt, n_tx);
    true_model = GaussianGenerativeModel(snr * randn(n_x), 1.0)
    learning_model = GaussianLearningModel(n_x)

    #This give you the raw data back, although we usually don't want this
    #sim_data = simulate_patients(X_bm, n_tx, true_model, learning_model; bm_tx_int_ind = int_ind, policies = policies)

    # This provides the regrets against the true model: 
    sim_function(true_model) = simulate_patients(X_bm, n_tx, true_model, learning_model; bm_tx_int_ind = int_ind, policies)

    true_models = [GaussianGenerativeModel(snr * randn(n_x), 1.0) for i in 1:n_iter]
    sim_data_array = @showprogress map(sim_function, true_models);
    result = Dict("policies"=>policies,"n_tx" => n_tx, "n_int" => n_int, "n_iters" => n_iters, "n_pt" => n_pt, "snr" => snr, "n_bm" => n_bm, "sim_data_array" => sim_data_array)
    
    # Store this result, save, and display
    pushfirst!(results,G_results)
    serialize("/home/jshrager/al_results.serialized", G_results)
    listresults()
    plotrun()
end


run (generic function with 1 method)

In [341]:
function plotrun(;result=G_results[1],ylimit=0,labelpos=false)
    # UUU FFF There's probably a better way to do this:
    n_int = result["n_int"]
    n_tx = result["n_tx"]
    n_iters = result["n_iters"]
    n_bm = result["n_bm"]
    n_pt = result["n_pt"]
    snr = result["snr"]
    policies = result["policies"]
    sim_data_array = result["sim_data_array"]
    for (i, key) in enumerate([key for (key,y) in policies])
        y = hcat([sim_data_array[i][key][:regrets] for i in 1:n_iter]...)
        # y = cumsum(y, dims=1)
        y_mean = reshape(mean(y, dims=2), n_pt)
        y_std = reshape(std(y, dims=2), n_pt)
        x = collect(1:n_pt)
        c = "C$i"
        plt.plot(x, y_mean, color=c, label=String(key))
        plt.fill_between(x, y_mean .- y_std/sqrt(n_iter), y_mean .+ y_std/sqrt(n_iter), alpha=0.3, color=c)
        plt.ylabel("Regret")
        plt.xlabel("Iteration")
    end
    axes = plt.gca()
    if ylimit != 0
        axes.set_ylim(ylimit)
    end
    if labelpos == false
        ypos=axes.get_ylim()[2]
        labelpos = [10,ypos-floor(ypos/10)]
    end
    context_label = @sprintf("(+-stderr) int=%s,tx=%s,iters=%s,bm=%s,snr=%s",n_int,n_tx,n_iters,n_bm,snr)
    plt.text(labelpos..., context_label)
    plt.legend()
end

plotrun (generic function with 2 methods)

In [342]:
run(n_tx=5,n_iters=10)

LoadError: UndefVarError: G_results not defined

In [343]:
listresults()