In [0]:
%run /Shared/_init_azure_conn

In [0]:
# ============================================================
# Guided Capstone Step 2 – Data Ingestion (Databricks CE version)
# Refactored to add bad-record routing (partition=B) and safer parsing
# ============================================================

from pyspark.sql import SparkSession
import pyspark.sql.types as T
from pyspark.sql.functions import input_file_name, regexp_extract, current_timestamp, lit
import json
from datetime import datetime

# Reuse the active session if one exists, otherwise start one for this job.
spark = (
    SparkSession.builder
    .appName("guided_step2_ingestion")
    .getOrCreate()
)

# === 1. Common schema (expanded) ===
# A single schema shared by trade and quote events; fields that do not apply
# to a record type stay null. All fields are nullable (the StructField default).
schema = (
    T.StructType()
    .add("trade_dt", T.StringType())          # both record types
    .add("rec_type", T.StringType())          # "T", "Q", or "B" (bad record)
    .add("symbol", T.StringType())            # both record types
    .add("exchange", T.StringType())          # both record types
    .add("event_tm", T.StringType())          # both record types
    .add("event_seq_nb", T.IntegerType())     # both record types
    .add("arrival_tm", T.StringType())        # from the file, else ingestion time
    .add("trade_pr", T.DoubleType())          # trade only
    .add("trade_size", T.IntegerType())       # trade only
    .add("bid_pr", T.DoubleType())            # quote only
    .add("bid_size", T.IntegerType())         # quote only
    .add("ask_pr", T.DoubleType())            # quote only
    .add("ask_size", T.IntegerType())         # quote only
    .add("execution_id", T.StringType())      # trade only, may be null
    .add("partition", T.StringType())         # "T", "Q", or "B"
)

# === helper to build a bad-record row ===
def bad_record():
    """Return a placeholder row for a line that could not be parsed.

    Every field is None except ``rec_type`` and ``partition`` (both "B") and
    ``arrival_tm``, which is stamped with the current UTC time in ISO format.
    Keys appear in the same order as the common schema.
    """
    field_names = (
        "trade_dt", "rec_type", "symbol", "exchange", "event_tm",
        "event_seq_nb", "arrival_tm", "trade_pr", "trade_size",
        "bid_pr", "bid_size", "ask_pr", "ask_size",
        "execution_id", "partition",
    )
    row = dict.fromkeys(field_names)  # every value starts as None
    row["rec_type"] = "B"
    row["arrival_tm"] = datetime.utcnow().isoformat()
    row["partition"] = "B"
    return row

# === 2. CSV parser (best-guess, position-based, tolerant) ===
def parse_csv(line: str):
    """Parse one comma-separated line into a common-schema row dict.

    Column positions: 0 trade_dt, 1 arrival_tm, 2 rec_type, 3 symbol,
    4 event_tm, 5 event_seq_nb, 6 exchange. Positions 7+ depend on the
    record type: trades carry trade_pr, trade_size; quotes carry bid_pr,
    bid_size, ask_pr, ask_size. Trailing positions may be absent or empty
    and then become None.

    Returns ``bad_record()`` when the line has fewer than 7 columns, when
    the record type is not "T" or "Q", or when any step raises.
    """
    def _opt(cols, idx, cast):
        # Absent or empty position -> None; otherwise convert the text.
        if idx < len(cols) and cols[idx]:
            return cast(cols[idx])
        return None

    try:
        # Empty pieces are kept so positions stay aligned.
        cols = [piece.strip() for piece in line.split(",")]
        if len(cols) < 7:
            return bad_record()

        kind = cols[2].upper()
        seq_nb = _opt(cols, 5, int)

        if kind not in ("T", "Q"):
            # Unknown record type.
            return bad_record()

        row = {
            "trade_dt": cols[0] or None,
            "rec_type": kind,
            "symbol": cols[3] or None,
            "exchange": cols[6] or None,
            "event_tm": cols[4] or None,
            "event_seq_nb": seq_nb,
            # Fall back to the ingestion time when the file gives none.
            "arrival_tm": cols[1] or datetime.utcnow().isoformat(),
            "trade_pr": None,
            "trade_size": None,
            "bid_pr": None,
            "bid_size": None,
            "ask_pr": None,
            "ask_size": None,
            # The CSV layout carries no execution id.
            "execution_id": None,
            "partition": kind,
        }

        if kind == "T":
            row["trade_pr"] = _opt(cols, 7, float)
            row["trade_size"] = _opt(cols, 8, int)
        else:
            row["bid_pr"] = _opt(cols, 7, float)
            row["bid_size"] = _opt(cols, 8, int)
            row["ask_pr"] = _opt(cols, 9, float)
            row["ask_size"] = _opt(cols, 10, int)

        return row

    except Exception:
        # Any conversion or shape problem routes the line to partition B.
        return bad_record()

