# Weak Supervision with Snorkel AI

Model individuals and businesses separately.

Create labels for unlabeled data with Snorkel AI weak supervision. Train model on newly labeled data. Validate model on labeled data. Generate explanation with Snorkel labeling functions and SHAPLEY.

In [1]:
import pandas as pd
import numpy as np
import os
# Snorkel AI weak supervision
from snorkel.labeling import LFAnalysis
from snorkel.labeling import labeling_function, PandasLFApplier
from snorkel.labeling.model import LabelModel
# classification model
from sklearn.preprocessing import StandardScaler
import xgboost as xgb
# SHAP explainability
import shap
# save model
import pickle
# evaluation metrics
from sklearn.metrics import (
    confusion_matrix, 
    accuracy_score, 
    precision_score, 
    recall_score,
    f1_score,
    roc_auc_score,
    classification_report
)

# directories
RAW_DIR = 'data/raw'
PROCESSED_DIR = 'data/processed'
CHECKPOINT_DIR = 'checkpoint'
FINAL_DIR = 'data/final'

# weak label constants
ABSTAIN = -1  # Uncertain
NEGATIVE = 0  # Low risk
POSITIVE = 1  # High risk

## Model for Businesses

### Part 1: Define Snorkel Labeling Functions LFs (Businesses)

In [2]:
@labeling_function()
def short_hold_time(x):
    """Short hold time in high risk industries"""
    condition = (
        x.industry_code in ("4561", "7761", "7499", "7214", "711", "4113", "919", "7599", "7292", "7215") and 
        x.median_hold_time_funds <= 26.0
    )
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def high_txn_volume_low_sales(x):
    """High total transaction volume despite low sales in high risk industries"""
    condition = (
        x.industry_code in ("4561", "7761", "7499", "7214", "711", "4113", "919", "7599", "7292", "7215") and 
        x.sales <= 14515.0 and 
        x.sum_amt_total >= 209466.1
    )
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def low_channel_diversification(x):
    """High transaction count with low channel diversification in high risk industries"""
    condition = (
        x.industry_code in ("4561", "7761", "7499", "7214", "711", "4113", "919", "7599", "7292", "7215") and 
        x.max_channel_share <= 0.820 and 
        x.n_txn_total >= 70.0
    )
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def transaction_same_amount(x):
    """Suspicious transaction patterns (repetitive amounts or high e-commerce ratio) in high risk industries"""
    condition = (
        x.industry_code in ("4561", "7761", "7499", "7214", "711", "4113", "919", "7599", "7292", "7215") and 
        (x.transaction_same_amount_frequency_90d >= 0.013 or x.transaction_ecommerce_ratio >= 0.087)
    )
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def infrequent_transaction(x):
    """Low transaction count (likely inactive/low activity)"""
    condition = x.n_txn_total <= 8.0
    return NEGATIVE if condition else ABSTAIN

@labeling_function()
def long_hold_time(x):
    """Funds held for a long period (slow movement of money)"""
    condition = x.median_hold_time_funds > 172.1
    return NEGATIVE if condition else ABSTAIN

@labeling_function()
def low_debit_transfers(x):
    """Low debit activity (low money going out)"""
    condition = x.channel_debit_volume <= 2690.3
    return NEGATIVE if condition else ABSTAIN

@labeling_function()
def no_recent_transaction(x):
    """No recent transaction (inactive account)"""
    condition = x.days_since_last_transaction > 19.0
    return NEGATIVE if condition else ABSTAIN

@labeling_function()
def low_var_transaction(x):
    """Low variability in transaction amounts across specific channels (consistent/predictable usage)"""
    condition = (x.cv_amt_emt <= 1.087 and x.cv_amt_eft <= 1.8 and x.std_amt_card <= 152.6)
    return NEGATIVE if condition else ABSTAIN

@labeling_function()
def high_var_wire(x):
    """Use of wire transfers in high risk industries"""
    condition = (
        x.industry_code in ("4561", "7761", "7499", "7214", "711", "4113", "919", "7599", "7292", "7215") and 
        x.std_amt_wire > 73018.4
    )
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def high_sale_high_debit(x):
    """High sales paired with high debit activity in high risk industries"""
    condition = (
        x.industry_code in ("4561", "7761", "7499", "7214", "711", "4113", "919", "7599", "7292", "7215") and 
        x.sales >= 98725.8 and x.channel_debit_volume >= 126121.7
    )
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def frequent_transaction_short_hold_time(x):
    """High transaction frequency with short hold times in high risk industries"""
    condition = (
        x.industry_code in ("4561", "7761", "7499", "7214", "711", "4113", "919", "7599", "7292", "7215") and 
        x.transaction_frequency_7d >= 6.0 and 
        x.median_hold_time_funds <= 33.356
    )
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def high_var_emt(x):
    """High variability in EMT amounts combined with other factors (variability in cards/EFT and high risk industries)"""
    condition = (
        x.industry_code in ("4561", "7761", "7499", "7214", "711", "4113", "919", "7599", "7292", "7215") and 
        (x.cv_amt_emt >= 1.538 or x.std_amt_card >= 817.522 or x.cv_amt_eft >= 2.956)
    )
    return POSITIVE if condition else ABSTAIN

