In [1]:
from __future__ import annotations
import json
import os
import pickle
from pathlib import Path
from typing import Dict, Any, List, Optional
import numpy as np
import pandas as pd
from concrete.ml.deployment.fhe_client_server import FHEModelClient

# The notebook's working directory is treated as the repository root.
REPO_DIR = Path.cwd()
# Deployment artefacts: report.json + client.zip + preprocessors
DEPLOYMENT_DIR = REPO_DIR.joinpath("deployment_files", "model")
REPORT_PATH = DEPLOYMENT_DIR.joinpath("report.json")

# Encrypted inputs are written below this directory, one sub-directory per client id.
CLIENT_FILES = REPO_DIR.joinpath("client_files").resolve()
CLIENT_FILES.mkdir(exist_ok=True, parents=True)

# Handed to the FHE client as its key directory, one sub-directory per client id.
FHE_KEYS_DIR = REPO_DIR.joinpath(".fhe_keys").resolve()
FHE_KEYS_DIR.mkdir(exist_ok=True, parents=True)

In [2]:
# --------- Small helpers ---------------------
def _short_hex(b: bytes, n: int = 500, shift: int = 100) -> str:
    return b[shift:shift + n].hex()

def _ensure_dir(p: Path) -> Path:
    p.mkdir(parents=True, exist_ok=True)
    return p

def load_json(p: Path) -> Dict[str, Any]:
    """Read and parse the JSON document stored at ``p``.

    Args:
        p: Path of the JSON file.

    Returns:
        The parsed JSON content.

    Raises:
        FileNotFoundError: if ``p`` does not exist.
        json.JSONDecodeError: if the file content is not valid JSON.
    """
    if not p.exists():
        raise FileNotFoundError(f"Missing JSON at {p}")
    # Pass raw bytes so that json detects the encoding itself (UTF-8 with or
    # without BOM, UTF-16, UTF-32) instead of relying on the platform's locale
    # encoding, which Path.read_text() would use by default.
    return json.loads(p.read_bytes())

def load_pkl(p: Path):
    """Unpickle and return the object stored at ``p``.

    Raises FileNotFoundError when ``p`` does not exist.

    NOTE(security): ``pickle.load`` can execute arbitrary code embedded in the
    file, so this must only be pointed at files from a trusted source.
    """
    if p.exists():
        with p.open("rb") as handle:
            return pickle.load(handle)
    raise FileNotFoundError(f"Missing pickle at {p}")
    
def to_bool(x) -> float:
    """Return 1.0/0.0 from diverse truthy/falsey tokens; raise on ambiguity.

    Accepted inputs:
      * ``bool`` / ``numpy.bool_`` values;
      * the case-insensitive, whitespace-trimmed tokens
        ``true/t/1/yes/y`` and ``false/f/0/no/n``;
      * anything whose string form parses as a number: zero maps to 0.0,
        every other number maps to 1.0.

    Raises:
        ValueError: if the value cannot be interpreted, including NaN, which
            marks a missing value and is neither true nor false.
    """
    if isinstance(x, (bool, np.bool_)):
        return 1.0 if bool(x) else 0.0

    token = str(x).strip().lower()
    if token in {"true", "t", "1", "yes", "y"}:
        return 1.0
    if token in {"false", "f", "0", "no", "n"}:
        return 0.0

    # numeric fallback (explicit)
    try:
        number = float(token)
    except ValueError as exc:
        raise ValueError(f"Boolean field not parseable: {x!r}") from exc
    # NaN compares unequal to 0.0 and would otherwise be reported as True.
    if np.isnan(number):
        raise ValueError(f"Boolean field not parseable: {x!r}")
    return 0.0 if number == 0.0 else 1.0

In [3]:

