In [0]:
# Databricks notebook source
# Script 4 — 04_per_institution_bronze_ingest
#
# Purpose:
#   Consume institution_ingest_plan (created by Script 3), and for each (file × institution):
#     - get bearer token from SST staging using X-API-KEY (from Databricks secrets)
#     - call /api/v1/institutions/pdp-id/{pdp_id} to resolve institution name
#     - map name -> schema prefix via databricksify_inst_name()
#     - locate <prefix>_bronze schema in staging_sst_02
#     - choose a volume in that schema containing "bronze"
#     - filter rows by institution id (exactly like current script)
#     - write to bronze volume using helper.process_and_save_file (exact same ingestion method)
#   After all institutions for a file are processed, update ingestion_manifest:
#     - BRONZE_WRITTEN if all institution ingests succeeded (or were already present)
#     - FAILED if any error occurred for that file (store error_message)
#
# Constraints:
#   - NO SFTP connection (uses staged local files from Script 1/3)
#   - Uses existing ingestion function + behavior from current script


In [0]:
%pip install pandas python-box pyyaml requests paramiko
%restart_python

In [0]:
import os
import re
from datetime import datetime, timezone
from urllib.parse import quote

import pandas as pd
import paramiko
import requests
import yaml
from box import Box
from pyspark.sql import functions as F
from pyspark.sql import types as T

from helper import process_and_save_file, CustomLogger


In [0]:
logger = CustomLogger()

# COMMAND ----------

# ---------------------------
# Config + constants
# ---------------------------
# Deployment config; the Databricks secret scope for the SST API key is read from it.
with open("gcp_config.yaml", "rb") as config_file:
    cfg = Box(yaml.safe_load(config_file))

# Orchestration tables shared with Scripts 1 and 3 live in CATALOG.DEFAULT_SCHEMA.
# NOTE(review): the header says bronze schemas are located in staging_sst_02, but the
# bronze schema/volume lookups in this notebook use CATALOG (staging_sst_01) — confirm.
CATALOG = "staging_sst_01"
DEFAULT_SCHEMA = "default"
_TABLE_PREFIX = f"{CATALOG}.{DEFAULT_SCHEMA}"
PLAN_TABLE = f"{_TABLE_PREFIX}.institution_ingest_plan"
MANIFEST_TABLE = f"{_TABLE_PREFIX}.ingestion_manifest"

# SST staging API: token exchange endpoint and institution lookup path template.
SST_BASE_URL = "https://staging-sst.datakind.org"
SST_TOKEN_ENDPOINT = f"{SST_BASE_URL}/api/v1/token-from-api-key"
INSTITUTION_LOOKUP_PATH = "/api/v1/institutions/pdp-id/{pdp_id}"

# IMPORTANT: set these two to your actual secret scope + key name(s)
SST_SECRET_SCOPE = cfg.institution.secure_assets["scope"]
SST_API_KEY_SECRET_KEY = "sst_staging_api_key"  # <-- update if your secret key is named differently
SST_API_KEY = dbutils.secrets.get(scope=SST_SECRET_SCOPE, key=SST_API_KEY_SECRET_KEY).strip()
if SST_API_KEY == "":
    raise RuntimeError(f"Empty SST API key from secrets: scope={SST_SECRET_SCOPE} key={SST_API_KEY_SECRET_KEY}")

# One HTTP session for the whole run; ensure_auth() adds the bearer token on first use.
_session = requests.Session()
_session.headers["accept"] = "application/json"

# Per-run auth/lookup state (bearer token, PDP id -> institution record).
_bearer_token = None
_institution_cache: dict[str, dict] = {}

In [0]:
def output_file_name_from_sftp(file_name: str) -> str:
    """
    Derive the bronze output file name from a staged/SFTP file name.

    Drops any directory component, keeps everything before the FIRST dot and
    appends ".csv" (e.g. "dir/report.2024.xlsx" -> "report.csv"). The caller's
    idempotency check looks for exactly this name, so the rule must stay stable.
    """
    stem, _sep, _rest = os.path.basename(file_name).partition(".")
    return stem + ".csv"

# Column normalization + renames (kept identical to current script)
def normalize_col(name: str) -> str:
    """
    Normalize a raw CSV header.

    Trims surrounding whitespace, lower-cases, replaces every character outside
    [a-z0-9_] with "_", collapses runs of "_" and trims leading/trailing "_".
    """
    underscored = re.sub(r"[^a-z0-9_]", "_", name.strip().lower())
    return re.sub(r"_+", "_", underscored).strip("_")