# === 3. JSON parser (normalized to common event) ===
def parse_json(line: str):
    """Parse one JSON line into a common-schema row dict.

    The record type is read from ``event_type`` (falling back to
    ``rec_type``) and upper-cased. Only "T" and "Q" are accepted; any other
    or missing type is normalized to "B" in BOTH ``rec_type`` and
    ``partition`` so the row honours the schema's T/Q/B domain. The
    remaining parsed fields are kept on such rows to aid diagnosis.

    ``trade_dt`` falls back to ``trade_date``; ``arrival_tm`` is taken from
    ``file_tm`` or, when that is absent, the current UTC time.

    Returns ``bad_record()`` when the line is not valid JSON, is not a JSON
    object, or a numeric conversion fails.
    """
    def _cast(value, caster):
        # Preserve explicit nulls / missing keys as None.
        return caster(value) if value is not None else None

    try:
        rec = json.loads(line)

        declared = rec.get("event_type") or rec.get("rec_type") or "B"
        declared = declared.upper()
        # Anything that is not a trade or a quote is a bad record.
        rec_type = declared if declared in ("T", "Q") else "B"

        row = {
            "trade_dt": rec.get("trade_dt") or rec.get("trade_date"),
            "rec_type": rec_type,
            "symbol": rec.get("symbol"),
            "exchange": rec.get("exchange"),
            "event_tm": rec.get("event_tm"),
            "event_seq_nb": _cast(rec.get("event_seq_nb"), int),
            "arrival_tm": rec.get("file_tm") or datetime.utcnow().isoformat(),
            "trade_pr": None,
            "trade_size": None,
            "bid_pr": None,
            "bid_size": None,
            "ask_pr": None,
            "ask_size": None,
            "execution_id": rec.get("execution_id"),
            "partition": rec_type,
        }

        if rec_type == "T":
            row["trade_pr"] = _cast(rec.get("trade_pr"), float)
            row["trade_size"] = _cast(rec.get("trade_size"), int)
        elif rec_type == "Q":
            row["bid_pr"] = _cast(rec.get("bid_pr"), float)
            row["bid_size"] = _cast(rec.get("bid_size"), int)
            row["ask_pr"] = _cast(rec.get("ask_pr"), float)
            row["ask_size"] = _cast(rec.get("ask_size"), int)

        return row

    except Exception:
        # Malformed JSON or unconvertible values route to partition B.
        return bad_record()

# === 4. Define all paths ===
# Base Azure Data Lake path (ABFS using SAS authentication).
# container_name and storage_account_name are not defined in this cell;
# presumably they come from the /Shared/_init_azure_conn notebook run above.
base_path = f"abfss://{container_name}@{storage_account_name}.dfs.core.windows.net"
csv_path = f"{base_path}/data/csv/*/*/*.txt"
json_path = f"{base_path}/data/json/*/*/*.txt"
output_path = f"{base_path}/output_dir/"


print("‚úÖ CSV path:", csv_path)
print("‚úÖ JSON path:", json_path)
print("‚úÖ Output path:", output_path)


