# Weak Supervision with Snorkel AI

Create labels for unlabeled data based on some heuristics.

In [1]:
import pandas as pd
import numpy as np
import os
# Snorkel AI weak supervision
from snorkel.labeling import LFAnalysis
from snorkel.labeling import labeling_function, PandasLFApplier
from snorkel.labeling.model import LabelModel
# classification model
from sklearn.preprocessing import StandardScaler
import xgboost as xgb
# SHAP explainability
import shap
# save model
import pickle
# evaluation metrics
from sklearn.metrics import (
    confusion_matrix, 
    accuracy_score, 
    precision_score, 
    recall_score,
    f1_score,
    roc_auc_score,
    classification_report
)

# directories
RAW_DIR = 'data/raw'
PROCESSED_DIR = 'data/processed'
CHECKPOINT_DIR = 'checkpoint'

# weak label constants
ABSTAIN = -1  # Uncertain
NEGATIVE = 0  # Low risk
POSITIVE = 1  # High risk

  from .autonotebook import tqdm as notebook_tqdm


## Part 1: Define Snorkel Labeling Functions (LFs)

In [2]:
@labeling_function()
def high_txn_volume_vs_income(x):
    """High transaction volume relative to income (suspicious if transaction volume is much higher than income)"""
    condition = (x.txn_volume_vs_income > 2477.8) & ((x.age <= 25) | (x.income <= 40000))
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def high_txn_volume_vs_occupation_median(x):
    """Transaction volume vs occupation median (if transaction volume is high compared to typical for the occupation)"""
    condition = (x.txn_volume_vs_occupation_median > 1562.0) & ((x.age <= 25) | (x.income <= 40000))
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def high_wire(x):
    """High wire transfer amounts if income or age is low (could be mule)"""
    condition = (x.median_amt_wire > 6817.2) & ((x.age <= 25) | (x.income <= 40000))
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def frequent_wire(x):
    """Frequent wire transfer amounts if income or age is low"""
    condition = (x.wire_ratio > 0.47) & ((x.age <= 25) | (x.income <= 40000))
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def structuring(x):
    """Structuring (frequent transactions just below the reporting threshold)"""
    condition = (x.count_txn_below_threshold_frequency > 0.009) & (x.n_txn_total > 5)
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def short_hold_time(x):
    """Rapid movement of funds (short hold time)"""
    condition = (x.n_txn_total > 50) & (x.median_hold_time_funds <= 0.11)
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def high_cross_border_ratio(x):
    """Cross-border transactions (if high ratio of cross-border)"""
    condition = (x.cross_border_ratio > 0.17) & (x.transaction_unique_countries >= 3)
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def high_ecommerce_ratio(x):
    """High e-commerce transaction ratio (if e-commerce is used for money laundering)"""
    condition = (x.transaction_ecommerce_ratio > 0.98) & (x.transaction_volume_90d > 198059.4) & (x.n_txn_total > 10)
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def ecommerce_occupation_mismatch(x):
    """High e-commerce for the unemployed"""
    condition = (x.transaction_ecommerce_ratio > 0.98) & (x.occupation_code == 'UNEMPLOYED')
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def unusual_merchant_pattern(x):
    """Make few transactions across many merchants"""
    condition = (x.transaction_unique_merchants > 40) & (x.n_txn_total < (1.5 * x.transaction_unique_merchants))
    return POSITIVE if condition else ABSTAIN

@labeling_function()
def transaction_same_amount(x):
    """Predictable amount patterns through e.g., gift cards"""
    condition = (x.transaction_same_amount_frequency_7d > 0.14) & (x.transaction_round_amount_frequency_7d >= 0.5)
    return POSITIVE if condition else ABSTAIN

In [3]:
LF_DEFINITIONS = [
    ("High transaction to income ratio", high_txn_volume_vs_income),
    ("High transaction to occupation median ratio", high_txn_volume_vs_occupation_median),
    ("High wire transfers", high_wire),
    ("Frequent wire transfers", frequent_wire),
    ("Structuring", structuring),
    ("Short hold time", short_hold_time),
    ("High cross border ratio", high_cross_border_ratio),
    ("High ecommerce ratio", high_ecommerce_ratio),
    ("Ecommerce occupation mismatch", ecommerce_occupation_mismatch),
    ("Unusual merchant pattern", unusual_merchant_pattern),
    ("Transaction same amount", transaction_same_amount),
]

