In [0]:
%pip install paramiko python-box pyyaml

In [0]:
%restart_python

In [0]:
import os
import stat
import yaml
import paramiko
from box import Box
from datetime import datetime, timezone
import hashlib
import shlex

from pyspark.sql import functions as F
from pyspark.sql import types as T

from helper import CustomLogger

In [0]:
logger = CustomLogger()

# Pipeline configuration and secret scope (kept consistent with the existing pipeline).
with open("gcp_config.yaml", "rb") as f:
    cfg = Box(yaml.safe_load(f))

asset_scope = cfg.institution.secure_assets["scope"]

# The config file only holds the *names* of the secrets; the values are read
# from the secret scope, in the order host -> user -> password.
secret_key_names = cfg.pdp.secret["keys"]
host, user, password = (
    dbutils.secrets.get(scope=asset_scope, key=secret_key_names[field])
    for field in ("host", "user", "password")
)

# Remote folder that is scanned, and the label recorded for every file found in it.
remote_folder = "./receive"
source_system = "NSC"

# Delta tables: the manifest of discovered files and the downstream work queue.
CATALOG = "staging_sst_01"
DEFAULT_SCHEMA = "default"
MANIFEST_TABLE = f"{CATALOG}.{DEFAULT_SCHEMA}.ingestion_manifest"
QUEUE_TABLE = f"{CATALOG}.{DEFAULT_SCHEMA}.pending_ingest_queue"

# Local staging directory for downloaded files (relative to the working directory).
TMP_DIR = "./tmp/pdp_sftp_stage"

logger.info("SFTP secured assets loaded successfully.")

In [0]:
def connect_sftp(host: str, username: str, password: str, port: int = 22):
    """
    Open a password-authenticated SSH transport and start an SFTP session on it.

    Args:
        host: SFTP server host name or address.
        username: Login user.
        password: Login password.
        port: TCP port of the SSH service (default 22).

    Returns:
        (transport, sftp_client). Caller must close both.

    Raises:
        paramiko.SSHException: if the SFTP session channel cannot be opened
            (and whatever paramiko raises on connection/authentication failure).
            The transport is closed before the error propagates, so a failed
            connection attempt does not leak the underlying socket.
    """
    transport = paramiko.Transport((host, port))
    try:
        # NOTE(review): no `hostkey` is passed, so the server's host key is not
        # verified here. Consider pinning the expected key.
        transport.connect(username=username, password=password)
        sftp = paramiko.SFTPClient.from_transport(transport)
        if sftp is None:
            # from_transport returns None when the session channel could not be opened.
            raise paramiko.SSHException(
                f"Could not open an SFTP session channel on {host}"
            )
    except BaseException:
        # The caller never receives the transport on failure, so it must be closed here.
        transport.close()
        raise
    print(f"Connected successfully to {host}")
    return transport, sftp

In [0]:
def ensure_tables():
    """
    Make sure both Delta tables used by this notebook exist.

    - ingestion_manifest: one row per file fingerprint (used for idempotency),
      with status / error bookkeeping columns.
    - pending_ingest_queue: files staged locally, including the local tmp path
      so downstream steps do not have to connect to SFTP again.

    Both statements use CREATE TABLE IF NOT EXISTS, so existing tables are left alone.
    """
    manifest_ddl = f"""
        CREATE TABLE IF NOT EXISTS {MANIFEST_TABLE} (
          file_fingerprint STRING,
          source_system STRING,
          sftp_path STRING,
          file_name STRING,
          file_size BIGINT,
          file_modified_time TIMESTAMP,
          ingested_at TIMESTAMP,
          processed_at TIMESTAMP,
          status STRING,
          error_message STRING
        )
        USING DELTA
        """

    queue_ddl = f"""
        CREATE TABLE IF NOT EXISTS {QUEUE_TABLE} (
          file_fingerprint STRING,
          source_system STRING,
          sftp_path STRING,
          file_name STRING,
          file_size BIGINT,
          file_modified_time TIMESTAMP,
          local_tmp_path STRING,
          queued_at TIMESTAMP
        )
        USING DELTA
        """

    # Manifest first, then queue.
    for ddl in (manifest_ddl, queue_ddl):
        spark.sql(ddl)

