In [1]:
import argparse
import json
import os
import pickle
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

import joblib
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression as SkLR
from sklearn.metrics import (accuracy_score, roc_auc_score,
                             precision_recall_fscore_support, confusion_matrix)
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.svm import LinearSVC as SkLinearSVC
from sklearn.tree import DecisionTreeClassifier as SkDecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier as SkRandomForestClassifier
from sklearn.utils.class_weight import compute_class_weight

# Optional XGBoost baseline
try:
    from xgboost import XGBClassifier as SkXGBClassifier
    HAS_XGB = True
except Exception:
    HAS_XGB = False

# Concrete-ML imports
from concrete.ml.sklearn import LogisticRegression as CMLLogReg
from concrete.ml.sklearn.svm import LinearSVC as CMLLinearSVC
from concrete.ml.sklearn.tree import DecisionTreeClassifier as CMLDecisionTreeClassifier
from concrete.ml.sklearn.rf import RandomForestClassifier as CMLRandomForestClassifier
from concrete.ml.sklearn.xgb import XGBClassifier as CMLXGBClassifier
from concrete.ml.sklearn import NeuralNetClassifier as CMLNeuralNetClassifier

import torch.nn as nn

In [2]:
@dataclass
class FeatureSpec:
    numeric: List[str]
    categorical: List[str]
    boolean: List[str]


def infer_feature_spec(df: pd.DataFrame, target: str) -> FeatureSpec:
    candidate_num = [
        "farm_area_ha", "rain_mm_gs", "eo_ndvi_gs", "soil_quality_index",
        "input_cost_kes", "sales_kes", "yield_t_ha",
        "mpesa_txn_count_90d", "mpesa_inflow_kes_90d",
        "agritech_score", "loan_amount_kes", "tenor_months",
        "interest_rate_pct", "climate_risk_index"
    ]
    candidate_cat = ["county", "crop_primary", "crop_secondary"]
    candidate_bool = ["irrigated", "prior_default", "processor_contract", "insured", "gov_subsidy"]

    numeric = [c for c in candidate_num if c in df.columns and c != target]
    categorical = [c for c in candidate_cat if c in df.columns and c != target]
    boolean = [c for c in candidate_bool if c in df.columns and c != target]

    # Cast booleans to int upfront
    for b in boolean:
        if b in df.columns:
            df[b] = df[b].astype(int)

    return FeatureSpec(numeric, categorical, boolean)

In [3]:
def build_preprocessor(spec: FeatureSpec) -> ColumnTransformer:
    numeric_pipe = Pipeline(steps=[
        ("imputer", SimpleImputer(strategy="median")),
        ("scaler", StandardScaler(with_mean=True, with_std=True)),
    ])
    categorical_pipe = Pipeline(steps=[
        ("imputer", SimpleImputer(strategy="most_frequent")),
        ("ohe", OneHotEncoder(handle_unknown="ignore", sparse=False)),
    ])
    # Boolean passthrough (already int)
    boolean_pipe = "passthrough"

    pre = ColumnTransformer(
        transformers=[
            ("num", numeric_pipe, spec.numeric),
            ("cat", categorical_pipe, spec.categorical),
            ("bool", boolean_pipe, spec.boolean),
        ],
        remainder="drop",
        sparse_threshold=0.0,
    )
    return pre


def to_numpy(pre: ColumnTransformer, dfX: pd.DataFrame) -> np.ndarray:
    X = pre.transform(dfX)
    return np.asarray(X, dtype=np.float32)

In [4]:
def evaluate_binary(y_true: np.ndarray, y_pred: np.ndarray, y_score: Optional[np.ndarray], label: str) -> dict:
    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)
    cm = confusion_matrix(y_true, y_pred).tolist()
    out = {"label": label, "accuracy": acc, "precision": prec, "recall": rec, "f1": f1, "confusion_matrix": cm}
    if y_score is not None:
        try:
            out["roc_auc"] = roc_auc_score(y_true, y_score)
        except Exception:
            pass
    return out


def evaluate_with_time(predict_fn, X: np.ndarray, y_true: np.ndarray, label: str,
                       score_fn=None) -> dict:
    n = len(X)
    t0 = time.time()
    y_pred = predict_fn(X)
    t1 = time.time()
    elapsed = t1 - t0
    latency = elapsed / max(1, n)

    y_score = None
    if score_fn is not None:
        try:
            y_score = score_fn(X)
        except Exception:
            y_score = None

    metrics = evaluate_binary(y_true, y_pred, y_score, label)
    metrics["total_time_s"] = elapsed
    metrics["num_samples"] = n
    metrics["latency_per_sample_s"] = latency
    return metrics