# Applied after normalize_col(): maps the run-together normalized headers to the
# underscore-separated column names used downstream.
RENAMES = {
    "attemptedgatewaymathyear1": "attempted_gateway_math_year_1",
    "attemptedgatewayenglishyear1": "attempted_gateway_english_year_1",
    "completedgatewaymathyear1": "completed_gateway_math_year_1",
    "completedgatewayenglishyear1": "completed_gateway_english_year_1",
    "gatewaymathgradey1": "gateway_math_grade_y_1",
    "gatewayenglishgradey1": "gateway_english_grade_y_1",
    "attempteddevmathy1": "attempted_dev_math_y_1",
    "attempteddevenglishy1": "attempted_dev_english_y_1",
    "completeddevmathy1": "completed_dev_math_y_1",
    "completeddevenglishy1": "completed_dev_english_y_1",
}

# Provided by you
def databricksify_inst_name(inst_name: str) -> str:
    """
    Follow DK standardized rules for naming conventions used in Databricks.

    Lower-cases the name, abbreviates common institution phrases (e.g.
    "community college" -> "cc", "university" -> "uni"), turns "&" and "-" into
    spaces, then spaces into underscores.

    Args:
        inst_name: institution display name (in this notebook, the SST API "name").

    Returns:
        A name consisting only of [a-z0-9_]; used as the prefix of "<prefix>_bronze".

    Raises:
        ValueError: the result contains any other character (e.g. ".", "'", ",",
            or a newline).
    """
    name = inst_name.lower()
    # Insertion order matters: the longer phrases must be replaced before "college".
    dk_replacements = {
        "community technical college": "ctc",
        "community college": "cc",
        "of science and technology": "st",
        "university": "uni",
        "college": "col",
    }

    for old, new in dk_replacements.items():
        name = name.replace(old, new)

    special_char_replacements = {" & ": " ", "&": " ", "-": " "}
    for old, new in special_char_replacements.items():
        name = name.replace(old, new)

    final_name = name.replace(" ", "_")

    # fullmatch rather than match("^...$"): "$" also matches just before a trailing
    # newline, which would let a value like "foo\n" through as a schema prefix.
    if not re.fullmatch(r"[a-z0-9_]*", final_name):
        raise ValueError(
            f"Unexpected character found in Databricks compatible name: {final_name!r} (from {inst_name!r})."
        )
    return final_name

In [0]:
def fetch_bearer_token() -> str:
    """
    Exchange the SST API key for a bearer token.

    POSTs to SST_TOKEN_ENDPOINT with the key in the X-API-KEY header and returns the
    first non-blank string found under access_token, token, bearer_token or jwt in
    the JSON response (field names are guesses at the endpoint's schema).

    Raises:
        PermissionError: the endpoint answered 401 (check the X-API-KEY secret).
        requests.HTTPError: any other non-2xx response.
        ValueError: the body is not a JSON object, or has none of the expected fields.
    """
    resp = _session.post(
        SST_TOKEN_ENDPOINT,
        # Authenticate with the API key only. refresh_auth() calls this after a 401,
        # when the session still carries the expired bearer token; a None value
        # tells requests to omit that session-level header for this request.
        headers={"accept": "application/json", "X-API-KEY": SST_API_KEY, "Authorization": None},
        timeout=30,
    )
    if resp.status_code == 401:
        raise PermissionError("Unauthorized calling token endpoint (check X-API-KEY secret).")
    resp.raise_for_status()

    data = resp.json()
    # Check the shape before .get(): a list/str body would otherwise fail with an
    # opaque AttributeError.
    if not isinstance(data, dict):
        raise ValueError(f"Token endpoint returned non-object JSON ({type(data).__name__}).")

    for k in ("access_token", "token", "bearer_token", "jwt"):
        v = data.get(k)
        if isinstance(v, str) and v.strip():
            return v.strip()

    raise ValueError(f"Token endpoint response missing expected token field. Keys={list(data.keys())}")

def refresh_auth():
    """Fetch a new bearer token and install it as the shared session's Authorization header."""
    global _bearer_token
    _bearer_token = fetch_bearer_token()
    _session.headers["Authorization"] = f"Bearer {_bearer_token}"

def ensure_auth():
    """Authenticate the shared session on first use; a no-op once a token is held."""
    if _bearer_token is None:
        refresh_auth()


