In [2]:
using StatsBase
using JLD
using DataFrames
include("./model.jl")

run_model (generic function with 2 methods)

In [5]:
"""
    get_block(true_planet; width = 20, nblocks = 5)

Map each value of `true_planet` to a block index in `1:nblocks` using consecutive
bins of size `width`.

Block `k` holds values with `(k-1)*width <= p < k*width`. Block 1 additionally
takes every value below `width`, including negatives. With the defaults the bins
are `< 20 → 1`, `[20, 40) → 2`, `[40, 60) → 3`, `[60, 80) → 4`, `[80, 100) → 5`.

Returns a `Vector{Int}` with exactly one entry per element of `true_planet`, so it
can be used as a DataFrame column next to `true_planet`.

Throws an `ArgumentError` for values `>= nblocks * width` (or `NaN`). Such values
fit no bin, and silently dropping them would misalign the result with its input.
"""
function get_block(true_planet; width = 20, nblocks = 5)
    block = Int[]
    for p in true_planet
        # The first bin whose upper edge exceeds p; `nothing` if p is past the last edge.
        k = findfirst(i -> p < i * width, 1:nblocks)
        k === nothing && throw(ArgumentError(
            "true_planet value $p is outside the supported range (< $(nblocks * width))"))
        push!(block, k)
    end
    return block
end


"""
    sub_params(model, sub_num)

Load the fitted parameters for subject `sub_num` under `model`, run the model on
that subject's data, and compare the model's `prt` values with `optimal_policy`.

`model` must be `"adaptive_discount"`, `"mvt"` or `"td"`. Parameters are read from
the `"res"` entry of `fit_params/<fit dir>/sub<sub_num>.jld`, where the fit
directory depends on the model.

Returns `(df, df_params)`:
- `df`: one row per element of the model output, with `sub_num` as the first
  column plus `true_planet`, `galaxy`, `block` (from `get_block`), `prt`,
  `opt_prt` and `diff` (`prt - opt_prt`).
- `df_params`: a single row with `sub_num`, `alpha`, `gamma_base` and
  `gamma_coef`, taken from `params[1]`, `params[2]` and `params[3]`.

Throws an `ArgumentError` for an unknown `model` before any data is loaded.
"""
function sub_params(model, sub_num)
    # Compare with `==`: `cmp` returns an Int, which `if` rejects. Validating first
    # also guarantees `b` and `params` are always assigned below.
    model in ("adaptive_discount", "mvt", "td") || throw(ArgumentError(
        "unknown model \"$model\"; expected \"adaptive_discount\", \"mvt\" or \"td\""))

    sub_data = get_sub_data(sub_num)

    if model == "adaptive_discount"
        params = load("fit_params/adaptive_discount_10_10_3/sub$(sub_num).jld")["res"]
        num_particles = 1
        b = crp_adaptiveDiscount(sub_data, params, num_particles)
    elseif model == "mvt"
        params = load("fit_params/MVT_learn_1_20_3/sub$(sub_num).jld")["res"]
        b = MVT_learn(sub_data, params)
    else  # model == "td", guaranteed by the check above
        params = load("fit_params/td_1_20_1_3/sub$(sub_num).jld")["res"]
        b = TD(sub_data, params)
    end

    opt_prt = optimal_policy(b)
    # Named `prt_diff` to avoid shadowing `Base.diff`; the column is still "diff".
    prt_diff = b.prt - opt_prt

    df = DataFrame(Dict("true_planet" => b.true_planet, "galaxy" => b.galaxy,
                        "block" => get_block(b.true_planet), "prt" => b.prt,
                        "opt_prt" => opt_prt, "diff" => prt_diff))
    # Subject id repeated on every row, keeping sub_num's own type so it matches
    # the `sub_num` column of `df_params`.
    insertcols!(df, 1, :sub_num => fill(sub_num, nrow(df)), makeunique = true)

    # NOTE(review): the same three labels are applied for every model; confirm they
    # match the MVT and TD parameterizations, which are fitted from different dirs.
    df_params = DataFrame(Dict("sub_num" => sub_num, "alpha" => params[1],
                               "gamma_base" => params[2], "gamma_coef" => params[3]))

    return df, df_params
end