In [5]:
def train_eval_concrete_model(name: str, model,
                              X_train, y_train,
                              X_test, y_test,
                              outdir: Path,
                              execute_all: bool,
                              execute_samples: int,
                              pickle_safe: bool = True,
                              sk_model=None) -> Dict[str, dict]:
    """
    Train/evaluate Concrete-ML model (clear, simulate, execute).
    Optionally: also train/evaluate sklearn plaintext equivalent if sk_model is given.
    """
    out: Dict[str, dict] = {}

    # -------- sklearn baseline (if provided) --------
    if sk_model is not None:
        sk_model.fit(X_train, y_train)

        try:
            joblib.dump(sk_model, outdir / f"{name}_sklearn.joblib")
        except Exception as e:
            print(f"[WARN] Could not save sklearn model {name}: {e}")

        out[f"{name}_sklearn_plaintext"] = evaluate_with_time(
            sk_model.predict,
            X_test, y_test,
            f"{name}_sklearn_plaintext",
            score_fn=(
                (lambda X: sk_model.predict_proba(X)[:, 1]) if hasattr(sk_model, "predict_proba")
                else (lambda X: sk_model.decision_function(X)) if hasattr(sk_model, "decision_function")
                else None
            ),
        )

    # -------- Concrete-ML model --------
    model.fit(X_train, y_train)

    if pickle_safe:
        try:
            with open(outdir / f"{name}_cml.pkl", "wb") as f:
                pickle.dump(model, f)
        except Exception as e:
            print(f"[WARN] Could not pickle {name} model: {e}. Skipping save.")

    # Clear-quantized
    out[f"{name}_clear_quantized"] = evaluate_with_time(
        model.predict,
        X_test, y_test,
        f"{name}_clear_quantized",
        score_fn=(lambda X: model.predict_proba(X)[:, 1]) if hasattr(model, "predict_proba") else None
    )

    # Compile
    t0 = time.time()
    model.compile(X_test)
    out[f"{name}_compile_time_s"] = time.time() - t0

    # FHE simulate
    out[f"{name}_fhe_simulate"] = evaluate_with_time(
        lambda X: model.predict(X, fhe="simulate"),
        X_test, y_test,
        f"{name}_fhe_simulate",
        score_fn=(lambda X: model.predict_proba(X, fhe="simulate")[:, 1]) if hasattr(model, "predict_proba") else None
    )

    # FHE execute (compulsory)
    if execute_all:
        subset, y_true_subset, subset_label = X_test, y_test, f"all_{len(X_test)}"
    else:
        k = max(1, min(execute_samples, len(X_test)))
        subset, y_true_subset, subset_label = X_test[:k], y_test[:k], f"subset_{k}"

    exec_metrics = evaluate_with_time(
        lambda X: model.predict(X, fhe="execute"),
        subset, y_true_subset,
        f"{name}_fhe_execute_{subset_label}",
        score_fn=None
    )
    exec_metrics["compile_time_s"] = out[f"{name}_compile_time_s"]
    out[f"{name}_fhe_execute"] = exec_metrics

    return out

In [6]:
    

outdir = Path("artifacts_multi")
outdir.mkdir(parents=True, exist_ok=True)

# Load
df = pd.read_csv("kenya_agri.csv")
target = "default_or_claim"
if target not in df.columns:
    raise ValueError(f"Target '{target}' not in CSV")

# Honour split if present
if "split" in df.columns:
    train_df = df[df["split"] == "train"].copy()
    test_df = df[df["split"] == "test"].copy()
else:
    train_df, test_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df[target])
# train_df, test_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df[target])


In [7]:
# Features
spec = infer_feature_spec(train_df, target=target)
pre = build_preprocessor(spec)
pre.fit(train_df[spec.numeric + spec.categorical + spec.boolean])

with open(outdir / "preprocessor.pkl", "wb") as f:
    pickle.dump({"preprocessor": pre, "feature_spec": spec}, f)

X_train = to_numpy(pre, train_df[spec.numeric + spec.categorical + spec.boolean])
X_test = to_numpy(pre, test_df[spec.numeric + spec.categorical + spec.boolean])
y_train = train_df[target].astype(int).to_numpy()
y_test = test_df[target].astype(int).to_numpy()

