## **🧠 Privacy-Preserving Mental Health Risk Detection**

This demo shows a lightweight, privacy-preserving AI system that detects early mental health risk in short text messages using encrypted inference. It simulates a healthcare scenario in which patient-generated data must remain confidential yet still be usable for AI-powered triage. By combining sentence embeddings, classical machine learning, and Fully Homomorphic Encryption (FHE), it runs inference on encrypted inputs without revealing the sensitive text to the server.

**🔐 Key Features**

**Federated-Style Scenario**: The model is trained locally on plaintext data, and inference runs on encrypted user inputs. This suits settings with distributed, sensitive healthcare data. Note that the demo does not perform federated training across sites.

**Privacy-Preserving AI**: Raw user data never leaves the client unencrypted; the server computes the risk score entirely on ciphertexts.

**Sentence Embeddings**: Uses `all-MiniLM-L6-v2` to convert input text into dense 384-dimensional semantic vectors.

**Encrypted Inference with TenSEAL**: Applies the CKKS scheme to run logistic regression on encrypted embeddings.


In [None]:
!pip install sentence-transformers tenseal

**🧮 Introduction to CKKS and TenSEAL**

To enable privacy-preserving inference over sensitive clinical text, this project leverages the CKKS (Cheon-Kim-Kim-Song) scheme for approximate homomorphic encryption. Unlike traditional encryption, CKKS supports arithmetic directly on encrypted real numbers, making it ideal for machine learning workflows involving floating-point operations like dot products and linear models.

We use [TenSEAL](https://github.com/OpenMined/TenSEAL), a Python library built on top of Microsoft SEAL, to:

*   Encrypt high-dimensional sentence embeddings (e.g., from MiniLM)
*   Perform encrypted linear inference (e.g., logistic regression)
*   Decrypt only the final result, preserving end-to-end confidentiality

Because the server operates only on ciphertexts, computation can be outsourced to untrusted environments (e.g., cloud or remote nodes) without revealing the inputs or intermediate results. Note that in this demo the model parameters themselves are kept in plaintext on the server; only the client's input and the returned result are encrypted.
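To make "approximate arithmetic on real numbers" more concrete, here is a minimal plaintext sketch of the fixed-point encoding idea that CKKS builds on: values are multiplied by a scale and rounded, and a product has to be rescaled back afterwards. It involves no encryption at all, and the helper names (`encode`, `decode`, `rescale`) are illustrative assumptions, not TenSEAL functions.

```python
# Plaintext analogy of CKKS fixed-point encoding (no encryption involved).
SCALE = 2 ** 40  # same magnitude as the global_scale used later in this notebook


def encode(x: float, scale: int = SCALE) -> int:
    """Represent a real number as a scaled, rounded integer."""
    return round(x * scale)


def decode(m: int, scale: int = SCALE) -> float:
    """Map a scaled integer back to a real number."""
    return m / scale


def rescale(m: int, scale: int = SCALE) -> int:
    """Divide by the scale (with rounding) after a multiplication."""
    return (m + scale // 2) // scale


a, b = encode(0.1), encode(0.7)
total = decode(a + b)             # addition keeps the scale unchanged
product = decode(rescale(a * b))  # a * b sits at scale**2 until rescaled

print(f"0.1 + 0.7 ≈ {total:.12f}")
print(f"0.1 * 0.7 ≈ {product:.12f}")
```

The results are close to 0.8 and 0.07 but not bit-exact, which is the same kind of small rounding error that the parity checks later in the notebook tolerate.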











In [2]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sentence_transformers import SentenceTransformer

# 1. Simulate 400 labeled text messages
mental_health_texts = [
    "I feel hopeless and tired all the time.",
    "Lately, I can't concentrate and everything feels overwhelming.",
    "I barely talk to anyone and feel isolated.",
    "School is stressing me out beyond what I can handle.",
    "I have no energy to do anything, even things I used to enjoy.",
    "I feel anxious constantly, even when nothing is wrong.",
    "I'm not sleeping well and my appetite is gone.",
    "Everything feels meaningless and I just want to be left alone.",
    "I cry randomly and can't explain why.",
    "Even getting out of bed feels like a chore."
] * 20  # 200 samples

no_problem_texts = [
    "I’ve been sleeping well and enjoying my time with friends.",
    "I feel confident and motivated about my goals.",
    "I’ve been going for daily walks and eating healthy.",
    "Things at school are busy but manageable.",
    "I enjoy socializing and staying active.",
    "Life has been stable and I’m feeling grateful.",
    "I’ve been productive and focused lately.",
    "My energy levels are good and I feel optimistic.",
    "I’ve been taking care of myself and feeling balanced.",
    "Everything is going smoothly and I’m content."
] * 20  # 200 samples

texts = mental_health_texts + no_problem_texts
labels = [1]*200 + [0]*200

df = pd.DataFrame({'text': texts, 'mental_health_problem': labels})
df = df.sample(frac=1, random_state=42).reset_index(drop=True)

# 2. Encode messages into embeddings
encoder = SentenceTransformer("all-MiniLM-L6-v2")
X_embeddings = encoder.encode(df['text'].tolist())
y = df['mental_health_problem'].values

# 3. Split and train
X_train, X_test, y_train, y_test = train_test_split(X_embeddings, y, test_size=0.2, stratify=y, random_state=42)

model = LogisticRegression(max_iter=1000)
model.fit(X_train, y_train)

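# 4. Evaluate
# Note: each class has only 10 distinct sentences, each repeated 20 times, so the
# same sentences land in both splits. The perfect scores below confirm that the
# pipeline runs; they do not measure generalization.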
y_pred = model.predict(X_test)
print("Classification Report:\n", classification_report(y_test, y_pred))


Classification Report:
               precision    recall  f1-score   support

           0       1.00      1.00      1.00        40
           1       1.00      1.00      1.00        40

    accuracy                           1.00        80
   macro avg       1.00      1.00      1.00        80
weighted avg       1.00      1.00      1.00        80



**🔓 Plaintext Inference Version**

This simplified version demonstrates the same mental health risk detection pipeline without encryption, serving as a baseline for comparison.

*   Text messages are embedded using a pretrained MiniLM model.
*   A logistic regression classifier predicts mental health risk based on those embeddings.
*   Inference is performed directly on plaintext vectors using a standard dot product.
*   Useful for validating model performance before deploying privacy-preserving encrypted inference.


In [3]:
# 5. Predict on new input (plaintext inference)
test_text = "I feel hopeless and tired all the time."
embedding = encoder.encode([test_text])[0]
score = np.dot(model.coef_[0], embedding) + model.intercept_[0]
probability = 1 / (1 + np.exp(-score))

print(f"\nPrediction probability for:\n\"{test_text}\"\n→ {probability:.4f}")
print("✅ Risk Detected" if probability > 0.5 else "✅ No Risk Detected")



Prediction probability for:
"I feel hopeless and tired all the time."
→ 0.9036
✅ Risk Detected


**🔐 Why Encryption Matters in Clinical AI: Protecting Against Embedding Leakage**



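*   🔓 Note: In clinical settings, sending plaintext text or unencrypted embeddings off the client is generally unacceptable under patient-confidentiality requirements. Embeddings are not human-readable, but they remain vulnerable to inversion and re-identification attacks. The cell below simulates a simple version of such an attack: an adversary who intercepts an embedding matches it against a corpus of candidate sentences by cosine similarity.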








In [4]:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

# 1. Load encoder and corpus of possible texts
encoder = SentenceTransformer("all-MiniLM-L6-v2")
corpus_sentences = [
    "I feel tired and sad all the time.",
    "I'm excited about my new project.",
    "Everything feels meaningless.",
    "I can't concentrate on anything lately.",
    "I’ve been sleeping well and eating healthy.",
    "Life is overwhelming and I want to cry.",
    "My motivation is gone and I feel hopeless."
]

# 2. Encode known corpus of candidate sentences
corpus_embeddings = encoder.encode(corpus_sentences)

# 3. Simulate intercepted embedding (e.g., from a client query)
# `test_text` is defined in the plaintext inference cell above
simulated_embedding = encoder.encode([test_text])[0]  # This is what gets intercepted

# 4. Attacker attempts reconstruction via cosine similarity
similarities = cosine_similarity([simulated_embedding], corpus_embeddings)
closest_idx = int(np.argmax(similarities[0]))

# 5. Display result
print("❗ Intercepted embedding likely corresponds to:")
print(f"🔍 Closest match: \"{corpus_sentences[closest_idx]}\"")
print(f"📈 Similarity score: {similarities[0][closest_idx]:.4f}")



❗ Intercepted embedding likely corresponds to:
🔍 Closest match: "I feel tired and sad all the time."
📈 Similarity score: 0.7972


🔐 **Encrypted Inference with CKKS (TenSEAL)**

- **Flow**: Client encrypts input embedding `x` → sends context (no SK) + `enc(x)` → server computes `enc(z)=enc(x)·w+b` → returns `enc(z)` → client decrypts and applies stable sigmoid  
  σ(z) = 0.5 · (1 + tanh(0.5z)) → obtains probability `p`.  
- **Sigmoid**: evaluated client-side in plaintext, avoiding approximation drift and ensuring exact probability values.  
- **CKKS params**: N=8192; coeff_mod_bit_sizes `[60,40,40,60]`; global_scale `2^40`.  
  Server requires only public + Galois keys; the secret key remains with the client.  
- **Security**: Server never sees raw inputs, decrypted logits, or probabilities.  
  Only encrypted vectors are processed; decryption happens solely on the client.  
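As a small illustration of why the client uses the tanh form, the sketch below compares it with the textbook expression 1 / (1 + e^(−z)) using only the standard-library `math` module. The function name `naive_sigmoid` is an assumption introduced for this comparison; the notebook itself uses NumPy, where the naive form would emit an overflow warning rather than raise.

```python
import math


def naive_sigmoid(z: float) -> float:
    """Textbook form; math.exp overflows for large negative z."""
    return 1.0 / (1.0 + math.exp(-z))


def stable_sigmoid(z: float) -> float:
    """tanh form used on the client side in this notebook."""
    return 0.5 * (1.0 + math.tanh(0.5 * z))


# Both forms agree on moderate inputs.
for z in (-2.0, 0.0, 2.2379):
    assert abs(naive_sigmoid(z) - stable_sigmoid(z)) < 1e-12

# Only the naive form fails on an extreme logit.
try:
    naive_sigmoid(-1000.0)
    overflowed = False
except OverflowError:
    overflowed = True

print("naive form overflowed at z=-1000:", overflowed)
print("stable form at z=-1000:", stable_sigmoid(-1000.0))
print(f"stable form at z=2.2379: {stable_sigmoid(2.2379):.4f}")
```

The last line uses the logit from the encrypted run in this section, so it should agree with the probability the client prints after decryption.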

✅ **Parity**: Decrypted logit matches plaintext logit within numerical precision; no drift observed in enc-plain operations.
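The CKKS parameters above can be sanity-checked without running any encryption. The sketch below compares the total coefficient-modulus size against the 128-bit security limits from Microsoft SEAL's default parameter tables and counts the rescaling levels the chain provides. The helper `check_ckks_params` is hypothetical and not part of TenSEAL; the level count assumes the usual convention that the first and last primes are not consumed by rescaling.

```python
# Upper bounds on total coefficient-modulus bits for 128-bit security,
# as listed in Microsoft SEAL's default parameter tables.
MAX_COEFF_MODULUS_BITS_128 = {
    1024: 27, 2048: 54, 4096: 109, 8192: 218, 16384: 438, 32768: 881,
}


def check_ckks_params(poly_modulus_degree, coeff_mod_bit_sizes):
    """Return (total_bits, limit_bits, rescale_levels) for a parameter set."""
    total = sum(coeff_mod_bit_sizes)
    limit = MAX_COEFF_MODULUS_BITS_128[poly_modulus_degree]
    # Assumption: only the middle primes are available for rescaling.
    levels = len(coeff_mod_bit_sizes) - 2
    return total, limit, levels


total, limit, levels = check_ckks_params(8192, [60, 40, 40, 60])
print(f"total={total} bits, limit={limit} bits, levels={levels}")
assert total <= limit, "coefficient modulus too large for 128-bit security"
```

For this configuration the total stays under the limit, and the chain leaves more than the single level that a plaintext-weight dot product needs.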


In [5]:
import tenseal as ts
import numpy as np

# -----------------------
# Utility
# -----------------------
def stable_sigmoid(x):
    """Numerically stable sigmoid implementation."""
    return 0.5 * (1.0 + np.tanh(0.5 * x))


# -----------------------
# CLIENT (with secret key)
# -----------------------
print("=== CLIENT: Setup private context with secret key ===")
client_priv = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[60, 40, 40, 60]
)
client_priv.global_scale = 2**40
client_priv.generate_galois_keys()  # needed for dot/rotations

# Public copy to send to server (no secret key inside)
ctx_bytes = client_priv.serialize(save_secret_key=False)
print("Client created private + public contexts. Public context ready for server.")

# -----------------------
# Prepare input embedding
# -----------------------
test_text = "I feel hopeless and tired all the time."
print(f"\nInput text: {test_text}")

embedding = encoder.encode([test_text])[0].astype(np.float64)
embedding /= (np.linalg.norm(embedding) + 1e-12)  # normalization (recommended)

# Encrypt input embedding
enc_vec = ts.ckks_vector(client_priv, embedding.tolist())
payload = enc_vec.serialize()
print("Client encrypted input embedding and serialized payload.")

# -----------------------
# SERVER (no secret key)
# -----------------------
print("\n=== SERVER: Received public context and encrypted payload ===")
server_ctx = ts.context_from(ctx_bytes)  # public context (no secret key)
enc_vec_server = ts.ckks_vector_from(server_ctx, payload)

# Model parameters (kept in plaintext on server)
weights = model.coef_[0]
bias = model.intercept_[0]

# Perform encrypted inference
enc_logit = enc_vec_server.dot(weights) + bias
resp = enc_logit.serialize()
print("Server computed encrypted dot-product + bias and sent response back.")

# -----------------------
# CLIENT (decrypt result)
# -----------------------
print("\n=== CLIENT: Received encrypted result from server ===")
enc_logit_client = ts.ckks_vector_from(client_priv, resp)  # requires secret key
logit = enc_logit_client.decrypt()[0]
prob  = stable_sigmoid(logit)
label = "✅ Risk Detected" if prob > 0.5 else "✅ No Risk Detected"
print(f"Decrypted logit: {logit:.4f}")
print(f"Probability: {prob:.4f} -> {label}")

# -----------------------
# Debug parity check
# -----------------------
logit_plain = float(np.dot(embedding, weights) + bias)
print("\nParity check (plaintext vs encrypted):")
print(f"Plaintext logit:  {logit_plain:.4f}")
print(f"Encrypted logit:  {logit:.4f}")
assert abs(logit - logit_plain) < 1e-3, "Mismatch between HE and plaintext paths!"


=== CLIENT: Setup private context with secret key ===
Client created private + public contexts. Public context ready for server.

Input text: I feel hopeless and tired all the time.
Client encrypted input embedding and serialized payload.

=== SERVER: Received public context and encrypted payload ===
Server computed encrypted dot-product + bias and sent response back.

=== CLIENT: Received encrypted result from server ===
Decrypted logit: 2.2379
Probability: 0.9036 -> ✅ Risk Detected

Parity check (plaintext vs encrypted):
Plaintext logit:  2.2379
Encrypted logit:  2.2379


🛡️ **Simulated Attacker Scenario: Why Homomorphic Encryption Prevents Data Leakage**

Even if a malicious actor intercepts the encrypted embedding or encrypted inference result, they cannot recover sensitive information.  

Three common attack attempts are illustrated:

- **🔓 Direct Inspection of Encrypted Vector**  
  Attacker tries to print or inspect ciphertext contents.  
  → Fails: CKKS ciphertexts are opaque; values look like random noise.

- **🔍 Cosine Similarity Matching**  
  Attacker tries to compute similarity between encrypted vector and known plaintext embeddings.  
  → Fails: Encrypted vectors cannot be processed with NumPy/sklearn operations.

- **🔑 Decryption Without Proper Context**  
  Attacker forges a new TenSEAL context and attempts to decrypt.  
  → Fails: Only the original context with the client’s private key can decrypt.


In [6]:
# 🚨 Simulated Attacker Section
# Assumes the encrypted-inference cell above has run, so `enc_vec`, `payload`,
# and `embedding` are defined.
print("\n🛡️ Simulated Attacker Scenario: Attempting to extract information...\n")

# --- Attack 1: Direct inspection ---
print("=== Attack 1: Direct inspection of ciphertext ===")
print("🔓 Attacker tries to print encrypted vector:")
print(enc_vec)  # only an opaque object representation, no plaintext values
print(f"Serialized ciphertext: {len(payload)} bytes; first 16: {payload[:16].hex()}")

# --- Attack 2: Cosine similarity ---
print("\n=== Attack 2: Cosine similarity against known plaintext ===")
try:
    from sklearn.metrics.pairwise import cosine_similarity
    cosine_similarity([enc_vec], [embedding])  # a ciphertext is not a numeric array
    print("❗ Unexpected: cosine similarity ran on a ciphertext")
except Exception as e:
    print("❌ Cosine similarity failed:", e)

# --- Attack 3: Forged context decryption ---
print("\n=== Attack 3: Decryption with fake context ===")
try:
    # Attacker builds a context with the same public parameters but fresh keys
    fake_context = ts.context(
        ts.SCHEME_TYPE.CKKS,
        poly_modulus_degree=8192,
        coeff_mod_bit_sizes=[60, 40, 40, 60]
    )
    fake_context.global_scale = 2**40
    stolen_vec = ts.ckks_vector_from(fake_context, payload)  # intercepted ciphertext
    guess = np.asarray(stolen_vec.decrypt()[:len(embedding)])
    if np.allclose(guess, embedding, atol=1e-3):
        print("❗ Unexpected success: attacker decrypted!")  # should never happen
    else:
        print("❌ Decryption with the wrong key yields values unrelated to the embedding.")
except Exception as e:
    print("❌ Decryption failed: attacker has no matching secret key:", e)


🔐 **Server-Side Encrypted Inference (Logistic + Homomorphic Stable Sigmoid)**

- **Flow**: Client encrypts `x` → sends context (no SK) + `enc(x)` → server computes `enc(z)=enc(x)·w+b` and applies a degree-5 polynomial approximation of the *stable sigmoid* form  
  σ(z) = 0.5 · (1 + tanh(0.5z)) → returns `enc(prob)` → client decrypts.  
- **Sigmoid poly (deg-5)**: σ(z) ≈ 0.5 + 0.2159198015·z − 0.0082176259·z³ + 0.0001825597·z⁵.  
  Using only odd powers around the constant 0.5 preserves the symmetry σ(−z)=1−σ(z), and degree 5 keeps the multiplicative depth modest. Unlike the exact tanh form, the polynomial is accurate only on a bounded range of z and diverges for large |z|, so logits should stay within the range the coefficients were fitted on.  
- **CKKS params**: N=16384; coeff_mod_bit_sizes `[60,40,40,40,40,60]`; global_scale `2^40`.  
  Server receives public + relin + Galois keys only; the secret key remains client-side.  
- **Security**: Server operates only on ciphertexts and never sees plaintext inputs, logits, or probabilities. Only the client can decrypt `p`.  
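The longer modulus chain in this configuration exists to pay for the polynomial. A rough way to reason about it is sketched below: one level for the plaintext-weight dot product plus about ceil(log2(degree + 1)) levels for the polynomial. Both helper functions are hypothetical, and the depth formula is an assumption about a square-and-multiply evaluation schedule; the levels TenSEAL actually consumes depend on its implementation.

```python
import math


def estimated_levels_needed(poly_degree: int, linear_layers: int = 1) -> int:
    """Rough estimate: one level per plaintext linear layer plus
    ceil(log2(degree + 1)) levels for the polynomial (assumed schedule)."""
    return linear_layers + math.ceil(math.log2(poly_degree + 1))


def levels_available(coeff_mod_bit_sizes) -> int:
    """Assumption: only the middle primes of the chain can be rescaled away."""
    return len(coeff_mod_bit_sizes) - 2


chain = [60, 40, 40, 40, 40, 60]
needed = estimated_levels_needed(poly_degree=5)
available = levels_available(chain)
print(f"estimated levels needed={needed}, available={available}")
```

Under these assumptions the degree-5 polynomial just fits the four middle primes, which is consistent with the run in this section completing without a depth error.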

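🤏 **Why decrypted `p` ≠ plaintext `p` (slight drift)**  
- **Polynomial ≠ true sigmoid**: this is the dominant source. Evaluating the degree-5 polynomial in plaintext at the logit from the run below (z = −0.348502) already gives ≈ 0.4251, versus 0.4137 for the exact sigmoid. The approximation error also grows with |z| outside the fitted range.  
- **CKKS is approximate**: fixed-point encoding and rescaling add rounding noise that accumulates, but at scale `2^40` it is far smaller than the polynomial error.  
- **Eval schedule**: the multiply/add order under HE differs from standard float64 math, which is a minor additional effect.  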

*Reduce drift*: calibrate logits, use larger scale or deeper chain, or refit with higher-degree / Chebyshev polynomials if multiplicative depth allows.
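To separate approximation error from CKKS noise, the polynomial can be evaluated in plaintext and compared with the exact sigmoid. The sketch below uses only the standard library and the coefficients quoted above; the function names `poly_sigmoid` and `exact_sigmoid` are illustrative.

```python
import math

# Coefficients in ascending order of power, as passed to polyval in this notebook.
COEFFS = [0.5, 0.2159198015, 0.0, -0.0082176259, 0.0, 0.0001825597]


def poly_sigmoid(z: float) -> float:
    """Evaluate the degree-5 polynomial with Horner's rule."""
    result = 0.0
    for c in reversed(COEFFS):
        result = result * z + c
    return result


def exact_sigmoid(z: float) -> float:
    """Stable tanh form of the sigmoid."""
    return 0.5 * (1.0 + math.tanh(0.5 * z))


z_run = -0.348502  # plaintext logit reported in the run below
p_poly, p_exact = poly_sigmoid(z_run), exact_sigmoid(z_run)
print(f"poly={p_poly:.4f}  exact={p_exact:.4f}  |Δ|={abs(p_poly - p_exact):.4e}")

# The gap is not constant: inspect a few other logits as well.
for z in (0.0, 1.0, 2.0, 4.0, 8.0):
    print(f"z={z:>4}: poly={poly_sigmoid(z):.4f}  exact={exact_sigmoid(z):.4f}")
```

At the logit from this notebook's run, the plaintext polynomial already lands on the value the client decrypts, so the observed drift comes almost entirely from the polynomial fit rather than from encryption noise.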



In [7]:
# -----------------------------------------------------------------------------
# 🔐 Encrypted logistic regression with server-side sigmoid (polynomial approx.)
#
#   Client constructs CKKS context at N=16384 with chain [60,40,40,40,40,60]
#   and global_scale=2**40, then encrypts the embedding x.
#
#   Client sends context (public/relin/galois keys only; NO secret key) + enc(x).
#
#   Server computes enc(z) = enc(x)·w + b, then approximates
#   σ(z) = 0.5 * (1 + tanh(0.5z)) with a degree-5 odd polynomial evaluated
#   under encryption via polyval.
#
# CKKS notes:
#   - CKKS supports only additions and multiplications, so tanh cannot be
#     evaluated directly; a polynomial stands in for it. The fit is accurate
#     only on a bounded range of z and diverges for large |z|.
#   - Server never sees the decrypted logit or probability; only the client can.
# -----------------------------------------------------------------------------

import tenseal as ts
import numpy as np

weights = np.asarray(model.coef_[0], dtype=np.float64)
bias = float(model.intercept_[0])

# --- Choose input and compute PLAINTEXT baseline ---
test_text = "I'm hesitating"
emb = encoder.encode([test_text], convert_to_numpy=True)[0].astype(np.float64)
z_plain = float(np.dot(emb, weights) + bias)

def stable_sigmoid(x):
    return 0.5 * (1.0 + np.tanh(0.5 * x))

p_plain = stable_sigmoid(z_plain)
print("=== PLAINTEXT baseline ===")
print(f"z={z_plain:.6f}  p={p_plain:.4f}\n")

# --- CLIENT: TenSEAL CKKS context ---
print("=== CLIENT: Setup private context with secret key ===")
ctx = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=16384,
    coeff_mod_bit_sizes=[60, 40, 40, 40, 40, 60],
)
ctx.global_scale = 2 ** 40
ctx.generate_galois_keys()
ctx.generate_relin_keys()
print("Client created private + public contexts. Public context ready for server.")

# Encrypt embedding
enc_vec = ts.ckks_vector(ctx, emb.tolist())
payload = enc_vec.serialize()
print("Client encrypted input embedding and serialized payload.\n")

# --- SERVER: encrypted z = w·x + b, then stable sigmoid approximation ---
def server_encrypted_prob(context_bytes: bytes, enc_vec_bytes: bytes,
                          weights: np.ndarray, bias: float) -> bytes:
    print("=== SERVER: Received public context and encrypted payload ===")
    sctx = ts.context_from(context_bytes)
    enc_x = ts.ckks_vector_from(sctx, enc_vec_bytes)
    enc_logit = enc_x.dot(weights) + float(bias)

    # tanh cannot be evaluated under CKKS, so the sigmoid is replaced by a
    # degree-5 polynomial. Coefficients are in ascending order of power
    # (constant term first); even-power terms are zero.
    coeffs = [0.5, 0.2159198015, 0.0, -0.0082176259, 0.0, 0.0001825597]
    enc_prob = enc_logit.polyval(coeffs)
    print("Server computed encrypted dot-product + bias and applied poly approx of stable sigmoid.")
    return enc_prob.serialize()

# Serialize context WITHOUT secret key for the server
ctx_bytes_no_sk = ctx.serialize(
    save_public_key=True, save_secret_key=False,
    save_galois_keys=True, save_relin_keys=True
)

enc_prob_bytes = server_encrypted_prob(
    ctx_bytes_no_sk,
    payload,
    weights,
    bias,
)


# --- CLIENT: decrypt result ---
print("\n=== CLIENT: Received encrypted result from server ===")
enc_prob = ts.ckks_vector_from(ctx, enc_prob_bytes)
p_he = float(enc_prob.decrypt()[0])
print(f"[HE DECRYPT] p={p_he:.4f}")
print("✅ Risk Detected" if p_he > 0.5 else "✅ No Risk Detected")

z_plain = float(np.dot(emb, weights) + bias)   # plaintext logit
p_plain = stable_sigmoid(z_plain)              # plaintext probability

print("\nParity check (plaintext vs encrypted):")
print(f"Plaintext probability: {p_plain:.4f}")
print(f"Encrypted probability: {p_he:.4f}")

diff = abs(p_he - p_plain)
print(f"|Δ| = {diff:.4e}")

if diff < 1e-3:
    print("✅ Parity: match within 1e-3 (excellent).")
elif diff < 5e-2:
    print("⚠️ Small drift observed (expected with HE poly approximation).")
else:
    print("❗Large drift — consider higher degree/scale or recalibration.")


=== PLAINTEXT baseline ===
z=-0.348502  p=0.4137

=== CLIENT: Setup private context with secret key ===
Client created private + public contexts. Public context ready for server.
Client encrypted input embedding and serialized payload.

=== SERVER: Received public context and encrypted payload ===
Server computed encrypted dot-product + bias and applied poly approx of stable sigmoid.

=== CLIENT: Received encrypted result from server ===
[HE DECRYPT] p=0.4251
✅ No Risk Detected

Parity check (plaintext vs encrypted):
Plaintext probability: 0.4137
Encrypted probability: 0.4251
|Δ| = 1.1353e-02
⚠️ Small drift observed (expected with HE poly approximation).
