In [0]:
import os
import json, re, time, requests
from collections import deque
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType
from delta.tables import DeltaTable

In [0]:
# Storage configuration for the enrichment job.
# SILVER_TABLE is the input read with spark.table(); only rows with a non-blank
# company_description are sent to the LLM.
SILVER_TABLE = "silver.unified.unified_companies"   # your Silver table name
# Delta location that every enriched batch is appended to.
GOLD_PATH = "abfss://gold@singaporecomadls.dfs.core.windows.net/llm_enriched_companies/"
# Namespace prefix and table name used when registering a table over GOLD_PATH.
GOLD_DB = "gold"
GOLD_TABLE = "llm_enriched_companies"

In [0]:
# Model identifier and endpoint for the NVIDIA chat/completions API.
NVIDIA_MODEL = "meta/llama-4-maverick-17b-128e-instruct"
NVIDIA_URL = "https://integrate.api.nvidia.com/v1/chat/completions"
# Never commit a real key into the notebook. The key is read from the
# NVIDIA_API_KEY environment variable (e.g. populated from a secret store by the
# cluster configuration); the original placeholder is kept as the fallback so the
# notebook behaves exactly as before when the variable is not set.
NVIDIA_API_KEY = os.environ.get("NVIDIA_API_KEY", "type your key")

In [0]:
# Batching + rate limiting
# BATCH_SIZE: number of pandas rows processed before one Delta append is issued.
BATCH_SIZE = 500               # rows per Spark/Pandas batch
# MAX_REQUESTS_PER_MIN: default ceiling enforced by rate_limit_guard() over a
# rolling 60-second window.
MAX_REQUESTS_PER_MIN = 40      # hard cap
# REQUEST_TIMEOUT: passed as `timeout=` to requests.post().
REQUEST_TIMEOUT = 60           # seconds
# Sampling parameters sent in the chat/completions payload.
TEMPERATURE = 0.3
MAX_TOKENS = 400

In [0]:
def build_prompt(company_name, description, industry):
    """Render the user prompt sent to the LLM for one company.

    Args:
        company_name: Value substituted after "Company Name:".
        description: Value substituted after "Description:".
        industry: Value substituted after "Industry:".

    Returns:
        The prompt text. Values are inserted with default string formatting,
        so ``None`` is rendered as the literal text ``None``.
    """
    # Doubled braces are literal braces in the rendered prompt.
    template = """
You are a structured business data enrichment assistant.

Given:
Company Name: {company_name}
Industry: {industry}
Description: {description}

Return ONLY a valid JSON object with these fields:
{{
  "keywords": "Comma-separated top 10 relevant business keywords",
  "normalized_industry": "General category like Finance, Technology, Retail, Manufacturing, Healthcare",
  "company_size": "Small, Medium, or Large (based on description/scale)",
  "products_offered": ["List", "of", "key", "products"] or null,
  "services_offered": ["List", "of", "key", "services"] or null
}}
"""
    return template.format(
        company_name=company_name,
        industry=industry,
        description=description,
    )

# =========================
# RATE LIMITER (rolling 60s window)
# =========================
# Monotonic timestamps of recent requests, oldest first.
_req_times = deque()

def rate_limit_guard(max_per_min=MAX_REQUESTS_PER_MIN):
    """Block until one more request fits in the rolling 60-second window.

    Timestamps older than 60 seconds are discarded. If at least
    ``max_per_min`` timestamps remain, the call sleeps until the oldest one
    leaves the window (plus a 10 ms margin). The current request is then
    recorded in ``_req_times``.

    Args:
        max_per_min: Maximum number of requests allowed per 60 seconds.
    """
    window_seconds = 60
    checked_at = time.monotonic()

    # Forget requests that have already left the window.
    while len(_req_times) > 0 and checked_at - _req_times[0] > window_seconds:
        _req_times.popleft()

    window_is_full = len(_req_times) >= max_per_min
    if window_is_full:
        # Wait for the oldest recorded request to expire.
        wait_seconds = window_seconds - (checked_at - _req_times[0]) + 0.01
        if wait_seconds > 0:
            time.sleep(wait_seconds)

    # Stamp the request with the time it is actually released.
    _req_times.append(time.monotonic())

# =========================
# NVIDIA CALL (chat/completions)
# =========================
class _QuotaExhausted(RuntimeError):
    """Raised on HTTP 429. Subclasses RuntimeError so existing callers still catch it."""


def _blank_enrichment(keywords=None, error=None):
    """Build the all-None fallback record, optionally carrying raw text or an error."""
    record = {
        "keywords": keywords,
        "normalized_industry": None,
        "company_size": None,
        "products_offered": None,
        "services_offered": None,
    }
    if error is not None:
        record["error"] = error
    return record


