## Anthropic Stress Test Demo
- Task: `00_llm_proprietary`
-------------

### Purpose
- Scrape Databricks' published DBU rates per 1 million tokens for proprietary foundation models, so that model usage can be priced accurately.

----------------------
### Links:
- https://www.databricks.com/product/pricing/proprietary-foundation-model-serving

-----------------------

### Input Parameters
- **Catalog**
  - Catalog that the results table is written to; it is created if it does not already exist.
- **Schema**
  - Schema within the output catalog that the results table is written to; it is created if it does not already exist.

-----------------------

## Setup

In [0]:
import requests
from bs4 import BeautifulSoup
from pyspark.sql import functions as f
from pyspark.sql.types import FloatType
from stress_test import create_widgets, get_indices

# Register the notebook widgets for this task via the stress_test helper.
# NOTE(review): presumably this defines the 'catalog' and 'schema' widgets that
# the next cell reads with dbutils.widgets.get — confirm in create_widgets.
create_widgets(dbutils, '00_llm_proprietary')

-----------------------

## Get Webpage Data

In [0]:
# get param values
catalog = dbutils.widgets.get('catalog')
schema = dbutils.widgets.get('schema')

# url for proprietary model pricing on databricks website
url = "https://www.databricks.com/product/pricing/proprietary-foundation-model-serving"
headers = {"User-Agent": "Mozilla/5.0"}

# seconds to wait on the pricing page; without a timeout, requests can block
# indefinitely on a stalled connection and hang the whole job
request_timeout_s = 30

# get the html response, raising on HTTP error status codes
resp = requests.get(url, headers=headers, timeout=request_timeout_s)
resp.raise_for_status()
soup = BeautifulSoup(resp.text, "html.parser")

# Find the section header (an h2/h3 containing this text), then the first
# table that follows it in the document
header_text = "Proprietary Foundation Model Serving DBU rates"
header = soup.find(
  lambda tag: tag.name in ["h2", "h3"]
  and header_text in tag.get_text()
  )
# soup.find / find_next return None when nothing matches; fail with a clear
# message (likely a page layout change) instead of an AttributeError on None
if header is None:
  raise RuntimeError(f'Could not find an h2/h3 header containing "{header_text}" on {url}')
table = header.find_next("table")
if table is None:
  raise RuntimeError(f'No <table> found after the "{header_text}" header on {url}')

---------------
## Clean up and navigate data

In [0]:
# find all html rows (<tr>) in the pricing table
records = table.find_all('tr')

# get all of our cells and their indices via the stress_test helper
# NOTE(review): each item appears to be a dict with at least an 'x_index'
# (which table row the cell belongs to) and a 'text' key — confirm in get_indices
rows = get_indices(records, 'td')

# group the cells by x index in one pass instead of re-filtering the whole
# list for every index; cells keep their original relative order
cells_by_index = {}
for cell in rows:
  cells_by_index.setdefault(cell['x_index'], []).append(cell)

# sorted list of the distinct x indices
x_indices = sorted(cells_by_index)

# Shorter rows inherit provider/model/endpoint/context values from the rows
# above them (presumably merged cells in the HTML table). Reset these fields
# so that re-running this notebook cell cannot pick up stale values left over
# from a previous run.
provider = None
model = None
endpoint_type = None
context_length = None

# init a list of empty rows
output_rows = []

