In [2]:
using CSV
using DataFrames
using .IAI
using Random

# ---------------------------------------------------------
# Helper: get ENV with default
# ---------------------------------------------------------
getenv(key::AbstractString, default::AbstractString) = haskey(ENV, key) ? ENV[key] : default

# ---------------------------------------------------------
# 1. Resolve config from environment (mirroring Python script)
# ---------------------------------------------------------

train_data_file = getenv("TRAIN_DATA_FILE", "../data_gen/train_countdown_results_with_prompt_augmented_xgb.csv")
test_data_file  = getenv("TEST_DATA_FILE",  "../data_gen/test_countdown_results_with_prompt_gemini.csv")

# Column with the true label; default is "correct" to match your Python code
label_col_name = "correct"

# Random seed (optional, for reproducibility)
seed = parse(Int, getenv("RANDOM_STATE", "42"))

println("===================================================================")
println("OPTIMAL CLASSIFICATION TREE (IAI) - UNIFIED TRAIN/TEST")
println("===================================================================")
println("Train Data File: ", train_data_file)
println("Test Data File:  ", test_data_file)
println("Label column:    ", label_col_name)
println("Random seed:     ", seed)
println("===================================================================")

# ---------------------------------------------------------
# 2. Load train and test DataFrames
# ---------------------------------------------------------

println("Loading training data...")
train_df = CSV.read(train_data_file, DataFrame)
println("Training samples: ", nrow(train_df))

println("Training columns:")
println(names(train_df))

println("Loading test data...")
test_df = CSV.read(test_data_file, DataFrame)
println("Test samples:     ", nrow(test_df))

# Check label column
@assert label_col_name in names(train_df) "Label column $(label_col_name) not found in train data."
@assert label_col_name in names(test_df)  "Label column $(label_col_name) not found in test data."

# Ensure we only use columns that appear in BOTH train and test
common_cols = intersect(names(train_df), names(test_df))

train_df = train_df[:, common_cols]
test_df  = test_df[:, common_cols]

println("Number of columns common to train and test: ", length(common_cols))

# ---------------------------------------------------------
# 3. Build feature matrix X and label y
#    - Features: columns starting with "inst_" or "prompt_"
#    - Label:    column label_col_name
# ---------------------------------------------------------

feature_cols = [c for c in common_cols if startswith(String(c), "inst_") || startswith(String(c), "prompt_")]

if isempty(feature_cols)
    error("No feature columns found starting with 'inst_' or 'prompt_'. " *
          "Check that your CSV has those columns.")
end

println("Feature columns used (", length(feature_cols), "):")
println(feature_cols)

X_train = train_df[:, feature_cols]
y_train = train_df[:, label_col_name]

X_test  = test_df[:, feature_cols]
y_test  = test_df[:, label_col_name]

println("Train size: ", nrow(X_train))
println("Test size:  ", nrow(X_test))

# ---------------------------------------------------------
# 4. Fit Optimal Classification Tree on FULL training set
# ---------------------------------------------------------

Random.seed!(seed)

# Define a small hyperparameter grid (shallow interpretable trees)
grid = IAI.GridSearch(
    IAI.OptimalTreeClassifier(
        random_seed = seed,
        criterion   = :gini,   # you can switch to :misclassification if you prefer
    ),
    max_depth = 2:5,
    minbucket = 5:5:25,
)

println("\nFitting OptimalTreeClassifier grid search on full training set...")
IAI.fit!(grid, X_train, y_train)

OPTIMAL CLASSIFICATION TREE (IAI) - UNIFIED TRAIN/TEST
Train Data File: ../data_gen/train_countdown_results_with_prompt_augmented_xgb.csv
Test Data File:  ../data_gen/test_countdown_results_with_prompt_gemini.csv
Label column:    correct
Random seed:     42
Loading training data...
Training samples: 43445
Training columns:
["prompt_id", "instance_id", "prompt", "correct", "message", "inst_n_numbers", "inst_range", "inst_std", "inst_count_small", "inst_count_large", "inst_count_duplicates", "inst_count_even", "inst_count_odd", "inst_count_div_2", "inst_count_div_3", "inst_count_div_5", "inst_count_div_7", "inst_count_primes", "inst_distance_simple", "inst_distance_max", "inst_distance_avg", "inst_easy_pairs", "inst_log_target", "inst_expr_depth", "inst_count_add", "inst_count_sub", "inst_count_mul", "inst_count_div", "inst_noncomm_ops", "numbers", "target", "solution", "prompt_paraphrasing", "prompt_role-specification", "prompt_reasoning-trigger", "prompt_chain-of-thought", "prompt_self

In [3]:
best = IAI.get_learner(grid)

println("\n=================== Learned Optimal Classification Tree ===================")
display(best)
println("===========================================================================")

# ---------------------------------------------------------
# 5. Evaluate on train and test sets
# ---------------------------------------------------------

train_acc = IAI.score(best, X_train, y_train, criterion = :accuracy, positive_label = 1)
test_acc  = IAI.score(best, X_test,  y_test,  criterion = :accuracy, positive_label = 1)

println("\nTrain accuracy: ", train_acc)
println("Test accuracy:  ", test_acc)

# If binary labels, also compute AUC (optional)
unique_labels = unique(y_test)
if length(unique_labels) == 2
    try
        test_auc = IAI.score(best, X_test, y_test, criterion = :auc)
        println("Test AUC:       ", test_auc)
    catch e
        @warn "Could not compute AUC: $e"
    end
end

println("\nDone.")



Train accuracy: 0.9112441017378294
Test accuracy:  0.8087318087318087
Test AUC:       0.8680696429839999

Done.
