In [1]:
using CSV
using DataFrames
using JLD2
using Random
using Statistics

include("scripts/data_structures.jl");
include("scripts/object_aware_HDP.jl");
include("scripts/plotting.jl");
include("scripts/traditional_HDP.jl");

In [2]:
# Read the stored train/test splits for both categories into `cat1_splits` / `cat2_splits`.
@load "data/train_test_splits.jld2" cat1_splits cat2_splits;

# The number of splits is taken from category 1.
n_splits = length(cat1_splits);

# Test percepts per category, counted on the first split by adding up the per-object counts.
n_cat1 = mapreduce(obj -> length(obj.percepts), +, cat1_splits[1].test_objects; init = 0);
n_cat2 = mapreduce(obj -> length(obj.percepts), +, cat2_splits[1].test_objects; init = 0);

# One result row per test percept per split (uses the first split's counts for every split).
n_rows = (n_cat1 + n_cat2) * n_splits;

# Percept dimensionality, read off the first test percept of category 1.
d = length(cat1_splits[1].test_objects[1].percepts[1]);

In [3]:
# Settings shared by every model fit in this notebook.
ALPHA = 1.0;                # supplied as `alpha` to both hyperparameter structs
K_CLU = 1.0;                # supplied as `k_clu` to both hyperparameter structs
K_OBJ = 5.0;                # supplied as `k_obj` (object-aware) and folded into `k_obj` (traditional)
K_PER = 5.0;                # supplied as `k_per` (object-aware) and folded into `k_obj` (traditional)
n_iter = 1_000;             # value given to the `iters` keyword of each clustering call
rng = MersenneTwister(7);   # single seeded generator handed to every clustering call

# Traditional HDP

In [4]:
# Hyperparameters for the traditional HDP runs.
hyper_trad = TraditionalHDPHyperparams(;
    alpha = ALPHA,
    a0 = 1.0,
    b0 = 1.0,
    m0 = zeros(d),
    k_clu = K_CLU,
    # Reciprocal of the summed reciprocals of K_OBJ and K_PER.
    # NOTE(review): presumably this merges the object and percept levels of the
    # object-aware model into one level — confirm against the model definition.
    k_obj = inv(inv(K_OBJ) + inv(K_PER)),
);

## Uniform training set

In [5]:
# Traditional HDP, uniform training sets.
# For every split: fit one model per category, score each held-out percept under both
# models, and store one row per test percept. The cell's value is the overall accuracy.

# Pre-sized columns; the loop below writes every one of the `n_rows` slots.
results_trad_uniform = DataFrame(
    split = Vector{Int}(undef, n_rows),
    true_cat = Vector{Int}(undef, n_rows),
    test_idx = Vector{Int}(undef, n_rows),
    test_x1 = Vector{Float64}(undef, n_rows),
    test_x2 = Vector{Float64}(undef, n_rows),
    logp_cat1 = Vector{Float64}(undef, n_rows),
    logp_cat2 = Vector{Float64}(undef, n_rows),
);

row = 0
for i in 1:n_splits
    # Pool the percepts of all training objects within each category.
    cat1_train = vcat((obj.percepts for obj in cat1_splits[i].train_objects_uniform)...)
    cat2_train = vcat((obj.percepts for obj in cat2_splits[i].train_objects_uniform)...)

    # Keep this order (category 1, then category 2): both calls receive the same `rng`.
    # Only the second return value is used afterwards.
    _, cat1_clu = trad_hdp_cluster_update(cat1_train, hyper_trad; iters=n_iter, rng=rng)
    _, cat2_clu = trad_hdp_cluster_update(cat2_train, hyper_trad; iters=n_iter, rng=rng)

    cat1_test = vcat((obj.percepts for obj in cat1_splits[i].test_objects)...)
    cat2_test = vcat((obj.percepts for obj in cat2_splits[i].test_objects)...)

    # Category-1 items are recorded first, then category-2 items, for each split.
    for (label, test_items, n_test) in ((1, cat1_test, n_cat1), (2, cat2_test, n_cat2))
        for j in 1:n_test
            lp1 = log_post_pred_x(test_items[j], cat1_clu, hyper_trad)
            lp2 = log_post_pred_x(test_items[j], cat2_clu, hyper_trad)

            # `global` lets this top-level loop update the counter when the code is run
            # as a script or via `include`, not only under the notebook's soft scope.
            global row += 1
            results_trad_uniform.split[row] = i
            results_trad_uniform.true_cat[row] = label
            results_trad_uniform.test_idx[row] = j
            results_trad_uniform.test_x1[row] = round(test_items[j][1], digits=3)
            results_trad_uniform.test_x2[row] = round(test_items[j][2], digits=3)
            results_trad_uniform.logp_cat1[row] = lp1
            results_trad_uniform.logp_cat2[row] = lp2
        end
    end
end

# Predict the category with the strictly larger log score; ties go to category 2.
results_trad_uniform.pred_cat = ifelse.(results_trad_uniform.logp_cat1 .> results_trad_uniform.logp_cat2, 1, 2);
results_trad_uniform.correct = results_trad_uniform.true_cat .== results_trad_uniform.pred_cat;
mean(results_trad_uniform.correct)

