In [0]:
"""
This notebook ingests course and cohort data for the Student Success Tool (SST) pipeline.

It downloads course and cohort CSV files from a Google Cloud Storage (GCS) bucket,
validates them against the PDP raw-data schemas from `student_success_tool.schemas.pdp`,
and writes the validated data to Delta Lake tables in Databricks Unity Catalog.

The notebook is designed to run within a Databricks environment as a job task, using Databricks
utilities for widget (job parameter) input, secret access, and Spark session management.

This is a proof-of-concept (POC) notebook; refactor it into a .py module and add tests before using it in production.

"""
import logging
import os

from databricks.connect import DatabricksSession
from databricks.sdk import WorkspaceClient
from databricks.sdk.runtime import dbutils
from google.cloud import storage
from email.headerregistry import Address

import student_success_tool.dataio as dataio
from student_success_tool.schemas import pdp as schemas
from student_success_tool import emails

# Configure logging at INFO for this notebook, but keep the py4j logger quiet.
logging.basicConfig(level=logging.INFO)
logging.getLogger("py4j").setLevel(logging.WARNING)  # Ignore Databricks logger

# Attempt to create a Spark session. Outside a Databricks runtime this fails;
# we then fall back to ``None`` so the Delta Lake write/verification steps are
# skipped (they check for a session) instead of crashing here.
# The caught exception is logged with its traceback so that a real
# misconfiguration (auth, cluster, network) is still diagnosable rather than
# being hidden behind the generic message.
try:
    spark_session = DatabricksSession.builder.getOrCreate()
except Exception:
    logging.warning(
        "Unable to create Spark session; are you in a Databricks runtime?",
        exc_info=True,
    )
    spark_session = None

# ---------------------------------------------------------------------------
# Job parameters, all read from Databricks widgets.
# ---------------------------------------------------------------------------
_get_widget = dbutils.widgets.get

# Databricks workspace identifier; also used below as the Unity Catalog catalog.
DB_workspace = _get_widget("DB_workspace")
# Recipient address for run notifications.
notif_email = _get_widget("notification_email")
# Institution name as used in Databricks object names (e.g. the "<name>_bronze"
# schema). This is NOT the same as the institution id in GCP.
institution_name = _get_widget("databricks_institution_name")
# File names (inside the bucket's "validated" folder) of the raw course and cohort data.
course_file_name = _get_widget("course_file_name")
cohort_file_name = _get_widget("cohort_file_name")
# Job run id. For jobs with two or more tasks this corresponds to {{parent_run_id}}:
# https://stackoverflow.com/questions/75428900/confusion-about-run-id-and-parent-run-id-variables-for-databricks-jobs
db_run_id = _get_widget("db_run_id")
# Name of the external GCS bucket; the internal bucket is this name with "_internal" appended.
gcp_bucket_name = _get_widget("gcp_bucket_name")

# Email notification setup for the "inference run kicked off" message.
# NOTE: actually sending the email is currently DISABLED -- the call at the
# bottom is commented out because (per the original note) the package could
# not be found. Sender/CC addresses and Mandrill credentials are still prepared
# here so the call can be re-enabled later.
SENDER_EMAIL = Address("Datakind Info", "help", "datakind.org")
DK_CC_EMAIL = "education@datakind.org"

w = WorkspaceClient()
# Mandrill credentials from the "sst" Databricks secret scope.
# NOTE(review): these are fetched even though sending is disabled, so a missing
# secret would presumably still fail the run -- confirm this is intended.
MANDRILL_USERNAME, MANDRILL_PASSWORD = (
    w.dbutils.secrets.get(scope="sst", key=secret_key)
    for secret_key in ("MANDRILL_USERNAME", "MANDRILL_PASSWORD")
)

# TODO: re-enable once the emails package can be found at runtime.
# emails.send_inference_kickoff_email(SENDER_EMAIL, [notif_email], [], MANDRILL_USERNAME, MANDRILL_PASSWORD)

# Landing directory for this run's raw files, on a Unity Catalog volume:
#   /Volumes/<DB_workspace>/<institution_name>_bronze/bronze_volume/inference_jobs/<db_run_id>/raw_files/
# The path ends with "/" so file names can be appended to it directly.
internal_pipeline_path = (
    f"/Volumes/{DB_workspace}"
    f"/{institution_name}_bronze"
    "/bronze_volume/inference_jobs"
    f"/{db_run_id}/raw_files/"
)