# === 5. Load, parse, DF-ify (modern, SAS-compatible) ===
# The source file name is captured here, directly on the file scan.
# input_file_name() relies on the scan's per-task file context, which is not
# reliably available once rows have made a round trip through a Python RDD,
# so the value is carried through the parse step as ordinary data instead.
csv_raw_df = spark.read.text(csv_path).withColumn("source_path", input_file_name())
json_raw_df = spark.read.text(json_path).withColumn("source_path", input_file_name())

# Common schema plus the carried-through source path.
ingest_schema = T.StructType(
    schema.fields + [T.StructField("source_path", T.StringType())]
)

# parse_csv / parse_json always return a dict (a bad-record row on failure),
# so every raw line yields exactly one output row.
csv_df = (
    csv_raw_df.rdd
    .map(lambda r: {**parse_csv(r["value"]), "source_path": r["source_path"]})
    .toDF(schema=ingest_schema)
)
json_df = (
    json_raw_df.rdd
    .map(lambda r: {**parse_json(r["value"]), "source_path": r["source_path"]})
    .toDF(schema=ingest_schema)
)

combined_df = csv_df.unionByName(json_df, allowMissingColumns=True)


# === 6. Add audit metadata columns ===
# Cached because the frame feeds several actions below (counts, partition
# breakdown, write); without it each action re-runs the Python parsers over
# every raw line and re-evaluates ingest_ts.
combined_df = (
    combined_df
    .withColumn("source_file", regexp_extract("source_path", r"([^/]+)$", 1))
    .withColumn("ingest_ts", current_timestamp())
    .cache()
)

# === 7. Record Count Audit (ABFS-safe) ===

print("CSV path:", csv_path)
print("JSON path:", json_path)
print("Output path:", output_path)

# Raw line counts come from the same DataFrame readers used for ingestion.
raw_csv_count = csv_raw_df.count()
raw_json_count = json_raw_df.count()
raw_total = raw_csv_count + raw_json_count
post_total = combined_df.count()

partition_counts = combined_df.groupBy("partition").count().collect()
partition_summary = {row["partition"]: row["count"] for row in partition_counts}

print("=== Record Count Audit ===")
print(f"Raw CSV count:   {raw_csv_count}")
print(f"Raw JSON count:  {raw_json_count}")
print(f"Raw total:       {raw_total}")
print(f"Post-ingest total: {post_total}")
print(f"Partition breakdown: {partition_summary}")

if raw_total == post_total:
    print("‚úÖ Record counts match ‚Äî no records dropped or added.")
else:
    print("‚ö†Ô∏è Mismatch detected! Investigate parser or schema drift.")


# === 8. Write partitioned output ===
combined_count = post_total
print("Combined Count:", combined_count)

if combined_count > 0:
    combined_df.groupBy("partition").count().show()
    # Dynamic partition overwrite replaces only the partition=<X> folders
    # present in this run. A plain overwrite clears the whole output
    # directory, including the append-only _audit_log and the eod output
    # that later cells write beneath it.
    # NOTE(review): a partition value absent from this run (e.g. no "B"
    # rows) keeps its files from the previous run.
    (
        combined_df.write
        .partitionBy("partition")
        .mode("overwrite")
        .option("partitionOverwriteMode", "dynamic")
        .parquet(output_path)
    )
    print(f"‚úÖ Data written successfully to: {output_path}")
else:
    print("‚ö†Ô∏è No data to write ‚Äî check parser output.")

combined_df.printSchema()


In [0]:
from datetime import datetime
from pyspark.sql import types as T

# === 8. Record Count Audit & Logging ===
# Compares raw input line counts with the ingested row count, then appends a
# one-row summary of the run to a parquet audit log.

print("=== Record Count Audit (ABFS) ===")
print(f"CSV path:  {csv_path}")
print(f"JSON path: {json_path}")
print(f"Output path: {output_path}")

# Pre-ingest: one row per raw text line, CSV first and then JSON.
raw_csv_count, raw_json_count = (
    spark.read.text(src).count() for src in (csv_path, json_path)
)
raw_total = raw_csv_count + raw_json_count