In [0]:
def fetch_institution_by_pdp_id(pdp_id: str) -> dict:
    """
    Resolve a PDP institution id to its SST staging institution record.

    Successful lookups are cached in _institution_cache for the rest of the run,
    keyed by str(pdp_id).strip(). On a 401 the bearer token is refreshed and the
    request retried exactly once.

    Args:
        pdp_id: PDP institution id (any type; converted with str() and stripped).

    Returns:
        The JSON object returned by the institution lookup endpoint.

    Raises:
        ValueError: the id is unknown to SST (404) or the body is not a JSON object.
        requests.HTTPError: any other non-2xx response (including a second 401).
    """
    pid = str(pdp_id).strip()
    if pid in _institution_cache:
        return _institution_cache[pid]

    ensure_auth()

    # Percent-encode the id: it comes from values in the staged files, and characters
    # such as "/", "?" or "#" would otherwise change the request path.
    url = SST_BASE_URL + INSTITUTION_LOOKUP_PATH.format(pdp_id=quote(pid, safe=""))
    resp = _session.get(url, timeout=30)

    if resp.status_code == 401:
        # Presumably the token expired mid-run: refresh once and retry.
        refresh_auth()
        resp = _session.get(url, timeout=30)

    if resp.status_code == 404:
        raise ValueError(f"Institution PDP ID not found in SST staging: {pid}")

    resp.raise_for_status()
    data = resp.json()
    # Only cache well-formed records; callers read fields with .get().
    if not isinstance(data, dict):
        raise ValueError(f"Unexpected institution lookup response for pdp_id={pid}: {data!r}")
    _institution_cache[pid] = data
    return data


In [0]:

# Per-run caches so each catalog / schema is listed at most once.
_schema_cache: dict[str, set[str]] = {}  # key: catalog -> schema names in that catalog
_bronze_volume_cache: dict[str, str] = {}  # key: f"{catalog}.{schema}" -> volume_name

def list_schemas_in_catalog(catalog: str) -> set[str]:
    """
    Return the names of all schemas in `catalog`.

    Cached per catalog for the rest of the run (previously a single cache was shared by
    every catalog, so a second catalog silently got the first one's schemas). Schemas
    created after the first call for a catalog are not seen.
    """
    if catalog not in _schema_cache:
        # Backtick-quote so unusual identifiers (e.g. leading digits) still parse.
        rows = spark.sql(f"SHOW SCHEMAS IN `{catalog}`").collect()
        _schema_cache[catalog] = {r["databaseName"] for r in rows}
    return _schema_cache[catalog]

def find_bronze_schema(catalog: str, inst_prefix: str) -> str:
    """
    Return "<inst_prefix>_bronze" if that schema exists in `catalog`.

    Raises:
        ValueError: the schema does not exist.
    """
    target = f"{inst_prefix}_bronze"
    if target not in list_schemas_in_catalog(catalog):
        raise ValueError(f"Bronze schema not found: {catalog}.{target}")
    return target

def find_bronze_volume_name(catalog: str, schema: str) -> str:
    """
    Return the first volume in `catalog`.`schema` whose name contains "bronze"
    (case-insensitive), in the order SHOW VOLUMES returns them.

    Successful lookups are cached per schema for the rest of the run.

    Raises:
        ValueError: the schema has no volumes, or none contains "bronze".
    """
    key = f"{catalog}.{schema}"
    if key in _bronze_volume_cache:
        return _bronze_volume_cache[key]

    vols = spark.sql(f"SHOW VOLUMES IN `{catalog}`.`{schema}`").collect()
    if not vols:
        raise ValueError(f"No volumes found in {catalog}.{schema}")

    # Usually "volume_name", but be defensive
    def _get_vol_name(row):
        d = row.asDict()
        for k in ["volume_name", "volumeName", "name"]:
            if k in d:
                return d[k]
        return list(d.values())[0]

    vol_names = [_get_vol_name(v) for v in vols]
    bronze_like = [v for v in vol_names if "bronze" in v.lower()]
    if bronze_like:
        _bronze_volume_cache[key] = bronze_like[0]
        return bronze_like[0]

    raise ValueError(f"No volume containing 'bronze' found in {catalog}.{schema}. Volumes={vol_names}")


