In [None]:
# Version
# Change history: superseded entries are kept below as commented-out prints;
# only the current (last) version line is printed when this cell runs.
#print('Version 1.0.0: 10/19/2022 11:59am - Nate Calvanese - Initial version')
#print('Version 1.0.1: 10/24/2022 3:18pm - Nate Calvanese - Added pass through of params dict')
#print('Version 1.0.2: 02/07/2024 12:54pm - Nate Calvanese - Added logic to handle part_of_dataset required field')
#print('Version 1.0.3: 03/12/2024 12:12pm - Nate Calvanese - Fixed a bug introduced in V1.0.2 update')
#print('Version 1.0.4: 09/26/2024 9:28am - Nate Calvanese - Improved logic for handling part_of_dataset field')
print('Version 1.0.5: 5/7/2025 4:15pm - Nate Calvanese - Updated to support dev')


In [None]:
#!pip install --upgrade data_repo_client

In [None]:
## imports and environment variables

# Imports
import import_ipynb
import pandas as pd
import os
import re
import json
import data_repo_client
from google.cloud import bigquery
import ingest_pipeline_utilities as utils
import output_data_validation as odv
import logging
from time import sleep

# Configure logging format
# Requests INFO-level logging with lines shaped like
# "<MM/DD/YYYY HH:MM:SS AM|PM> - <LEVEL>: <message>".
logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO)

# workspace environment variables
# Read with os.environ[...] (not .get), so a missing variable raises KeyError
# as soon as this cell runs.
ws_name = os.environ["WORKSPACE_NAME"]
ws_project = os.environ["WORKSPACE_NAMESPACE"]
ws_bucket = os.environ["WORKSPACE_BUCKET"]
# Bucket value with a leading "gs://" removed, if present.
ws_bucket_name = re.sub('^gs://', '', ws_bucket)

# print(f"workspace name = {ws_name}")
# print(f"workspace project = {ws_project}")
# print(f"workspace bucket = {ws_bucket}")
# print(f"workspace bucket name = {ws_bucket_name}")

In [None]:
def ingest_records(params, dataset_id, table, records_dict):
    """Ingest placeholder records for dangling foreign keys into a TDR dataset table.

    Builds an "array"-format ingest request carrying `records_dict`, submits it
    through the TDR datasets API and waits for the resulting job. If the attempt
    raises, it is retried once after a 10 second pause.

    Args:
        params: Dict providing "profile_id", "tdr_url" and "run_env".
        dataset_id: ID of the dataset to ingest into.
        table: Name of the target table.
        records_dict: Records to ingest (sent as the request's "records" value).

    Returns:
        Tuple (message, status): ("Ingest Succeeded", "Success") on success, or
        ("Ingest Failed (<error>)", "Error") once both attempts have failed.
    """
    max_attempts = 2          # one initial try plus one retry
    retry_delay_seconds = 10

    record_count = len(records_dict)
    logging.info("Missing key values found for {}: {} new records to ingest".format(table, str(record_count)))
    logging.info("Submitting dangling foreign key ingest for {}.".format(table))

    # Build the ingest request once; it is reused unchanged on retry
    load_tag = "Dangling foreign key ingest for {} in {}".format(table, dataset_id)
    ingest_request = {
        "table": table,
        "profile_id": params["profile_id"],
        "ignore_unknown_values": True,
        "resolve_existing_files": True,
        "updateStrategy": "replace",
        "format": "array",
        "load_tag": load_tag,
        "records": records_dict
    }

    # Submit and monitor the ingest job, retrying on any exception
    for attempt in range(1, max_attempts + 1):
        try:
            # A fresh client is requested on every attempt
            api_client = utils.refresh_tdr_api_client(params["tdr_url"])
            datasets_api = data_repo_client.DatasetsApi(api_client=api_client)
            submitted_job = datasets_api.ingest_dataset(id=dataset_id, ingest=ingest_request)
            ingest_result, _job_id = utils.wait_for_tdr_job(submitted_job, params["run_env"])
            # Only the first 1000 characters of the result are logged
            logging.info("Dangling foreign key ingest for {} succeeded: {}".format(table, str(ingest_result)[0:1000]))
            return "Ingest Succeeded", "Success"
        except Exception as e:
            logging.error("Error on dangling foreign key ingest: {}".format(str(e)))
            if attempt >= max_attempts:
                logging.error("Maximum number of retries exceeded. Logging error.")
                return "Ingest Failed ({})".format(str(e)), "Error"
            logging.info("Retrying dangling foreign key ingest (attempt #{})...".format(str(attempt)))
            sleep(retry_delay_seconds)