In [0]:
def list_receive_files(sftp: paramiko.SFTPClient, remote_dir: str):
    """
    List non-directory entries in remote_dir with their metadata.

    Args:
        sftp: Open SFTP client.
        remote_dir: Remote directory to list (not recursive).

    Returns:
        list[dict] with keys: source_system, sftp_path, file_name, file_size,
        file_modified_time. file_size is an int or None; file_modified_time is a
        timezone-aware UTC datetime or None. sftp_path is remote_dir itself and
        source_system comes from the module-level `source_system` setting.
    """
    results = []
    for attr in sftp.listdir_attr(remote_dir):
        # st_mode is None when the server does not send permission bits, and
        # stat.S_ISDIR(None) raises TypeError. Such entries cannot be shown to be
        # directories, so they are kept rather than silently dropped.
        if attr.st_mode is not None and stat.S_ISDIR(attr.st_mode):
            continue

        file_size = int(attr.st_size) if attr.st_size is not None else None
        # A missing (or zero) mtime is recorded as None; this feeds the file
        # fingerprint, so the rule is intentionally left unchanged.
        mtime = (
            datetime.fromtimestamp(int(attr.st_mtime), tz=timezone.utc)
            if attr.st_mtime
            else None
        )

        results.append(
            {
                "source_system": source_system,
                "sftp_path": remote_dir,
                "file_name": attr.filename,
                "file_size": file_size,
                "file_modified_time": mtime,
            }
        )
    return results

In [0]:
def build_listing_df(file_rows):
    """
    Turn the SFTP listing rows into a Spark DataFrame and add `file_fingerprint`.

    The fingerprint is the SHA-256 of source_system, sftp_path, file_name,
    file_size and file_modified_time joined with "||". Missing size / mtime
    contribute an empty string so the field positions stay fixed.
    """
    column_specs = (
        ("source_system", T.StringType(), False),
        ("sftp_path", T.StringType(), False),
        ("file_name", T.StringType(), False),
        ("file_size", T.LongType(), True),
        ("file_modified_time", T.TimestampType(), True),
    )
    listing_schema = T.StructType(
        [T.StructField(name, dtype, nullable) for name, dtype, nullable in column_specs]
    )
    listing = spark.createDataFrame(file_rows, schema=listing_schema)

    # Render size and mtime as text explicitly so the hashed string has one fixed format.
    # NOTE(review): date_format output (including the XXX offset) presumably follows
    # the Spark session time zone - confirm it is pinned, otherwise the same file
    # could hash differently between runs.
    size_text = F.coalesce(F.col("file_size").cast("string"), F.lit(""))
    mtime_text = F.coalesce(
        F.date_format(F.col("file_modified_time"), "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"),
        F.lit(""),
    )
    identity = F.concat_ws(
        "||",
        F.col("source_system"),
        F.col("sftp_path"),
        F.col("file_name"),
        size_text,
        mtime_text,
    )

    return listing.withColumn("file_fingerprint", F.sha2(identity, 256))

In [0]:
def upsert_new_to_manifest(df_listing):
    """
    Add a manifest row with status 'NEW' for every fingerprint not seen before.

    Fingerprints already present in the manifest are left untouched: the MERGE
    has no WHEN MATCHED clause.
    """
    listing_columns = [
        "file_fingerprint",
        "source_system",
        "sftp_path",
        "file_name",
        "file_size",
        "file_modified_time",
    ]

    # Bookkeeping columns start empty; only `status` gets a value.
    incoming = df_listing.select(
        *listing_columns,
        F.lit(None).cast("timestamp").alias("ingested_at"),
        F.lit(None).cast("timestamp").alias("processed_at"),
        F.lit("NEW").alias("status"),
        F.lit(None).cast("string").alias("error_message"),
    )
    incoming.createOrReplaceTempView("incoming_manifest_rows")

    spark.sql(
        f"""
        MERGE INTO {MANIFEST_TABLE} AS t
        USING incoming_manifest_rows AS s
        ON t.file_fingerprint = s.file_fingerprint
        WHEN NOT MATCHED THEN INSERT *
        """
    )

