In [99]:
import re
import requests
import json
from pyspark.sql.functions import col, explode, size, lit, array, struct, udf
from pyspark.sql.types import ArrayType, StructType, StructField, StringType, IntegerType
from pyspark.sql import SparkSession
from pyspark.sql.functions import monotonically_increasing_id
import time
from pyspark.sql import Row

StatementMeta(, 75ce30fc-c90e-42d0-8531-2b5f13e08d35, 33, Finished, Available, Finished)

In [100]:
# Initialize Spark session (already available in Fabric/Synapse, but good practice)
spark = SparkSession.builder.appName("KEGGDrugEnrichment").getOrCreate()

# Set a base URL for the KEGG API
KEGG_BASE_URL = "https://rest.kegg.jp/get/dr:"

StatementMeta(, 75ce30fc-c90e-42d0-8531-2b5f13e08d35, 34, Finished, Available, Finished)

In [101]:
def parse_kegg_response(text: str) -> dict:
    """
    Parses the plain text response from the KEGG API to extract Diseases and Targets.
    
    Args:
        text: The raw text response from the KEGG 'get' endpoint.
        
    Returns:
        A dictionary containing lists of extracted diseases and targets.
    """
    diseases = []
    targets = []
    names = []
    current_key = None
    
    # Split the response into lines and iterate
    lines = text.strip().split('\n')
    
    # Use a dictionary to store extracted sections for easier parsing
    sections = {}
    
    # 1. Collect lines into sections based on the starting key
    current_section_key = None
    for line in lines:
        if not line:
            continue
            
        # Check if the line starts a new section (e.g., "ENTRY", "DISEASE", "TARGET")
        # FIX: Use r'^\s*([A-Z_]+)\s+' to handle optional leading whitespace before the key
        match_key = re.match(r'^\s*([A-Z_]+)\s+', line)
        
        if match_key:
            current_section_key = match_key.group(1)
            
            # Calculate the starting position of the content after the key and surrounding whitespace
            # We find the index of the key and then add its length plus one for the following space
            key_start_index = line.find(current_section_key)
            content_start_index = key_start_index + len(current_section_key)
            
            # Store the line content, ensuring to strip the key and any leading/trailing whitespace
            content = line[content_start_index:].strip()
            sections[current_section_key] = [content]
        elif current_section_key and line.startswith(' '):
            # Continuation line (indented)
            sections[current_section_key].append(line.strip())

    # 2. Parse DISEASE section
    if 'NAME' in sections:
        name_lines = sections['NAME']
        temp_name = {}
        smallest_index = 1000
        for index, line in enumerate(name_lines):
            # Expected format: Disease Name [DS:Hxxxxx]
            # Use regex to robustly find the name and the DS ID
            match = re.search(r'(.+)', line)
            if match:
                name = match.group(1).strip()
                name_property = []

                if index < smallest_index:
                    smallest_index = index
                    name_property.append("preferred name")

                if ";" in name:
                    name = re.sub(";", "", name)
                if "(" in name and ")" in name:
                    name_property_search = re.findall(r'\((.+?)\)', name)
                    if name_property_search:
                        name_property.extend(name_property_search)

                    name = re.sub("\(.+?\)", "", name)
                name = name.strip()
                #ds_id = match.group(2).strip()
                if name not in temp_name:
                    temp_name[name] = name_property
                else:
                    temp_name[name].extend(name_property)
    for n in temp_name:
        names.append({"name": n, "property": temp_name[n]})


    # 2. Parse DISEASE section
    if 'DISEASE' in sections:
        disease_lines = sections['DISEASE']
        for line in disease_lines:
            # Expected format: Disease Name [DS:Hxxxxx]
            # Use regex to robustly find the name and the DS ID
            match = re.search(r'(.+?)\s+\[DS:(H\d+)\]', line)
            if match:
                name = match.group(1).strip()
                ds_id = match.group(2).strip()
                diseases.append({"ds_id": ds_id, "name": name})

    # 3. Parse TARGET section
    if 'TARGET' in sections:
        target_lines = sections['TARGET']
        for line in target_lines:
            # 1. Extract the name (before the first bracket)
            match_name = re.match(r'(.+?)\s+\[', line)
            
            name = None
            if match_name:
                name_part = match_name.group(1).strip()
                # remove content in parentheses
                name = re.sub(r'\s*\([^)]*\)\s*', ' ', name_part).strip() 
            
            # 2. Extract the entire KO IDs block
            # Capture content between [KO: and ]
            match_ko_block = re.search(r'\[KO:(.+?)\]', line) 
            
            if name and match_ko_block:
                # The captured group contains one or more K IDs separated by spaces
                ko_id_string = match_ko_block.group(1).strip()
                
                # Split the string by whitespace to get individual KO IDs
                # Filter to ensure we only process valid 'K' identifiers
                ko_ids = [k.strip() for k in ko_id_string.split() if k.strip().startswith('K')]
                
                # Create a target entry for each individual KO ID found
                for ko_id in ko_ids:
                    targets.append({"ko_id": ko_id, "name": name})

    return {"names": names, "diseases": diseases, "targets": targets}