# --------- Multi-input client (pads then encrypts one slice) ------------------
class MultiInputsFHEModelClient(FHEModelClient):
    """FHE client that quantizes and encrypts one party's slice of the model input.

    The slice is first placed into an otherwise zero-filled row spanning the
    full model width, the whole row is quantized, and only the party's columns
    are then encrypted as one of the ``nb_inputs`` circuit inputs.
    """

    def __init__(self, path_dir: Path, key_dir: Path, nb_inputs: int):
        """Remember the number of circuit inputs, then defer to the base client.

        Args:
            path_dir: directory handed to ``FHEModelClient`` as ``path_dir``.
            key_dir: directory handed to ``FHEModelClient`` as ``key_dir``.
            nb_inputs: number of positional inputs passed to ``client.encrypt``.
        """
        self.nb_inputs = nb_inputs
        super().__init__(path_dir, key_dir=key_dir)

    @property
    def total_width(self) -> int:
        """Number of entries in ``self.model.input_quantizers``.

        Raises:
            RuntimeError: if the model has no ``input_quantizers`` attribute or
                it is ``None``.
        """
        quantizers = getattr(self.model, "input_quantizers", None)
        if quantizers is None:
            raise RuntimeError("Model has no input_quantizers; ensure it was compiled.")
        return len(quantizers)

    def quantize_encrypt_slice(self, x_slice: np.ndarray, input_index: int, party_slice: slice) -> bytes:
        """Quantize and encrypt a single-row party slice, returning the serialized ciphertext.

        Args:
            x_slice: array of shape ``(1, n_cols)`` holding this party's columns.
            input_index: position of this party among the ``nb_inputs`` inputs.
            party_slice: column range of this party within the full model width;
                both ``start`` and ``stop`` must be set.

        Returns:
            The serialized encrypted value at position ``input_index``.

        Raises:
            ValueError: on a wrongly shaped ``x_slice``, an open-ended slice, a
                slice reaching past the model width, or a width mismatch.
        """
        if not (x_slice.ndim == 2 and x_slice.shape[0] == 1):
            raise ValueError(f"x_slice must be (1, n_cols); got {x_slice.shape}")

        model_width = self.total_width
        lower, upper = party_slice.start, party_slice.stop
        if lower is None or upper is None:
            raise ValueError("party_slice must have start/stop.")
        if upper > model_width:
            raise ValueError(f"party_slice.stop={upper} exceeds model width {model_width}")
        expected_cols = upper - lower
        if expected_cols != x_slice.shape[1]:
            raise ValueError(f"Slice width mismatch: expects {expected_cols}, got {x_slice.shape[1]}")

        # Embed the party's values in a zero row of full model width so the
        # model's quantize_input sees an input of the width it expects.
        full_row = np.zeros((1, model_width), dtype=float)
        full_row[:, party_slice] = x_slice

        # Quantize the whole row, then keep only this party's columns.
        quantized_slice = self.model.quantize_input(full_row)[:, party_slice]

        # Every other input position is passed as None; only ours carries data.
        per_input = [None] * self.nb_inputs
        per_input[input_index] = quantized_slice

        encrypted = self.client.encrypt(*per_input)
        return encrypted[input_index].serialize()


In [4]:

# --------- Verification using report["features"] ------------------------------
def verify_party_alignment(report: Dict[str, Any], party: str, preproc) -> slice:
    """Check that ``preproc`` still emits the features recorded in ``report``.

    When the preprocessor can report its output feature names, they must equal
    the saved ``feature_names_out`` exactly (same names, same order).  When it
    cannot, a single all-NaN row is transformed and only the output width is
    compared with the saved number of features; names cannot be checked in
    that case.

    Args:
        report: parsed report; reads ``features.per_party_features[party]``
            and, for the width-only fallback, ``group_columns[party]``.
        party: name of the party to verify.
        preproc: fitted preprocessor exposing ``get_feature_names_out()``
            and/or ``transform()``.

    Returns:
        ``slice(start, stop)`` built from the saved ``slice`` entry of the party.

    Raises:
        ValueError: if ``party`` is absent from ``features.per_party_features``.
        RuntimeError: if the current feature names or width differ from the
            saved ones.
    """
    per_party = report["features"]["per_party_features"]
    if party not in per_party:
        raise ValueError(f"Party '{party}' missing in features.per_party_features.")

    saved = per_party[party]
    saved_names = list(saved["feature_names_out"])
    saved_start, saved_stop = saved["slice"]

    try:
        names_now = np.asarray(preproc.get_feature_names_out()).tolist()
    except Exception:
        # Best effort: some preprocessors cannot report names.
        names_now = None

    if names_now is None:
        # Fallback: transform a NA row to get the width only.  Comparing
        # synthesized placeholder names against the saved real names would
        # always look like drift, so only the width is checked here.
        columns = report["group_columns"][party]
        cols = columns["numeric"] + columns["categorical"] + columns["boolean"]
        probe = pd.DataFrame([{c: np.nan for c in cols}], columns=cols)
        width_now = int(preproc.transform(probe).shape[1])
        if width_now != len(saved_names):
            raise RuntimeError(
                f"[{party}] post-processed feature width drift detected.\n"
                f"Saved width={len(saved_names)}, now={width_now}."
            )
    elif saved_names != names_now:
        raise RuntimeError(
            f"[{party}] post-processed feature order drift detected.\n"
            f"Saved width={len(saved_names)}, now={len(names_now)}.\n"
            f"First few saved: {saved_names[:8]}\nFirst few now  : {names_now[:8]}"
        )
    return slice(saved_start, saved_stop)

