Model manipulation

1. Traditional HDP (objects=percepts)
2. Object-Aware HDP (objects can generate their percepts)

Input data manipulation

1. Category variance
2. Uniform vs. skewed token count distribution 

Test sets
1. Generalization: Novel items
2. Memory: Familiar items

Test: 
- Compute P(test item | category 1) vs. P(test item | category 2)
- If P(test item | category 1) > P(test item | category 2), categorize the test item as an instance of category 1; otherwise, categorize it as an instance of category 2

In [10]:
using DataFrames
using JLD2
using Random

include("data_structures.jl");
include("object_aware_HDP.jl");
include("plotting.jl");
include("traditional_HDP.jl");

In [13]:
# Read the saved train/test splits for both categories from disk.
@load "train_test_splits.jld2" cat1_splits cat2_splits;

# Number of test percepts in one split, counted after concatenating the
# percepts of all of its test objects.
test_percept_count(split) = length(vcat((obj.percepts for obj in split.test_objects)...));

n_splits = length(cat1_splits);

# Per-category test counts are taken from the first split only.
n_cat1 = test_percept_count(cat1_splits[1]);
n_cat2 = test_percept_count(cat2_splits[1]);

# Total number of result rows: one per test percept per split.
n_rows = (n_cat1 + n_cat2) * n_splits;

# Traditional HDP

In [None]:
# Hyperparameters used by every traditional-HDP run below.
# NOTE(review): field semantics are defined by TraditionalHDPHyperparams (not
# visible here); m0 has two components, so percepts are presumably 2-D — confirm.
hyper_trad = TraditionalHDPHyperparams(
    m0 = [0.0, 0.0],
    k_clu = 1.0,
    k_obj = 3.33,
    a0 = 1.0,
    b0 = 1.0,
    alpha = 1.0,  # float literal, consistent with the other fields and with hyper_obj
);

## Uniform training set

In [25]:
# Evaluate the traditional HDP trained on the uniform training sets.
#
# The table gets one row per (split, test percept) holding the value returned by
# log_post_pred_x for that percept under the category-1 clusters and under the
# category-2 clusters. Rows are appended as they are produced, so the table never
# contains uninitialized entries and every split contributes exactly as many rows
# as it has test percepts (no reliance on the counts of split 1).
results_trad_uniform = DataFrame(
    split     = Int[],
    true_cat  = Int[],      # 1 or 2
    test_idx  = Int[],
    logp_cat1 = Float64[],
    logp_cat2 = Float64[],
);

# NOTE(review): if trad_hdp_cluster_update draws from the global RNG, results will
# differ between runs; seed the RNG before this loop for reproducibility — confirm.
for i in 1:n_splits
    # Concatenate the percepts of all training objects into one collection.
    cat1_train = vcat((obj.percepts for obj in cat1_splits[i].train_objects_uniform)...)
    cat2_train = vcat((obj.percepts for obj in cat2_splits[i].train_objects_uniform)...)

    # Run the cluster update once per category; only the second return value is
    # needed for the predictive calls below.
    _, cat1_clu = trad_hdp_cluster_update(cat1_train, hyper_trad; iters=1000)
    _, cat2_clu = trad_hdp_cluster_update(cat2_train, hyper_trad; iters=1000)

    # Score the held-out percepts of category 1 first, then those of category 2,
    # against both categories' clusters.
    for (true_cat, splits) in ((1, cat1_splits), (2, cat2_splits))
        test_items = vcat((obj.percepts for obj in splits[i].test_objects)...)
        for (j, x) in enumerate(test_items)
            push!(results_trad_uniform, (
                split     = i,
                true_cat  = true_cat,
                test_idx  = j,
                logp_cat1 = log_post_pred_x(x, cat1_clu, hyper_trad),
                logp_cat2 = log_post_pred_x(x, cat2_clu, hyper_trad),
            ))
        end
    end
end

In [36]:
# Predict category 1 only when its log probability is strictly larger;
# every other case (including ties) is assigned to category 2.
results_trad_uniform.pred_cat = map(
    (lp1, lp2) -> lp1 > lp2 ? 1 : 2,
    results_trad_uniform.logp_cat1,
    results_trad_uniform.logp_cat2,
);
# Flag the rows where the prediction matches the true category.
results_trad_uniform.correct = results_trad_uniform.pred_cat .== results_trad_uniform.true_cat;
# Proportion of correctly categorized test items.
mean(results_trad_uniform.correct)

0.6158333333333333

## Skewed training set