# Separate convenience lists used throughout the notebook
lf_names = [name for name, _ in LF_DEFINITIONS]
lfs = [lf for _, lf in LF_DEFINITIONS]

## Part 2: LF functions Sanity Check

In [4]:
# read labeled data
df_labeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'individual_labeled.csv'))

# Apply LFs to golden_eval
L_dev = np.array([[lf(row) for lf in lfs] for _, row in df_labeled.iterrows()])
lf_analysis = LFAnalysis(L_dev, lfs=lfs).lf_summary()
print(lf_analysis)

                                       j Polarity  Coverage  Overlaps  \
high_txn_volume_vs_income              0      [1]  0.006579  0.002193   
high_txn_volume_vs_occupation_median   1      [1]  0.003289  0.002193   
high_wire                              2      [1]  0.002193  0.002193   
frequent_wire                          3      [1]  0.002193  0.002193   
structuring                            4      [1]  0.004386  0.000000   
short_hold_time                        5      [1]  0.006579  0.000000   
high_cross_border_ratio                6       []  0.000000  0.000000   
high_ecommerce_ratio                   7       []  0.000000  0.000000   
ecommerce_occupation_mismatch          8       []  0.000000  0.000000   
unusual_merchant_pattern               9       []  0.000000  0.000000   
transaction_same_amount               10       []  0.000000  0.000000   

                                      Conflicts  
high_txn_volume_vs_income                   0.0  
high_txn_volume_vs_occu

In [5]:
# read unlabeled data
df_unlabeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'individual_unlabeled.csv'))

# apply LFs to golden_eval
L_dev = np.array([[lf(row) for lf in lfs] for _, row in df_unlabeled.iterrows()])
lf_analysis = LFAnalysis(L_dev, lfs=lfs).lf_summary()
print(lf_analysis)

                                       j Polarity  Coverage  Overlaps  \
high_txn_volume_vs_income              0      [1]  0.036695  0.013164   
high_txn_volume_vs_occupation_median   1      [1]  0.039627  0.014276   
high_wire                              2      [1]  0.004944  0.004005   
frequent_wire                          3      [1]  0.005116  0.003794   
structuring                            4      [1]  0.010347  0.001686   
short_hold_time                        5      [1]  0.004254  0.000019   
high_cross_border_ratio                6      [1]  0.001246  0.000134   
high_ecommerce_ratio                   7      [1]  0.000211  0.000057   
ecommerce_occupation_mismatch          8      [1]  0.000996  0.000134   
unusual_merchant_pattern               9       []  0.000000  0.000000   
transaction_same_amount               10      [1]  0.000977  0.000038   

                                      Conflicts  
high_txn_volume_vs_income                   0.0  
high_txn_volume_vs_occu

## Part 3: Train Snorkel's Generative Model

Snorkel learns to combine LFs into probabilistic labels

In [6]:
# read unlabeled data
df_unlabeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'individual_unlabeled.csv'))

# apply labeling functions
applier = PandasLFApplier(lfs)
L_unlabeled = applier.apply(df_unlabeled)

100%|██████████| 52187/52187 [00:09<00:00, 5324.70it/s]


In [7]:
# fit Snorkel AI weak supervision model
label_model = LabelModel(cardinality=2, verbose=True)
label_model.fit(L_unlabeled, n_epochs=50, log_freq=10, seed=1)

# save Snorkel label model
label_model_path = os.path.join(CHECKPOINT_DIR, 'snorkel_label_ind.pkl')
with open(label_model_path, 'wb') as f:
    pickle.dump(label_model, f)

INFO:root:Computing O...
INFO:root:Estimating \mu...
  0%|          | 0/50 [00:00<?, ?epoch/s]INFO:root:[0 epochs]: TRAIN:[loss=0.002]
INFO:root:[10 epochs]: TRAIN:[loss=0.001]
INFO:root:[20 epochs]: TRAIN:[loss=0.000]
INFO:root:[30 epochs]: TRAIN:[loss=0.000]
INFO:root:[40 epochs]: TRAIN:[loss=0.000]
100%|██████████| 50/50 [00:00<00:00, 1014.13epoch/s]
INFO:root:Finished Training