In [0]:
def update_manifest(file_fingerprint: str, status: str, error_message: str | None):
    """
    Record the outcome of ingesting one file in ingestion_manifest.

    MERGEs a one-row update keyed on file_fingerprint into MANIFEST_TABLE. The MERGE
    has only a WHEN MATCHED clause, so rows missing from the manifest are not
    created; this relies on Script 1 having inserted the file with status=NEW.

    Args:
        file_fingerprint: key of the manifest row to update.
        status: new status (this notebook writes "BRONZE_WRITTEN" or "FAILED").
        error_message: failure detail, or None to clear any previous one.
    """
    update_schema = T.StructType(
        [
            T.StructField("file_fingerprint", T.StringType(), False),
            T.StructField("status", T.StringType(), False),
            T.StructField("error_message", T.StringType(), True),
            T.StructField("ingested_at", T.TimestampType(), True),
            T.StructField("processed_at", T.TimestampType(), False),
        ]
    )

    stamp = datetime.now(timezone.utc)
    # ingested_at is stamped only on BRONZE_WRITTEN; otherwise it is None and the
    # MERGE's COALESCE keeps the row's existing value.
    succeeded = status == "BRONZE_WRITTEN"
    update_row = {
        "file_fingerprint": file_fingerprint,
        "status": status,
        "error_message": error_message,
        "ingested_at": stamp if succeeded else None,
        "processed_at": stamp,
    }

    updates_df = spark.createDataFrame([update_row], schema=update_schema)
    updates_df.createOrReplaceTempView("manifest_updates")

    merge_sql = f"""
        MERGE INTO {MANIFEST_TABLE} AS t
        USING manifest_updates AS s
        ON t.file_fingerprint = s.file_fingerprint
        WHEN MATCHED THEN UPDATE SET
          t.status = s.status,
          t.error_message = s.error_message,
          t.ingested_at = COALESCE(s.ingested_at, t.ingested_at),
          t.processed_at = s.processed_at
        """
    spark.sql(merge_sql)


In [0]:
# ---------------------------
# Select work: plan rows whose file is still NEW in the manifest
# ---------------------------
def _exit_noop(reason: str, exit_value: str) -> None:
    """Log why there is nothing to do, then end the notebook run with `exit_value`."""
    logger.info(reason)
    dbutils.notebook.exit(exit_value)


if not spark.catalog.tableExists(PLAN_TABLE):
    _exit_noop(f"Plan table not found: {PLAN_TABLE}. Exiting (no-op).", "NO_PLAN_TABLE")

if not spark.catalog.tableExists(MANIFEST_TABLE):
    raise RuntimeError(f"Manifest table missing: {MANIFEST_TABLE}")

plan_df = spark.table(PLAN_TABLE)
if plan_df.limit(1).count() == 0:
    _exit_noop("institution_ingest_plan is empty. Exiting (no-op).", "NO_WORK_ITEMS")

# Keep only plan rows whose file the manifest still reports as NEW.
manifest_df = spark.table(MANIFEST_TABLE).select("file_fingerprint", "status")
plan_new_df = plan_df.join(manifest_df, on="file_fingerprint", how="inner").where(
    F.col("status") == F.lit("NEW")
)
display(plan_new_df)
if plan_new_df.limit(1).count() == 0:
    _exit_noop("No planned work items where manifest status=NEW. Exiting (no-op).", "NO_NEW_TO_INGEST")

# One entry per staged file (the plan has one row per file x institution).
_FILE_GROUP_COLUMNS = (
    "file_fingerprint",
    "file_name",
    "local_path",
    "inst_col",
    "file_size",
    "file_modified_time",
)
file_groups = plan_new_df.select(*_FILE_GROUP_COLUMNS).distinct().collect()

logger.info(f"Preparing to ingest {len(file_groups)} NEW file(s).")


In [0]:
# ---------------------------
# Main per-file ingest loop
# ---------------------------
# For each NEW file: read it once, slice it per planned institution, write each slice to
# that institution's bronze volume, then set the file's manifest status
# (BRONZE_WRITTEN if every institution succeeded or was already present, else FAILED).

# Manifest error_message values are truncated to this many characters.
MAX_ERROR_MESSAGE_CHARS = 8000


def _planned_inst_ids_by_file(plan_new) -> dict:
    """
    Map file_fingerprint -> list of distinct planned institution_ids.

    Collected once before the loop: plan_new is a lazy join against MANIFEST_TABLE,
    which update_manifest() modifies during the loop. Querying it per file would
    launch one Spark job per file and re-read a table that is changing underneath us.
    """
    by_fp: dict = {}
    for row in plan_new.select("file_fingerprint", "institution_id").distinct().collect():
        by_fp.setdefault(row["file_fingerprint"], []).append(row["institution_id"])
    return by_fp