# Post-ingest: rows in the combined DataFrame.
post_total = combined_df.count()

# Rows per partition value (T, Q, B).
partition_counts = combined_df.groupBy("partition").count().collect()
partition_summary = dict((r["partition"], r["count"]) for r in partition_counts)

print(f"Raw CSV count:   {raw_csv_count}")
print(f"Raw JSON count:  {raw_json_count}")
print(f"Raw total:       {raw_total}")
print(f"Post-ingest total: {post_total}")
print(f"Partition breakdown: {partition_summary}")

if raw_total != post_total:
    diff = raw_total - post_total
    print(f"‚ö†Ô∏è Mismatch detected! {diff} record(s) lost or added during ingestion.")
    job_status = "mismatch"
else:
    print("‚úÖ Record counts match ‚Äî all records accounted for.")
    job_status = "success"

# === 9. Persist audit summary for job lineage ===

audit_data = [dict(
    run_ts=datetime.utcnow().isoformat(),
    raw_csv_count=raw_csv_count,
    raw_json_count=raw_json_count,
    raw_total=raw_total,
    post_total=post_total,
    partition_summary=str(partition_summary),
    output_path=output_path,
    job_status=job_status,
    processed_by="guided_step2_ingestion",
)]

# Explicit schema so the counts are stored as 64-bit integers.
audit_schema = T.StructType([
    T.StructField(col_name, col_type())
    for col_name, col_type in (
        ("run_ts", T.StringType),
        ("raw_csv_count", T.LongType),
        ("raw_json_count", T.LongType),
        ("raw_total", T.LongType),
        ("post_total", T.LongType),
        ("partition_summary", T.StringType),
        ("output_path", T.StringType),
        ("job_status", T.StringType),
        ("processed_by", T.StringType),
    )
])

audit_df = spark.createDataFrame(audit_data, schema=audit_schema)

# Appended, so earlier runs' rows are kept alongside this one.
audit_log_path = f"{output_path}/_audit_log"
audit_df.write.mode("append").parquet(audit_log_path)

print(f"üßæ Audit log updated at: {audit_log_path}")

In [0]:
# ============================================================
# Step 3 – End-of-Day (EOD) Data Load
# Consistent with Step 2 ABFS/SAS setup
# Reads Step 2 output parquet partitions -> cleans, dedups,
# and writes EOD parquet partitioned by trade_dt
# ============================================================

from pyspark.sql import functions as F
from pyspark.sql.window import Window
from datetime import datetime

# === 0. Config ===
# EOD datasets are written beneath the Step 2 output location.
eod_dir = "/".join((output_path, "eod"))

# === 1. Utility – Keep latest arrival_tm per unique event key ===
def apply_latest(df):
    """Keep one row per event key: the one with the greatest ``arrival_tm``.

    The event key is (trade_dt, symbol, exchange, event_tm, event_seq_nb).
    Rows are ranked within each key by ``arrival_tm`` descending and only
    rank 1 is kept; the helper ranking column is dropped before returning.
    """
    event_key = ("trade_dt", "symbol", "exchange", "event_tm", "event_seq_nb")
    newest_first = Window.partitionBy(*event_key).orderBy(F.desc("arrival_tm"))
    ranked = df.withColumn("rn", F.row_number().over(newest_first))
    return ranked.where(F.col("rn") == 1).drop("rn")

# === 2. Metadata-Driven Dataset Processing ===
# One entry per Step 2 partition value:
#   name           - dataset name; used for the target folder and messages
#   cols           - columns carried into the EOD output
#   required       - columns that must all be non-null for a row to be kept
#   null_label     - wording used in the dropped-row message
#   null_count_col - name of the per-symbol dropped-row count column
datasets = {
    "T": {
        "name": "trade",
        "cols": ["trade_dt", "symbol", "exchange", "event_tm",
                 "event_seq_nb", "arrival_tm", "trade_pr"],
        "required": ["trade_pr"],
        "null_label": "trade_pr",
        "null_count_col": "null_trade_pr_count",
    },
    "Q": {
        "name": "quote",
        "cols": ["trade_dt", "symbol", "exchange", "event_tm",
                 "event_seq_nb", "arrival_tm", "bid_pr",
                 "bid_size", "ask_pr", "ask_size"],
        "required": ["bid_pr", "ask_pr"],
        "null_label": "bid/ask",
        "null_count_col": "null_quote_count",
    },
}