In [3]:
LF_DEFINITIONS = [
    ("Short hold time", short_hold_time),
    ("High transaction volume and low sales", high_txn_volume_low_sales),
    ("Low channel diversification", low_channel_diversification),
    ("Many transactions of the same amount", transaction_same_amount),
    ("Infrequent transaction", infrequent_transaction),
    ("Long hold time", long_hold_time),
    ("Low debit transfers", low_debit_transfers),
    ("No recent transaction", no_recent_transaction),
    ("Low variability in transaction amounts", low_var_transaction),
    ("High variability in wire transfers", high_var_wire),
    ("High sale and high debit transfers", high_sale_high_debit),
    ("Frequent transaction and short hold time", frequent_transaction_short_hold_time),
    ("High variability in EMT transfers", high_var_emt)
]

# Separate convenience lists used throughout the notebook
lf_names = [name for name, _ in LF_DEFINITIONS]
lfs = [lf for _, lf in LF_DEFINITIONS]

### Part 2: LF functions Sanity Check (Businesses)

Check the percentage of samples that satisfy each LF

In [4]:
# read labeled data
df_labeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'business_labeled.csv'))

# Apply LFs to golden_eval
L_dev = np.array([[lf(row) for lf in lfs] for _, row in df_labeled.iterrows()])
lf_analysis = LFAnalysis(L_dev, lfs=lfs).lf_summary()
print(lf_analysis)

                                       j Polarity  Coverage  Overlaps  \
short_hold_time                        0      [1]  0.079545  0.079545   
high_txn_volume_low_sales              1       []  0.000000  0.000000   
low_channel_diversification            2      [1]  0.022727  0.022727   
transaction_same_amount                3      [1]  0.045455  0.045455   
infrequent_transaction                 4      [0]  0.204545  0.204545   
long_hold_time                         5      [0]  0.204545  0.136364   
low_debit_transfers                    6      [0]  0.204545  0.181818   
no_recent_transaction                  7      [0]  0.136364  0.136364   
low_var_transaction                    8      [0]  0.454545  0.295455   
high_var_wire                          9       []  0.000000  0.000000   
high_sale_high_debit                  10      [1]  0.011364  0.011364   
frequent_transaction_short_hold_time  11      [1]  0.011364  0.011364   
high_var_emt                          12      [1]  

In [5]:
# read unlabeled data
df_unlabeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'business_unlabeled.csv'))

# apply LFs to golden_eval
L_dev = np.array([[lf(row) for lf in lfs] for _, row in df_unlabeled.iterrows()])
lf_analysis = LFAnalysis(L_dev, lfs=lfs).lf_summary()
print(lf_analysis)

                                       j Polarity  Coverage  Overlaps  \
short_hold_time                        0      [1]  0.062994  0.062021   
high_txn_volume_low_sales              1      [1]  0.022133  0.020795   
low_channel_diversification            2      [1]  0.020066  0.018728   
transaction_same_amount                3      [1]  0.034051  0.028943   
infrequent_transaction                 4      [0]  0.190441  0.188009   
long_hold_time                         5      [0]  0.179861  0.137906   
low_debit_transfers                    6      [0]  0.149459  0.141919   
no_recent_transaction                  7      [0]  0.093640  0.092302   
low_var_transaction                    8      [0]  0.361182  0.236045   
high_var_wire                          9      [1]  0.001459  0.001338   
high_sale_high_debit                  10      [1]  0.010094  0.009121   
frequent_transaction_short_hold_time  11      [1]  0.031254  0.031011   
high_var_emt                          12      [1]  

### Part 3: Train Snorkel's Generative Model (Businesses)

Snorkel learns to combine LFs into probabilistic labels

In [6]:
# read unlabeled data
df_unlabeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'business_unlabeled.csv'))

# apply labeling functions to unlabeled data
applier = PandasLFApplier(lfs)
L_unlabeled = applier.apply(df_unlabeled)

100%|██████████| 8223/8223 [00:01<00:00, 7689.72it/s]


In [49]:
# fit Snorkel AI weak supervision model on unlabeled data
label_model = LabelModel(cardinality=2, verbose=True)
label_model.fit(L_unlabeled, n_epochs=50, log_freq=10, seed=1)