def _ingest_institution(filtered_df, inst_id, sftp_file_name) -> None:
    """
    Write one institution's rows to its bronze volume, skipping if the output exists.

    Resolves inst_id -> SST institution name -> "<prefix>_bronze" schema -> bronze
    volume. Raises on any lookup, naming or write failure; the caller records the
    error against the file.
    """
    # Resolve institution -> name
    inst_info = fetch_institution_by_pdp_id(inst_id)
    inst_name = inst_info.get("name")
    if not inst_name:
        raise ValueError(f"SST API returned no 'name' for pdp_id={inst_id}. Response={inst_info}")

    inst_prefix = databricksify_inst_name(inst_name)

    # Find bronze schema + volume
    bronze_schema = find_bronze_schema(CATALOG, inst_prefix)
    bronze_volume_name = find_bronze_volume_name(CATALOG, bronze_schema)
    volume_dir = f"/Volumes/{CATALOG}/{bronze_schema}/{bronze_volume_name}"

    # Output naming rule (same as current script)
    out_file_name = output_file_name_from_sftp(sftp_file_name)
    full_path = os.path.join(volume_dir, out_file_name)

    # Idempotency: an output already written by an earlier run counts as success.
    if os.path.exists(full_path):
        logger.info(f"file={sftp_file_name} inst={inst_id}: already exists in {volume_dir}; skipping write.")
        return

    logger.info(f"file={sftp_file_name} inst={inst_id}: writing to {volume_dir} as {out_file_name}")
    process_and_save_file(volume_dir=volume_dir, file_name=out_file_name, df=filtered_df)
    logger.info(f"file={sftp_file_name} inst={inst_id}: write complete.")


inst_ids_by_fp = _planned_inst_ids_by_file(plan_new_df)

processed_files = 0
failed_files = 0
skipped_files = 0

for fg in file_groups:
    fp = fg["file_fingerprint"]
    sftp_file_name = fg["file_name"]
    local_path = fg["local_path"]
    inst_col = fg["inst_col"]

    if not local_path or not os.path.exists(local_path):
        err = f"Staged local file missing for fp={fp}: {local_path}"
        logger.error(err)
        update_manifest(fp, status="FAILED", error_message=err[:MAX_ERROR_MESSAGE_CHARS])
        failed_files += 1
        continue

    try:
        df_full = pd.read_csv(local_path, on_bad_lines="warn")
        df_full = df_full.rename(columns={c: normalize_col(c) for c in df_full.columns})
        df_full = df_full.rename(columns=RENAMES)

        if inst_col not in df_full.columns:
            err = f"Expected institution column '{inst_col}' not found after normalization/renames for file={sftp_file_name} fp={fp}"
            logger.error(err)
            update_manifest(fp, status="FAILED", error_message=err[:MAX_ERROR_MESSAGE_CHARS])
            failed_files += 1
            continue

        inst_ids = inst_ids_by_fp.get(fp, [])

        if not inst_ids:
            logger.info(f"No institution_ids in plan for file={sftp_file_name} fp={fp}. Marking BRONZE_WRITTEN (no-op).")
            update_manifest(fp, status="BRONZE_WRITTEN", error_message=None)
            skipped_files += 1
            continue

        # Compare institution ids numerically. If pandas inferred the column as text
        # (e.g. quoted or zero-padded ids), `df_full[inst_col] == int(inst_id)` matches
        # nothing and the file would be marked BRONZE_WRITTEN with no data written.
        # For an already-numeric column this selects exactly the same rows.
        inst_key = pd.to_numeric(df_full[inst_col], errors="coerce")

        # Aggregate errors at file-level
        file_errors = []

        for inst_id in inst_ids:
            try:
                filtered_df = df_full[inst_key == int(inst_id)].reset_index(drop=True)

                if filtered_df.empty:
                    logger.info(f"file={sftp_file_name} fp={fp}: institution {inst_id} has 0 rows; skipping.")
                    continue

                _ingest_institution(filtered_df, inst_id, sftp_file_name)

            except Exception as e:
                msg = f"inst_ingest_failed file={sftp_file_name} fp={fp} inst={inst_id}: {e}"
                logger.exception(msg)
                file_errors.append(msg)

        if file_errors:
            err = " | ".join(file_errors)[:MAX_ERROR_MESSAGE_CHARS]
            update_manifest(fp, status="FAILED", error_message=err)
            failed_files += 1
        else:
            update_manifest(fp, status="BRONZE_WRITTEN", error_message=None)
            processed_files += 1

    except Exception as e:
        msg = f"fatal_file_error file={sftp_file_name} fp={fp}: {e}"
        logger.exception(msg)
        update_manifest(fp, status="FAILED", error_message=msg[:MAX_ERROR_MESSAGE_CHARS])
        failed_files += 1

logger.info(f"Done. processed_files={processed_files}, failed_files={failed_files}, skipped_files={skipped_files}")
dbutils.notebook.exit(f"PROCESSED={processed_files};FAILED={failed_files};SKIPPED={skipped_files}")