def _parse_enrichment_json(content):
    """Return the first JSON object found in ``content``, else a fallback record.

    Parsing starts at the first ``{`` and stops at the end of that object, so
    explanatory text the model appends after the JSON (even text containing
    braces) no longer makes the whole answer unparseable.
    """
    first_brace = content.find("{")
    if first_brace != -1:
        try:
            parsed, _end = json.JSONDecoder().raw_decode(content[first_brace:])
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            return parsed
    # Fallback: keep the raw model output in `keywords` so nothing is lost.
    return _blank_enrichment(keywords=content)


def call_nvidia_enrichment(prompt_text):
    """
    Calls NVIDIA NIM Llama-4 Maverick; returns a JSON-compatible dict.

    Args:
        prompt_text: User prompt, normally produced by build_prompt().

    Returns:
        - the JSON object parsed from the model's answer, or
        - a record whose enrichment fields are None and whose ``keywords``
          holds the raw answer when no JSON object could be parsed, or
        - a record whose enrichment fields are None plus an ``error`` string
          for 401, other non-2xx responses, network failures and malformed
          response bodies.

    Raises:
        RuntimeError('QuotaExhausted') on HTTP 429 so the caller can stop
        immediately. Only a 429 triggers this; an error message that merely
        contains the text "QuotaExhausted" is returned as an error record.
    """

    headers = {
        "Authorization": f"Bearer {NVIDIA_API_KEY}",
        "Content-Type": "application/json",
        "Accept": "application/json",
    }

    payload = {
        "model": NVIDIA_MODEL,
        "messages": [
            {
                "role": "system",
                "content": (
                    "You are a strict JSON responder. "
                    "Return ONLY a valid JSON object. No markdown or explanations."
                ),
            },
            {"role": "user", "content": prompt_text},
        ],
        "temperature": TEMPERATURE,
        "max_tokens": MAX_TOKENS,
        "stream": False,
    }

    # Respect rolling 60s rate limit
    rate_limit_guard()

    try:
        resp = requests.post(NVIDIA_URL, headers=headers, json=payload, timeout=REQUEST_TIMEOUT)
        if resp.status_code == 429:
            # stop immediately (user wants to preserve progress)
            raise _QuotaExhausted("QuotaExhausted")
        if resp.status_code == 401:
            return _blank_enrichment(error="Unauthorized — check NVIDIA API key.")
        if not resp.ok:
            return _blank_enrichment(error=f"API Error: {resp.status_code} - {resp.text}")

        content = resp.json()["choices"][0]["message"]["content"].strip()
        return _parse_enrichment_json(content)

    except _QuotaExhausted:
        # The one failure the caller must see as an exception.
        raise
    except Exception as e:
        # Keep processing; attach error so schema remains stable
        return _blank_enrichment(error=str(e))

In [0]:
# Load the Silver companies and keep only rows whose description is present
# and not just whitespace; these are the rows worth sending to the LLM.
df_silver = spark.table(SILVER_TABLE)

df_filtered = df_silver.where(
    F.col("company_description").isNotNull()
    & (F.length(F.trim(F.col("company_description"))) > 0)
)

print(f"Records eligible for enrichment: {df_filtered.count()}")

# Pull only the columns the prompt needs to the driver as a pandas DataFrame.
pdf = df_filtered.select(
    ["uen", "company_name", "company_description", "industry"]
).toPandas()


In [0]:
# Column order of the rows written to Gold; every column is a nullable string.
expected_keys = [
    "uen",
    "keywords",
    "normalized_industry",
    "company_size",
    "products_offered",
    "services_offered",
    "error",
    "source_of_data",
]

# One nullable StringType field per expected key, in the same order.
schema = StructType(
    [StructField(field_name, StringType(), True) for field_name in expected_keys]
)

In [0]:
def _as_delta_string(value):
    """Coerce one result value into something a StringType column accepts.

    None and str pass through; lists/tuples become comma-separated text;
    dicts are JSON-encoded; anything else is converted with str().
    """
    if value is None or isinstance(value, str):
        return value
    if isinstance(value, (list, tuple)):
        return ", ".join(str(item) for item in value)
    if isinstance(value, dict):
        return json.dumps(value, ensure_ascii=False)
    return str(value)


def _normalize_result(result):
    """Project one raw result onto `expected_keys` with string-safe values."""
    flat = {key: _as_delta_string(result.get(key)) for key in expected_keys}
    if flat["source_of_data"] is None:
        flat["source_of_data"] = "LLM (NVIDIA Llama-4 Maverick 17B)"
    return flat


total = len(pdf)
quota_exhausted = False