"""
    all_subs(model, subs)

Run `sub_params(model, sub)` for every subject id in `subs`, printing each id as
progress, and stack the per-subject results row-wise.

Returns `(df_prt, df_params)`. A subject whose evaluation throws is skipped with a
warning that carries the exception and backtrace, so one bad subject does not
abort the batch. An interrupt (Ctrl-C) still stops the loop.
"""
function all_subs(model, subs)
    df_prt = DataFrame()
    df_params = DataFrame()
    for sub in subs
        try
            println(sub)
            prt, params = sub_params(model, sub)
            df_prt = vcat(df_prt, prt)
            df_params = vcat(df_params, params)
        catch err
            # Never swallow a user interrupt.
            err isa InterruptException && rethrow()
            # Surface the failure instead of silently dropping the subject.
            @warn("Skipping subject $sub for model \"$model\"",
                  exception = (err, catch_backtrace()))
        end
    end
    return df_prt, df_params
end

all_subs (generic function with 2 methods)

In [6]:
# Subject ids evaluated under each model.
subs=[  0,   1,   2,   3,   5,   8,   9,  10,  13,  15,  18,  19,  21,
        22,  23,  25,  28,  29,  30,  31,  32,  33,  34,  37,  39,  40,
        41,  44,  45,  46,  47,  48,  50,  53,  54,  55,  56,  57,  58,
        59,  64,  65,  69,  70,  71,  75,  76,  77,  78,  80,  81,  82,
        85,  89,  92,  94,  96,  97,  99, 100, 101, 104, 105, 106, 107,
       108, 110, 112, 113, 115, 116, 117, 119, 120, 121, 123, 124, 126,
       127, 128, 132, 134, 135, 136, 137, 138, 141, 142, 143, 146, 148,
       151, 154, 158, 159, 161, 162, 163, 164, 165, 167, 168, 169, 170,
       173, 175, 177, 182, 183, 184, 188, 190, 192, 195, 196, 197]

# The model name must exactly match a name accepted by `sub_params`; a misspelled
# name makes every subject fail, which `all_subs` skips, leaving empty results.
prt,params = all_subs("adaptive_discount",subs)
CSV.write("model_results/prt_val_adaptive_discount_10_10_3.csv",prt)
CSV.write("model_results/params_val_adaptive_discount_10_10_3.csv",params)

prt,params = all_subs("mvt",subs)
CSV.write("model_results/prt_val_MVT_learn.csv",prt)
CSV.write("model_results/params_val_MVT_learn.csv",params)

prt,params = all_subs("td",subs)
CSV.write("model_results/prt_val_TD.csv",prt)
CSV.write("model_results/params_val_TD.csv",params)

0
1
2
3
5
8
9
10
13
15
18
19
21
22
23
25
28
29
30
31
32
33
34
37
39
40
41
44
45
46
47
48
50
53
54
55
56
57
58
59
64
65
69
70
71
75
76
77
78
80
81
82
85
89
92
94
96
97
99
100
101
104
105
106
107
108
110
112
113
115
116
117
119
120
121
123
124
126
127
128
132
134
135
136
137
138
141
142
143
146
148
151
154
158
159
161
162
163
164
165
167
168
169
170
173
175
177
182
183
184
188
190
192
195
196
197
0
1
2
3
5
8
9
10
13
15
18
19
21
22
23
25
28
29
30
31
32
33
34
37
39
40
41
44
45
46
47
48
50
53
54
55
56
57
58
59
64
65
69
70
71
75
76
77
78
80
81
82
85
89
92
94
96
97
99
100
101
104
105
106
107
108
110
112
113
115
116
117
119
120
121
123
124
126
127
128
132
134
135
136
137
138
141
142
143
146
148
151
154
158
159
161
162
163
164
165
167
168
169
170
173
175
177
182
183
184
188
190
192
195
196
197
0
1
2
3
5
8
9
10
13
15
18
19
21
22
23
25
28
29
30
31
32
33
34
37
39
40
41
44
45
46
47
48
50
53
54
55
56
57
58
59
64
65
69
70
71
75
76
77
78
80
81
82
85
89
92
94
96
97
99
100
101
104
105
106
107
108
110
11

"model_results/params_val_TD.csv"