In [8]:
# read Snorkel label model
label_model_path = os.path.join(CHECKPOINT_DIR, 'snorkel_label_ind.pkl')
with open(label_model_path, 'rb') as f:
    label_model = pickle.load(f)

# predict labels for unlabeled data
pred_probs_unlabeled = label_model.predict_proba(L_unlabeled)
pred_unlabeled = label_model.predict(L_unlabeled)
# all predicted 1s are high-risk, while all predicted 0s or abstains are low-risk
pred_unlabeled = np.where(pred_unlabeled==1, 1, 0)

## Step 4: Train Final Classification Model

In [9]:
# read unlabeled data
df_unlabeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'individual_feature_unlabeled.csv'))

# continuous features
features_continuous = [c for c in df_unlabeled if c != 'label' and len(df_unlabeled[c].unique()) > 2]
df_continuous = df_unlabeled[features_continuous].copy()
# fit normalization scaler
scaler = StandardScaler()
scaler.fit(df_continuous)
# save normalization scaler
scaler_path = os.path.join(CHECKPOINT_DIR, 'scaler_ind.pkl')
with open(scaler_path, 'wb') as f:
    pickle.dump(scaler, f)
# transform continuous data (z-scores)
X_train = scaler.transform(df_continuous)
X_train = pd.DataFrame(X_train, columns=df_continuous.columns)

# categorical features
features_categorical = [c for c in df_unlabeled if c != 'label' and len(df_unlabeled[c].unique()) <= 2]
df_categorical = df_unlabeled[features_categorical].copy()
X_train = pd.concat((X_train, df_categorical), axis=1)

# fill missing values with zeros
X_train = X_train.fillna(value=0)

# all features
features = features_continuous + features_categorical

# labels
y_train = pred_unlabeled

In [10]:
# train XGBoost (inherently interpretable with feature importance)
classifier = xgb.XGBClassifier(
    n_estimators=100,
    max_depth=6,
    learning_rate=0.1,
    random_state=1,
    eval_metric='logloss'
)
classifier.fit(X_train, y_train)

# save XGBoost model
classifier_path = os.path.join(CHECKPOINT_DIR, 'xgboost_ind.pkl')
with open(classifier_path, 'wb') as f:
    pickle.dump(classifier, f)

## Step 5: Model Performance

In [11]:
# read XGBoost model
classifier_path = os.path.join(CHECKPOINT_DIR, 'xgboost_ind.pkl')
with open(classifier_path, 'rb') as f:
    classifier = pickle.load(f)

# read normalization scaler
scaler_path = os.path.join(CHECKPOINT_DIR, 'scaler_ind.pkl')
with open(scaler_path, 'rb') as f:
    scaler = pickle.load(f)

# read labeled data
df_labeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'individual_feature_labeled.csv'))

# normalize continuous features (z-scores)
df_continuous = df_labeled[features_continuous].copy()
X_test = scaler.transform(df_continuous)
X_test = pd.DataFrame(X_test, columns=df_continuous.columns)
# categorical features
df_categorical = df_labeled[features_categorical].copy()
X_test = pd.concat((X_test, df_categorical), axis=1)

In [12]:
pred = classifier.predict(X_test)
target = df_labeled['label']

# accuracy
a = accuracy_score(target, pred)
print("Accuracy:", a)

# precision
p = precision_score(target, pred)
print("Precision (Fraud=1):", p)

# recall
r = recall_score(target, pred)
print("Recall (Fraud=1):", r)

# F1 score
f1 = f1_score(target, pred)
print("F1 (Fraud=1):", f1)

# AUC score
auc = roc_auc_score(target, pred)
print("AUC:", auc)

# confusion matrix
m = confusion_matrix(target, pred)
print("\nConfusion matrix: [[TN, FP], [FN, TP]]")
print(m)

cr = classification_report(target, pred)
print("\nFull report:")
print(cr)

Accuracy: 0.9857456140350878
Precision (Fraud=1): 0.26666666666666666
Recall (Fraud=1): 0.6666666666666666
F1 (Fraud=1): 0.38095238095238093
AUC: 0.8272626931567328