def _coerce_party_row(
    party: str,
    raw_row: Dict[str, Any],
    nums: List[str],
    cats: List[str],
    bools: List[str],
) -> Dict[str, Any]:
    """Strictly validate one party's raw row and coerce every field to its training type."""
    expected = nums + cats + bools
    known = set(expected)

    # Check for missing/extras BEFORE coercion
    missing = [c for c in expected if c not in raw_row]
    extras = [c for c in raw_row if c not in known]
    if missing or extras:
        problems = []
        if missing:
            problems.append(f"missing={missing}")
        if extras:
            problems.append(f"unexpected={extras}")
        raise ValueError(f"[{party}] strict input check failed: " + "; ".join(problems))

    coerced: Dict[str, Any] = {}

    # numeric → finite float; float() alone would let NaN/inf through
    for c in nums:
        v = raw_row[c]
        try:
            number = float(v)
        except Exception as exc:
            raise ValueError(f"[{party}] numeric field '{c}' not parseable: {v!r}") from exc
        if not np.isfinite(number):
            raise ValueError(f"[{party}] numeric field '{c}' must be a finite number: {v!r}")
        coerced[c] = number

    # categorical → non-empty string; None/NaN would otherwise be stringified
    # into the bogus categories "None"/"nan"
    for c in cats:
        v = raw_row[c]
        if pd.api.types.is_scalar(v) and pd.isna(v):
            raise ValueError(f"[{party}] categorical field '{c}' is empty")
        text = str(v).strip()
        if text == "":
            raise ValueError(f"[{party}] categorical field '{c}' is empty")
        coerced[c] = text

    # boolean → float in {0.0, 1.0}
    for c in bools:
        coerced[c] = to_bool(raw_row[c])

    return coerced


def preprocess_encrypt_send_party(
    client_id: str,
    party: str,
    raw_row: Dict[str, Any],
    *,
    deployment_dir: Path = DEPLOYMENT_DIR,
) -> str:
    """
    Validate, preprocess, quantize and encrypt one party's raw feature row.

    Require ALL features for the party to be present. Refuse if any missing or extras.

    Steps: load ``report.json`` from ``deployment_dir``, strictly coerce
    ``raw_row``, load the party's pickled preprocessor from the path stored in
    the report, verify it still matches the saved feature layout, transform the
    row, encrypt the party's slice, and write the ciphertext to
    ``CLIENT_FILES/<client_id>/encrypted_inputs_<party>``.

    Args:
        client_id: client identifier; used (as a string) as the sub-directory
            name for both the key directory and the output file.
        party: party name; must appear in ``features.party_order`` of the report.
        raw_row: mapping of raw column name to value for exactly the party's
            numeric, categorical and boolean columns.
        deployment_dir: directory containing ``report.json`` and the files the
            FHE client is constructed from.

    Returns:
        A hex preview of the ciphertext as produced by ``_short_hex``.

    Raises:
        ValueError: unknown party, missing/unexpected columns, or a value that
            cannot be coerced (including NaN/inf numerics and missing
            categoricals).
        FileNotFoundError: if the report or the preprocessor file is missing.
        RuntimeError: if the preprocessor output drifted from the saved layout.
    """
    report = load_json(deployment_dir / "report.json")
    parties = report["features"]["party_order"]
    if party not in parties:
        raise ValueError(f"Party '{party}' not in party_order: {parties}")

    columns = report["group_columns"][party]
    nums, cats, bools = columns["numeric"], columns["categorical"], columns["boolean"]
    expected = nums + cats + bools

    # Coerce types strictly; raise if any field cannot be parsed
    coerced = _coerce_party_row(party, raw_row, nums, cats, bools)

    # Load preprocessor and verify alignment/slice.
    # NOTE(security): this unpickles the file named in the report; only use
    # deployment directories from a trusted source.
    preproc = load_pkl(Path(report["preprocessors"][party]))
    party_slice = verify_party_alignment(report, party, preproc)

    # Preprocess exactly in training column order
    df_raw = pd.DataFrame([coerced], columns=expected)
    x_post = preproc.transform(df_raw)
    if hasattr(x_post, "toarray"):
        # Sparse transformer output must be densified before the float cast.
        x_post = x_post.toarray()
    x_post = np.asarray(x_post, dtype=float)

    # Path components must be strings; accept ids supplied as integers too.
    client_key = str(client_id)

    # Encrypt only this party slice
    key_dir = _ensure_dir(FHE_KEYS_DIR / client_key)
    client = MultiInputsFHEModelClient(deployment_dir, key_dir=key_dir, nb_inputs=len(parties))
    enc_bytes = client.quantize_encrypt_slice(
        x_slice=x_post,
        input_index=parties.index(party),
        party_slice=party_slice,
    )

    out_dir = _ensure_dir(CLIENT_FILES / client_key)
    out_path = out_dir / f"encrypted_inputs_{party}"
    out_path.write_bytes(enc_bytes)

    print(
        f"[{party}] client_id={client_id} → {out_path.name} ({out_path.stat().st_size} bytes) | "
        f"slice={party_slice.start}:{party_slice.stop} (total={client.total_width})"
    )
    return _short_hex(enc_bytes)