In [8]:
execute_all = True
results: Dict[str, dict] = {}
n_bits = 8
execute_samples = 64

In [9]:
# Logistic Regression
results["logreg"] = train_eval_concrete_model(
    "logreg",
    CMLLogReg(n_bits=n_bits, fit_intercept=True, max_iter=2000, class_weight="balanced"),
    X_train, y_train, X_test, y_test,
    outdir, execute_all, execute_samples,
    pickle_safe=True,
    sk_model=SkLR(max_iter=2000, class_weight="balanced")
)
results["logreg"]

{'logreg_sklearn_plaintext': {'label': 'logreg_sklearn_plaintext',
  'accuracy': 0.744,
  'precision': 0.14074074074074075,
  'recall': 0.6129032258064516,
  'f1': 0.22891566265060243,
  'confusion_matrix': [[353, 116], [12, 19]],
  'roc_auc': 0.7279730380356284,
  'total_time_s': 0.0004279613494873047,
  'num_samples': 500,
  'latency_per_sample_s': 8.559226989746093e-07},
 'logreg_clear_quantized': {'label': 'logreg_clear_quantized',
  'accuracy': 0.742,
  'precision': 0.13970588235294118,
  'recall': 0.6129032258064516,
  'f1': 0.2275449101796407,
  'confusion_matrix': [[352, 117], [12, 19]],
  'roc_auc': 0.7282137698603756,
  'total_time_s': 0.0006210803985595703,
  'num_samples': 500,
  'latency_per_sample_s': 1.2421607971191406e-06},
 'logreg_compile_time_s': 0.3832590579986572,
 'logreg_fhe_simulate': {'label': 'logreg_fhe_simulate',
  'accuracy': 0.742,
  'precision': 0.13970588235294118,
  'recall': 0.6129032258064516,
  'f1': 0.2275449101796407,
  'confusion_matrix': [[352, 1

In [10]:
# Linear SVM
results["linear_svm"] = train_eval_concrete_model(
    "linear_svm",
    CMLLinearSVC(n_bits=n_bits, max_iter=5000, class_weight="balanced"),
    X_train, y_train, X_test, y_test,
    outdir, execute_all, execute_samples,
    pickle_safe=True,
    sk_model=SkLinearSVC(max_iter=5000, class_weight="balanced")
)
results["linear_svm"]

{'linear_svm_sklearn_plaintext': {'label': 'linear_svm_sklearn_plaintext',
  'accuracy': 0.748,
  'precision': 0.14814814814814814,
  'recall': 0.6451612903225806,
  'f1': 0.24096385542168672,
  'confusion_matrix': [[354, 115], [11, 20]],
  'roc_auc': 0.7210949858999932,
  'total_time_s': 0.00017309188842773438,
  'num_samples': 500,
  'latency_per_sample_s': 3.4618377685546874e-07},
 'linear_svm_clear_quantized': {'label': 'linear_svm_clear_quantized',
  'accuracy': 0.748,
  'precision': 0.14814814814814814,
  'recall': 0.6451612903225806,
  'f1': 0.24096385542168672,
  'confusion_matrix': [[354, 115], [11, 20]],
  'roc_auc': 0.721266937203384,
  'total_time_s': 0.00026607513427734375,
  'num_samples': 500,
  'latency_per_sample_s': 5.321502685546875e-07},
 'linear_svm_compile_time_s': 0.2985382080078125,
 'linear_svm_fhe_simulate': {'label': 'linear_svm_fhe_simulate',
  'accuracy': 0.748,
  'precision': 0.14814814814814814,
  'recall': 0.6451612903225806,
  'f1': 0.24096385542168672,

In [11]:
# Decision Tree
results["decision_tree"] = train_eval_concrete_model(
    "decision_tree",
    CMLDecisionTreeClassifier(max_depth=3, class_weight="balanced", n_bits=n_bits),
    X_train, y_train, X_test, y_test,
    outdir, execute_all, execute_samples,
    pickle_safe=False,
    sk_model=SkDecisionTreeClassifier(max_depth=3, class_weight="balanced")
)
results["decision_tree"]

{'decision_tree_sklearn_plaintext': {'label': 'decision_tree_sklearn_plaintext',
  'accuracy': 0.736,
  'precision': 0.13138686131386862,
  'recall': 0.5806451612903226,
  'f1': 0.2142857142857143,
  'confusion_matrix': [[350, 119], [13, 18]],
  'roc_auc': 0.6393837265286471,
  'total_time_s': 0.00022983551025390625,
  'num_samples': 500,
  'latency_per_sample_s': 4.596710205078125e-07},
 'decision_tree_clear_quantized': {'label': 'decision_tree_clear_quantized',
  'accuracy': 0.732,
  'precision': 0.1347517730496454,
  'recall': 0.6129032258064516,
  'f1': 0.22093023255813954,
  'confusion_matrix': [[347, 122], [12, 19]],
  'roc_auc': 0.6510420248985488,
  'total_time_s': 0.001318216323852539,
  'num_samples': 500,
  'latency_per_sample_s': 2.636432647705078e-06},
 'decision_tree_compile_time_s': 1.2307240962982178,
 'decision_tree_fhe_simulate': {'label': 'decision_tree_fhe_simulate',
  'accuracy': 0.732,
  'precision': 0.1347517730496454,
  'recall': 0.6129032258064516,
  'f1': 0.22

In [12]:
# Random Forest
results["random_forest"] = train_eval_concrete_model(
    "random_forest",
    CMLRandomForestClassifier(n_estimators=10, max_depth=3, class_weight="balanced", n_bits=n_bits),
    X_train, y_train, X_test, y_test,
    outdir, execute_all, execute_samples,
    pickle_safe=False,
    sk_model=SkRandomForestClassifier(n_estimators=10, max_depth=3, class_weight="balanced")
)
results["random_forest"]

{'random_forest_sklearn_plaintext': {'label': 'random_forest_sklearn_plaintext',
  'accuracy': 0.806,
  'precision': 0.2,
  'recall': 0.7096774193548387,
  'f1': 0.3120567375886525,
  'confusion_matrix': [[381, 88], [9, 22]],
  'roc_auc': 0.7877777013549764,
  'total_time_s': 0.0006608963012695312,
  'num_samples': 500,
  'latency_per_sample_s': 1.3217926025390624e-06},
 'random_forest_clear_quantized': {'label': 'random_forest_clear_quantized',
  'accuracy': 0.764,
  'precision': 0.12173913043478261,
  'recall': 0.45161290322580644,
  'f1': 0.19178082191780824,
  'confusion_matrix': [[368, 101], [17, 14]],
  'roc_auc': 0.6899374097255657,
  'total_time_s': 0.0032651424407958984,
  'num_samples': 500,
  'latency_per_sample_s': 6.530284881591797e-06},
 'random_forest_compile_time_s': 1.326725959777832,
 'random_forest_fhe_simulate': {'label': 'random_forest_fhe_simulate',
  'accuracy': 0.764,
  'precision': 0.12173913043478261,
  'recall': 0.45161290322580644,
  'f1': 0.1917808219178082

In [13]:
# XGBoost
if HAS_XGB:
    results["xgboost"] = train_eval_concrete_model(
        "xgboost",
        CMLXGBClassifier(n_estimators=10, max_depth=3, n_bits=n_bits),
        X_train, y_train, X_test, y_test,
        outdir, execute_all, execute_samples,
        pickle_safe=False,
        sk_model=SkXGBClassifier(n_estimators=10, max_depth=3, use_label_encoder=False, eval_metric="logloss")
    )
else:
    results["xgboost"] = {"skipped": True}
results["xgboost"]

{'xgboost_sklearn_plaintext': {'label': 'xgboost_sklearn_plaintext',
  'accuracy': 0.938,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'confusion_matrix': [[469, 0], [31, 0]],
  'roc_auc': 0.7981979503404637,
  'total_time_s': 0.002543926239013672,
  'num_samples': 500,
  'latency_per_sample_s': 5.087852478027344e-06},
 'xgboost_clear_quantized': {'label': 'xgboost_clear_quantized',
  'accuracy': 0.938,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'confusion_matrix': [[469, 0], [31, 0]],
  'roc_auc': 0.7822752596464682,
  'total_time_s': 0.008324146270751953,
  'num_samples': 500,
  'latency_per_sample_s': 1.6648292541503907e-05},
 'xgboost_compile_time_s': 1.664928913116455,
 'xgboost_fhe_simulate': {'label': 'xgboost_fhe_simulate',
  'accuracy': 0.938,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'confusion_matrix': [[469, 0], [31, 0]],
  'roc_auc': 0.7824128206891808,
  'total_time_s': 9.113968849182129,
  'num_samples': 500,
  'latency_per_sample_s': 0.01822793

In [14]:
# QNN (safe to pickle)
qnn_params = dict(
    module__n_layers=3,
    module__activation_function=nn.ReLU,
    module__n_hidden_neurons_multiplier=4,
    module__n_w_bits=4,       # quantization bits for weights
    module__n_a_bits=4,       # quantization bits for activations
    module__n_accum_bits=10,   # quantization bits for accumulators > n_w_bits + n_a_bits (or slightly larger).
    max_epochs=50,
    verbose=1,
)
cml_qnn = CMLNeuralNetClassifier(**qnn_params)
results["concrete_qnn"] = train_eval_concrete_model(
    "concrete_qnn", cml_qnn, X_train, y_train, X_test, y_test, outdir, execute_all, execute_samples, pickle_safe=True
)
results["concrete_qnn"]

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6637[0m       [32m0.9450[0m        [35m0.5816[0m  0.1933
      2        [36m0.5116[0m       0.9450        [35m0.4224[0m  0.1446
      3        [36m0.3569[0m       0.9450        [35m0.2822[0m  0.1484
      4        [36m0.2514[0m       0.9450        [35m0.2187[0m  0.1518
      5        [36m0.2206[0m       0.9450        [35m0.2166[0m  0.1440
      6        [36m0.2187[0m       0.9450        0.2174  0.1588
      7        [36m0.2179[0m       0.9450        [35m0.2153[0m  0.1419
      8        [36m0.2169[0m       0.9450        0.2159  0.1451
      9        [36m0.2153[0m       0.9450        [35m0.2134[0m  0.1548
     10        [36m0.2148[0m       0.9450        [35m0.2110[0m  0.1549
     11        0.2151       0.9450        0.2117  0.1456
     12        [36m0.2140[0m       0.9450        0.2120  0.1525
     13        0.2141 

{'concrete_qnn_clear_quantized': {'label': 'concrete_qnn_clear_quantized',
  'accuracy': 0.938,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'confusion_matrix': [[469, 0], [31, 0]],
  'roc_auc': 0.5021321961620469,
  'total_time_s': 0.06636285781860352,
  'num_samples': 500,
  'latency_per_sample_s': 0.00013272571563720703},
 'concrete_qnn_compile_time_s': 1.601952075958252,
 'concrete_qnn_fhe_simulate': {'label': 'concrete_qnn_fhe_simulate',
  'accuracy': 0.938,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'confusion_matrix': [[469, 0], [31, 0]],
  'roc_auc': 0.5021321961620469,
  'total_time_s': 4.317154169082642,
  'num_samples': 500,
  'latency_per_sample_s': 0.008634308338165282},
 'concrete_qnn_fhe_execute': {'label': 'concrete_qnn_fhe_execute_all_500',
  'accuracy': 0.938,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'confusion_matrix': [[469, 0], [31, 0]],
  'total_time_s': 418.5588130950928,
  'num_samples': 500,
  'latency_per_sample_s': 0.837117626190185

In [15]:
with open(outdir / "metrics_all.json", "w") as f:
    json.dump(results, f, indent=2)

print("[Summary]")
print(json.dumps({k: list(v.keys()) for k, v in results.items() if isinstance(v, dict)}, indent=2))
print(f"\nArtifacts in: {outdir.resolve()}")

[Summary]
{
  "logreg": [
    "logreg_sklearn_plaintext",
    "logreg_clear_quantized",
    "logreg_compile_time_s",
    "logreg_fhe_simulate",
    "logreg_fhe_execute"
  ],
  "linear_svm": [
    "linear_svm_sklearn_plaintext",
    "linear_svm_clear_quantized",
    "linear_svm_compile_time_s",
    "linear_svm_fhe_simulate",
    "linear_svm_fhe_execute"
  ],
  "decision_tree": [
    "decision_tree_sklearn_plaintext",
    "decision_tree_clear_quantized",
    "decision_tree_compile_time_s",
    "decision_tree_fhe_simulate",
    "decision_tree_fhe_execute"
  ],
  "random_forest": [
    "random_forest_sklearn_plaintext",
    "random_forest_clear_quantized",
    "random_forest_compile_time_s",
    "random_forest_fhe_simulate",
    "random_forest_fhe_execute"
  ],
  "xgboost": [
    "xgboost_sklearn_plaintext",
    "xgboost_clear_quantized",
    "xgboost_compile_time_s",
    "xgboost_fhe_simulate",
    "xgboost_fhe_execute"
  ],
  "concrete_qnn": [
    "concrete_qnn_clear_quantized",
    "con