In [0]:
def get_files_to_queue(df_listing):
    """
    Select the listing rows that still need to be staged and queued.

    A row qualifies when its fingerprint:
      - is in the current SFTP listing (df_listing),
      - has a manifest row with status = 'NEW', and
      - is NOT yet in pending_ingest_queue.
    """
    key = "file_fingerprint"

    # Fingerprints the manifest still marks as NEW.
    new_in_manifest = (
        spark.table(MANIFEST_TABLE)
        .where(F.col("status") == F.lit("NEW"))
        .select(key)
    )

    # Fingerprints that already have a queue entry.
    queued_keys = spark.table(QUEUE_TABLE).select(key).distinct()

    # listing AND manifest-NEW, minus anything already queued
    candidates = df_listing.join(new_in_manifest, on=key, how="inner")
    return candidates.join(queued_keys, on=key, how="left_anti")


In [0]:
def _hash_file(path, algo="sha256", chunk_size=8 * 1024 * 1024):
    """
    Return the hex digest of the file at `path`.

    The file is read in `chunk_size`-byte pieces so large files are never held
    in memory at once. `algo` is any name accepted by hashlib.new().
    """
    digest = hashlib.new(algo)
    with open(path, "rb") as stream:
        # read() returns b"" at end of file, which terminates the iterator.
        for block in iter(lambda: stream.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()

def _remote_hash(ssh, remote_path, algo="sha256"):
    """
    Best-effort checksum of a remote file, computed on the server.

    Runs `sha256sum` / `md5sum` on `remote_path` through `ssh.exec_command`.

    Args:
        ssh: Object exposing exec_command(cmd, timeout=...) -> (stdin, stdout, stderr),
            e.g. a connected paramiko.SSHClient.
        remote_path: Path of the file on the server (shell-quoted before use).
        algo: "sha256" or "md5" (case-insensitive).

    Returns:
        The lowercase hex digest, or None when the algorithm is unsupported, the
        command fails or writes to stderr, or stdout does not start with a
        well-formed digest. Callers treat None as "no remote hash available".
    """
    # algorithm -> (remote tool, expected number of hex characters)
    tools = {"sha256": ("sha256sum", 64), "md5": ("md5sum", 32)}
    spec = tools.get(algo.lower())
    if spec is None:
        return None
    tool, hex_len = spec
    cmd = f"{tool} -- {shlex.quote(remote_path)}"

    try:
        _, stdout, stderr = ssh.exec_command(cmd, timeout=300)
        out = stdout.read().decode("utf-8", "replace").strip()
        err = stderr.read().decode("utf-8", "replace").strip()
    except Exception:
        # Deliberate best-effort: any transport/command failure means "no hash".
        return None

    if err or not out:
        return None

    # Format: "<hash>  <filename>". GNU coreutils starts the line with a backslash
    # when the file name needs escaping; that marker is not part of the digest.
    digest = out.split()[0].lstrip("\\").lower()

    # Reject anything that is not a digest (login banner, error text on stdout, ...)
    # so garbage is never compared against the local hash.
    if len(digest) != hex_len or any(ch not in "0123456789abcdef" for ch in digest):
        return None
    return digest
    
def download_sftp_atomic(
    sftp,
    remote_path,
    local_path,
    *,
    chunk: int = 150,
    verify="size", # "size" | "sha256" | "md5" | None
    ssh_for_remote_hash=None, # paramiko.SSHClient if you want remote hash verify
    progress=True
):
    """
    Resumable SFTP download that only exposes `local_path` once it is complete.

    Data is written to `local_path + '.part'`; after verification the part file
    is moved onto `local_path` with os.replace. An existing part file that is
    not larger than the remote file is resumed by appending to it.

    Args:
        sftp: SFTP client providing stat() and file().
        remote_path: File to fetch.
        local_path: Final destination on the local filesystem.
        chunk: Read size in MiB.
        verify: The size check always runs, so "size" and None behave the same.
            "sha256" / "md5" additionally hash the local file and compare it
            with the remote hash when one can be obtained; if no remote hash is
            available the comparison is skipped.
        ssh_for_remote_hash: SSH client used to compute the remote hash.
        progress: Print progress and completion messages.

    Raises:
        IOError: on a short read, a size mismatch, or a hash mismatch (the part
            file is removed only in the hash-mismatch case).
    """
    tmp_path = f"{local_path}.part"
    remote_size = sftp.stat(remote_path).st_size
    chunk_size = chunk * 1024 * 1024

    # Resume from an earlier partial download unless it is bigger than the
    # remote file, in which case it is discarded and the download restarts.
    offset = 0
    if os.path.exists(tmp_path):
        part_size = os.path.getsize(tmp_path)
        if part_size > remote_size:
            os.remove(tmp_path)
        else:
            offset = part_size

    with sftp.file(remote_path, "rb") as rf:
        # Pipelining is optional; carry on without it if it cannot be enabled.
        try:
            rf.set_pipelined(True)
        except Exception:
            pass

        if offset:
            rf.seek(offset)

        # Append when resuming, truncate when starting fresh.
        mode = "ab" if offset else "wb"
        with open(tmp_path, mode) as lf:
            transferred = offset
            while transferred < remote_size:
                data = rf.read(min(chunk_size, remote_size - transferred))
                if not data:
                    # An empty read before the expected size is an error, not EOF.
                    raise IOError(
                        f"Short read at {transferred:,} of {remote_size:,} bytes"
                    )
                lf.write(data)
                transferred += len(data)
                if progress and remote_size:
                    print(f"{transferred / remote_size:.2%} transferred...")

            # Push the data to disk before the file is verified and renamed.
            lf.flush()
            os.fsync(lf.fileno())

    # Size check (always performed)
    local_size = os.path.getsize(tmp_path)
    if local_size != remote_size:
        raise IOError(
            f"Post-download size mismatch (local {local_size:,}, remote {remote_size:,})"
        )

    # Optional content check against a server-side hash
    if verify in {"sha256", "md5"}:
        algo = verify
        local_hash = _hash_file(tmp_path, algo=algo)
        remote_hash = (
            _remote_hash(ssh_for_remote_hash, remote_path, algo=algo)
            if ssh_for_remote_hash is not None
            else None
        )

        if remote_hash and remote_hash != local_hash:
            # Drop the part file so the next attempt does not resume bad data.
            try:
                os.remove(tmp_path)
            except Exception:
                pass
            raise IOError(
                f"{algo.upper()} mismatch: local={local_hash} remote={remote_hash}"
            )

    # Publish the finished file in a single rename.
    os.replace(tmp_path, local_path)
    if progress:
        print("Download complete (atomic & verified).")


In [0]:
def download_new_files_and_queue(sftp: paramiko.SFTPClient, df_new):
    """
    Stage each new file under TMP_DIR and upsert it into pending_ingest_queue.

    Files whose staged copy already exists (e.g. on a rerun) are not downloaded
    again. The queue row stores the absolute local path, so consumers that run
    with a different working directory can still locate the staged file.

    Args:
        sftp: Open SFTP client used for the downloads.
        df_new: DataFrame with file_fingerprint, source_system, sftp_path,
            file_name, file_size and file_modified_time columns.

    Returns:
        Number of rows merged into the queue table (0 when df_new is empty).

    Raises:
        Any download error propagates; in that case nothing from this call is
        merged into the queue.
    """
    # TMP_DIR is relative to the working directory; resolve it once so the path
    # persisted in the queue table does not depend on the reader's cwd.
    stage_dir = os.path.abspath(TMP_DIR)
    os.makedirs(stage_dir, exist_ok=True)

    # Collect is OK if you expect modest number of files. If you expect thousands, we can paginate and stream.
    rows = df_new.select(
        "file_fingerprint",
        "source_system",
        "sftp_path",
        "file_name",
        "file_size",
        "file_modified_time",
    ).collect()

    queued = []
    for r in rows:
        fp = r["file_fingerprint"]
        sftp_path = r["sftp_path"]
        file_name = r["file_name"]

        remote_path = f"{sftp_path.rstrip('/')}/{file_name}"
        # The fingerprint prefix keeps different versions of a same-named file apart.
        local_path = os.path.join(stage_dir, f"{fp}__{file_name}")

        # If local already exists (e.g., rerun), skip re-download
        if not os.path.exists(local_path):
            print(f"Downloading new file from SFTP: {remote_path} -> {local_path}")
            logger.info(f"Downloading new file from SFTP: {remote_path} -> {local_path}")
            download_sftp_atomic(sftp, remote_path, local_path, chunk = 150)
        else:
            print(f"Skipping download, file already exists: {local_path}")
            logger.info(f"Local file already staged, skipping download: {local_path}")

        queued.append(
            {
                "file_fingerprint": fp,
                "source_system": r["source_system"],
                "sftp_path": sftp_path,
                "file_name": file_name,
                "file_size": r["file_size"],
                "file_modified_time": r["file_modified_time"],
                "local_tmp_path": local_path,
                "queued_at": datetime.now(timezone.utc),
            }
        )

    if not queued:
        return 0

    qschema = T.StructType(
        [
            T.StructField("file_fingerprint", T.StringType(), False),
            T.StructField("source_system", T.StringType(), False),
            T.StructField("sftp_path", T.StringType(), False),
            T.StructField("file_name", T.StringType(), False),
            T.StructField("file_size", T.LongType(), True),
            T.StructField("file_modified_time", T.TimestampType(), True),
            T.StructField("local_tmp_path", T.StringType(), False),
            T.StructField("queued_at", T.TimestampType(), False),
        ]
    )

    df_queue = spark.createDataFrame(queued, schema=qschema)
    df_queue.createOrReplaceTempView("incoming_queue_rows")

    # Upsert into queue (idempotent by fingerprint): an existing row only gets
    # its local path and queue time refreshed.
    spark.sql(
        f"""
        MERGE INTO {QUEUE_TABLE} AS t
        USING incoming_queue_rows AS s
        ON t.file_fingerprint = s.file_fingerprint
        WHEN MATCHED THEN UPDATE SET
        t.local_tmp_path = s.local_tmp_path,
        t.queued_at = s.queued_at
        WHEN NOT MATCHED THEN INSERT *
        """
    )

    return len(queued)

In [0]:
# Both handles start as None so the cleanup below is safe even when the
# connection was never established.
transport = None
sftp = None

try:
    ensure_tables()

    transport, sftp = connect_sftp(host, user, password)
    logger.info(f"Connected to SFTP host={host} and scanning folder={remote_folder}")

    file_rows = list_receive_files(sftp, remote_folder)
    if not file_rows:
        logger.info(f"No files found in SFTP folder: {remote_folder}. Exiting (no-op).")
        dbutils.notebook.exit("NO_FILES")

    df_listing = build_listing_df(file_rows)

    # Step 1: every file currently on SFTP gets a manifest row (status NEW if unseen).
    upsert_new_to_manifest(df_listing)

    # Step 2: pick the files that are still NEW and have no queue entry yet.
    df_to_queue = get_files_to_queue(df_listing)

    to_queue_count = df_to_queue.count()
    if to_queue_count == 0:
        logger.info("No files to queue: either nothing is NEW, or NEW files are already queued. Exiting (no-op).")
        dbutils.notebook.exit("QUEUED_FILES=0")

    # Step 3: stage those files locally and record them in the queue table.
    logger.info(f"Queuing {to_queue_count} NEW-unqueued file(s) to {QUEUE_TABLE} and staging locally.")
    queued_count = download_new_files_and_queue(sftp, df_to_queue)

    logger.info(f"Queued {queued_count} file(s) for downstream processing in {QUEUE_TABLE}.")
    dbutils.notebook.exit(f"QUEUED_FILES={queued_count}")

finally:
    # Best-effort teardown: SFTP client first, then the transport. A failure to
    # close one must neither stop the other nor mask an error from the try block.
    for connection in (sftp, transport):
        try:
            if connection is not None:
                connection.close()
        except Exception:
            pass