0.7625

## Skewed training set

In [6]:
# Traditional HDP, skewed training sets.
# For every split: fit one model per category, score each held-out percept under both
# models, and store one row per test percept. The cell's value is the overall accuracy.

# Pre-sized columns; the loop below writes every one of the `n_rows` slots.
results_trad_skewed = DataFrame(
    split = Vector{Int}(undef, n_rows),
    true_cat = Vector{Int}(undef, n_rows),
    test_idx = Vector{Int}(undef, n_rows),
    test_x1 = Vector{Float64}(undef, n_rows),
    test_x2 = Vector{Float64}(undef, n_rows),
    logp_cat1 = Vector{Float64}(undef, n_rows),
    logp_cat2 = Vector{Float64}(undef, n_rows),
);

row = 0
for i in 1:n_splits
    # Pool the percepts of all training objects within each category.
    cat1_train = vcat((obj.percepts for obj in cat1_splits[i].train_objects_skewed)...)
    cat2_train = vcat((obj.percepts for obj in cat2_splits[i].train_objects_skewed)...)

    # Keep this order (category 1, then category 2): both calls receive the same `rng`.
    # Only the second return value is used afterwards.
    _, cat1_clu = trad_hdp_cluster_update(cat1_train, hyper_trad; iters=n_iter, rng=rng)
    _, cat2_clu = trad_hdp_cluster_update(cat2_train, hyper_trad; iters=n_iter, rng=rng)

    cat1_test = vcat((obj.percepts for obj in cat1_splits[i].test_objects)...)
    cat2_test = vcat((obj.percepts for obj in cat2_splits[i].test_objects)...)

    # Category-1 items are recorded first, then category-2 items, for each split.
    for (label, test_items, n_test) in ((1, cat1_test, n_cat1), (2, cat2_test, n_cat2))
        for j in 1:n_test
            lp1 = log_post_pred_x(test_items[j], cat1_clu, hyper_trad)
            lp2 = log_post_pred_x(test_items[j], cat2_clu, hyper_trad)

            # `global` lets this top-level loop update the counter when the code is run
            # as a script or via `include`, not only under the notebook's soft scope.
            global row += 1
            results_trad_skewed.split[row] = i
            results_trad_skewed.true_cat[row] = label
            results_trad_skewed.test_idx[row] = j
            results_trad_skewed.test_x1[row] = round(test_items[j][1], digits=3)
            results_trad_skewed.test_x2[row] = round(test_items[j][2], digits=3)
            results_trad_skewed.logp_cat1[row] = lp1
            results_trad_skewed.logp_cat2[row] = lp2
        end
    end
end

# Predict the category with the strictly larger log score; ties go to category 2.
results_trad_skewed.pred_cat = ifelse.(results_trad_skewed.logp_cat1 .> results_trad_skewed.logp_cat2, 1, 2);
results_trad_skewed.correct = results_trad_skewed.true_cat .== results_trad_skewed.pred_cat;
mean(results_trad_skewed.correct)

0.7435

# Object-Aware HDP

## Uniform training set

In [7]:
# Hyperparameters for the object-aware HDP runs; K_OBJ and K_PER are passed separately
# here rather than combined into a single value.
hyper_obj = ObjectAwareHDPHyperparams(;
    alpha = ALPHA,
    a0 = 1.0,
    b0 = 1.0,
    m0 = zeros(d),
    k_clu = K_CLU,
    k_obj = K_OBJ,
    k_per = K_PER,
);

In [8]:
# Object-aware HDP, uniform training sets.
# For every split: fit one model per category, score each held-out percept under both
# models, and store one row per test percept. The cell's value is the overall accuracy.

# Pre-sized columns; the loop below writes every one of the `n_rows` slots.
results_obj_uniform = DataFrame(
    split = Vector{Int}(undef, n_rows),
    true_cat = Vector{Int}(undef, n_rows),
    test_idx = Vector{Int}(undef, n_rows),
    test_x1 = Vector{Float64}(undef, n_rows),
    test_x2 = Vector{Float64}(undef, n_rows),
    logp_cat1 = Vector{Float64}(undef, n_rows),
    logp_cat2 = Vector{Float64}(undef, n_rows),
);