# save Snorkel model
label_model_path = os.path.join(CHECKPOINT_DIR, 'snorkel_label_bsn.pkl')
with open(label_model_path, 'wb') as f:
    pickle.dump(label_model, f)

INFO:root:Computing O...


INFO:root:Estimating \mu...
  0%|          | 0/50 [00:00<?, ?epoch/s]INFO:root:[0 epochs]: TRAIN:[loss=0.298]
INFO:root:[10 epochs]: TRAIN:[loss=0.153]
INFO:root:[20 epochs]: TRAIN:[loss=0.037]
INFO:root:[30 epochs]: TRAIN:[loss=0.046]
INFO:root:[40 epochs]: TRAIN:[loss=0.028]
100%|██████████| 50/50 [00:00<00:00, 924.06epoch/s]
INFO:root:Finished Training


In [7]:
# read Snorkel model
label_model_path = os.path.join(CHECKPOINT_DIR, 'snorkel_label_bsn.pkl')
with open(label_model_path, 'rb') as f:
    label_model = pickle.load(f)

# predict labels for unlabeled data
pred_probs_unlabeled = label_model.predict_proba(L_unlabeled)
pred_unlabeled = label_model.predict(L_unlabeled)
# all predicted 1s are high-risk, while all predicted 0s or abstains are low-risk
pred_unlabeled = np.where(pred_unlabeled==1, 1, 0)

### Step 4: Train Final Classification Model (Businesses)

Train classification model on newly labeled data

In [8]:
# read unlabeled data
df_unlabeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'business_feature_unlabeled.csv'))

# continuous features
features_continuous = [c for c in df_unlabeled if c != 'label' and len(df_unlabeled[c].unique()) > 2]
df_continuous = df_unlabeled[features_continuous].copy()
# fit normalization scaler
scaler = StandardScaler()
scaler.fit(df_continuous)
# save normalization scaler
scaler_path = os.path.join(CHECKPOINT_DIR, 'scaler_bsn.pkl')
with open(scaler_path, 'wb') as f:
    pickle.dump(scaler, f)
# transform continuous data (z-scores)
X_train = scaler.transform(df_continuous)
X_train = pd.DataFrame(X_train, columns=df_continuous.columns)

# categorical features
features_categorical = [c for c in df_unlabeled if c != 'label' and len(df_unlabeled[c].unique()) <= 2]
df_categorical = df_unlabeled[features_categorical].copy()
X_train = pd.concat((X_train, df_categorical), axis=1)

# fill missing values with zeros
X_train = X_train.fillna(value=0)

# all features
features = features_continuous + features_categorical

# labels
y_train = pred_unlabeled

In [52]:
# train XGBoost (inherently interpretable with feature importance)
classifier = xgb.XGBClassifier(
    n_estimators=100,
    max_depth=4,
    learning_rate=0.1,
    random_state=1,
    eval_metric='logloss'
)
classifier.fit(X_train, y_train)

# save XGBoost model
classifier_path = os.path.join(CHECKPOINT_DIR, 'xgboost_bsn.pkl')
with open(classifier_path, 'wb') as f:
    pickle.dump(classifier, f)

### Step 5: Model Performance (Businesses)

Evaluate performance of classification model on labeled data

In [9]:
# read XGBoost model
classifier_path = os.path.join(CHECKPOINT_DIR, 'xgboost_bsn.pkl')
with open(classifier_path, 'rb') as f:
    classifier = pickle.load(f)

# read normalization scaler
scaler_path = os.path.join(CHECKPOINT_DIR, 'scaler_bsn.pkl')
with open(scaler_path, 'rb') as f:
    scaler = pickle.load(f)

# read labeled data
df_labeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'business_feature_labeled.csv'))

# normalize continuous features (z-scores)
df_continuous = df_labeled[features_continuous].copy()
X_test = scaler.transform(df_continuous)
X_test = pd.DataFrame(X_test, columns=df_continuous.columns)
# categorical features
df_categorical = df_labeled[features_categorical].copy()
X_test = pd.concat((X_test, df_categorical), axis=1)

In [10]:
pred = classifier.predict(X_test)
target = df_labeled['label']

# accuracy
a = accuracy_score(target, pred)
print("Accuracy:", a)

# precision
p = precision_score(target, pred)
print("Precision (Fraud=1):", p)

# recall
r = recall_score(target, pred)
print("Recall (Fraud=1):", r)

# F1 score
f1 = f1_score(target, pred)
print("F1 (Fraud=1):", f1)

# AUC score
auc = roc_auc_score(target, pred)
print("AUC:", auc)

# confusion matrix
m = confusion_matrix(target, pred)
print("\nConfusion matrix: [[TN, FP], [FN, TP]]")
print(m)

cr = classification_report(target, pred)
print("\nFull report:")
print(cr)

