# 🚕 DQX Anomaly Detection - NYC Taxi Data Quality

## Find Real Data Quality Issues in Real Data

This notebook demonstrates DQX anomaly detection on the **NYC Taxi dataset** - real trip data with **real data quality issues**.

**What we'll do:**
1. Load NYC Taxi data (available in Databricks samples)
2. Train DQX anomaly detection model (unsupervised)
3. Find anomalies and see they represent real DQ issues
4. Understand WHY records are flagged

**Time**: ~10 minutes


---

## Section 1: Setup


In [None]:
# Imports
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

from databricks.labs.dqx.anomaly import AnomalyEngine, has_no_anomalies
from databricks.labs.dqx.engine import DQEngine
from databricks.labs.dqx.rule import DQDatasetRule
from databricks.sdk import WorkspaceClient

# Initialize
spark = SparkSession.builder.getOrCreate()
ws = WorkspaceClient()
anomaly_engine = AnomalyEngine(ws)
dq_engine = DQEngine(ws)

print("✅ Setup complete!")
print(f"   Spark version: {spark.version}")


---

## Section 2: Load NYC Taxi Data

The NYC Taxi dataset is available in the Databricks `samples` catalog.


In [None]:
# Load NYC Taxi data from the samples catalog
# Candidate locations, tried in order; add entries if your workspace stores the data elsewhere
taxi_tables = [
    "samples.nyctaxi.trips",
]

df_taxi = None
for table in taxi_tables:
    try:
        df_taxi = spark.table(table)
        print(f"✅ Loaded taxi data from: {table}")
        break
    except Exception:
        continue

if df_taxi is None:
    raise ValueError("Could not find NYC Taxi data. Check samples catalog.")

# Show schema
print(f"\n📋 Schema:")
df_taxi.printSchema()


In [None]:
# Sample and prepare data for anomaly detection
# Cap the demo at 50K rows; the 10% random sample below may return fewer rows than this cap
SAMPLE_SIZE = 50000

print(f"🔄 Sampling up to {SAMPLE_SIZE:,} trips for demo...\n")

# First, let's see what columns are available
print("Available columns:", df_taxi.columns)

# Select relevant columns and add computed features
# Note: Databricks samples.nyctaxi.trips schema varies - adjust as needed
df_sample = (
    df_taxi
    .sample(fraction=0.1, seed=42)  # Random sample
    .limit(SAMPLE_SIZE)
    .withColumn("trip_id", F.monotonically_increasing_id())  # Add ID for merging
    .withColumn(
        "trip_duration_mins",
        (F.unix_timestamp("tpep_dropoff_datetime") - F.unix_timestamp("tpep_pickup_datetime")) / 60
    )
    .withColumn(
        "speed_mph",
        F.when(F.col("trip_duration_mins") > 0,
               F.col("trip_distance") / (F.col("trip_duration_mins") / 60))
        .otherwise(0)
    )
    .select(
        "trip_id",
        "tpep_pickup_datetime",
        "tpep_dropoff_datetime",
        "trip_distance",
        "trip_duration_mins",
        "speed_mph",
        "fare_amount",
        "pickup_zip",
        "dropoff_zip"
    )
    .filter(F.col("trip_duration_mins").isNotNull())  # Remove nulls
)

print(f"\n✅ Prepared {df_sample.count():,} trips")
print(f"\n📊 Sample data:")
display(df_sample.limit(10))
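
The derived features in the cell above are plain arithmetic, so they can be sanity-checked outside Spark. The following is a minimal pure-Python sketch of the same calculation, assuming timestamps are `datetime` objects. The function name `derive_trip_features` and the example trip are hypothetical and are not part of DQX or the dataset.

```python
from datetime import datetime


def derive_trip_features(pickup, dropoff, distance_miles):
    """Return (duration_mins, speed_mph) for a single trip."""
    duration_mins = (dropoff - pickup).total_seconds() / 60
    # Same guard as the Spark expression: avoid dividing by a zero or negative duration
    if duration_mins > 0:
        speed_mph = distance_miles / (duration_mins / 60)
    else:
        speed_mph = 0.0
    return duration_mins, speed_mph


# Hypothetical trip: 5 miles in 30 minutes
pickup = datetime(2016, 2, 14, 16, 52, 0)
dropoff = datetime(2016, 2, 14, 17, 22, 0)
duration, speed = derive_trip_features(pickup, dropoff, 5.0)
print(duration, speed)  # 30.0 10.0
```

Because trips with a non-positive duration get a speed of 0, they will not show up as "impossibly fast"; they surface through the duration feature instead.
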


In [None]:
# Quick stats to understand the data
print("📈 Data Statistics:\n")

stats = df_sample.select(
    F.count("*").alias("total_trips"),
    F.round(F.avg("trip_distance"), 2).alias("avg_distance_mi"),
    F.round(F.avg("trip_duration_mins"), 2).alias("avg_duration_mins"),
    F.round(F.avg("fare_amount"), 2).alias("avg_fare"),
    F.round(F.avg("speed_mph"), 2).alias("avg_speed_mph"),
).first()

print(f"Total trips: {stats['total_trips']:,}")
print(f"Avg distance: {stats['avg_distance_mi']} miles")
print(f"Avg duration: {stats['avg_duration_mins']} mins")
print(f"Avg fare: ${stats['avg_fare']}")
print(f"Avg speed: {stats['avg_speed_mph']} mph")

# Show potential issues already visible
print(f"\n🔍 Potential Data Quality Issues:")

issues = df_sample.select(
    F.sum(F.when(F.col("fare_amount") <= 0, 1).otherwise(0)).alias("zero_or_negative_fares"),
    F.sum(F.when(F.col("trip_distance") <= 0, 1).otherwise(0)).alias("zero_distance_trips"),
    F.sum(F.when(F.col("speed_mph") > 100, 1).otherwise(0)).alias("impossibly_fast_trips"),
    F.sum(F.when(F.col("trip_duration_mins") > 180, 1).otherwise(0)).alias("very_long_trips_3hr_plus"),
).first()

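# Use descriptive loop names so the common Spark alias `col` is not shadowed
for issue_name, issue_count in issues.asDict().items():
    if issue_count:  # skips 0 and None (sums are None when the sample is empty)
        print(f"   ⚠️ {issue_name.replace('_', ' ').title()}: {issue_count}")

print(f"\n💡 These are just obvious issues - DQX will find subtle patterns too!")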


In [None]:
# Save to table for model training
catalog = spark.sql("SELECT current_catalog()").first()[0]
schema_name = "dqx_demo"
spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog}.{schema_name}")

table_name = f"{catalog}.{schema_name}.nyctaxi_sample"
df_sample.write.mode("overwrite").saveAsTable(table_name)

print(f"✅ Data saved to: {table_name}")

# Set up the model registry; drop any existing registry table so the demo starts clean
registry_table = f"{catalog}.{schema_name}.anomaly_model_registry_nyctaxi"
spark.sql(f"DROP TABLE IF EXISTS {registry_table}")
print(f"📋 Model registry: {registry_table}")


---

## Section 3: Train Anomaly Detection Model

We'll train on numeric features that capture trip characteristics.


In [None]:
# Train anomaly detection model
print("🎯 Training anomaly detection model...\n")

# Select features that capture trip characteristics
feature_columns = [
    "trip_distance",
    "trip_duration_mins",
    "speed_mph",
    "fare_amount",
]

model_uri = anomaly_engine.train(
    df=spark.table(table_name),
    columns=feature_columns,
    model_name="nyctaxi_dq",
    registry_table=registry_table
)

print(f"✅ Model trained!")
print(f"   Model URI: {model_uri}")

# Show registry
print(f"\n📋 Model Registry:")
display(
    spark.table(registry_table)
    .select(
        "identity.model_name",
        "training.columns",
        "training.training_rows",
        "identity.status"
    )
)


---

## Section 4: Find Anomalies

Now let's score all trips and see what DQX finds!


In [None]:
# Apply BOTH rule-based checks AND ML anomaly detection in one pass
from databricks.labs.dqx.rule import DQRowRule
from databricks.labs.dqx.check_funcs import is_in_range, is_not_less_than

print("🔍 Applying Rule-Based Checks + ML Anomaly Detection...\n")

# Combine simple rules with ML anomaly detection
# NOTE: Rule thresholds are intentionally LENIENT (they flag only obvious errors)
# This lets ML demonstrate value by catching subtler issues
all_checks = [
    # Simple rule-based checks - only catch CLEARLY invalid data
    DQRowRule(
        name="impossible_speed",
        check_func=is_in_range,
        check_func_kwargs={"column": "speed_mph", "min_limit": 0, "max_limit": 150}  # 150mph = clearly impossible
    ),
    DQRowRule(
        name="negative_fare",
        check_func=is_not_less_than,
        check_func_kwargs={"column": "fare_amount", "limit": 0}  # Only negative fares
    ),
    DQRowRule(
        name="negative_distance",
        check_func=is_not_less_than,
        check_func_kwargs={"column": "trip_distance", "limit": 0}  # Only negative distances
    ),
    
    # ML anomaly detection (catches subtle issues rules miss)
    DQDatasetRule(
        check_func=has_no_anomalies,
        check_func_kwargs={
            "merge_columns": ["trip_id"],
            "model": "nyctaxi_dq",
            "include_contributions": True,
            "registry_table": registry_table
        }
    )
]

# Apply all checks in ONE pass - this is the recommended pattern!
df_scored = dq_engine.apply_checks(spark.table(table_name), all_checks)

# Now df_scored has BOTH:
# - _errors: contains BOTH rule violations AND anomaly failures
# - _info.anomaly: ML anomaly scores (from has_no_anomalies)

# Filter _errors to exclude anomaly check (we'll count that separately via _info.anomaly)
row_rule_errors_col = F.filter(
    F.col("_errors"),
    lambda x: x["function"] != "has_no_anomalies"
)

total = df_scored.count()
rule_violations = df_scored.filter(F.size(row_rule_errors_col) > 0).count()
ml_anomalies = df_scored.filter(F.col("_info.anomaly.score") >= 0.5).count()

print(f"✅ Applied {len(all_checks)} checks in one pass!")
print(f"\n📊 Results:")
print(f"   Total trips: {total:,}")
print(f"   Rule violations: {rule_violations:,} ({rule_violations/total*100:.1f}%)")
print(f"   ML anomalies (score ≥ 0.5): {ml_anomalies:,} ({ml_anomalies/total*100:.1f}%)")


In [None]:
# Show top anomalies
print("🚨 Top 15 Anomalous Trips:\n")

top_anomalies = (
    df_scored
    .filter(F.col("_info.anomaly.score") >= 0.5)
    .orderBy(F.col("_info.anomaly.score").desc())
    .select(
        "trip_id",
        F.round("trip_distance", 1).alias("distance_mi"),
        F.round("trip_duration_mins", 1).alias("duration_mins"),
        F.round("speed_mph", 1).alias("speed_mph"),
        F.round("fare_amount", 2).alias("fare"),
        F.round("_info.anomaly.score", 3).alias("anomaly_score")
    )
    .limit(15)
)

display(top_anomalies)

print("\n💡 Look for patterns:")
print("   • Unusually high fares")
print("   • Impossible speeds (>100 mph in NYC traffic?)")
print("   • Very long durations with short distances")
print("   • Zero fares or distances")


---

## Section 5: Understand WHY Trips Are Anomalous

DQX provides **feature contributions** that explain which features made each trip unusual.
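
To make the ranking step concrete, here is a minimal pure-Python sketch of picking the strongest contributors from a feature-to-contribution mapping. It assumes contributions arrive as a dictionary of feature name to numeric value, which is how the cell below treats them. The helper `top_contributors` and the example values are hypothetical and do not come from DQX or from real scoring output.

```python
def top_contributors(contributions, k=3):
    """Return the k (feature, value) pairs with the largest absolute contribution."""
    ranked = sorted(contributions.items(), key=lambda kv: abs(kv[1]), reverse=True)
    return ranked[:k]


# Hypothetical contributions for one flagged trip
example = {
    "trip_distance": 0.10,
    "trip_duration_mins": 0.55,
    "speed_mph": -0.05,
    "fare_amount": 0.30,
}

for feature, value in top_contributors(example):
    print(f"{feature}: {abs(value) * 100:.1f}% contribution")
```

Sorting on the absolute value keeps a feature with a large negative contribution near the top of the list, which matters if the values are signed.
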


In [None]:
# Analyze top anomalies in detail
print("🔎 Deep Dive: Why Are These Trips Anomalous?\n")

top_5 = (
    df_scored
    .filter(F.col("_info.anomaly.score") >= 0.5)
    .orderBy(F.col("_info.anomaly.score").desc())
    .select(
        "trip_id",
        "trip_distance",
        "trip_duration_mins",
        "speed_mph",
        "fare_amount",
        "pickup_zip",
        F.col("_info.anomaly.score").alias("score"),
        F.col("_info.anomaly.contributions").alias("contributions")
    )
    .limit(5)
    .collect()
)

for i, row in enumerate(top_5, 1):
    print(f"{'='*60}")
    print(f"Anomaly #{i} (Score: {row['score']:.3f})")
    print(f"{'='*60}")
    print(f"Trip Details:")
    print(f"   Distance: {row['trip_distance']:.1f} miles")
    print(f"   Duration: {row['trip_duration_mins']:.1f} mins")
    print(f"   Speed: {row['speed_mph']:.1f} mph")
    print(f"   Fare: ${row['fare_amount']:.2f}")
    print(f"   Pickup ZIP: {row['pickup_zip']}")
    
    print(f"\nTop Contributing Factors:")
    top_factors = []
    if row['contributions']:
        sorted_contrib = sorted(row['contributions'].items(), key=lambda x: abs(x[1]), reverse=True)
        for feature, value in sorted_contrib[:3]:
            print(f"   📌 {feature}: {abs(value)*100:.1f}% contribution")
            top_factors.append(feature)

    # Heuristic interpretation: check the raw values first, then the top contributing factors
    print(f"\n🎯 Likely Issue:")
    if row['speed_mph'] > 100:
        print(f"   → Impossible speed ({row['speed_mph']:.0f} mph) - GPS or timestamp error")
    elif row['fare_amount'] <= 0:
        print(f"   → Zero/negative fare - payment system error")
    elif row['trip_duration_mins'] > 180:
        print(f"   → Very long trip ({row['trip_duration_mins']:.0f} mins) - meter issue?")
    elif row['trip_distance'] <= 0 and row['fare_amount'] > 0:
        print(f"   → Zero distance but charged ${row['fare_amount']:.2f} - GPS error")
    else:
        # Interpret based on the two strongest contributing features.
        # Only the four trained features can appear here (ZIP codes were not model inputs).
        interpretations = []
        if any('distance' in f for f in top_factors[:2]):
            interpretations.append(f"unusual distance ({row['trip_distance']:.0f} mi)")
        if any('fare' in f for f in top_factors[:2]):
            interpretations.append(f"unusual fare (${row['fare_amount']:.0f})")
        if any('duration' in f for f in top_factors[:2]):
            interpretations.append(f"unusual duration ({row['trip_duration_mins']:.0f} mins)")
        if any('speed' in f for f in top_factors[:2]):
            interpretations.append(f"unusual speed ({row['speed_mph']:.0f} mph)")

        if interpretations:
            print(f"   → Multi-factor: {' + '.join(interpretations)}")
            print(f"   → This combination is rare in the training data")
        else:
            print(f"   → Unusual combination of values")
    print()


In [None]:
# Analyze the combined results - Rules vs ML
print("📊 Rule-Based vs ML Detection Comparison\n")

# Add derived columns for easier analysis
# Filter _errors to exclude anomaly check (function="has_no_anomalies")
df_analysis = df_scored.withColumn(
    "row_rule_error_count",
    F.size(F.filter(F.col("_errors"), lambda x: x["function"] != "has_no_anomalies"))
).withColumn(
    "is_ml_anomaly",
    F.col("_info.anomaly.score") >= 0.5
)

# Count overlaps using the derived columns
total_rows = df_analysis.count()
rule_count = df_analysis.filter(F.col("row_rule_error_count") > 0).count()
ml_count = df_analysis.filter(F.col("is_ml_anomaly")).count()

both_flagged = df_analysis.filter((F.col("row_rule_error_count") > 0) & F.col("is_ml_anomaly")).count()
rules_only = df_analysis.filter((F.col("row_rule_error_count") > 0) & ~F.col("is_ml_anomaly")).count()
ml_only = df_analysis.filter((F.col("row_rule_error_count") == 0) & F.col("is_ml_anomaly")).count()

print("🔴 DQX Rule-Based Checks (from _errors):")
print(f"   Rows with violations: {rule_count}")

print(f"\n🟡 ML Anomaly Detection (from _info.anomaly):")
print(f"   Anomalies detected: {ml_count}")

print(f"\n📈 Overlap Analysis:")
print(f"   • Caught by BOTH rules AND ML: {both_flagged}")
print(f"   • Caught by rules ONLY: {rules_only}")
print(f"   • Caught by ML ONLY: {ml_only} ← This is the ML value-add!")

if ml_only > 0:
    pct_ml_only = (ml_only / ml_count) * 100
    print(f"\n💡 Key Insight:")
    print(f"   ML found {ml_only} issues ({pct_ml_only:.0f}% of all anomalies) that simple rules missed!")
    print(f"   These are subtle patterns like unusual distance+duration+fare combinations.")

# Show a sample of ML-only anomalies
print(f"\n🔎 Sample ML-Only Anomalies (passed rules but flagged by ML):")
display(
    df_analysis
    .filter((F.col("row_rule_error_count") == 0) & F.col("is_ml_anomaly"))
    .select(
        "trip_id",
        F.round("trip_distance", 1).alias("distance"),
        F.round("trip_duration_mins", 1).alias("duration"),
        F.round("fare_amount", 2).alias("fare"),
        F.round("_info.anomaly.score", 3).alias("anomaly_score")
    )
    .orderBy(F.col("anomaly_score").desc())
    .limit(5)
)
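
The overlap counts in the cell above are ordinary set arithmetic on the two groups of flagged trips. A minimal pure-Python sketch of the same idea follows, assuming each detector yields a collection of trip IDs. The function `overlap_counts` and the IDs are hypothetical and are used only for illustration.

```python
def overlap_counts(rule_flagged, ml_flagged):
    """Count trips flagged by both detectors, by rules only, and by ML only."""
    rule_flagged, ml_flagged = set(rule_flagged), set(ml_flagged)
    return {
        "both": len(rule_flagged & ml_flagged),        # intersection
        "rules_only": len(rule_flagged - ml_flagged),  # set difference
        "ml_only": len(ml_flagged - rule_flagged),
    }


# Hypothetical trip IDs flagged by each detector
rule_ids = {1, 2, 3}
ml_ids = {3, 4, 5, 6}

counts = overlap_counts(rule_ids, ml_ids)
print(counts)  # {'both': 1, 'rules_only': 2, 'ml_only': 3}
```

The three counts partition the union of flagged trips, so `both + rules_only + ml_only` should equal the number of distinct flagged trips. That makes a quick consistency check for the Spark version.
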


---

## Section 6: Compare Normal vs Anomalous Trips


In [None]:
# Statistical comparison
print("📈 Normal vs Anomalous Trip Comparison:\n")

normal = df_scored.filter(F.col("_info.anomaly.score") < 0.5)
anomalous = df_scored.filter(F.col("_info.anomaly.score") >= 0.5)

def get_stats(df, label):
    return df.select(
        F.lit(label).alias("type"),
        F.count("*").alias("count"),
        F.round(F.avg("trip_distance"), 2).alias("avg_distance"),
        F.round(F.avg("trip_duration_mins"), 2).alias("avg_duration"),
        F.round(F.avg("speed_mph"), 2).alias("avg_speed"),
        F.round(F.avg("fare_amount"), 2).alias("avg_fare"),
    )

comparison = get_stats(normal, "Normal").union(get_stats(anomalous, "Anomalous"))
display(comparison)

print("\n💡 Compare the averages above: anomalous trips usually differ clearly from normal ones.")
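
The same comparison can be sketched without Spark, which is handy for unit-testing the grouping logic. This is a minimal pure-Python version, assuming each scored row is a dictionary with a `score` key. The function `compare_groups`, its parameters, and the sample rows are hypothetical.

```python
from statistics import mean


def compare_groups(rows, threshold=0.5, fields=("trip_distance", "fare_amount")):
    """Summarise per-field averages for normal vs anomalous rows."""
    groups = {
        "Normal": [r for r in rows if r["score"] < threshold],
        "Anomalous": [r for r in rows if r["score"] >= threshold],
    }
    summary = {}
    for label, members in groups.items():
        summary[label] = {"count": len(members)}
        for field in fields:
            # An empty group has no average, so report None instead of failing
            avg = round(mean(r[field] for r in members), 2) if members else None
            summary[label][f"avg_{field}"] = avg
    return summary


# Hypothetical scored rows
rows = [
    {"trip_distance": 2.0, "fare_amount": 10.0, "score": 0.1},
    {"trip_distance": 3.0, "fare_amount": 14.0, "score": 0.2},
    {"trip_distance": 0.1, "fare_amount": 80.0, "score": 0.8},
]

summary = compare_groups(rows)
print(summary["Normal"])
print(summary["Anomalous"])
```
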


---

## Summary

### What We Found

DQX successfully detected **real data quality issues** in NYC Taxi data:

| Issue Type | Description | Likely Cause |
|------------|-------------|--------------|
| 🚗 Impossible speeds | >100 mph in NYC | GPS or timestamp error |
| 💰 Zero/negative fares | $0 or negative amounts | Payment system error |
| 📍 Zero distance trips | 0 miles but charged | GPS malfunction |
| ⏱️ Extreme durations | 3+ hour trips | Meter left running |
| 💵 High fares | Unusually expensive trips | Long distance or data error |

### Key Takeaways

1. **Unsupervised detection works** - No labels needed to find problems
2. **Contributions explain WHY** - Not just "anomaly" but which features drove it
3. **Catches subtle patterns** - Not just simple rule violations

### Next Steps

- **Quarantine** anomalous trips for investigation
- **Combine** with rule-based checks for comprehensive DQ
- **Monitor** ongoing data for new anomaly patterns
- **Tune the threshold** to your tolerance for false positives (this notebook uses 0.5; raise it to flag fewer trips, lower it to flag more)
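
The quarantine and threshold-tuning steps can be sketched together. This is a minimal pure-Python illustration, assuming each scored trip is a dictionary with a `score` key. The function `split_by_threshold` and the scores are hypothetical, and a production pipeline would do the equivalent filtering in Spark.

```python
def split_by_threshold(scored_rows, threshold=0.5):
    """Split rows into (clean, quarantined) using the anomaly score threshold."""
    clean = [r for r in scored_rows if r["score"] < threshold]
    quarantined = [r for r in scored_rows if r["score"] >= threshold]
    return clean, quarantined


# Hypothetical anomaly scores
scored = [
    {"trip_id": 1, "score": 0.12},
    {"trip_id": 2, "score": 0.48},
    {"trip_id": 3, "score": 0.51},
    {"trip_id": 4, "score": 0.63},
    {"trip_id": 5, "score": 0.90},
]

# See how the quarantine size responds to the threshold before committing to one
for threshold in (0.4, 0.5, 0.6):
    clean, quarantined = split_by_threshold(scored, threshold)
    print(f"threshold={threshold}: {len(quarantined)} quarantined, {len(clean)} clean")
```

Sweeping a few thresholds like this shows how many trips sit close to the cutoff, which helps when choosing a value that matches your review capacity.
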

### Resources

- [DQX Documentation](https://databrickslabs.github.io/dqx)
- [Anomaly Detection Guide](https://databrickslabs.github.io/dqx/guide/anomaly_detection)