# --- 2. UDF for API Call and Parsing ---

# Define the schema for the returned structured data
KEGG_SCHEMA = StructType([
    StructField("names", ArrayType(StructType([
        StructField("name", StringType(), True),
        StructField("property", ArrayType(StringType(), True))
    ])), True),
    StructField("diseases", ArrayType(StructType([
        StructField("ds_id", StringType(), True),
        StructField("name", StringType(), True)
    ])), True),
    StructField("targets", ArrayType(StructType([
        StructField("ko_id", StringType(), True),
        StructField("name", StringType(), True)
    ])), True)
])


@udf(KEGG_SCHEMA)
def get_kegg_details_udf(kegg_id: str) -> dict:
    """
    Makes the API call to KEGG and processes the response.
    Returns structured data for diseases and targets.
    Includes basic retry logic.
    """
    if not kegg_id:
        return None
        
    url = f"{KEGG_BASE_URL}{kegg_id}"
    
    # Simple exponential backoff retry logic
    max_retries = 3
    for attempt in range(max_retries):
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
            
            # The API response is text, not JSON
            return parse_kegg_response(response.text)
            
        except requests.exceptions.HTTPError as e:
            if response.status_code == 404:
                # Log the issue for a specific KEGG ID and return None
                print(f"KEGG ID not found (404) for {kegg_id}. Skipping. Error: {e}")
                return None
            
            print(f"HTTP Error for {kegg_id} on attempt {attempt + 1}: {e}")
            if attempt < max_retries - 1:
                import time
                time.sleep(2 ** attempt) # Wait 1s, 2s
            else:
                print(f"Failed to fetch KEGG details for {kegg_id} after {max_retries} attempts.")
                return None
        except requests.exceptions.RequestException as e:
            print(f"Request Error for {kegg_id} on attempt {attempt + 1}: {e}")
            if attempt < max_retries - 1:
                import time
                time.sleep(2 ** attempt)
            else:
                print(f"Failed to fetch KEGG details for {kegg_id} after {max_retries} attempts.")
                return None
    return None

StatementMeta(, 75ce30fc-c90e-42d0-8531-2b5f13e08d35, 35, Finished, Available, Finished)

In [102]:
def get_kegg_details_test(kegg_id: str) -> dict:
    import requests
    import time # Needed because the UDF uses time.sleep

    
    if not kegg_id:
        return None
        
    url = f"{KEGG_BASE_URL}{kegg_id}"
    
    max_retries = 3
    for attempt in range(max_retries):
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status() 
            return parse_kegg_response(response.text)
            
        except Exception as e:
            print(f"Test Error on attempt {attempt + 1}: {e}")
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)
            else:
                return None
    return None

# Execute the test function
#test_id = "D10574"
test_id = "D01531"
test_output = get_kegg_details_test(test_id)
test_output

StatementMeta(, 75ce30fc-c90e-42d0-8531-2b5f13e08d35, 36, Finished, Available, Finished)

{'names': [{'name': 'Pipemidic acid hydrate',
   'property': ['preferred name', 'JP18']},
  {'name': 'Karunomazin', 'property': ['TN']}],
 'diseases': [],
 'targets': [{'ko_id': 'K02469', 'name': 'DNA gyrase'},
  {'ko_id': 'K02470', 'name': 'DNA gyrase'}]}

In [103]:

source_df = spark.read.format("delta").table("SilverLakeHouse.atc_codes")

filtered_df = source_df.filter(
    (col("depth") == "E") &
    (col("KEGG").isNotNull()) &
    (size(col("KEGG")) > 0)
)