Accuracy: 0.9545454545454546
Precision (Fraud=1): 0.5
Recall (Fraud=1): 1.0
F1 (Fraud=1): 0.6666666666666666
AUC: 0.9761904761904762

Confusion matrix: [[TN, FP], [FN, TP]]
[[80  4]
 [ 0  4]]

Full report:
              precision    recall  f1-score   support

         0.0       1.00      0.95      0.98        84
         1.0       0.50      1.00      0.67         4

    accuracy                           0.95        88
   macro avg       0.75      0.98      0.82        88
weighted avg       0.98      0.95      0.96        88



### Step 6: Add SHAP for Local Interpretability (Businesses)

Generate explanation for each client with LFs and SHAPLEY
- LFs that are satisfied
- Top features based on SHAPLEY values

In [11]:
# explain classification model with SHAP
explainer = shap.TreeExplainer(classifier)

In [12]:
def explain_lfs(customer_data):
    """Which rules fired?"""
    fired_rules = []
    for lf, lf_name in zip(lfs, lf_names):
        if lf_name != "" and lf(customer_data) == POSITIVE:
            fired_rules.append(lf_name)
    return fired_rules

def explain_features(transformed_data, threshold=0.5):
    """Top features based on SHAPLEY values"""
    shap_values = explainer.shap_values(transformed_data.to_numpy().reshape(1, -1))
    # crate tuples of features and scores
    feature_impacts = [(f, float(s)) for f, s, in zip(features, shap_values[0])]
    # keep features with absolute scores above 0.5
    top_features = [(f, float(s)) for f, s in zip(features, shap_values[0]) if abs(s) > threshold]
    return top_features

def explain(customer_data, transformed_data, pred, threshold=0.5):
    """LFs that are satisfied anf top features based on SHAPLEY values"""
    # lfs
    lfs_fired = explain_lfs(customer_data)
    # features with SHAP
    features_important = explain_features(transformed_data, threshold)
    features_description = []
    # generate description for top features:
    # e.g. high, low, satisfied, not satisfied
    for f, v in features_important:
        if v >= 0 and p == 1 or v < 0 and p == 0:
            if v >= 0:
                d = "satisfied" if f.startswith(('is_', 'has_')) else "high"
            else:
                d = "not satisfied" if f.startswith(('is_', 'has_')) else "low"
            features_description.append((d, f, v))
    return lfs_fired, features_description

In [13]:
# read labeled data
df_labeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'business_labeled.csv'))

print("Predicted high-risk businesses")
for (i, customer_data), (i, transformed_data), p in zip(df_labeled.iterrows(), X_test.iterrows(), pred):
    if p == 1:
        print('-'*10)
        print(f"Sample {i}")
        lfs_fired, features_description = explain(customer_data, transformed_data, p)
        print('Fired LFs:')
        for l in lfs_fired:
            print("-", l)
        print("Top features & SHAP values:")
        for d, f, v in features_description:
            print("-", d, f, round(v, 2))

Predicted high-risk businesses
----------
Sample 0
Fired LFs:
- Short hold time
Top features & SHAP values:
- high n_txn_total 1.04
- high sum_amt_total 1.01
- high median_hold_time_funds 0.64
- satisfied is_industry_code_7215 4.62
----------
Sample 11
Fired LFs:
- Short hold time
- Many transactions of the same amount
Top features & SHAP values:
- high sum_amt_card 0.73
- satisfied is_industry_code_7761 4.35
----------
Sample 36
Fired LFs:
- High variability in EMT transfers
Top features & SHAP values:
- high cv_amt_eft 0.82
- satisfied is_industry_code_7599 3.83
----------
Sample 40
Fired LFs:
Top features & SHAP values:
- high sum_amt_card 0.62
- satisfied is_industry_code_7761 4.45
----------
Sample 44
Fired LFs:
- Short hold time
- High sale and high debit transfers
- Frequent transaction and short hold time
Top features & SHAP values:
- high n_txn_total 1.12
- high sum_amt_total 1.01
- high median_hold_time_funds 0.66
- satisfied is_industry_code_7215 4.62
----------
Sample 55
Fi

### Step 7: Get final output (Businesses)

In [14]:
# master dataframes (businesses)
df_bsn_labeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'business_labeled.csv'))
df_bsn_unlabeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'business_unlabeled.csv'))
# feature selected dataframes (businesses)
df_bsn_feat_labeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'business_feature_labeled.csv'))
df_bsn_feat_unlabeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'business_feature_unlabeled.csv'))
# all data (businesses)
df_bsn = pd.concat((df_bsn_labeled, df_bsn_unlabeled))
df_bsn_feat = pd.concat((df_bsn_feat_labeled, df_bsn_feat_unlabeled), ignore_index=True)