In [28]:
# Evaluate the traditional HDP trained on the skewed training sets.
#
# The table gets one row per (split, test percept) holding the value returned by
# log_post_pred_x for that percept under the category-1 clusters and under the
# category-2 clusters. Rows are appended as they are produced, so the table never
# contains uninitialized entries and every split contributes exactly as many rows
# as it has test percepts (no reliance on the counts of split 1).
results_trad_skewed = DataFrame(
    split     = Int[],
    true_cat  = Int[],      # 1 or 2
    test_idx  = Int[],
    logp_cat1 = Float64[],
    logp_cat2 = Float64[],
);

# NOTE(review): if trad_hdp_cluster_update draws from the global RNG, results will
# differ between runs; seed the RNG before this loop for reproducibility — confirm.
for i in 1:n_splits
    # Concatenate the percepts of all training objects into one collection.
    cat1_train = vcat((obj.percepts for obj in cat1_splits[i].train_objects_skewed)...)
    cat2_train = vcat((obj.percepts for obj in cat2_splits[i].train_objects_skewed)...)

    # Run the cluster update once per category; only the second return value is
    # needed for the predictive calls below.
    _, cat1_clu = trad_hdp_cluster_update(cat1_train, hyper_trad; iters=1000)
    _, cat2_clu = trad_hdp_cluster_update(cat2_train, hyper_trad; iters=1000)

    # Score the held-out percepts of category 1 first, then those of category 2,
    # against both categories' clusters.
    for (true_cat, splits) in ((1, cat1_splits), (2, cat2_splits))
        test_items = vcat((obj.percepts for obj in splits[i].test_objects)...)
        for (j, x) in enumerate(test_items)
            push!(results_trad_skewed, (
                split     = i,
                true_cat  = true_cat,
                test_idx  = j,
                logp_cat1 = log_post_pred_x(x, cat1_clu, hyper_trad),
                logp_cat2 = log_post_pred_x(x, cat2_clu, hyper_trad),
            ))
        end
    end
end

In [35]:
# Predict category 1 only when its log probability is strictly larger;
# every other case (including ties) is assigned to category 2.
results_trad_skewed.pred_cat = map(
    (lp1, lp2) -> lp1 > lp2 ? 1 : 2,
    results_trad_skewed.logp_cat1,
    results_trad_skewed.logp_cat2,
);
# Flag the rows where the prediction matches the true category.
results_trad_skewed.correct = results_trad_skewed.pred_cat .== results_trad_skewed.true_cat;
# Proportion of correctly categorized test items.
mean(results_trad_skewed.correct)

0.5858333333333333

# Object-Aware HDP

## Uniform training set

In [31]:
# Hyperparameters used by both object-aware HDP runs (uniform and skewed).
# NOTE(review): field semantics are defined by ObjectAwareHDPHyperparams (not
# visible here) — confirm against its definition.
hyper_obj = ObjectAwareHDPHyperparams(;
    m0 = [0.0, 0.0],
    k_clu = 1.0, k_obj = 5.0, k_per = 10.0,
    a0 = 1.0, b0 = 1.0,
    alpha = 1.0,
);

In [34]:
# Evaluate the object-aware HDP trained on the uniform training sets.
#
# The table gets one row per (split, test percept) holding the value returned by
# log_post_pred_y for that percept under the category-1 clusters and under the
# category-2 clusters. Rows are appended as they are produced, so the table never
# contains uninitialized entries and every split contributes exactly as many rows
# as it has test percepts (no reliance on the counts of split 1).
results_obj_uniform = DataFrame(
    split     = Int[],
    true_cat  = Int[],      # 1 or 2
    test_idx  = Int[],
    logp_cat1 = Float64[],
    logp_cat2 = Float64[],
);

# NOTE(review): if object_aware_hdp_cluster_update draws from the global RNG, results
# will differ between runs; seed the RNG before this loop for reproducibility — confirm.
for i in 1:n_splits
    # Keep the percepts grouped per training object (one entry per object).
    cat1_train = [obj.percepts for obj in cat1_splits[i].train_objects_uniform]
    cat2_train = [obj.percepts for obj in cat2_splits[i].train_objects_uniform]

    # Run the cluster update once per category; only the second return value is
    # needed for the predictive calls below.
    _, cat1_clu = object_aware_hdp_cluster_update(cat1_train, hyper_obj; iters=1000)
    _, cat2_clu = object_aware_hdp_cluster_update(cat2_train, hyper_obj; iters=1000)

    # Score the held-out percepts of category 1 first, then those of category 2,
    # against both categories' clusters. Test percepts are concatenated across objects.
    for (true_cat, splits) in ((1, cat1_splits), (2, cat2_splits))
        test_items = vcat((obj.percepts for obj in splits[i].test_objects)...)
        for (j, y) in enumerate(test_items)
            push!(results_obj_uniform, (
                split     = i,
                true_cat  = true_cat,
                test_idx  = j,
                logp_cat1 = log_post_pred_y(y, cat1_clu, hyper_obj),
                logp_cat2 = log_post_pred_y(y, cat2_clu, hyper_obj),
            ))
        end
    end