for i in x_indices:
  children = cells_by_index[i]
  texts = [child['text'] for child in children]
  n_cells = len(texts)

  # a single cell only contains the model provider
  if n_cells == 1:
    provider = texts[0]
    continue

  # 8 cells: a full row with the model name and every column
  if n_cells == 8:
    model, endpoint_type, context_length = texts[:3]
    prices = texts[3:]

  # 7 cells: a different endpoint_type and context length for the current
  # model (e.g. Claude Sonnet 4.5)
  elif n_cells == 7:
    endpoint_type, context_length = texts[:2]
    prices = texts[2:]

  # 6 cells: either a different context window or a different endpoint type
  # for the current model, depending on the provider/model
  elif n_cells == 6:
    if provider == 'Google' or model == 'Claude Sonnet 3.7 / 4 / 4.1':
      context_length = texts[0]
    else:
      endpoint_type = texts[0]
    prices = texts[1:]

  # Any other cell count is an unknown layout. Emitting it would repeat the
  # previous row's values and silently duplicate pricing, so skip it with a
  # warning so that a page layout change gets noticed.
  else:
    print(f'WARNING: skipping table row {i} with unexpected cell count {n_cells}: {texts}')
    continue

  # every recognised data row ends with the same five price columns
  (
    input_dbu_1m_tokens,
    output_dbu_1m_tokens,
    cache_writes_dbu_1m_tokens,
    cache_reads_dbu_1m_tokens,
    batch_inference_dbu_hour,
  ) = prices

  output_rows.append({
    'provider': provider,
    'model': model,
    'endpoint_type': endpoint_type,
    'context_length': context_length,
    'input_dbu_1m_tokens': input_dbu_1m_tokens,
    'output_dbu_1m_tokens': output_dbu_1m_tokens,
    'cache_writes_dbu_1m_tokens': cache_writes_dbu_1m_tokens,
    'cache_reads_dbu_1m_tokens': cache_reads_dbu_1m_tokens,
    'batch_inference_dbu_hour': batch_inference_dbu_hour
  })

-----------
## Create location within Unity Catalog and write results

In [0]:
# ensure the output catalog exists (no-op when it is already there)
spark.sql(f'CREATE CATALOG IF NOT EXISTS {catalog}')

In [0]:
# ensure the output schema exists inside the catalog (no-op when present)
spark.sql(f'CREATE SCHEMA IF NOT EXISTS {catalog}.{schema}')

In [0]:
from pyspark.sql.types import StringType, StructField, StructType

# define the order of the columns that we want our final dataframe to be
select_columns = [
  'provider',
  'model',
  'endpoint_type',
  'context_length',
  'input_dbu_1m_tokens',
  'output_dbu_1m_tokens',
  'cache_writes_dbu_1m_tokens',
  'cache_reads_dbu_1m_tokens',
  'batch_inference_dbu_hour',
]

# The write below uses mode('overwrite'), so an empty parse would wipe the
# existing pricing table. Fail loudly instead.
if not output_rows:
  raise RuntimeError(
    'No pricing rows were parsed from the Databricks pricing page; '
    'refusing to overwrite the existing table.'
  )

# Build the dataframe with an explicit all-string schema in the final column
# order. The scraped values are text at this point, and an explicit schema
# avoids inference failures when a column is entirely null.
raw_schema = StructType([StructField(c, StringType(), True) for c in select_columns])
cost_df = spark.createDataFrame(
  [tuple(row[c] for c in select_columns) for row in output_rows],
  schema=raw_schema,
)
# the page uses 'n/a' for prices that do not apply; store them as null
cost_df = cost_df.na.replace('n/a', None)

# define a list of columns to convert to a float
to_float_cols = [
  'input_dbu_1m_tokens',
  'output_dbu_1m_tokens',
  'cache_writes_dbu_1m_tokens',
  'cache_reads_dbu_1m_tokens',
  'batch_inference_dbu_hour',
]
# strip thousands separators, then cast to float
# NOTE(review): FloatType is single precision; DoubleType would be more exact
# for prices, but changing it alters the saved table's schema (overwrite would
# then need overwriteSchema) — left as-is.
for col in to_float_cols:
  cost_df = cost_df.withColumn(
    col, f.regexp_replace(f.col(col), ',', '').cast(FloatType())
  )

# write our table out to our catalog and schema
cost_df.write.mode('overwrite').saveAsTable(f'{catalog}.{schema}.llm_proprietary_costs')
display(cost_df)