# normalize continuous features (z-scores)
df_continuous = df_bsn_feat[features_continuous].copy()
df_bsn_transformed = scaler.transform(df_continuous)
df_bsn_transformed = pd.DataFrame(df_bsn_transformed, columns=df_continuous.columns)
# categorical features
df_categorical = df_bsn_feat[features_categorical].copy()
df_bsn_transformed = pd.concat((df_bsn_transformed, df_categorical), axis=1)

In [15]:
# predicted labels
pred = classifier.predict(df_bsn_transformed)

# get explanation for each client
explanation_list = []
for (i, customer_data), (i, transformed_data), p in zip(df_bsn.iterrows(), df_bsn_transformed.iterrows(), pred):
    lfs_fired, features_description = explain(customer_data, transformed_data, p)
    explanation = ""
    if len(lfs_fired) > 0:
        explanation += ', '.join(lfs_fired) + '. '
    if len(features_description) > 0:
        explanation += "Top features: " + ', '.join([d+' '+f for d, f, _ in features_description])
    explanation_list.append(explanation)

In [16]:
df_final = pd.DataFrame({
    'customer_id': df_bsn['customer_id'],
    'label': df_bsn['label'],
    'prediction': pred,
    'explanation': explanation_list
})
df_final.to_csv(os.path.join(FINAL_DIR, 'business_output.csv'), index=False)

## Model for Individuals

### Part 1: Define Snorkel Labeling Functions LFs (Individuals)

In [2]:
@labeling_function()
def high_txn_volume_vs_income(x):
    """High transaction volume relative to income (suspicious if transaction volume is much higher than income)"""
    condition = (x.txn_volume_vs_income > 2477.8) & ((x.age <= 25) | (x.income <= 40000))
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def high_txn_volume_vs_occupation_median(x):
    """Transaction volume vs occupation median (if transaction volume is high compared to typical for the occupation)"""
    condition = (x.txn_volume_vs_occupation_median > 1562.0) & ((x.age <= 25) | (x.income <= 40000))
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def high_wire(x):
    """High wire transfer amounts if income or age is low (could be mule)"""
    condition = (x.median_amt_wire > 6817.2) & ((x.age <= 25) | (x.income <= 40000))
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def frequent_wire(x):
    """Frequent wire transfer amounts if income or age is low"""
    condition = (x.wire_ratio > 0.47) & ((x.age <= 25) | (x.income <= 40000))
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def structuring(x):
    """Structuring (frequent transactions just below the reporting threshold)"""
    condition = (x.count_txn_below_threshold_frequency > 0.009) & (x.n_txn_total > 5)
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def short_hold_time(x):
    """Rapid movement of funds (short hold time)"""
    condition = (x.n_txn_total > 50) & (x.median_hold_time_funds <= 0.11)
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def high_cross_border_ratio(x):
    """Cross-border transactions (if high ratio of cross-border)"""
    condition = (x.cross_border_ratio > 0.17) & (x.transaction_unique_countries >= 3)
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def high_ecommerce_ratio(x):
    """High e-commerce transaction ratio (if e-commerce is used for money laundering)"""
    condition = (x.transaction_ecommerce_ratio > 0.98) & (x.transaction_volume_90d > 198059.4) & (x.n_txn_total > 10)
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def ecommerce_occupation_mismatch(x):
    """High e-commerce for the unemployed"""
    condition = (x.transaction_ecommerce_ratio > 0.98) & (x.occupation_code == 'UNEMPLOYED')
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def unusual_merchant_pattern(x):
    """Make few transactions across many merchants"""
    condition = (x.transaction_unique_merchants > 40) & (x.n_txn_total < (1.5 * x.transaction_unique_merchants))
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def transaction_same_amount(x):
    """Predictable amount patterns through e.g., gift cards"""
    condition = (x.transaction_same_amount_frequency_7d > 0.14) & (x.transaction_round_amount_frequency_7d >= 0.5)
    return POSITIVE if condition else ABSTAIN

In [3]:
LF_DEFINITIONS = [
    ("High transaction to income ratio", high_txn_volume_vs_income),
    ("High transaction to occupation median ratio", high_txn_volume_vs_occupation_median),
    ("High wire transfers", high_wire),
    ("Frequent wire transfers", frequent_wire),
    ("Structuring", structuring),
    ("Short hold time", short_hold_time),
    ("High cross border ratio", high_cross_border_ratio),
    ("High ecommerce ratio", high_ecommerce_ratio),
    ("Ecommerce occupation mismatch", ecommerce_occupation_mismatch),
    ("Unusual merchant pattern", unusual_merchant_pattern),
    ("Transaction same amount", transaction_same_amount),
]