print(f"Found {filtered_df.count()} ATC level E codes with KEGG drugs.")

StatementMeta(, 75ce30fc-c90e-42d0-8531-2b5f13e08d35, 37, Finished, Available, Finished)

Found 4595 ATC level E codes with KEGG drugs.


In [104]:

# 3.3 Explode the KEGG list and extract the KEGG drug ID
drug_df = filtered_df.withColumn("kegg_drug_info", explode(col("KEGG"))) \
                     .withColumn("drug_kegg_id", col("kegg_drug_info.kegg_id")) \
                     .filter(col("drug_kegg_id").startswith("D")) \
                     .select(
                         col("atc"),
                         col("name"),
                         col("drug_kegg_id").alias("drug_kegg_id")
                     ).distinct()

print(f"Total unique drugs to process: {drug_df.count()}")

StatementMeta(, 75ce30fc-c90e-42d0-8531-2b5f13e08d35, 38, Finished, Available, Finished)

Total unique drugs to process: 7751


In [105]:
drug_df.head(3)

StatementMeta(, 75ce30fc-c90e-42d0-8531-2b5f13e08d35, 39, Finished, Available, Finished)

[Row(atc='G03FB09', name='Levonorgestrel and estrogen', drug_kegg_id='D04482'),
 Row(atc='H02AB11', name='Prednylidene', drug_kegg_id='D07230'),
 Row(atc='H02AB12', name='Rimexolone', drug_kegg_id='D05729')]

In [106]:
# New cell to handle sequential API calls on the driver


def fetch_kegg_data_sequentially(drug_df, kegg_base_url, rate_limit_delay=0.4):
    """
    Collects KEGG IDs to the driver, calls the KEGG API sequentially
    with a delay to respect rate limits, and returns a list of Rows.
    """
    
    # 1. Collect unique KEGG IDs to the driver
    kegg_ids_to_process = drug_df.select("drug_kegg_id").distinct().rdd.map(lambda row: row[0]).collect()
    total_drugs = len(kegg_ids_to_process)
    print(f"Starting sequential fetching for {total_drugs} unique KEGG IDs...")
    
    results = []
    
    # 2. Sequential API Call Loop
    start_time = time.time()
    for i, kegg_id in enumerate(kegg_ids_to_process):
        # Apply the API call logic from your original UDF
        url = f"{kegg_base_url}{kegg_id}"
        max_retries = 3
        kegg_details = None
        
        for attempt in range(max_retries):
            try:
                response = requests.get(url, timeout=10)
                response.raise_for_status()
                kegg_details = parse_kegg_response(response.text)
                
                # Success: Append the result and break retry loop
                results.append(Row(drug_kegg_id=kegg_id, kegg_details=kegg_details))
                break 

            except requests.exceptions.HTTPError as e:
                # 404 is a permanent error, no need to retry
                if response.status_code == 404:
                    print(f"KEGG ID not found (404) for {kegg_id}. Skipping.")
                    break
                
                # Other HTTP errors (including 429 Rate Limit)
                print(f"HTTP Error for {kegg_id} on attempt {attempt + 1}. Status: {response.status_code}")
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
                else:
                    print(f"Failed to fetch {kegg_id} after {max_retries} attempts.")
                    break
            except requests.exceptions.RequestException as e:
                print(f"Request Error for {kegg_id} on attempt {attempt + 1}: {e}")
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
                else:
                    print(f"Failed to fetch {kegg_id} after {max_retries} attempts.")
                    break

        # CRITICAL: Rate Limiting Delay
        time.sleep(rate_limit_delay)
        
        # Checking Logic 1: Print progress
        if (i + 1) % 500 == 0 or (i + 1) == total_drugs:
            elapsed_time = time.time() - start_time
            progress = (i + 1) / total_drugs * 100
            print(f"Progress: {i + 1}/{total_drugs} ({progress:.2f}%) | Time Elapsed: {elapsed_time:.2f}s")
            
    end_time = time.time()
    print(f"\nSequential fetching complete. Total time: {end_time - start_time:.2f} seconds.")
    print(f"Successfully retrieved details for {len(results)} drugs.")

    return results