row = 0
for i in 1:n_splits
    # Training percepts stay grouped by object: one entry per training object.
    cat1_train = [obj.percepts for obj in cat1_splits[i].train_objects_uniform]
    cat2_train = [obj.percepts for obj in cat2_splits[i].train_objects_uniform]

    # Keep this order (category 1, then category 2): both calls receive the same `rng`.
    # Only the second return value is used afterwards.
    _, cat1_clu = object_aware_hdp_cluster_update(cat1_train, hyper_obj; iters=n_iter, rng=rng)
    _, cat2_clu = object_aware_hdp_cluster_update(cat2_train, hyper_obj; iters=n_iter, rng=rng)

    # Test percepts are pooled across objects and scored one at a time.
    cat1_test = vcat((obj.percepts for obj in cat1_splits[i].test_objects)...)
    cat2_test = vcat((obj.percepts for obj in cat2_splits[i].test_objects)...)

    # Category-1 items are recorded first, then category-2 items, for each split.
    for (label, test_items, n_test) in ((1, cat1_test, n_cat1), (2, cat2_test, n_cat2))
        for j in 1:n_test
            lp1 = log_post_pred_y(test_items[j], cat1_clu, hyper_obj)
            lp2 = log_post_pred_y(test_items[j], cat2_clu, hyper_obj)

            # `global` lets this top-level loop update the counter when the code is run
            # as a script or via `include`, not only under the notebook's soft scope.
            global row += 1
            results_obj_uniform.split[row] = i
            results_obj_uniform.true_cat[row] = label
            results_obj_uniform.test_idx[row] = j
            results_obj_uniform.test_x1[row] = round(test_items[j][1], digits=3)
            results_obj_uniform.test_x2[row] = round(test_items[j][2], digits=3)
            results_obj_uniform.logp_cat1[row] = lp1
            results_obj_uniform.logp_cat2[row] = lp2
        end
    end
end

# Predict the category with the strictly larger log score; ties go to category 2.
results_obj_uniform.pred_cat = ifelse.(results_obj_uniform.logp_cat1 .> results_obj_uniform.logp_cat2, 1, 2);
results_obj_uniform.correct = results_obj_uniform.true_cat .== results_obj_uniform.pred_cat;
mean(results_obj_uniform.correct)

0.77

## Skewed training set

In [9]:
# Object-aware HDP, skewed training sets.
# For every split: fit one model per category, score each held-out percept under both
# models, and store one row per test percept. The cell's value is the overall accuracy.

# Pre-sized columns; the loop below writes every one of the `n_rows` slots.
results_obj_skewed = DataFrame(
    split = Vector{Int}(undef, n_rows),
    true_cat = Vector{Int}(undef, n_rows),
    test_idx = Vector{Int}(undef, n_rows),
    test_x1 = Vector{Float64}(undef, n_rows),
    test_x2 = Vector{Float64}(undef, n_rows),
    logp_cat1 = Vector{Float64}(undef, n_rows),
    logp_cat2 = Vector{Float64}(undef, n_rows),
);

row = 0
for i in 1:n_splits
    # Training percepts stay grouped by object: one entry per training object.
    cat1_train = [obj.percepts for obj in cat1_splits[i].train_objects_skewed]
    cat2_train = [obj.percepts for obj in cat2_splits[i].train_objects_skewed]

    # Keep this order (category 1, then category 2): both calls receive the same `rng`.
    # Only the second return value is used afterwards.
    _, cat1_clu = object_aware_hdp_cluster_update(cat1_train, hyper_obj; iters=n_iter, rng=rng)
    _, cat2_clu = object_aware_hdp_cluster_update(cat2_train, hyper_obj; iters=n_iter, rng=rng)

    # Test percepts are pooled across objects and scored one at a time.
    cat1_test = vcat((obj.percepts for obj in cat1_splits[i].test_objects)...)
    cat2_test = vcat((obj.percepts for obj in cat2_splits[i].test_objects)...)

    # Category-1 items are recorded first, then category-2 items, for each split.
    for (label, test_items, n_test) in ((1, cat1_test, n_cat1), (2, cat2_test, n_cat2))
        for j in 1:n_test
            lp1 = log_post_pred_y(test_items[j], cat1_clu, hyper_obj)
            lp2 = log_post_pred_y(test_items[j], cat2_clu, hyper_obj)

            # `global` lets this top-level loop update the counter when the code is run
            # as a script or via `include`, not only under the notebook's soft scope.
            global row += 1
            results_obj_skewed.split[row] = i
            results_obj_skewed.true_cat[row] = label
            results_obj_skewed.test_idx[row] = j
            results_obj_skewed.test_x1[row] = round(test_items[j][1], digits=3)
            results_obj_skewed.test_x2[row] = round(test_items[j][2], digits=3)
            results_obj_skewed.logp_cat1[row] = lp1
            results_obj_skewed.logp_cat2[row] = lp2
        end
    end
end

# Predict the category with the strictly larger log score; ties go to category 2.
results_obj_skewed.pred_cat = ifelse.(results_obj_skewed.logp_cat1 .> results_obj_skewed.logp_cat2, 1, 2);
results_obj_skewed.correct = results_obj_skewed.true_cat .== results_obj_skewed.pred_cat;
mean(results_obj_skewed.correct)

0.7696666666666667

In [10]:
# Merge all dataframes and save results
# Tag each result table with its model and training regime, stack them, and write one CSV.
dfs = [
    results_trad_uniform,
    results_trad_skewed,
    results_obj_uniform,
    results_obj_skewed,
]

# Labels line up positionally with `dfs`.
models = repeat(["trad", "obj"]; inner = 2)
types = repeat(["uniform", "skewed"]; outer = 2)

# The label columns are added to the original tables in place.
for k in eachindex(dfs, models, types)
    dfs[k][!, :model] .= models[k]
    dfs[k][!, :training_type] .= types[k]
end

results_all = reduce(vcat, dfs; cols = :union)

CSV.write("data/results_all.csv", results_all)

"data/results_all.csv"