# Separate convenience lists used throughout the notebook
lf_names = [name for name, _ in LF_DEFINITIONS]
lfs = [lf for _, lf in LF_DEFINITIONS]

### Part 2: LF functions Sanity Check (Individuals)

Check the percentage of samples that satisfy each LF

In [4]:
# read labeled data
df_labeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'individual_labeled.csv'))

# Apply LFs to golden_eval
L_dev = np.array([[lf(row) for lf in lfs] for _, row in df_labeled.iterrows()])
lf_analysis = LFAnalysis(L_dev, lfs=lfs).lf_summary()
print(lf_analysis)

                                       j Polarity  Coverage  Overlaps  \
high_txn_volume_vs_income              0      [1]  0.006579  0.002193   
high_txn_volume_vs_occupation_median   1      [1]  0.003289  0.002193   
high_wire                              2      [1]  0.002193  0.002193   
frequent_wire                          3      [1]  0.002193  0.002193   
structuring                            4      [1]  0.004386  0.000000   
short_hold_time                        5      [1]  0.006579  0.000000   
high_cross_border_ratio                6       []  0.000000  0.000000   
high_ecommerce_ratio                   7       []  0.000000  0.000000   
ecommerce_occupation_mismatch          8       []  0.000000  0.000000   
unusual_merchant_pattern               9       []  0.000000  0.000000   
transaction_same_amount               10       []  0.000000  0.000000   

                                      Conflicts  
high_txn_volume_vs_income                   0.0  
high_txn_volume_vs_occu

In [5]:
# read unlabeled data
df_unlabeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'individual_unlabeled.csv'))

# apply LFs to golden_eval
L_dev = np.array([[lf(row) for lf in lfs] for _, row in df_unlabeled.iterrows()])
lf_analysis = LFAnalysis(L_dev, lfs=lfs).lf_summary()
print(lf_analysis)

                                       j Polarity  Coverage  Overlaps  \
high_txn_volume_vs_income              0      [1]  0.036695  0.013164   
high_txn_volume_vs_occupation_median   1      [1]  0.039627  0.014276   
high_wire                              2      [1]  0.004944  0.004005   
frequent_wire                          3      [1]  0.005116  0.003794   
structuring                            4      [1]  0.010347  0.001686   
short_hold_time                        5      [1]  0.004254  0.000019   
high_cross_border_ratio                6      [1]  0.001246  0.000134   
high_ecommerce_ratio                   7      [1]  0.000211  0.000057   
ecommerce_occupation_mismatch          8      [1]  0.000996  0.000134   
unusual_merchant_pattern               9       []  0.000000  0.000000   
transaction_same_amount               10      [1]  0.000977  0.000038   

                                      Conflicts  
high_txn_volume_vs_income                   0.0  
high_txn_volume_vs_occu

### Part 3: Train Snorkel's Generative Model (Individuals)

Snorkel learns to combine LFs into probabilistic labels

In [6]:
# read unlabeled data
df_unlabeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'individual_unlabeled.csv'))

# apply labeling functions to unlabeled data
applier = PandasLFApplier(lfs)
L_unlabeled = applier.apply(df_unlabeled)

100%|██████████| 52187/52187 [00:09<00:00, 5350.19it/s]


In [None]:
# fit Snorkel AI weak supervision model on unlabeled data
label_model = LabelModel(cardinality=2, verbose=True)
label_model.fit(L_unlabeled, n_epochs=50, log_freq=10, seed=1)

# save Snorkel model
label_model_path = os.path.join(CHECKPOINT_DIR, 'snorkel_label_ind.pkl')
with open(label_model_path, 'wb') as f:
    pickle.dump(label_model, f)

INFO:root:Computing O...
INFO:root:Estimating \mu...
  0%|          | 0/50 [00:00<?, ?epoch/s]INFO:root:[0 epochs]: TRAIN:[loss=0.002]
INFO:root:[10 epochs]: TRAIN:[loss=0.001]
INFO:root:[20 epochs]: TRAIN:[loss=0.000]
INFO:root:[30 epochs]: TRAIN:[loss=0.000]
INFO:root:[40 epochs]: TRAIN:[loss=0.000]
100%|██████████| 50/50 [00:00<00:00, 1014.13epoch/s]
INFO:root:Finished Training


In [7]:
# read Snorkel model
label_model_path = os.path.join(CHECKPOINT_DIR, 'snorkel_label_ind.pkl')
with open(label_model_path, 'rb') as f:
    label_model = pickle.load(f)

# predict labels for unlabeled data
pred_probs_unlabeled = label_model.predict_proba(L_unlabeled)
pred_unlabeled = label_model.predict(L_unlabeled)
# all predicted 1s are high-risk, while all predicted 0s or abstains are low-risk
pred_unlabeled = np.where(pred_unlabeled==1, 1, 0)