# 3. Define the Schema (needed for creating the DataFrame)
# Use your existing schema definition from the original cell 101
KEGG_SCHEMA = StructType([
    StructField("names", ArrayType(StructType([
        StructField("name", StringType(), True),
        StructField("property", ArrayType(StringType(), True))
    ])), True),
    StructField("diseases", ArrayType(StructType([
        StructField("ds_id", StringType(), True),
        StructField("name", StringType(), True)
    ])), True),
    StructField("targets", ArrayType(StructType([
        StructField("ko_id", StringType(), True),
        StructField("name", StringType(), True)
    ])), True)
])

ENRICHED_SCHEMA = StructType([
    StructField("drug_kegg_id", StringType(), True),
    StructField("kegg_details", KEGG_SCHEMA, True)
])

# Execute the sequential fetch
kegg_results = fetch_kegg_data_sequentially(
    drug_df, 
    KEGG_BASE_URL, 
    rate_limit_delay=0.4  # Set a safe delay like 0.12s (8.3 calls/sec)
)

# 4. Create the final enriched Spark DataFrame
enriched_df = spark.createDataFrame(kegg_results, schema=ENRICHED_SCHEMA)
print(f"Enriched DataFrame count: {enriched_df.count()}")

StatementMeta(, 75ce30fc-c90e-42d0-8531-2b5f13e08d35, 40, Submitted, Running, Running)

Starting sequential fetching for 5820 unique KEGG IDs...
Progress: 500/5820 (8.59%) | Time Elapsed: 263.90s
Progress: 1000/5820 (17.18%) | Time Elapsed: 528.04s
KEGG ID not found (404) for D08771. Skipping.
Progress: 1500/5820 (25.77%) | Time Elapsed: 790.97s
Progress: 2000/5820 (34.36%) | Time Elapsed: 1054.21s
Progress: 2500/5820 (42.96%) | Time Elapsed: 1317.39s


Progress: 3000/5820 (51.55%) | Time Elapsed: 1580.28s


Progress: 3500/5820 (60.14%) | Time Elapsed: 1843.14s


Progress: 4000/5820 (68.73%) | Time Elapsed: 2105.99s


Progress: 4500/5820 (77.32%) | Time Elapsed: 2368.98s


Progress: 5000/5820 (85.91%) | Time Elapsed: 2631.57s


Progress: 5500/5820 (94.50%) | Time Elapsed: 2894.37s


Progress: 5820/5820 (100.00%) | Time Elapsed: 3063.07s

Sequential fetching complete. Total time: 3063.07 seconds.
Successfully retrieved details for 5819 drugs.
Enriched DataFrame count: 5819


In [107]:
#enriched_df.show()

StatementMeta(, , -1, Waiting, , Waiting)

In [108]:
# 3.5 Explode and structure the results for final tables

# --- Create Drug-Disease Fact Table ---
drug_name_df = enriched_df.withColumn("name", explode(col("kegg_details.names"))) \
                             .select(
                                 col("drug_kegg_id"),
                                 col("name.name").alias("drug_name"),
                                 col("name.property").alias("drug_name_property")
                             )


drug_disease_df = enriched_df.withColumn("disease", explode(col("kegg_details.diseases"))) \
                             .select(
                                 col("drug_kegg_id"),
                                 col("disease.ds_id").alias("disease_ds_id"),
                                 col("disease.name").alias("disease_name")
                             )

# --- Create Drug-Target Fact Table ---
drug_target_df = enriched_df.withColumn("target", explode(col("kegg_details.targets"))) \
                            .select(
                                col("drug_kegg_id"),
                                col("target.ko_id").alias("target_ko_id"),
                                col("target.name").alias("target_name")
                            )

StatementMeta(, , -1, Waiting, , Waiting)

In [109]:
#drug_name_df.show()

StatementMeta(, , -1, Waiting, , Waiting)

In [110]:
#drug_df.head()

StatementMeta(, , -1, Waiting, , Waiting)

In [111]:
# --- Create Dimension Tables (Diseases and Targets) ---

# Create the new table for the ATC-KEGG mapping (dim_drug_atc)
dim_drug_atc_df = drug_df.select(
    col("atc"),
    col("name"),
    col("drug_kegg_id").alias("kegg_id")
)

# Disease Dimension Table
dim_drugs_df = drug_name_df.select(
    col("drug_kegg_id").alias("drug_id"),
    col("drug_name").alias("drug_name"),
    col("drug_name_property").alias("drug_name_property")
).distinct().withColumn("name_id", monotonically_increasing_id())