end

In [37]:
# Predict category 1 only when its log probability is strictly larger;
# every other case (including ties) is assigned to category 2.
results_obj_uniform.pred_cat = map(
    (lp1, lp2) -> lp1 > lp2 ? 1 : 2,
    results_obj_uniform.logp_cat1,
    results_obj_uniform.logp_cat2,
);
# Flag the rows where the prediction matches the true category.
results_obj_uniform.correct = results_obj_uniform.pred_cat .== results_obj_uniform.true_cat;
# Proportion of correctly categorized test items.
mean(results_obj_uniform.correct)

0.6166666666666667

## Skewed training set

In [38]:
# Evaluate the object-aware HDP trained on the skewed training sets.
#
# The table gets one row per (split, test percept) holding the value returned by
# log_post_pred_y for that percept under the category-1 clusters and under the
# category-2 clusters. Rows are appended as they are produced, so the table never
# contains uninitialized entries and every split contributes exactly as many rows
# as it has test percepts (no reliance on the counts of split 1).
results_obj_skewed = DataFrame(
    split     = Int[],
    true_cat  = Int[],      # 1 or 2
    test_idx  = Int[],
    logp_cat1 = Float64[],
    logp_cat2 = Float64[],
);

# NOTE(review): if object_aware_hdp_cluster_update draws from the global RNG, results
# will differ between runs; seed the RNG before this loop for reproducibility — confirm.
for i in 1:n_splits
    # Keep the percepts grouped per training object (one entry per object).
    cat1_train = [obj.percepts for obj in cat1_splits[i].train_objects_skewed]
    cat2_train = [obj.percepts for obj in cat2_splits[i].train_objects_skewed]

    # Run the cluster update once per category; only the second return value is
    # needed for the predictive calls below.
    _, cat1_clu = object_aware_hdp_cluster_update(cat1_train, hyper_obj; iters=1000)
    _, cat2_clu = object_aware_hdp_cluster_update(cat2_train, hyper_obj; iters=1000)

    # Score the held-out percepts of category 1 first, then those of category 2,
    # against both categories' clusters. Test percepts are concatenated across objects.
    for (true_cat, splits) in ((1, cat1_splits), (2, cat2_splits))
        test_items = vcat((obj.percepts for obj in splits[i].test_objects)...)
        for (j, y) in enumerate(test_items)
            push!(results_obj_skewed, (
                split     = i,
                true_cat  = true_cat,
                test_idx  = j,
                logp_cat1 = log_post_pred_y(y, cat1_clu, hyper_obj),
                logp_cat2 = log_post_pred_y(y, cat2_clu, hyper_obj),
            ))
        end
    end
end

In [39]:
# Predict category 1 only when its log probability is strictly larger;
# every other case (including ties) is assigned to category 2.
results_obj_skewed.pred_cat = map(
    (lp1, lp2) -> lp1 > lp2 ? 1 : 2,
    results_obj_skewed.logp_cat1,
    results_obj_skewed.logp_cat2,
);
# Flag the rows where the prediction matches the true category.
results_obj_skewed.correct = results_obj_skewed.pred_cat .== results_obj_skewed.true_cat;
# Proportion of correctly categorized test items.
mean(results_obj_skewed.correct)

0.6158333333333333

In [44]:
using CSV

# The four per-condition result tables, with the model and training-set label
# that belongs to each (same order in all three vectors).
dfs = [results_trad_uniform, results_trad_skewed, results_obj_uniform, results_obj_skewed]
models = ["trad", "trad", "obj", "obj"]
types = ["uniform", "skewed", "uniform", "skewed"]

# Tag every row of each table with its condition. This adds (or overwrites) the
# :model and :training_type columns on the original tables in place.
for k in eachindex(dfs)
    tbl = dfs[k]
    tbl[!, :model] .= models[k]
    tbl[!, :training_type] .= types[k]
end

# Stack the tagged tables into one long table, keeping the union of their columns.
results_all = reduce(vcat, dfs; cols = :union)

# Save the combined results; CSV.write returns the output path.
CSV.write("results_all.csv", results_all)

"results_all.csv"

TODO: Quinn-style test

Train on category 1 only, then test on novel objects from category 1 vs. novel objects from category 2.
Novel objects from category 2 should be more surprising, i.e., have a lower predictive probability.