## Anthropic Stress Test Demo
- Task: `01_llm_cost_mapping`
-------------

### Purpose
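- Combine the open-source (`llm_oss_costs`) and proprietary (`llm_proprietary_costs`) LLM pricing datasets, each sourced from a separate web page, into a single table.
- Use AI (`ai_query`) to map each model name to the corresponding Databricks natively served model endpoint.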

-----------------------

### Input Parameters
- **Catalog**
  - The catalog that results are written to upon completion.
- **Schema**
  - The schema within the output catalog that results are written to.
- **Mapping Model**
  - The model endpoint that `ai_query` uses to identify the matching served endpoint for each model in the pricing dataset.
    - Defaults to `databricks-meta-llama-3-3-70b-instruct`.


-----------------------

## Setup

In [0]:
from pyspark.sql import functions as f
from stress_test import create_widgets

create_widgets(dbutils, '01_llm_cost_mapping')

--------
## Get data from tables  

In [0]:
# read widget values: output catalog, output schema, and mapping model endpoint
catalog = dbutils.widgets.get('catalog')
schema = dbutils.widgets.get('schema')
mapping_model = dbutils.widgets.get('mapping_model_endpoint')

In [0]:
oss_df = spark.sql(f'select * from {catalog}.{schema}.llm_oss_costs')
display(oss_df)

In [0]:
prop_df = spark.sql(f'select * from {catalog}.{schema}.llm_proprietary_costs')
display(prop_df)

--------
## Union datasets together and fill null values

In [0]:
# union datasets together and allow for missing columns between datasets
union_df = oss_df.unionByName(prop_df, allowMissingColumns = True)

# fill nulls for endpoint_type and context_length
union_df = union_df.fillna('Global', subset='endpoint_type')
union_df = union_df.fillna('All Lengths', subset='context_length')
union_df.createOrReplaceTempView('union_df')
display(union_df)
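For readers less familiar with `unionByName(..., allowMissingColumns=True)`, the pure-Python sketch below mimics its behaviour on lists of row dicts: every column from either dataset is kept, a row that lacks a column gets a null (`None`), and a `fillna`-style step then replaces those nulls with per-column defaults. The helpers `union_by_name` and `fill_nulls` and the sample rows are hypothetical and only illustrate the semantics; unlike Spark's `fillna`, the sketch ignores column types.

```python
from typing import Any, Dict, List

Row = Dict[str, Any]


def union_by_name(left: List[Row], right: List[Row]) -> List[Row]:
    """Stack two lists of rows by column name, filling absent columns with None."""
    columns: List[str] = []
    for row in left + right:
        for name in row:
            if name not in columns:
                columns.append(name)  # first-seen order: left columns, then right-only ones
    return [{name: row.get(name) for name in columns} for row in left + right]


def fill_nulls(rows: List[Row], defaults: Dict[str, Any]) -> List[Row]:
    """Replace None with a per-column default, similar to DataFrame.fillna(value, subset=...)."""
    return [
        {
            name: defaults[name] if value is None and name in defaults else value
            for name, value in row.items()
        }
        for row in rows
    ]


# hypothetical sample rows; the real tables have many more columns
oss_rows = [{"model": "model-a", "input_dbu_1m_tokens": 1.0}]
prop_rows = [{"model": "model-b", "endpoint_type": "Regional", "context_length": None}]

rows = fill_nulls(
    union_by_name(oss_rows, prop_rows),
    {"endpoint_type": "Global", "context_length": "All Lengths"},
)
for row in rows:
    print(row)
```

Note that `input_dbu_1m_tokens` stays `None` for `model-b` because no default is given for it, matching how the notebook only fills `endpoint_type` and `context_length`.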

---------------

## Get list of Databricks natively served model endpoints

In [0]:
# get distinct list of model endpoint names
served_models = spark.sql('select distinct endpoint_name from system.serving.served_entities')

# filter for endpoints that start with "databricks" (natively served models)
served_models = served_models.filter(f.col('endpoint_name').like('databricks%'))
served_models = served_models.toPandas()['endpoint_name'].to_list()

# concat into a single string to pass to a prompt
served_models = ', '.join(served_models)
print(served_models)
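Because `select distinct` does not guarantee row order, the joined list (and therefore the prompt text) can differ between runs. A minimal sketch of a deterministic variant, assuming the endpoint names have already been collected into a Python list as above: drop empty values, keep names with the `databricks` prefix, de-duplicate, and sort before joining. `build_endpoint_list` and the sample names are hypothetical.

```python
from typing import Iterable, Optional


def build_endpoint_list(endpoint_names: Iterable[Optional[str]], prefix: str = "databricks") -> str:
    """Return a sorted, de-duplicated, comma-separated list of endpoints starting with prefix."""
    # like SQL LIKE 'databricks%', str.startswith is case-sensitive
    matches = sorted({name for name in endpoint_names if name and name.startswith(prefix)})
    return ", ".join(matches)


# hypothetical input; "my-custom-endpoint" stands in for a customer-created endpoint
names = [
    "databricks-meta-llama-3-3-70b-instruct",
    "my-custom-endpoint",
    "databricks-gpt-5-1",
    "databricks-gpt-5-1",
    None,
]
print(build_endpoint_list(names))  # databricks-gpt-5-1, databricks-meta-llama-3-3-70b-instruct
```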

---------
## Use AI to determine relevant model endpoints

In [0]:

# define the prompt passed to ai_query; each row's model name is appended after the final "model: " label
prompt = f"""
You are an assistant designed to map an LLM model name to one of the served model endpoints provided. You will be given a single model name and a comma-separated list of served model endpoints. Return the served endpoint most likely associated with the model name. If there is no clear match, return "none". Do not include a summary or any additional text; return only the served endpoint name. Thank you!

Example:
  - model: GPT 5.1
  - output: databricks-gpt-5-1
Example:
  - model: Mistral 7B
  - output: none

served_models: {served_models}

model: """

# escape backslashes and single quotes so the prompt is a valid Spark SQL string literal
escaped_prompt = prompt.replace('\\', '\\\\').replace("'", "\\'")

# define the sql query that we will run
sql = f"""
SELECT
  model,
  ai_query('{mapping_model}', '{escaped_prompt}' || model) AS served_endpoint_name,
  provider,
  endpoint_type,
  context_length,
  input_dbu_1m_tokens,
  output_dbu_1m_tokens,
  cache_writes_dbu_1m_tokens,
  cache_reads_dbu_1m_tokens,
  batch_inference_dbu_hour,
  dbu_per_hour_entry_cap,
  dbu_per_hour_scale_cap
FROM
  union_df
"""

In [0]:
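# execute the query and save the results as a table
mapped_df = spark.sql(sql)
mapped_df.write.mode('overwrite').saveAsTable(f'{catalog}.{schema}.llm_cost_mapping')

# display the saved table; displaying mapped_df directly would re-run ai_query for every row
display(spark.table(f'{catalog}.{schema}.llm_cost_mapping'))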
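The prompt asks for a bare endpoint name, but the model is not guaranteed to comply, so a response may carry quotes, whitespace, different casing, or an endpoint name that does not exist. One hedged option is to normalize each response and keep it only if it matches an endpoint from the served list. The `normalize_mapping` helper below is a hypothetical plain-Python sketch of that check; applying it to the table (for example through a UDF) is not shown.

```python
from typing import List, Optional


def normalize_mapping(raw: Optional[str], allowed: List[str]) -> Optional[str]:
    """Return the allowed endpoint matching a raw model response, or None if there is no match."""
    if raw is None:
        return None
    lookup = {name.lower(): name for name in allowed}
    # trim whitespace and stray quotes/backticks, then compare case-insensitively
    candidate = raw.strip().strip("\"'`").strip().lower()
    # "none", empty strings and endpoints not in the list all map to None
    return lookup.get(candidate)


allowed = ["databricks-gpt-5-1", "databricks-meta-llama-3-3-70b-instruct"]
print(normalize_mapping(' "Databricks-GPT-5-1"\n', allowed))  # databricks-gpt-5-1
print(normalize_mapping("none", allowed))  # None
```

In the notebook, the allowed list could be recovered from the joined string built earlier with `served_models.split(', ')`.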