# Ensure the landing directory exists (no error if it already does).
os.makedirs(internal_pipeline_path, exist_ok=True)

# Fetch the raw course and cohort CSVs from the "validated" folder of the
# external GCS bucket into this run's landing directory.
storage_client = storage.Client()
bucket = storage_client.bucket(gcp_bucket_name)
sst_container_folder = "validated"

# Object names inside the bucket: "<folder>/<file name>".
course_blob_name = "/".join((sst_container_folder, course_file_name))
cohort_blob_name = "/".join((sst_container_folder, cohort_file_name))

# Course file -> <landing dir>/<course_file_name>
course_blob = bucket.blob(course_blob_name)
course_blob.download_to_filename(internal_pipeline_path + course_file_name)

# Cohort file -> <landing dir>/<cohort_file_name>
cohort_blob = bucket.blob(cohort_blob_name)
cohort_blob.download_to_filename(internal_pipeline_path + cohort_file_name)


# Alias kept for compatibility with Datakind's code, which refers to `path_volume`.
path_volume = internal_pipeline_path

# Full local paths of the two downloaded files.
fpath_course, fpath_cohort = (
    os.path.join(path_volume, file_name)
    for file_name in (course_file_name, cohort_file_name)
)

# Load each CSV into a pandas DataFrame, validated against its PDP raw-data schema.
# `dttm_format` is only passed for the course file; presumably it controls how
# datetime columns are parsed there -- confirm in dataio.pdp.
df_course = dataio.pdp.read_raw_course_data(
    file_path=fpath_course,
    schema=schemas.RawPDPCourseDataSchema,
    dttm_format="%Y-%m-%d",
)
df_cohort = dataio.pdp.read_raw_cohort_data(
    file_path=fpath_cohort,
    schema=schemas.RawPDPCohortDataSchema,
)


# Delta Lake destination: the workspace id doubles as the catalog, and each
# institution has its own "<institution_name>_bronze" schema.
catalog = DB_workspace
write_schema = f"{institution_name}_bronze"

# Fully qualified table names, built once so the write and the read-back
# verification can never refer to different tables.
course_table_name = f"{catalog}.{write_schema}.{db_run_id}_course_dataset_validated"
cohort_table_name = f"{catalog}.{write_schema}.{db_run_id}_cohort_dataset_validated"

# Write DataFrames to Delta Lake tables (only if a Spark session is available).
if spark_session is not None:
    dataio.to_delta_table(
        df_course,
        course_table_name,
        spark_session=spark_session,
    )

    dataio.to_delta_table(
        df_cohort,
        cohort_table_name,
        spark_session=spark_session,
    )

    # Verify the Delta Lake writes: read each table back and re-apply its raw
    # schema to the result.
    df_course_from_catalog = schemas.RawPDPCourseDataSchema(
        dataio.from_delta_table(
            course_table_name,
            spark_session=spark_session,
        )
    )
    print(f"Course DataFrame shape from catalog: {df_course_from_catalog.shape}")

    df_cohort_from_catalog = schemas.RawPDPCohortDataSchema(
        dataio.from_delta_table(
            cohort_table_name,
            spark_session=spark_session,
        )
    )
    print(f"Cohort DataFrame shape from catalog: {df_cohort_from_catalog.shape}")

    # Also check that the read-back row counts match what was written. This only
    # warns (it does not fail the job), since a mismatch could also come from
    # pre-existing rows, e.g. if a table for this run id already existed.
    for table_label, table_name, df_written, df_read_back in (
        ("course", course_table_name, df_course, df_course_from_catalog),
        ("cohort", cohort_table_name, df_cohort, df_cohort_from_catalog),
    ):
        if len(df_read_back) != len(df_written):
            logging.warning(
                "Row count mismatch for %s table %s: wrote %d rows, read back %d rows.",
                table_label,
                table_name,
                len(df_written),
                len(df_read_back),
            )
else:
    logging.warning(
        "Spark session not initialized. Skipping Delta Lake write and verification."
    )