### Step 4: Train Final Classification Model (Individuals)

Train classification model on newly labeled data

In [8]:
# read unlabeled data
df_unlabeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'individual_feature_unlabeled.csv'))

# continuous features
features_continuous = [c for c in df_unlabeled if c != 'label' and len(df_unlabeled[c].unique()) > 2]
df_continuous = df_unlabeled[features_continuous].copy()
# fit normalization scaler
scaler = StandardScaler()
scaler.fit(df_continuous)
# save normalization scaler
scaler_path = os.path.join(CHECKPOINT_DIR, 'scaler_ind.pkl')
with open(scaler_path, 'wb') as f:
    pickle.dump(scaler, f)
# transform continuous data (z-scores)
X_train = scaler.transform(df_continuous)
X_train = pd.DataFrame(X_train, columns=df_continuous.columns)

# categorical features
features_categorical = [c for c in df_unlabeled if c != 'label' and len(df_unlabeled[c].unique()) <= 2]
df_categorical = df_unlabeled[features_categorical].copy()
X_train = pd.concat((X_train, df_categorical), axis=1)

# fill missing values with zeros
X_train = X_train.fillna(value=0)

# all features
features = features_continuous + features_categorical

# labels
y_train = pred_unlabeled

In [10]:
# train XGBoost (inherently interpretable with feature importance)
classifier = xgb.XGBClassifier(
    n_estimators=100,
    max_depth=6,
    learning_rate=0.1,
    random_state=1,
    eval_metric='logloss'
)
classifier.fit(X_train, y_train)

# save XGBoost model
classifier_path = os.path.join(CHECKPOINT_DIR, 'xgboost_ind.pkl')
with open(classifier_path, 'wb') as f:
    pickle.dump(classifier, f)

### Step 5: Model Performance (Individuals)

Evaluate performance of classification model on labeled data

In [9]:
# read XGBoost model
classifier_path = os.path.join(CHECKPOINT_DIR, 'xgboost_ind.pkl')
with open(classifier_path, 'rb') as f:
    classifier = pickle.load(f)

# read normalization scaler
scaler_path = os.path.join(CHECKPOINT_DIR, 'scaler_ind.pkl')
with open(scaler_path, 'rb') as f:
    scaler = pickle.load(f)

# read labeled data
df_labeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'individual_feature_labeled.csv'))

# normalize continuous features (z-scores)
df_continuous = df_labeled[features_continuous].copy()
X_test = scaler.transform(df_continuous)
X_test = pd.DataFrame(X_test, columns=df_continuous.columns)
# categorical features
df_categorical = df_labeled[features_categorical].copy()
X_test = pd.concat((X_test, df_categorical), axis=1)

In [10]:
pred = classifier.predict(X_test)
target = df_labeled['label']

# accuracy
a = accuracy_score(target, pred)
print("Accuracy:", a)

# precision
p = precision_score(target, pred)
print("Precision (Fraud=1):", p)

# recall
r = recall_score(target, pred)
print("Recall (Fraud=1):", r)

# F1 score
f1 = f1_score(target, pred)
print("F1 (Fraud=1):", f1)

# AUC score
auc = roc_auc_score(target, pred)
print("AUC:", auc)

# confusion matrix
m = confusion_matrix(target, pred)
print("\nConfusion matrix: [[TN, FP], [FN, TP]]")
print(m)

cr = classification_report(target, pred)
print("\nFull report:")
print(cr)

Accuracy: 0.9857456140350878
Precision (Fraud=1): 0.26666666666666666
Recall (Fraud=1): 0.6666666666666666
F1 (Fraud=1): 0.38095238095238093
AUC: 0.8272626931567328

Confusion matrix: [[TN, FP], [FN, TP]]
[[895  11]
 [  2   4]]

Full report:
              precision    recall  f1-score   support

         0.0       1.00      0.99      0.99       906
         1.0       0.27      0.67      0.38         6

    accuracy                           0.99       912
   macro avg       0.63      0.83      0.69       912
weighted avg       0.99      0.99      0.99       912



### Step 6: Add SHAP for Local Interpretability (Individuals)

Generate explanation for each client with LFs and SHAPLEY
- LFs that are satisfied
- Top features based on SHAPLEY values

In [11]:
# explain classification model with SHAP
explainer = shap.TreeExplainer(classifier)

In [12]:
def explain_lfs(customer_data):
    """Which rules fired?"""
    fired_rules = []
    for lf, lf_name in zip(lfs, lf_names):
        if lf_name != "" and lf(customer_data) == POSITIVE:
            fired_rules.append(lf_name)
    return fired_rules

