In [None]:
using Pkg
Pkg.activate(".")
using AxisKeys
using Revise
using EasyHybrid
using Lux
using Optimisers
using GLMakie
using Random
using LuxCore
using CSV, DataFrames
using EasyHybrid.MLUtils
using Statistics
using Plots
using Flux
using NNlib 
using JLD2

In [None]:
# ---------------------------------------------------------------------------
# Run identifiers and output location
# ---------------------------------------------------------------------------
testid = "01_uniNN"
version = "v20251125";
results_dir = joinpath(@__DIR__, "eval");

# One single-output network is searched/trained per target listed here.
target_names = [:BD, :SOCconc, :CF, :SOCdensity];

# ---------------------------------------------------------------------------
# Input table
# ---------------------------------------------------------------------------
df = CSV.read(
    joinpath(@__DIR__, "data/lucas_preprocessed_$version.csv"),
    DataFrame;
    normalizenames = true,
)

# ---------------------------------------------------------------------------
# Per-target scale factors (kept for reference; not used further in this
# part of the script)
# ---------------------------------------------------------------------------
scalers = Dict(
    :SOCconc    => 0.151,  # g/kg,    log(x+1)*0.151
    :CF         => 0.263,  # percent, log(x+1)*0.263
    :BD         => 0.529,  # g/cm3,   x*0.529
    :SOCdensity => 0.167,  # kg/m3,   log(x)*0.167
);

# ---------------------------------------------------------------------------
# Predictor columns
# ---------------------------------------------------------------------------
# Selected purely by position (column 18 up to the 7th-from-last), so the
# slice has to be re-checked whenever the CSV layout changes.
predictors = Symbol.(names(df)[18:end-6]); # CHECK EVERY TIME
nf = length(predictors)

# ---------------------------------------------------------------------------
# Hyperparameter search space
# ---------------------------------------------------------------------------
# Hidden-layer widths, listed from the widest/deepest stack to the smallest.
hidden_configs = [
    (512, 256, 128, 64, 32, 16),
    (512, 256, 128, 64, 32),
    (256, 128, 64, 32, 16),
    (256, 128, 64, 32),
    (256, 128, 64),
    (128, 64, 32, 16),
    (128, 64, 32),
    (64, 32, 16),
];
batch_sizes = [128, 256, 512];
lrs = [1e-3, 5e-4, 1e-4];
activations = [relu, tanh, swish, gelu];

# Full grid as a flat vector of named tuples. The nesting order (hidden
# layout outermost, activation innermost) fixes each configuration's index,
# which later ends up in the checkpoint file names.
configs = [
    (h = layers, bs = batch, lr = rate, act = fn)
    for layers in hidden_configs
    for batch in batch_sizes
    for rate in lrs
    for fn in activations
]

# ---------------------------------------------------------------------------
# Cross-validation bookkeeping
# ---------------------------------------------------------------------------
k = 5;
folds = make_folds(df; k = k, shuffle = true);

# One slot per fold; every slot is assigned by the cross-validation loop.
rlt_list_param = Vector{DataFrame}(undef, k)
rlt_list_pred = Vector{DataFrame}(undef, k)

@info "Threads: $(Threads.nthreads())"