audit_records = []

for part, meta in datasets.items():
    print(f"\n=== Processing partition {part} ({meta['name']}) ===")

    src_base = f"{output_path}/partition={part}"
    tgt_path = f"{eod_dir}/{meta['name']}"

    # --- SAFE FULL-PARTITION READ ---
    # dbutils.fs.ls raises when the directory does not exist (for example,
    # Step 2 produced no rows for this partition). Treat that as "nothing
    # to read" so the skip branch below is reached; re-raise anything else.
    # NOTE(review): relies on the error text naming FileNotFoundException.
    try:
        listing = dbutils.fs.ls(src_base)
    except Exception as exc:
        if "FileNotFoundException" not in str(exc):
            raise
        listing = []

    # Ignore marker/temp entries such as _SUCCESS and dot-files.
    valid_paths = [
        f.path for f in listing
        if not f.name.startswith(("_", "."))
    ]

    if not valid_paths:
        print(f"‚ö†Ô∏è  No valid partitions found under {src_base}. Skipping.")
        continue

    # Read everything found under the partition folder (no date filtering).
    df = spark.read.parquet(*valid_paths)

    # Keep only the columns this dataset needs.
    df_selected = df.select(*meta["cols"])

    # True for rows where at least one required column is null.
    any_missing = None
    for col_name in meta["required"]:
        col_missing = F.col(col_name).isNull()
        any_missing = col_missing if any_missing is None else (any_missing | col_missing)

    # Keep rows where every required column is present.
    df_selected = df_selected.filter(~any_missing)

    # Deduplicate: latest arrival_tm per unique event key.
    df_deduped = apply_latest(df_selected)

    # === Sanity Check: count the rows dropped for nulls, per symbol ===
    count_col = meta["null_count_col"]
    df_nulls = (
        df.filter(any_missing)
        .groupBy("symbol", "exchange")
        .agg(F.count("*").alias(count_col))
        .orderBy(F.desc(count_col))
    )

    # The sum over an empty frame is null, hence the "or 0".
    total_nulls = df_nulls.agg(F.sum(count_col)).collect()[0][0] or 0
    print(f"‚ö†Ô∏è  Found {total_nulls:,} null {meta['null_label']} rows in {meta['name']} dataset")

    if total_nulls > 0:
        display(df_nulls.limit(20))  # top 20 symbol/exchange pairs by dropped rows

    # partitionBy moves trade_dt into the folder name, so a copy is kept as
    # a regular column inside the data files.
    df_deduped = df_deduped.withColumn("trade_dt_copy", F.col("trade_dt"))
    df_deduped.write.mode("overwrite").partitionBy("trade_dt").parquet(tgt_path)

    # Collect audit info
    record = {
        "partition": part,
        "dataset": meta["name"],
        "input_count": df_selected.count(),
        "output_count": df_deduped.count(),
        "input_path": src_base,
        "output_path": tgt_path,
        "status": "success",
        "run_ts": datetime.utcnow().isoformat()
    }
    audit_records.append(record)

    print(f"‚úÖ {meta['name']} written to {tgt_path}")

# === 3. Audit Log Write ===
audit_log_path = f"{eod_dir}/_audit_log"

if audit_records:
    audit_df = spark.createDataFrame(audit_records)
    audit_df.write.mode("append").parquet(audit_log_path)

    print("\nüßæ Audit log updated:", audit_log_path)
    display(audit_df)
    print("\nüéØ EOD load complete for all available trade_dt partitions.")
else:
    # createDataFrame cannot infer a schema from an empty list, and there is
    # nothing to record when every partition was skipped.
    print("\nNo partitions were processed; audit log not updated.")