def explain_features(transformed_data, threshold=0.5):
    """Top features based on SHAPLEY values"""
    shap_values = explainer.shap_values(transformed_data.to_numpy().reshape(1, -1))
    # crate tuples of features and scores
    feature_impacts = [(f, float(s)) for f, s, in zip(features, shap_values[0])]
    # keep features with absolute scores above 0.5
    top_features = [(f, float(s)) for f, s in zip(features, shap_values[0]) if abs(s) > threshold]
    return top_features

def explain(customer_data, transformed_data, pred, threshold=0.5):
    """LFs that are satisfied anf top features based on SHAPLEY values"""
    # lfs
    lfs_fired = explain_lfs(customer_data)
    # features with SHAP
    features_important = explain_features(transformed_data, threshold)
    features_description = []
    # generate description for top features:
    # e.g. high, low, satisfied, not satisfied
    for f, v in features_important:
        if v >= 0 and p == 1 or v < 0 and p == 0:
            if v >= 0:
                d = "satisfied" if f.startswith(('is_', 'has_')) else "high"
            else:
                d = "not satisfied" if f.startswith(('is_', 'has_')) else "low"
            features_description.append((d, f, v))
    return lfs_fired, features_description

In [14]:
# read labeled data
df_labeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'individual_labeled.csv'))

print("Predicted high-risk individuals")
for (i, customer_data), (i, transformed_data), p in zip(df_labeled.iterrows(), X_test.iterrows(), pred):
    if p == 1:
        print('-'*10)
        print(f"Sample {i}")
        lfs_fired, features_description = explain(customer_data, transformed_data, p)
        print('Fired LFs:')
        for l in lfs_fired:
            print("-", l)
        print("Top features & SHAP values:")
        for d, f, v in features_description:
            print("-", d, f, round(v, 2))

Predicted high-risk individuals
----------
Sample 31
Fired LFs:
- High transaction to income ratio
Top features & SHAP values:
- high income 5.34
- high sum_amt_total 1.69
- high channel_credit_volume 1.24
----------
Sample 42
Fired LFs:
- Short hold time
Top features & SHAP values:
- high channel_credit_volume 2.27
- high median_hold_time_funds 2.27
- high n_txn_total 2.34
----------
Sample 60
Fired LFs:
- Short hold time
Top features & SHAP values:
- high channel_credit_volume 2.41
- high median_hold_time_funds 2.27
- high n_txn_total 2.18
----------
Sample 108
Fired LFs:
Top features & SHAP values:
- high income 5.02
----------
Sample 179
Fired LFs:
- Short hold time
Top features & SHAP values:
- high channel_credit_volume 2.49
- high median_hold_time_funds 2.25
- high n_txn_total 2.43
----------
Sample 246
Fired LFs:
- High transaction to income ratio
Top features & SHAP values:
- high income 3.7
- high sum_amt_total 2.17
- high channel_credit_volume 0.54
----------
Sample 389
Fire

### Step 7: Get final output (Individuals)

In [16]:
# master dataframes (individuals)
df_ind_labeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'individual_labeled.csv'))
df_ind_unlabeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'individual_labeled.csv'))
# feature selected dataframes (individuals)
df_ind_feat_labeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'individual_feature_labeled.csv'))
df_ind_feat_unlabeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'individual_feature_labeled.csv'))
# all data (individuals)
df_ind = pd.concat((df_ind_labeled, df_ind_unlabeled))
df_ind_feat = pd.concat((df_ind_feat_labeled, df_ind_feat_unlabeled), ignore_index=True)

# normalize continuous features (z-scores)
df_continuous = df_ind_feat[features_continuous].copy()
df_ind_transformed = scaler.transform(df_continuous)
df_ind_transformed = pd.DataFrame(df_ind_transformed, columns=df_continuous.columns)
# categorical features
df_categorical = df_ind_feat[features_categorical].copy()
df_ind_transformed = pd.concat((df_ind_transformed, df_categorical), axis=1)

In [17]:
# predicted labels
pred = classifier.predict(df_ind_transformed)

# get explanation for each client
explanation_list = []
for (i, customer_data), (i, transformed_data), p in zip(df_ind.iterrows(), df_ind_transformed.iterrows(), pred):
    lfs_fired, features_description = explain(customer_data, transformed_data, p)
    explanation = ""
    if len(lfs_fired) > 0:
        explanation += ', '.join(lfs_fired) + '. '
    if len(features_description) > 0:
        explanation += "Top features: " + ', '.join([d+' '+f for d, f, _ in features_description])
    explanation_list.append(explanation)

In [19]:
df_final = pd.DataFrame({
    'customer_id': df_ind['customer_id'],
    'label': df_ind['label'],
    'prediction': pred,
    'explanation': explanation_list
})
df_final.to_csv(os.path.join(FINAL_DIR, 'individual_output.csv'), index=False)