# Disease Dimension Table
dim_diseases_df = drug_disease_df.select(
    col("disease_ds_id").alias("disease_id"),
    col("disease_name").alias("name")
).distinct().withColumn("sk", monotonically_increasing_id())

# Target Dimension Table
dim_targets_df = drug_target_df.select(
    col("target_ko_id").alias("ko_number"),
    col("target_name").alias("name")
).distinct().withColumn("target_id", monotonically_increasing_id()) # Use a Spark ID for internal target_id

StatementMeta(, , -1, Waiting, , Waiting)

In [112]:

# 3.6 Finalize Fact Tables with natural keys (DS_ID, KO_ID)
#final_fact_drug_name_df = drug_name_df.select(
#    col("drug_kegg_id"),
#    col("drug_name").alias("name")
#).distinct()

final_fact_drug_disease_df = drug_disease_df.select(
    col("drug_kegg_id"),
    col("disease_ds_id").alias("disease_id")
).distinct()

final_fact_drug_target_df = drug_target_df.select(
    col("drug_kegg_id"),
    col("target_ko_id").alias("target_ko_number") # Use KO number as FK for simplicity
).distinct()

StatementMeta(, , -1, Waiting, , Waiting)

In [114]:




# 4. Write all five tables back to the Silver layer

print("Writing Dimension and Fact Tables to Silver...")

# Write Dim Drug-ATC (new table)
(dim_drug_atc_df
    .write
    .format("delta")
    .mode("overwrite")
    .saveAsTable("SilverLakeHouse.dim_drug_atc")
)

# Write Dim Diseases
(dim_drugs_df
    .write
    .format("delta")
    .mode("overwrite")
    .saveAsTable("SilverLakeHouse.dim_drugs")
)

# Write Dim Diseases
(dim_diseases_df
    .write
    .format("delta")
    .mode("overwrite")
    .saveAsTable("SilverLakeHouse.dim_diseases")
)

# Write Dim Targets
(dim_targets_df
    .write
    .format("delta")
    .mode("overwrite")
    .saveAsTable("SilverLakeHouse.dim_targets")
)


# Write Fact Drug-Name
#(final_fact_drug_name_df
#    .write
#    .format("delta")
#    .mode("overwrite")
#    .saveAsTable("SilverLakeHouse.fact_drug")
#)

# Write Fact Drug-Disease
(final_fact_drug_disease_df
    .write
    .format("delta")
    .mode("overwrite")
    .saveAsTable("SilverLakeHouse.fact_drug_disease")
)

# Write Fact Drug-Target
(final_fact_drug_target_df
    .write
    .format("delta")
    .mode("overwrite")
    .saveAsTable("SilverLakeHouse.fact_drug_target")
)

print("KEGG Drug Enrichment process completed.")
print(f"SilverLakeHouse.dim_drug_atc count: {dim_drug_atc_df.count()}")
print(f"SilverLakeHouse.dim_diseases count: {dim_diseases_df.count()}")
print(f"SilverLakeHouse.dim_targets count: {dim_targets_df.count()}")
print(f"SilverLakeHouse.fact_drug_disease count: {final_fact_drug_disease_df.count()}")
print(f"SilverLakeHouse.fact_drug_target count: {final_fact_drug_target_df.count()}")

# Show a sample of the new Dim Drug-ATC table
print("Sample of SilverLakeHouse.dim_drug_atc:")
dim_drug_atc_df.limit(5).show(truncate=False)

StatementMeta(, , -1, Waiting, , Waiting)

Writing Dimension and Fact Tables to Silver...


KEGG Drug Enrichment process completed.
SilverLakeHouse.dim_drug_atc count: 7751
SilverLakeHouse.dim_diseases count: 940
SilverLakeHouse.dim_targets count: 783
SilverLakeHouse.fact_drug_disease count: 2549


SilverLakeHouse.fact_drug_target count: 7422
Sample of SilverLakeHouse.dim_drug_atc:
+-------+---------------------------+-------+
|atc    |name                       |kegg_id|
+-------+---------------------------+-------+
|G03FB09|Levonorgestrel and estrogen|D04482 |
|H02AB11|Prednylidene               |D07230 |
|H02AB12|Rimexolone                 |D05729 |
|H04AA01|Glucagon                   |D00116 |
|J01GB13|Bekanamycin                |D07497 |
+-------+---------------------------+-------+