In [None]:
# k-fold cross-validation with a per-target grid search.
#
# For every fold and every target in `target_names`:
#   1. all entries of `configs` are trained in parallel on the training folds
#      (rows with a missing target are dropped),
#   2. the configuration with the lowest `rlt.best_loss` is kept,
#   3. its hyperparameters and validation metrics are appended to the fold's
#      parameter table,
#   4. predictions for the held-out fold are joined onto the test table as
#      column `pred_<target>` (NaN where no prediction is available).
#
# Results are stored in `rlt_list_param[test_fold]` and
# `rlt_list_pred[test_fold]`.
@time for test_fold in 1:k
    @info "Fold $test_fold"
    # -----------------------------
    # Split training / test sets
    # -----------------------------
    train_folds = setdiff(1:k, test_fold)
    train_idx = findall(in(train_folds), folds)
    train_df = df[train_idx, :]
    test_idx  = findall(==(test_fold), folds)
    test_df_full = df[test_idx, :]

    # Storage for this fold
    fold_params = DataFrame()

    # -----------------------------
    # Loop over single target
    # -----------------------------
    for tgt in target_names
        @info "Target $tgt"
        colname = Symbol("pred_", tgt)

        # ----- train: drop missing -----
        train_df_t = dropmissing(train_df, tgt)
        if nrow(train_df_t) == 0
            @warn "No training rows for $tgt — filling NaN for all rows"
            test_df_full[!, colname] = fill(NaN32, nrow(test_df_full))
            continue
        end

        # Shared "best so far" state; every read/update from the worker
        # threads below happens while holding `lk`.
        lk = ReentrantLock()
        best_loss   = Inf
        best_cfg    = nothing
        best_rlt    = nothing
        best_model  = nothing
        best_model_path = nothing

        Threads.@threads for i in eachindex(configs)
            # The try/catch lives inside the loop body so that one failing
            # configuration does not abort the remaining ones.
            try
                cfg = configs[i]
                h  = cfg.h
                bs = cfg.bs
                lr = cfg.lr
                act = cfg.act
                println("Testing h=$h, bs=$bs, lr=$lr, activation=$act")

                nn_local = constructNNModel(
                    predictors, [tgt];
                    hidden_layers = collect(h),
                    activation = act,
                    scale_nn_outputs = true,
                    input_batchnorm = true
                )

                rlt = train(
                    nn_local, train_df_t, ();
                    nepochs = 200,
                    batchsize = bs,
                    opt = AdamW(lr),
                    training_loss = :mse,
                    loss_types = [:mse, :r2],
                    shuffleobs = true,
                    file_name = "$(testid)_$(tgt)_config$(i)_fold$(test_fold).jld2",
                    patience = 15,
                    return_model = :best,
                    plotting = false,
                    show_progress = false,
                    hybrid_name = "$(testid)_$(tgt)_config$(i)_fold$(test_fold)"
                )

                # Compare-and-update must be atomic across threads, otherwise
                # two threads can interleave and leave the best_* variables
                # describing different configurations.
                lock(lk) do
                    if rlt.best_loss < best_loss
                        best_loss = rlt.best_loss
                        best_cfg  = (h=h, bs=bs, lr=lr, act=act)
                        best_rlt  = rlt
                        best_model   = deepcopy(nn_local)
                        best_model_path = "$(testid)_$(tgt)_config$(i)_fold$(test_fold)"
                    end
                end
            catch err
                @error "Thread $i crashed (target=$tgt)" exception=(err, catch_backtrace())
            end
        end

        # Every configuration failed (or none produced a loss below Inf):
        # there is nothing to report or predict with for this target.
        if isnothing(best_rlt)
            @warn "No usable configuration for $tgt on fold $test_fold — filling NaN for all rows"
            test_df_full[!, colname] = fill(NaN32, nrow(test_df_full))
            continue
        end

        # Validation metrics of the winning configuration at its best epoch.
        agg = :sum
        r2s  = map(vh -> getproperty(vh, agg), best_rlt.val_history.r2)
        mses = map(vh -> getproperty(vh, agg), best_rlt.val_history.mse)
        be = max(best_rlt.best_epoch, 1)  # clamp so it is a valid 1-based index

        push!(fold_params, (
            fold       = test_fold,
            target     = String(tgt),
            h          = string(best_cfg.h),
            bs         = best_cfg.bs,
            lr         = best_cfg.lr,
            act        = string(best_cfg.act),
            r2         = r2s[be],
            mse        = mses[be],
            best_epoch = be,
            best_model_path = best_model_path
        ))

        ps, st = best_rlt.ps, best_rlt.st

        try
            # remove missing rows for the current target
            test_df_t = dropmissing(test_df_full, tgt)

            # prepare model input
            x_test, _ = prepare_data(best_model, test_df_t)
            ŷ, _ = best_model(x_test, ps, LuxCore.testmode(st))

            preds_clean = ŷ[tgt]  # predictions on filtered rows
            rids_clean  = test_df_t.row_id # row_ids for those rows

            pred_df = DataFrame(
                row_id = rids_clean,
                colname => preds_clean
            )

            # Rows that were dropped above get `missing` from the join ...
            test_df_full = leftjoin(
                test_df_full,
                pred_df,
                on = :row_id,
                makeunique = true
            )

            # ... which is then turned into NaN.
            replace!(test_df_full[!, colname], missing => NaN32)

        catch err
            @warn "Prediction failed for $tgt on fold $test_fold — using NaN" exception=(err, catch_backtrace())
            test_df_full[!, colname] = fill(NaN32, nrow(test_df_full))
        end

        # -------------------------------------------------------
        # Sanity check after join
        # -------------------------------------------------------

        # `names(df)` returns strings, so a Symbol has to be checked with
        # `hasproperty` rather than `in names(...)`.
        if hasproperty(test_df_full, colname)
            # check how many predictions are non-NaN
            pred_col = test_df_full[!, colname]
            n_nonmissing = count(!isnan, pred_col)

            @info "[Fold $test_fold | $tgt] Join completed."
            @info "  → Column added: $(colname)"
            @info "  → Non-NaN predictions: $n_nonmissing / $(nrow(test_df_full))"

            if n_nonmissing == 0
                @warn "  All predictions are NaN for $tgt on fold $test_fold."
            end

        else
            @error "[Fold $test_fold | $tgt] Column $(colname) missing after join!"
        end

    end

    # save results for this fold
    rlt_list_param[test_fold] = fold_params
    rlt_list_pred[test_fold]  = test_df_full
end


# Final combined outputs: stack the per-fold tables and write them as CSV.
# `reduce(vcat, ...)` concatenates the vector of data frames without
# splatting it into a varargs call.
rlt_param = reduce(vcat, rlt_list_param)
rlt_pred  = reduce(vcat, rlt_list_pred)

# Create the output directory if needed, so the write does not fail after
# the whole search has already run.
mkpath(results_dir)

CSV.write(joinpath(results_dir, "$(testid)_cv.pred_$version.csv"), rlt_pred)
CSV.write(joinpath(results_dir, "$(testid)_hyperparams_$version.csv"), rlt_param)