In [5]:
def sample_rows_from_report(report: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Return one hard-coded example raw row per party.

    The rows are fixed demo values keyed by party name (``agritech``, ``bank``,
    ``processor``, ``insurance``, ``government``).  ``report`` is accepted for
    interface compatibility but is not consulted, so no particular key is
    required in it.  A fresh dictionary is built on every call, so callers may
    mutate the result freely.

    Args:
        report: parsed report; currently unused.

    Returns:
        Mapping of party name to a mapping of raw column name to sample value.
    """
    return {
        "agritech": {
            "farm_area_ha": 1.6,
            "input_cost_kes": 45000,
            "agritech_score": 0.72,
            "mpesa_txn_count_90d": 42,
            "mpesa_inflow_kes_90d": 150000,
            "eo_ndvi_gs": 0.51,
            "crop_primary": "maize",
            "crop_secondary": "beans",
            "irrigated": 0,
        },
        "bank": {
            "loan_amount_kes": 250000,
            "tenor_months": 12,
            "interest_rate_pct": 14.0,
            "prior_default": 0,
        },
        "processor": {
            "yield_t_ha": 2.2,
            "sales_kes": 180000,
            "processor_contract": 1,
        },
        "insurance": {
            "climate_risk_index": 0.33,
            "insured": 1,
        },
        "government": {
            "rain_mm_gs": 520.0,
            "soil_quality_index": 0.58,
            "county": "Nakuru",
            "gov_subsidy": 0,
        },
    }

In [6]:
report = load_json(REPORT_PATH)

# Set to a specific id (e.g. "4091376614") to target one client; leave as None
# to pick the first numerical subdirectory in CLIENT_FILES.
client_id = None
if not client_id:
    # Sort so that "first" is deterministic; directory listing order is arbitrary.
    subdirs = sorted(
        entry.name
        for entry in CLIENT_FILES.iterdir()
        if entry.is_dir() and entry.name.isdigit()
    )
    if not subdirs:
        raise ValueError("No numerical client_id subdirectory found in CLIENT_FILES.")
    client_id = subdirs[0]
# The id is used as a path component downstream, so it must be a string even
# when it was entered above as an integer.
client_id = str(client_id)

previews: Dict[str, str] = {}

# Encrypt one sample row per party, in the party order recorded in the report.
rows = sample_rows_from_report(report)
for party in report["features"]["party_order"]:
    previews[party] = preprocess_encrypt_send_party(
        client_id=client_id,
        party=party,
        raw_row=rows[party],
    )

# Optional example: encrypt an alternative bank row for the same client.
# bank_row = {
#     "loan_amount_kes": "5675",
#     "tenor_months": 12,
#     "interest_rate_pct": 14.0,
#     "prior_default": 0,
# }
# previews_bank = preprocess_encrypt_send_party(
#     client_id=client_id,
#     party="bank",
#     raw_row=bank_row
# )
# previews["bank (alt)"] = previews_bank

print("\nHex previews (truncated):")
for p, h in previews.items():
    print(f"  {p:10s}: {h[:64]}...")

[agritech] client_id=427684149 → encrypted_inputs_agritech (508416 bytes) | slice=0:31 (total=90)
[bank] client_id=427684149 → encrypted_inputs_bank (65832 bytes) | slice=31:35 (total=90)
[processor] client_id=427684149 → encrypted_inputs_processor (49440 bytes) | slice=35:38 (total=90)
[insurance] client_id=427684149 → encrypted_inputs_insurance (33048 bytes) | slice=38:40 (total=90)
[government] client_id=427684149 → encrypted_inputs_government (819864 bytes) | slice=40:90 (total=90)

Hex previews (truncated):
  agritech  : 1698b8b644954a5eb058774c6952bd96a19a41b8d873abb6b0d7cdde28f3d8e9...
  bank      : 99dcda3aa921a9115c2163b7a38b50ff454b166ae4c05244c36682af2180b91e...
  processor : 17af52b72d9cc98927027c6ab8c341fd57a1ce18c034877481fceebace32e984...
  insurance : f31498a6164f8c6fd83bb9197445464ae360c5bd9a204a115aceed3bb1fc7729...
  government: feea646c33f30bebbbfcc751532bde3b3701acbd0c6b490f82dd552bfd71559b...
