# Layer 2: ML Classification with Snowflake Model Registry

## What We're Building
- **XGBoost Classifier** trained on relative semantic risk scores
- **Model Registry** - Version-controlled, auditable ML models
- **Feature Store → Model lineage** - Full traceability

In [None]:
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col
from snowflake.ml.feature_store import FeatureStore
from snowflake.ml.registry import Registry
from snowflake.ml.modeling.xgboost import XGBClassifier

session = Session.builder.getOrCreate()
session.use_warehouse('COMPLIANCE_DEMO_WH')
session.use_database('COMPLIANCE_DEMO')
session.use_schema('ML')

print("Layer 2: Training ML model on relative risk scores...")

## Step 1: Load Features from Feature Store

In [None]:
fs = FeatureStore(
    session=session,
    database="COMPLIANCE_DEMO",
    name="ML",
    default_warehouse="COMPLIANCE_DEMO_WH"
)

semantic_fv = fs.get_feature_view("EMAIL_SEMANTIC_FEATURES", "V1")
print(f"Loaded Feature View: {semantic_fv.name}/V1")

In [None]:
features_df = session.table('COMPLIANCE_DEMO.ML.EMAIL_SEMANTIC_FEATURES')

# Count once and reuse the result; each .count() call runs its own query.
total = features_df.count()
violations = features_df.filter(col('IS_VIOLATION') == 1).count()

print(f"Total samples: {total:,}")
print(f"Violation rate: {violations / total * 100:.1f}%")

## Step 2: Prepare Training Data

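The `*_RISK_SCORE` features are **relative risk scores** (risk similarity minus baseline similarity): negative = normal, positive = risky. `BASELINE_SIMILARITY` and `CROSS_BARRIER_FLAG` are passed to the model alongside them.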
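
To make the sign convention concrete, here is a minimal sketch of how a relative score could be computed. It assumes cosine similarity over embedding vectors. The actual feature engineering happens in the previous layer and may differ, and the 2-D vectors and function names below are hypothetical.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def relative_risk_score(email_vec, risk_vec, baseline_vec):
    """Similarity to a risk concept minus similarity to baseline language."""
    return cosine_similarity(email_vec, risk_vec) - cosine_similarity(email_vec, baseline_vec)


# Hypothetical 2-D embeddings, for illustration only.
risk_concept = [1.0, 0.0]
baseline = [0.0, 1.0]

routine_email = [0.0, 2.0]  # points along the baseline direction
risky_email = [3.0, 0.0]    # points along the risk direction

routine_score = relative_risk_score(routine_email, risk_concept, baseline)
risky_score = relative_risk_score(risky_email, risk_concept, baseline)

print(f"routine: {routine_score:+.2f}")  # negative: closer to baseline language
print(f"risky:   {risky_score:+.2f}")    # positive: closer to the risk concept
```

An email that sits closer to ordinary business language than to the risk concept gets a negative score, which is why zero is a natural dividing line for these features.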

In [None]:
feature_cols = [
    'BASELINE_SIMILARITY',
    'MNPI_RISK_SCORE', 
    'CONFIDENTIALITY_RISK_SCORE', 
    'PERSONAL_TRADING_RISK_SCORE',
    'INFO_BARRIER_RISK_SCORE',
    'CROSS_BARRIER_FLAG'
]

print("Features (semantic risk scores):")
for f in feature_cols:
    print(f"  - {f}")

train_df, test_df = features_df.random_split([0.8, 0.2], seed=42)
print(f"\nTrain: {train_df.count():,}, Test: {test_df.count():,}")
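
A fixed seed makes the split repeatable, and it is worth checking that both splits have a similar violation rate before training. The sketch below illustrates both ideas in plain Python. It uses Python's `random` module rather than Snowpark's `random_split`, and the rows and the `EMAIL_ID` column are hypothetical.

```python
import random


def seeded_split(rows, train_fraction=0.8, seed=42):
    """Assign each row to train or test using a seeded random generator."""
    rng = random.Random(seed)
    train, test = [], []
    for row in rows:
        (train if rng.random() < train_fraction else test).append(row)
    return train, test


def violation_rate(rows):
    """Share of rows labelled as violations (IS_VIOLATION == 1)."""
    return sum(r["IS_VIOLATION"] for r in rows) / len(rows) if rows else 0.0


# Hypothetical labelled rows: every tenth email is a violation.
rows = [{"EMAIL_ID": i, "IS_VIOLATION": 1 if i % 10 == 0 else 0} for i in range(1000)]

train_rows, test_rows = seeded_split(rows)
train_again, test_again = seeded_split(rows)  # same seed, same assignment

print(f"Train: {len(train_rows)}, Test: {len(test_rows)}")
print(f"Train violation rate: {violation_rate(train_rows):.3f}")
print(f"Test violation rate:  {violation_rate(test_rows):.3f}")
print(f"Reproducible: {train_rows == train_again}")
```

If the two rates diverge noticeably, the test metrics in the next step will be harder to interpret.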

## Step 3: Train XGBoost Model

In [None]:
xgb_model = XGBClassifier(
    input_cols=feature_cols,
    label_cols=['IS_VIOLATION'],
    output_cols=['PREDICTED_VIOLATION'],
    n_estimators=100,
    max_depth=5,
    learning_rate=0.1
)

print("Training XGBoost on semantic risk scores...")
xgb_model.fit(train_df)
print("Model trained!")

## Step 4: Evaluate Model Performance

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

predictions = xgb_model.predict(test_df)
pred_pd = predictions.to_pandas()

# Ground-truth labels and model predictions from the scored test set
y_true = pred_pd['IS_VIOLATION']
y_pred = pred_pd['PREDICTED_VIOLATION']

accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)

print("\n" + "="*60)
print("MODEL PERFORMANCE ON TEST SET")
print("="*60)
print(f"\nAccuracy:  {accuracy*100:.1f}%")
print(f"Precision: {precision*100:.1f}%")
print(f"Recall:    {recall*100:.1f}%")
print(f"F1 Score:  {f1*100:.1f}%")
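
When violations are rare, accuracy can look strong while recall is weak. The sketch below computes the same four metrics from confusion-matrix counts to show how they relate. The counts are hypothetical and are not results from this model.

```python
def classification_metrics(tp, fp, fn, tn):
    """Accuracy, precision, recall and F1 from confusion-matrix counts."""
    total = tp + fp + fn + tn
    accuracy = (tp + tn) / total
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (
        2 * precision * recall / (precision + recall)
        if (precision + recall)
        else 0.0
    )
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Hypothetical counts: 60 true violations among 1,000 emails.
metrics = classification_metrics(tp=40, fp=10, fn=20, tn=930)

for name, value in metrics.items():
    print(f"{name:<10} {value * 100:.1f}%")
```

With these counts accuracy is 97% even though a third of the violations are missed, so recall and precision deserve more weight than accuracy for this use case.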

## Step 5: Register Model in Snowflake Model Registry

In [None]:
reg = Registry(session=session, database_name='COMPLIANCE_DEMO', schema_name='ML')

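try:
    reg.delete_model('EMAIL_COMPLIANCE_CLASSIFIER')
    print("Deleted existing model")
except Exception as e:
    # Most likely the model does not exist yet; report it instead of hiding every error.
    print(f"No existing model deleted: {e}")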

In [None]:
from snowflake.ml.model import task as model_task

model_version = reg.log_model(
    model=xgb_model,
    model_name='EMAIL_COMPLIANCE_CLASSIFIER',
    version_name='V1',
    conda_dependencies=['xgboost', 'scikit-learn'],
    metrics={
        'accuracy': float(accuracy),
        'precision': float(precision),
        'recall': float(recall),
        'f1_score': float(f1),
        'training_samples': int(train_df.count()),
        'test_samples': int(test_df.count())
    },
    sample_input_data=train_df.select(*feature_cols).limit(100),
    task=model_task.Task.TABULAR_BINARY_CLASSIFICATION,
    comment='XGBoost on relative semantic risk scores (risk - baseline similarity)'
)

print("\n" + "="*60)
print("MODEL REGISTERED IN SNOWFLAKE MODEL REGISTRY")
print("="*60)
print("\nModel: EMAIL_COMPLIANCE_CLASSIFIER/V1")
print(f"Features: {feature_cols}")
print("\nMetrics stored in registry for audit trail.")
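
Once the version is logged, the stored metrics can be read back and the registered version (rather than the in-memory object) can be used for scoring. The helper below is a sketch under that assumption. `audit_and_score` is a hypothetical name, and it relies on the registry's `get_model`, `version`, `show_metrics` and `run` methods being available in the installed `snowflake-ml-python` release.

```python
def audit_and_score(
    registry,
    features_df,
    model_name="EMAIL_COMPLIANCE_CLASSIFIER",
    version_name="V1",
):
    """Read back logged metrics and score a DataFrame with the registered version.

    `registry` is expected to be a snowflake.ml.registry.Registry instance and
    `features_df` a Snowpark DataFrame holding the feature columns.
    """
    model_version = registry.get_model(model_name).version(version_name)

    # Metrics that were passed to log_model(metrics=...)
    metrics = model_version.show_metrics()

    # Run the registered model's predict_proba method on the input rows
    scored_df = model_version.run(features_df, function_name="predict_proba")
    return metrics, scored_df
```

In this notebook it could be called as `metrics, scored_df = audit_and_score(reg, features_df)`. The names of the output columns depend on the logged model's signature, so inspect `scored_df.columns` before relying on them.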

## Step 6: Run Inference at Scale

In [None]:
import time

all_emails_df = session.table('COMPLIANCE_DEMO.ML.EMAIL_SEMANTIC_FEATURES')

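start = time.time()
# predict_proba adds one probability column per class; PREDICT_PROBA_1 is the violation class.
scored_df = xgb_model.predict_proba(all_emails_df)
# The timing below includes writing the scored rows to a table.
scored_df.write.mode('overwrite').save_as_table('MODEL_PREDICTIONS_RAW')
elapsed = time.time() - start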

count = session.sql('SELECT COUNT(*) as cnt FROM MODEL_PREDICTIONS_RAW').collect()[0]['CNT']
print(f"\nScored {count:,} emails in {elapsed:.1f} seconds (with probabilities)")

In [None]:
session.sql("""
CREATE OR REPLACE TABLE MODEL_PREDICTIONS_V1 AS
SELECT 
    *,
    PREDICT_PROBA_1 as VIOLATION_PROBABILITY,
    CASE 
        WHEN PREDICT_PROBA_1 >= 0.7 THEN 'HIGH_RISK'
        WHEN PREDICT_PROBA_1 <= 0.3 THEN 'LOW_RISK'
        ELSE 'NEEDS_REVIEW'
    END as ML_DECISION
FROM MODEL_PREDICTIONS_RAW
""").collect()

print("\n" + "="*70)
print("THREE-WAY ML CLASSIFICATION")
print("="*70)
print("\n  HIGH_RISK (prob >= 0.7):         Auto-escalate to compliance")
print("  NEEDS_REVIEW (0.3 < prob < 0.7): Send to LLM for deep analysis")
print("  LOW_RISK (prob <= 0.3):          Auto-clear")

session.sql("""
SELECT 
    ML_DECISION,
    COUNT(*) as EMAIL_COUNT,
    ROUND(COUNT(*) * 100.0 / SUM(COUNT(*)) OVER(), 1) as PCT,
    SUM(IS_VIOLATION) as ACTUAL_VIOLATIONS,
    ROUND(SUM(IS_VIOLATION) * 100.0 / COUNT(*), 1) as VIOLATION_RATE
FROM MODEL_PREDICTIONS_V1
GROUP BY 1
ORDER BY VIOLATION_RATE DESC
""").show()

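The three-way routing in the SQL `CASE` expression above can be sketched as a plain Python function, which is handy for unit-testing the boundary behavior. The function name `ml_decision` is hypothetical; the thresholds and labels mirror the SQL.

```python
def ml_decision(violation_probability, high=0.7, low=0.3):
    """Map a violation probability to one of three routing decisions."""
    if violation_probability >= high:
        return "HIGH_RISK"      # auto-escalate to compliance
    if violation_probability <= low:
        return "LOW_RISK"       # auto-clear
    return "NEEDS_REVIEW"       # send to the LLM layer


for p in (0.95, 0.7, 0.5, 0.3, 0.05):
    print(f"{p:.2f} -> {ml_decision(p)}")
```

Both thresholds are inclusive, so a probability of exactly 0.7 is escalated and exactly 0.3 is cleared. Only values strictly between the two go to review.
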
## Why This Approach Works

**ML on semantic features beats keyword rules because:**

1. **Captures meaning, not words** - "Let's keep this quiet" matches the secrecy concept even without the word "confidential"
2. **Hard to evade** - Violators can change their vocabulary, but the meaning still clusters near risk concepts
3. **Learns combinations** - XGBoost finds interactions between features (for example, high secrecy + high urgency = very risky)
4. **Relative scoring** - Each risk score is normalized against baseline business language

**Limitation:** The classifier sees only numeric scores, so on its own it can't weigh nuance or context in the email. That's where LLMs come in (next layer).

---

**Next:** Add LLM for nuanced analysis →