def identify_dangling_fks(params, table, client, bq_project, bq_schema, field_list, array_field_set, dataset_name):
    """Find foreign key values that reference `table` but have no matching primary key row.

    Uses the first `field_list` entry that belongs to `table`, is flagged as a
    primary key and has at least one "joins_from" reference. For that entry a
    query is built that unions every referencing column (unnesting those listed
    in `array_field_set`) and keeps the non-null values that are absent from the
    primary key column of `table`.

    Args:
        params: Pipeline parameter dict; accepted for interface consistency, not read here.
        table: Name of the table whose primary key is being checked.
        client: Object exposing `query(sql).result().to_dataframe()`.
        bq_project: BigQuery project that holds the tables.
        bq_schema: BigQuery dataset (schema) that holds the tables.
        field_list: List of dicts with "table", "column", "is_primary_key" and
            "joins_from" keys; each "joins_from" item has "table" and "column".
        array_field_set: Collection of "table.column" strings naming array columns.
        dataset_name: Name hashed into `part_of_dataset_id` for the
            anvil_donor and anvil_biosample tables.

    Returns:
        Tuple (records, message, status):
            - matching entry, query OK: (list of record dicts, "Missing N values", "Success")
            - matching entry, query failed: ([], "Error retrieving missing values (<error>)", "Error")
            - no matching entry: ([], "No foreign keys identified", "Success")
    """
    for column_entry in field_list:
        is_referenced_pk = (
            column_entry["table"] == table
            and column_entry["is_primary_key"] == True
            and len(column_entry["joins_from"]) > 0
        )
        if not is_referenced_pk:
            continue

        logging.info("Identifying dangling foreign keys for {}...".format(table))

        # Construct a CTE that includes all foreign keys that reference the primary key
        table_name = column_entry["table"]
        col_name = column_entry["column"]
        cte_segments = []
        for entry in column_entry["joins_from"]:
            source_table = entry["table"]
            source_column = entry["column"]
            if source_table + "." + source_column in array_field_set:
                # Array-valued foreign keys are flattened so each element is compared individually
                segment = "SELECT DISTINCT {tar_col} FROM `{project}.{schema}.{table}` CROSS JOIN UNNEST({src_col}) AS {tar_col}".format(project = bq_project, schema = bq_schema, table = source_table, src_col = source_column, tar_col = col_name)
            else:
                segment = "SELECT DISTINCT {src_col} as {tar_col}  FROM `{project}.{schema}.{table}`".format(project = bq_project, schema = bq_schema, table = source_table, src_col = source_column, tar_col = col_name)
            cte_segments.append(segment)
        cte_query = "WITH temp_fks AS (" + " UNION ALL ".join(cte_segments) + " )"

        # Construct the query to identify foreign keys not present in the primary key field.
        # Exact equality is required here: a substring test (`in "anvil_donor"`) would also
        # match unrelated table names such as "donor" and add a column they do not have.
        if table_name == "anvil_donor":
            dataset_select = f", `dsp-data-ingest.transform_resources`.uuid_hash_value('{dataset_name}') AS part_of_dataset_id"
        elif table_name == "anvil_biosample":
            # Emitted as a single-element array for this table
            dataset_select = f", [`dsp-data-ingest.transform_resources`.uuid_hash_value('{dataset_name}')] AS part_of_dataset_id"
        else:
            dataset_select = ""
        dangling_fk_query = """{cte}, base AS (SELECT DISTINCT src.{col} AS {col},
                           FROM temp_fks src LEFT JOIN `{project}.{schema}.{table}` tar ON src.{col} = tar.{col}
                           WHERE src.{col} IS NOT NULL AND tar.{col} IS NULL)
                           SELECT {col}{mod} FROM base""".format(cte = cte_query, project = bq_project, schema = bq_schema, table = table_name, col = col_name, mod = dataset_select)

        # Execute the query and convert results to a list of record dicts
        try:
            df = client.query(dangling_fk_query).result().to_dataframe()
            record_dict = json.loads(df.to_json(orient="records"))
            return record_dict, "Missing {} values".format(str(len(record_dict))), "Success"
        except Exception as e:
            logging.error("Error during query execution: {}".format(str(e)))
            return [], "Error retrieving missing values ({})".format(str(e)), "Error"
    return [], "No foreign keys identified", "Success"