Confusion matrix: [[TN, FP], [FN, TP]]
[[895  11]
 [  2   4]]

Full report:
              precision    recall  f1-score   support

         0.0       1.00      0.99      0.99       906
         1.0       0.27      0.67      0.38         6

    accuracy                           0.99       912
   macro avg       0.63      0.83      0.69       912
weighted avg       0.99      0.99      0.99       912



## Step 6: Add SHAP for Local Interpretability

In [13]:
# explain classification model with SHAP
explainer = shap.TreeExplainer(classifier)

In [14]:
def explain_lfs(customer_data):
    """Which rules fired?"""
    fired_rules = []
    for lf, lf_name in zip(lfs, lf_names):
        if lf_name != "" and lf(customer_data) == POSITIVE:
            fired_rules.append(lf_name)
    return fired_rules

def explain_features(customer_data):
    """Top features by SHAP"""
    shap_values = explainer.shap_values(customer_data.to_numpy().reshape(1, -1))
    # crate tuples of features and scores
    feature_impacts = [(f, float(s)) for f, s, in zip(features, shap_values[0])]
    # keep features with absolute scores above 0.5
    threshold = 0.5
    top_features = [(f, float(s)) for f, s in zip(features, shap_values[0]) if abs(s) > threshold]
    return top_features

In [15]:
df_labeled = pd.read_csv(os.path.join(PROCESSED_DIR, 'individual_labeled.csv'))

print("High-risk individuals & LFs")
for i, row in df_labeled[pred==1].iterrows():
    print('-'*10)
    print(f"Sample {i}")
    lfs_fired = explain_lfs(row)
    print('Fired LFs:', ', '.join(lfs_fired))

High-risk individuals & LFs
----------
Sample 31
Fired LFs: High transaction to income ratio
----------
Sample 42
Fired LFs: Short hold time
----------
Sample 60
Fired LFs: Short hold time
----------
Sample 108
Fired LFs: 
----------
Sample 179
Fired LFs: Short hold time
----------
Sample 246
Fired LFs: High transaction to income ratio
----------
Sample 389
Fired LFs: Short hold time
----------
Sample 470
Fired LFs: Short hold time
----------
Sample 482
Fired LFs: Short hold time
----------
Sample 513
Fired LFs: High transaction to income ratio, High transaction to occupation median ratio
----------
Sample 615
Fired LFs: High transaction to income ratio
----------
Sample 623
Fired LFs: High wire transfers, Frequent wire transfers
----------
Sample 675
Fired LFs: High transaction to occupation median ratio
----------
Sample 687
Fired LFs: High transaction to income ratio, High transaction to occupation median ratio, High wire transfers, Frequent wire transfers
----------
Sample 741
Fire

In [16]:
print("High-risk individuals, top features & SHAP values")
for i, row in X_test[pred==1].iterrows():
    print('-'*10)
    print(f"Sample {i}")
    features_important = explain_features(row)
    print("Top features & SHAP values:")
    for f, v in features_important:
        print(f, round(v, 2))

High-risk individuals, top features & SHAP values
----------
Sample 31
Top features & SHAP values:
income 5.34
sum_amt_total 1.69
channel_debit_volume -0.57
channel_credit_volume 1.24
----------
Sample 42
Top features & SHAP values:
channel_credit_volume 2.27
median_hold_time_funds 2.27
n_txn_total 2.34
----------
Sample 60
Top features & SHAP values:
channel_credit_volume 2.41
median_hold_time_funds 2.27
n_txn_total 2.18
----------
Sample 108
Top features & SHAP values:
income 5.02
n_txn_total -0.5
----------
Sample 179
Top features & SHAP values:
channel_credit_volume 2.49
median_hold_time_funds 2.25
n_txn_total 2.43
----------
Sample 246
Top features & SHAP values:
income 3.7
sum_amt_total 2.17
channel_credit_volume 0.54
----------
Sample 389
Top features & SHAP values:
channel_credit_volume 2.38
median_hold_time_funds 2.24
n_txn_total 2.42
----------
Sample 470
Top features & SHAP values:
channel_credit_volume 2.46
median_hold_time_funds 2.25
n_txn_total 2.36
----------
Sample 482
