In [None]:
from databricks.sdk.runtime import dbutils, spark
from pyspark.sql.functions import col, mean, stddev, when

# Flag trips whose fare-per-mile lies more than three standard deviations
# from the mean and persist them to <catalog>.<schema>.anomalous_trips.
# The catalog and schema names are supplied through notebook widgets.
catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")
df = spark.read.table(f"{catalog}.{schema}.trips_raw")

# Add fare_per_mile column.
# The division is guarded with when(): with ANSI mode enabled a division by
# zero raises instead of yielding null, and Spark does not guarantee that the
# trip_distance filter below is evaluated before this expression. when()
# without otherwise() yields null for zero (or null) distances, which is the
# same value the unguarded division produces in non-ANSI mode.
df = df.withColumn(
    "fare_per_mile",
    when(col("trip_distance") != 0, col("fare_amount") / col("trip_distance")),
)

# Filter out invalid trips (non-positive distance or undefined ratio).
filtered = df.filter((col("trip_distance") > 0) & (col("fare_per_mile").isNotNull()))

# Compute mean and stddev over the valid trips.
stats = filtered.select(
    mean("fare_per_mile").alias("mean_fpm"),
    stddev("fare_per_mile").alias("std_fpm"),
).first()
mean_val, std_val = stats["mean_fpm"], stats["std_fpm"]

# Both aggregates are null when no valid trips remain, and stddev may also be
# null when there is only a single row; fail with an explicit message rather
# than an opaque TypeError from arithmetic on None.
if mean_val is None or std_val is None:
    raise ValueError(
        f"Cannot compute outlier thresholds for {catalog}.{schema}.trips_raw: "
        f"not enough valid trips (mean={mean_val}, stddev={std_val})"
    )

# mean() returns a Decimal when the input column is a decimal type, while
# stddev() returns a float; Python refuses to mix the two in arithmetic, so
# normalize both to float before building the thresholds.
mean_val, std_val = float(mean_val), float(std_val)
lower_bound = mean_val - 3 * std_val
upper_bound = mean_val + 3 * std_val

# Flag outliers on either side of the mean (|z-score| > 3).
anomalies = filtered.filter(
    (col("fare_per_mile") > upper_bound) | (col("fare_per_mile") < lower_bound)
)

# Replace any previous result table with the current set of anomalies.
anomalies.write.mode("overwrite").saveAsTable(f"{catalog}.{schema}.anomalous_trips")