for start in range(0, total, BATCH_SIZE):
    batch_no = start // BATCH_SIZE + 1
    batch = pdf.iloc[start:start + BATCH_SIZE]
    batch_results = []

    print(f"🔹 Processing batch {batch_no} ({len(batch)} rows)")

    for row in batch.to_dict("records"):
        name = row.get("company_name")
        desc = row.get("company_description")
        ind = row.get("industry")

        # Skip empty rows (save quota)
        if not (desc or name):
            continue

        prompt = build_prompt(name, desc, ind)

        try:
            out = call_nvidia_enrichment(prompt)
        except RuntimeError as exc:
            if "QuotaExhausted" in str(exc):
                print("🚨 NVIDIA rate limit reached — stopping to preserve progress.")
                quota_exhausted = True
                break
            # Any other RuntimeError is recorded instead of silently dropping the row.
            out = {"error": str(exc)}
        except Exception as exc:
            out = {"error": str(exc)}

        if not isinstance(out, dict):
            out = {"error": f"Unexpected response type: {type(out).__name__}"}

        out["uen"] = row.get("uen")
        batch_results.append(out)

    # Normalize to flat rows with consistent, string-typed keys. This runs even
    # when the quota was hit mid-batch, so rows already paid for are not lost.
    normalized_batch = [_normalize_result(r) for r in batch_results]

    if not normalized_batch:
        print(f"⚠️ No results for batch {batch_no}, skipping write.")
    elif all(nb.get("error") for nb in normalized_batch):
        # If every row is just an error, skip writing
        print(f"⚠️ Batch {batch_no} had no usable results, skipping write.")
    else:
        # Create Spark DF with explicit schema and write
        batch_sdf = spark.createDataFrame(normalized_batch, schema=schema)
        batch_sdf = (
            batch_sdf
            .withColumnRenamed("normalized_industry", "llm_normalized_industry")
            .withColumn("created_at", F.current_timestamp())
            .withColumn("updated_at", F.current_timestamp())
        )

        (
            batch_sdf.write
            .format("delta")
            .mode("append")
            .option("mergeSchema", "true")
            .save(GOLD_PATH)
        )

        print(f"✅ Batch {batch_no} written ({len(normalized_batch)} records).")

    if quota_exhausted:
        # Partial batch (if any) has been written above; stop issuing requests.
        break

print("🏁 Enrichment complete (stopped early if rate limit hit).")

# =========================
# POST-RUN SUMMARY
# =========================
if DeltaTable.isDeltaTable(spark, GOLD_PATH):
    total_written = spark.read.format("delta").load(GOLD_PATH).count()
    print(f"📊 Total enriched records written to Gold: {total_written}")
else:
    print("⚠️ No Delta table created at GOLD_PATH (all batches empty or stopped too early).")

In [0]:
# Preview the Gold data. Guarded the same way as the post-run summary so a run
# that wrote nothing does not fail here on a missing Delta table.
if DeltaTable.isDeltaTable(spark, GOLD_PATH):
    display(spark.read.format("delta").load(GOLD_PATH))
else:
    print("No Delta table at GOLD_PATH yet - nothing to display.")

In [0]:
# =========================
# REGISTER GOLD TABLE
# =========================
# The table is addressed with a three-part name (<GOLD_DB>.final.<GOLD_TABLE>),
# so the namespace that has to exist is `<GOLD_DB>.final`. Creating a database
# called just `<GOLD_DB>` does not provide it.
# NOTE(review): assumes `<GOLD_DB>` already exists as a catalog — confirm.
spark.sql(f"CREATE SCHEMA IF NOT EXISTS {GOLD_DB}.final")

# Register the table over the exact location the enrichment loop writes to,
# using GOLD_PATH rather than a second copy of the literal.
spark.sql(f"""
CREATE TABLE IF NOT EXISTS {GOLD_DB}.final.{GOLD_TABLE}
USING DELTA
LOCATION '{GOLD_PATH}'
""")

In [0]:
# =========================
# REGISTER GOLD TABLE (second name: gold.final.llm_companies)
# =========================
# NOTE(review): this registers gold.final.llm_companies over the same storage
# location that another cell registers as gold.final.llm_enriched_companies.
# It looks like a leftover from before a rename — confirm whether it is still
# needed. The f-prefix below has no placeholders.

spark.sql(f"""
CREATE TABLE IF NOT EXISTS gold.final.llm_companies
USING DELTA
LOCATION 'abfss://gold@singaporecomadls.dfs.core.windows.net/llm_enriched_companies/'
""")

In [0]:
%sql
-- Preview the registered Gold table.
SELECT *
FROM gold.final.llm_enriched_companies


In [0]:
%sql
-- Left commented out on purpose; run manually only if the table still has its old name:
-- ALTER TABLE gold.final.llm_companies RENAME TO gold.final.llm_enriched_companies