# LLM Mapping Notebook Template

Use this scaffold to normalize free-form values into canonical targets with `map_column_with_llm`.
It includes logging, dry-run safety checks, and an optional persistence step.


## Notebook guidelines

- Name notebooks in clear snake_case and keep one mapping target per notebook.
- Start in dry-run mode to estimate coverage and cost before enabling live LLM calls.
- Keep target values short and stable; treat them as the canonical vocabulary.
- Avoid logging secrets; rely on environment variables or secret scopes.
- Make runs idempotent with deterministic transforms and safe overwrite logic.


In [None]:
from app.config.settings import get_settings

import datetime as _dt

from pyspark.sql import Row, functions as F

from spark_fuse.spark import create_session
from spark_fuse.utils.dataframe import ensure_columns, preview
from spark_fuse.utils.progress import (
    console,
    create_progress_tracker,
    enable_spark_logging,
    log_end as log_end_step,
    log_error as log_error_step,
    log_info as log_info_step,
    log_warn as log_warn_step,
)
from spark_fuse.utils.llm import map_column_with_llm

# Number of tracker-advancing log calls (advance=1) made by the cells below;
# keep this in sync when adding or removing notebook steps.
_TOTAL_STEPS = 9
progress_tracker = create_progress_tracker(total_steps=_TOTAL_STEPS)
log = console()
# Render HTML log cards in notebooks by default.
SHOW_HTML = True

def log_info(label: str, *, advance: int = 1, show_html: bool = SHOW_HTML) -> None:
    """Log ``label`` at info level through ``log_info_step``.

    Forwards the shared ``progress_tracker`` and console ``log`` together with
    ``advance`` (tracker increment) and ``show_html`` (HTML card rendering).
    """
    log_info_step(
        progress_tracker,
        log,
        label,
        show_html=show_html,
        advance=advance,
    )


def log_warn(label: str, *, advance: int = 1, show_html: bool = SHOW_HTML) -> None:
    """Log ``label`` at warning level through ``log_warn_step``.

    Forwards the shared ``progress_tracker`` and console ``log`` together with
    ``advance`` (tracker increment) and ``show_html`` (HTML card rendering).
    """
    log_warn_step(
        progress_tracker,
        log,
        label,
        show_html=show_html,
        advance=advance,
    )


def log_error(label: str, *, advance: int = 1, show_html: bool = SHOW_HTML) -> None:
    """Log ``label`` at error level through ``log_error_step``.

    Forwards the shared ``progress_tracker`` and console ``log`` together with
    ``advance`` (tracker increment) and ``show_html`` (HTML card rendering).
    """
    log_error_step(
        progress_tracker,
        log,
        label,
        show_html=show_html,
        advance=advance,
    )


def log_end(label: str, *, advance: int = 1, show_html: bool = SHOW_HTML) -> None:
    """Log the closing message ``label`` through ``log_end_step``.

    Forwards the shared ``progress_tracker`` and console ``log`` together with
    ``advance`` (tracker increment) and ``show_html`` (HTML card rendering).
    """
    log_end_step(
        progress_tracker,
        log,
        label,
        show_html=show_html,
        advance=advance,
    )


def show_df(df, n: int = 5) -> None:
    """Render ``df`` in the notebook.

    Uses the notebook's ``display`` helper when it is defined in this module's
    globals (e.g. on Databricks). Otherwise it prints ``preview(df, n=n)``.
    ``n`` only affects the printed fallback.
    """
    if "display" not in globals():
        print(preview(df, n=n))
        return
    display(df)  # type: ignore[name-defined]

# Reusable notebook parameters; edit these before running.

# Run timestamp: local time, ISO-8601, truncated to whole seconds.
job_ts = _dt.datetime.now().isoformat(timespec="seconds")

# Column to normalize and the canonical vocabulary it should map onto.
TARGET_COLUMN = "company"
TARGET_VALUES = ["OpenAI", "Alphabet", "Microsoft", "Amazon"]

# Model passed to map_column_with_llm for live (non-dry-run) mapping.
MAPPING_MODEL = "o4-mini"

# Safety switches: live LLM calls and persistence are both off by default.
RUN_LIVE_MAPPING = False
WRITE_OUTPUT = False


## Load configuration

Settings are layered in priority order (config/base.yaml -> config/{env}.yaml -> .env -> env vars).
Set APP_ENV to local, staging, or prod, and keep secrets in .env or real environment variables.


In [None]:
# Resolve the layered settings once; later cells read settings.env and settings.logging.level.
settings = get_settings()
_settings_message = f"Loaded settings for env={settings.env}"
log_info(_settings_message)


## Configure LLM credentials

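On Databricks, enter credentials in the notebook widgets created by the next cell; their values are exported as environment variables for the mapping helper. Outside Databricks, the cell reads the existing `OPENAI_API_KEY` and `AZURE_OPENAI_*` environment variables instead.
Optionally, provide a secret scope and secret key names to resolve values from Databricks Secrets; a non-empty secret value overrides the matching widget value.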


In [None]:
import os

_DEFAULT_AZURE_API_VERSION = "2023-05-15"

if "dbutils" in globals():
    # Declare the credential widgets. Do NOT call dbutils.widgets.removeAll() here.
    # It would discard any values already typed into the widgets on every re-run,
    # and Databricks does not allow creating widgets in the same cell that removed them.
    # The "Clean up widgets" cell near the end of the notebook removes them instead.
    dbutils.widgets.text("OPENAI_API_KEY", "", "OpenAI API Key")
    dbutils.widgets.text("AZURE_OPENAI_ENDPOINT", "", "Azure OpenAI Endpoint (optional)")
    dbutils.widgets.text("AZURE_OPENAI_KEY", "", "Azure OpenAI Key (optional)")
    dbutils.widgets.text("AZURE_OPENAI_API_VERSION", _DEFAULT_AZURE_API_VERSION, "Azure OpenAI API Version")
    dbutils.widgets.text("LLM_SECRET_SCOPE", "", "Secret Scope (optional)")
    dbutils.widgets.text("SECRET_OPENAI_API_KEY", "", "Secret Key: OpenAI API Key")
    dbutils.widgets.text("SECRET_AZURE_ENDPOINT", "", "Secret Key: Azure Endpoint")
    dbutils.widgets.text("SECRET_AZURE_API_KEY", "", "Secret Key: Azure API Key")
    dbutils.widgets.text("SECRET_AZURE_API_VERSION", "", "Secret Key: Azure API Version")

    def _widget(name: str) -> str:
        """Return the current value of widget ``name`` with surrounding whitespace removed."""
        return dbutils.widgets.get(name).strip()

    scope = _widget("LLM_SECRET_SCOPE")

    def _resolve(widget_name: str, secret_widget: str) -> str:
        """Return the value for ``widget_name``, preferring Databricks Secrets.

        When a secret scope is set and ``secret_widget`` holds a key name, a
        non-empty secret overrides the widget value. The secret lookup is best
        effort: a failure is reported as a warning and the widget value is kept.
        """
        value = _widget(widget_name)
        secret_name = _widget(secret_widget) if scope else ""
        if not (scope and secret_name):
            return value
        try:
            secret_value = dbutils.secrets.get(scope=scope, key=secret_name)
        except Exception as exc:  # noqa: BLE001 - best effort; keep the widget value
            log_warn(
                f"Unable to read secret '{secret_name}' from scope '{scope}': {exc}",
                advance=0,
            )
            return value
        return secret_value or value

    openai_key = _resolve("OPENAI_API_KEY", "SECRET_OPENAI_API_KEY")
    azure_endpoint = _resolve("AZURE_OPENAI_ENDPOINT", "SECRET_AZURE_ENDPOINT")
    azure_key = _resolve("AZURE_OPENAI_KEY", "SECRET_AZURE_API_KEY")
    azure_version = (
        _resolve("AZURE_OPENAI_API_VERSION", "SECRET_AZURE_API_VERSION") or _DEFAULT_AZURE_API_VERSION
    )
else:
    openai_key = os.environ.get("OPENAI_API_KEY", "").strip()
    azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT", "").strip()
    azure_key = os.environ.get("AZURE_OPENAI_KEY", "").strip()
    # Fall back to the default when the variable is unset OR empty, matching the widget branch.
    azure_version = os.environ.get("AZURE_OPENAI_API_VERSION", "").strip() or _DEFAULT_AZURE_API_VERSION

# Export non-empty values so the mapping helper can read them from the environment.
for _env_name, _env_value in (
    ("OPENAI_API_KEY", openai_key),
    ("AZURE_OPENAI_ENDPOINT", azure_endpoint),
    ("AZURE_OPENAI_KEY", azure_key),
    ("AZURE_OPENAI_API_VERSION", azure_version),
):
    if _env_value:
        os.environ[_env_name] = _env_value

# `provider` only feeds the log message below; it treats any endpoint as meaning Azure.
provider = "azure" if azure_endpoint else "openai"
if openai_key or azure_key:
    log_info(f"LLM credentials configured for {provider}.", advance=0)
else:
    log_warn("No LLM credentials found; live mapping will fail without API keys.", advance=0)


## Create a session

Adjust `app_name`, `master`, and `extra_configs` for your environment.


In [None]:
log_info("Starting Spark session...", advance=0)
# Put only non-secret tuning options in extra_configs. Spark properties are listed
# in the Spark UI "Environment" tab, and only keys matching spark.redaction.regex
# are masked there. Keep credentials in env vars or secret scopes (see above).
spark = create_session(
    app_name=f"llm-mapping-template-{settings.env}",
    master="local[*]",
    extra_configs={},  # e.g. {"spark.sql.shuffle.partitions": "8"}
)
log_info("Spark session ready")
spark


## Start logging

Set the Spark log level from `settings.logging.level`. While you iterate, raise the verbosity so shuffle and scheduler details show up in the driver logs.


In [None]:
# Apply the configured log level to Spark, then record the step.
_spark_log_level = settings.logging.level
enable_spark_logging(spark, level=_spark_log_level)
log_info(f"Spark logging enabled at {_spark_log_level}.", advance=0)
log_info("Logging configured")


## Load relevant data

Declare input locations and load dataframes. Replace the sample with your sources.


In [None]:
log_info("Loading input data (dummy samples; replace with real sources)...", advance=0)

# Dummy company values: an exact match, near-misses, a lower-case variant and a null.
_sample_companies = ["OpenAI Inc.", "Alphabet", "Micro Soft", "amazon.com", None]
sample_data = [
    Row(id=row_id, company=name)
    for row_id, name in enumerate(_sample_companies, start=1)
]

source_df = spark.createDataFrame(sample_data)
# Check that the input exposes the column to be mapped
# (ensure_columns presumably raises when it is missing; confirm).
ensure_columns(source_df, [TARGET_COLUMN])

log_info("Input data loaded")
show_df(source_df)


## Run dry-run mapping

Dry-run mode performs exact case-insensitive matching without calling an LLM.


In [None]:
log_info("Running dry-run mapping...", advance=0)

# dry_run=True: no LLM is called, so this cell is safe to run without credentials.
_dry_run_options = {
    "column": TARGET_COLUMN,
    "target_values": TARGET_VALUES,
    "dry_run": True,
}
dry_run_df = map_column_with_llm(source_df, **_dry_run_options)

show_df(dry_run_df)
log_info("Dry-run mapping complete")


## Run live mapping

Set `RUN_LIVE_MAPPING = True` to call the LLM for fuzzy matching. Live mapping also needs an API key from the credentials step; without one, the step is skipped.


In [None]:
# mapped_df stays None whenever live mapping is skipped; the metrics cell then
# falls back to dry_run_df.
if not RUN_LIVE_MAPPING:
    mapped_df = None
    log_warn("RUN_LIVE_MAPPING is False; skipping live LLM calls.")
elif not (openai_key or azure_key):
    mapped_df = None
    log_warn("RUN_LIVE_MAPPING=True but no API key found; skipping live mapping.")
else:
    log_info("Running live LLM mapping...", advance=0)
    mapped_df = map_column_with_llm(
        source_df,
        column=TARGET_COLUMN,
        target_values=TARGET_VALUES,
        model=MAPPING_MODEL,
        dry_run=False,
        temperature=None,
    )
    show_df(mapped_df)
    log_info("Live mapping complete")


## Review mapping metrics


In [None]:
# Prefer the live result; fall back to the dry-run output when live mapping was skipped.
output_df = mapped_df if mapped_df is not None else dry_run_df
_mapped_column = f"{TARGET_COLUMN}_mapped"
ensure_columns(output_df, [TARGET_COLUMN, _mapped_column])

# Compute both metrics in a single aggregation job instead of two count() actions.
# F.count(column) counts only non-null values, so:
#   mapped = non-null mapped values; unmapped = total rows - mapped.
# NOTE(review): each Spark action may re-evaluate the mapping (and any LLM calls)
# unless map_column_with_llm caches its result; that is unverified here.
_metrics = output_df.agg(
    F.count(F.lit(1)).alias("total_rows"),
    F.count(F.col(_mapped_column)).alias("mapped_rows"),
).first()
mapped_count = _metrics["mapped_rows"]
unmapped_count = _metrics["total_rows"] - mapped_count

log_info(f"Mapped rows: {mapped_count}", advance=0)
log_info(f"Unmapped rows: {unmapped_count}", advance=0)
log_info("Mapping metrics computed")


## Optional write

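Set `WRITE_OUTPUT = True` to persist the mapped results. Delta is used when the `delta` Python package is importable; otherwise the data is written as Parquet.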


In [None]:
output_path = f"/tmp/spark_fuse/{settings.env}/llm_mapped_{TARGET_COLUMN}"
output_format = "delta"


def _write_output(df):
    """Overwrite ``output_path`` with ``df`` using the current ``output_format``.

    Reads the module-level ``output_path`` and ``output_format`` at call time,
    so a Parquet fallback selected below is honoured.
    """
    df.write.format(output_format).mode("overwrite").save(output_path)


if WRITE_OUTPUT:
    try:
        from delta.tables import DeltaTable  # noqa: F401
    except ImportError:
        output_format = "parquet"
        log_warn("Delta Lake not available; falling back to parquet for write.", advance=0)
    # NOTE(review): this only checks that the delta Python package is importable;
    # the Spark session must also be configured for Delta for the write to succeed.

    log_info(f"Writing mapped data to {output_path} (format={output_format})...", advance=0)
    _write_output(output_df)

    persisted_df = spark.read.format(output_format).load(output_path)
    ensure_columns(persisted_df, [TARGET_COLUMN, f"{TARGET_COLUMN}_mapped"])
    # Explicit check rather than `assert`, which is skipped under `python -O`.
    # head(1) fetches at most one row instead of counting the whole dataset.
    if not persisted_df.head(1):
        raise RuntimeError("Persisted dataset is empty")

    log_info("Write complete")
else:
    log_warn("WRITE_OUTPUT is False; skipping persistence step.")


## Clean up widgets


In [None]:
# Widgets only exist on Databricks, where `dbutils` is injected into the notebook globals.
_running_on_databricks = "dbutils" in globals()
if _running_on_databricks:
    dbutils.widgets.removeAll()


## Stop session

Shut down the session once the job completes.


In [None]:
# Announce shutdown without advancing the tracker; log_end records the final step.
_stop_label = "Stopping Spark session."
log_info(_stop_label, advance=0)
spark.stop()
log_end("Spark session stopped")