def resolve_dangling_fks(params, dataset_id, target_schema):
    """Identify dangling foreign keys in every target-schema table and ingest records for them.

    Retrieves the BigQuery project/schema for the dataset, parses `target_schema`,
    looks up the dataset name (with any trailing "_<digits>" suffix removed), then
    for each table calls `identify_dangling_fks` and, when it returns records,
    `ingest_records`.

    Args:
        params: Dict providing at least "tdr_url"; also passed through to
            `identify_dangling_fks` and `ingest_records`.
        dataset_id: ID of the TDR dataset to inspect and ingest into.
        target_schema: Schema object handed to `odv.process_tdr_schema`.

    Returns:
        Tuple (log, status). `log` is the per-table summaries joined with "; "
        (or a single error message when setup fails); `status` is "Error" if
        setup failed or any table reported an identify/ingest error, otherwise
        "Success".
    """
    # Establish TDR API client and retrieve the schema for the specified dataset
    logging.info("Attempting to identify the TDR object, and collect and parse its schema...")
    api_client = utils.refresh_tdr_api_client(params["tdr_url"])
    _, bq_project, bq_schema, skip_bq_queries = odv.retrieve_tdr_schema(dataset_id, "dataset", api_client)
    if skip_bq_queries:
        logging.error("Error retrieving BQ project and schema for dataset {}".format(dataset_id))
        return "Error retrieving BQ project and schema", "Error"
    table_set, array_field_set, field_list, _ = odv.process_tdr_schema(target_schema, "file")

    # Retrieve dataset name
    try:
        datasets_api = data_repo_client.DatasetsApi(api_client=api_client)
        response = datasets_api.retrieve_dataset(id=dataset_id, include=["SCHEMA", "ACCESS_INFORMATION"]).to_dict()
        # Drop a trailing "_<digits>" suffix from the name
        dataset_name = re.sub("(_[0-9]+$)", "", response["name"])
    except Exception as e:
        # Record the underlying cause; the returned message stays unchanged for callers
        logging.error("Error retrieving dataset details: {}".format(str(e)))
        return "Error retrieving dataset details", "Error"

    # Loop through target schema tables, identify dangling foreign keys, and create new ingest records for them
    logging.info("Attempting to identify and remediate dangling foreign keys...")
    client = bigquery.Client()
    table_logs = []
    any_failed = False
    for table in table_set:
        records_dict, identify_str, identify_status = identify_dangling_fks(params, table, client, bq_project, bq_schema, field_list, array_field_set, dataset_name)
        table_log = table + ": " + identify_str
        ingest_status = ""
        if records_dict:
            ingest_str, ingest_status = ingest_records(params, dataset_id, table, records_dict)
            table_log = table_log + " - " + ingest_str
        table_logs.append(table_log)
        if identify_status == "Error" or ingest_status == "Error":
            any_failed = True

    return "; ".join(table_logs), "Error" if any_failed else "Success"


In [None]:
# # Test
# params = {}
# params["profile_id"] = "e0e03e48-5b96-45ec-baa4-8cc1ebf74c61"
# params["run_env"] = "prod"
# params["tdr_url"] = "https://data.terra.bio"
# #dataset_id = "d239dd7b-8d10-4960-aa91-8f8ede641e25" #anvil_donor example
# #dataset_id = "ae4c80c8-a946-49cb-b376-81b4749f3221" #anvil_biosample example
# dataset_id = "6d18aafc-0240-499c-902e-a72a5b98ff0a"
# mapping_target = "anvil"
# from google.cloud import storage
# storage_client = storage.Client()
# bucket = storage_client.get_bucket(ws_bucket_name)
# blob = bucket.blob("ingest_pipeline/output/transformed/{}/{}/schema/mapping_schema_object.json".format(mapping_target, dataset_id))
# target_schema = json.loads(blob.download_as_string(client=None))
# output, status = resolve_dangling_fks(params, dataset_id, target_schema)
