## Anthropic Stress Test Demo
-------------

### Purpose
- Create randomized datasets of various sizes and test the performance of different LLMs against them.
  - Standardize / track input and output tokens
  - Summarize each row with AI query
  - Capture a **rough estimate** of DBU consumption, cost, and total runtime
    - *Actual costs may differ from what this process indicates*

----------------------
### Links:
- https://www.databricks.com/product/pricing/proprietary-foundation-model-serving
- https://docs.databricks.com/aws/en/large-language-models/ai-functions#-general-purpose-function-ai_query

-----------------------

### Input Parameters
- **Catalog**
  - The desired catalog for results to be written to upon completion.
- **Schema**
  - The desired schema within the output catalog for results to be written.
- **Model Endpoints to Run**
  - A comma-separated list of endpoints available to you in Databricks.
    - e.g. "databricks-claude-sonnet-4-5, databricks-claude-opus-4-5, databricks-claude-opus-4-1"
      - ***exclude the quotes when inputting data***
- **Sample Dataset Sizes**
  - A comma-separated list of numbers that sets the size of each sample dataset. Runs on datasets with fewer than 100 rows are filtered out of the logged results.
    - e.g. "100, 1000, 10000"
      - ***exclude the quotes when inputting data***
- **Time Zone**
  - The time zone of the region your workspace is in. For example, if your workspace is in AWS us-east-2, set the time zone to a value such as America/New_York.
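
The comma-separated inputs above are parsed later in this notebook with `split(',')` and `strip()`. The sketch below shows that parsing in isolation and adds some validation the notebook itself does not perform. `parse_csv_list` and `parse_sizes` are hypothetical helper names used only for illustration; they are not part of the `stress_test` module.

```python
def parse_csv_list(raw: str) -> list[str]:
    """Split a comma-separated widget value and drop empty entries."""
    return [item.strip() for item in raw.split(',') if item.strip()]


def parse_sizes(raw: str) -> list[int]:
    """Parse dataset sizes, rejecting anything that is not a positive integer."""
    sizes = []
    for item in parse_csv_list(raw):
        if not item.isdecimal() or int(item) == 0:
            raise ValueError(f'Invalid dataset size: {item!r}')
        sizes.append(int(item))
    return sizes


# Example values taken from the parameter descriptions above
endpoints = parse_csv_list('databricks-claude-sonnet-4-5, databricks-claude-opus-4-5')
sizes = parse_sizes('100, 1000, 10000')
print(endpoints)
print(sizes)
```

Failing early on a malformed size avoids discovering the problem only after the first dataset has already been generated.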

------------------

## Setup

In [0]:
from pyspark.sql import functions as f
from pyspark.sql import Window
from pyspark.sql.types import StructType, StructField, StringType, TimestampType, IntegerType, DoubleType
from stress_test import create_widgets, generate_sales_data, run_sample_query, to_valid_table_name

create_widgets(dbutils, '02_llm_performance_test')

### Defining Helper Functions
- **generate_sales_data**
  - Creates a randomized dataset of a specified size with the following features:
    - *date*: the date a transaction occurred
    - *average_temperature*: the temp in celsius for the day of the transaction
    - *rainfall*: the rainfall in cm for the day of the transaction
    - *weekend*: if the transaction was on a weekend
    - *holiday*: if the transaction was on a holiday
    - *price_per_kg*: the price per kg of goods for the transaction
    - *demand*: the amount of goods purchased in kgs
    - *month*: the month number of the transaction
    - *total_spend*: the total amount spent on the transaction

- **run_sample_query**
  - Executes a given SQL query and prints the time it took to complete. It can also write the query output to a specified location, if one is provided.

- **to_valid_table_name**
  - Converts a given string into a valid Databricks table name by performing a series of regex manipulations.
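
The helpers themselves live in the `stress_test` module and are not shown in this notebook. As a rough illustration of what two of them might do, here is a minimal sketch under stated assumptions: the value ranges, the start date, the `total_spend = price_per_kg * demand` relationship, and the exact regex rules are all guesses, and the names `generate_sales_rows` and `to_valid_table_name_sketch` are hypothetical.

```python
import random
import re
from datetime import date, timedelta


def generate_sales_rows(n_rows: int, seed: int = 0) -> list[dict]:
    """Build n_rows random transactions with the nine fields listed above."""
    rng = random.Random(seed)
    start = date(2024, 1, 1)  # arbitrary start date chosen for illustration
    rows = []
    for _ in range(n_rows):
        day = start + timedelta(days=rng.randrange(365))
        price = round(rng.uniform(1.0, 10.0), 2)
        demand = round(rng.uniform(0.5, 50.0), 2)
        rows.append({
            'date': day,
            'average_temperature': round(rng.uniform(-5.0, 35.0), 1),
            'rainfall': round(rng.uniform(0.0, 5.0), 2),
            'weekend': day.weekday() >= 5,  # Saturday or Sunday
            'holiday': rng.random() < 0.03,  # assumed holiday frequency
            'price_per_kg': price,
            'demand': demand,
            'month': day.month,
            # Assumption: total spend is simply price times quantity
            'total_spend': round(price * demand, 2),
        })
    return rows


def to_valid_table_name_sketch(name: str) -> str:
    """Approximate a table-name sanitizer; the real helper may behave differently."""
    cleaned = name.strip().lower()
    # Replace each run of characters other than letters, digits, or underscores
    cleaned = re.sub(r'[^a-z0-9_]+', '_', cleaned)
    # Collapse repeated underscores and trim them from both ends
    cleaned = re.sub(r'_+', '_', cleaned).strip('_')
    # Assumed convention: avoid names that start with a digit
    if re.match(r'\d', cleaned):
        cleaned = f't_{cleaned}'
    return cleaned


print(len(generate_sales_rows(3)))
print(to_valid_table_name_sketch('databricks-claude-sonnet-4-5'))
```

With these rules the endpoint name `databricks-claude-sonnet-4-5` becomes `databricks_claude_sonnet_4_5`, which is the kind of identifier the notebook later combines with the dataset size to name each results table.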

-----------------
### Get the values from our widgets

In [0]:
# Set the sizes for our randomized datasets
record_counts = [int(num.strip()) for num in dbutils.widgets.get('sample_dataset_sizes').split(',')]

# Set the endpoints for the models we want to use
model_endpoints = [model.strip() for model in dbutils.widgets.get('model_endpoints').split(',')]

# Set the catalog and schema for the outputs
catalog = dbutils.widgets.get('catalog')
schema = dbutils.widgets.get('schema')

# get timezone for spark jobs
time_zone = dbutils.widgets.get('time_zone')
spark.conf.set('spark.sql.session.timeZone', time_zone)

# The prompt we want to give our ai_query
prompt = """
You will be given a dataset with the following fields:
- date: the date a transaction occurred
- average_temperature: the temp in celsius for the day of the transaction
- rainfall: the rainfall in cm for the day of the transaction
- weekend: if the transaction was on a weekend
- holiday: if the transaction was on a holiday
- price_per_kg: the price per kg of goods for the transaction
- demand: the amount of goods purchased in kgs
- month: the month number of the transaction
- total_spend: the total amount spent on the transaction

Provide me with concise observations (2 sentences max) about the following transactions:
"""

# SQL to execute, leave space open for model and prompt
sql = """
        SELECT
            *,
            ai_query(
                '{model}',
                '{prompt}' ||
                concat_ws(', ',
                    CAST(date AS STRING),
                    CAST(average_temperature AS STRING),
                    CAST(rainfall AS STRING),
                    CAST(weekend AS STRING),
                    CAST(holiday AS STRING),
                    CAST(price_per_kg AS STRING),
                    CAST(demand AS STRING),
                    CAST(month AS STRING),
                    CAST(total_spend AS STRING)
                )
            ) AS summary
        FROM sample_records
    """

-----------------------

## Executing the Queries

- The process:
  1. Loop through each dataset size in ***record_counts***
      - Generate a dataset with the given number of rows and create a temporary view for the SQL query to read
  2. Loop through each model in ***model_endpoints***
      - For every dataset size we will run each model specified to compare results later.
  3. Write query results
      - Save results from queries for future reference

In [0]:
# Create an empty list to store query information
run_logs = []

# Loop through each data set size
for record_count in record_counts:
    print(f'Generating {record_count} records...\n')
    sample_records = generate_sales_data(record_count)
    spark.createDataFrame(sample_records) \
        .createOrReplaceTempView('sample_records')
    
    # Query the dataset with each model
    for model in model_endpoints:
        print(f'Running model: {model}')
        run_name = to_valid_table_name(model)
        table_name = f'{run_name}_results_{record_count}'
        run_sql = sql.format(model=model, prompt=prompt)
        
        # Attempt to execute the query and log the results
        try:
            run_log = run_sample_query(spark, run_sql, catalog, schema, table_name)
            run_log['model'] = model
            run_log['run_name'] = run_name
            
            run_logs.append(run_log)
            print('Elapsed time:', run_log['elapsed_time'])
            print('-' * 80, '\n')
        except Exception as e:
            print(e)

In [0]:
# Log our runs
logging_schema = StructType([
    StructField('run_id', StringType(), True),
    StructField('run_name', StringType(), True),
    StructField('model', StringType(), True),
    StructField('output_table', StringType(), True),
    StructField('n_rows', IntegerType(), True),
    StructField('n_columns', IntegerType(), True),
    StructField('start_time', TimestampType(), True),
    StructField('end_time', TimestampType(), True),
    StructField('elapsed_time', DoubleType(), True),
    StructField('sql', StringType(), True),
])

logging_df = spark.createDataFrame(run_logs, schema=logging_schema)
logging_df = logging_df.withColumn('insert_time', f.current_timestamp())
logging_df.write.mode('append').option('mergeSchema','true').saveAsTable(f'{catalog}.{schema}.runs')

-------------------

## View the Results

In [0]:
%sql
SELECT * FROM identifier(:catalog || '.' || :schema || '.runs')
ORDER BY start_time DESC

--------------

## Analyze the Queries with Databricks System Tables

- We need to track our query performance and understand how each model performs on identical datasets.

- The first step is to find our queries in the **system.query.history** table.
  - This will allow us to identify start / end times for our queries and analyze run times.

In [0]:
%sql
CREATE OR REPLACE TEMPORARY VIEW model_queries AS
SELECT 
  account_id,
  workspace_id,
  statement_id,
  executed_by,
  total_duration_ms,
  execution_duration_ms,
  compilation_duration_ms,
  result_fetch_duration_ms,
  start_time,
  end_time
FROM system.query.history
WHERE
  statement_text = "llm_observer_sdf.write.mode('overwrite').saveAsTable(f'{llm_observer_catalog}.{llm_observer_schema}.{llm_observer_table_name}')"
ORDER BY
  start_time DESC
;

SELECT * FROM model_queries

- Now we can join the run log with the query history view. The two sources share no concrete id, so we link their records by start time.

- The query starts only seconds before the model run starts. Because these runs take a long time to complete, we assume that a query and a model run that start within 5 seconds of each other are the same record.
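
The start-time matching described above can be sketched in plain Python. This is an illustration of the idea rather than the notebook's implementation (which does the join in SQL on whole-second Unix timestamps); `match_by_start_time` and the sample records are hypothetical.

```python
from datetime import datetime, timedelta


def match_by_start_time(runs, queries, tolerance_seconds=5):
    """Pair each run with every query whose start time is within the tolerance."""
    tolerance = timedelta(seconds=tolerance_seconds)
    pairs = []
    for run in runs:
        for query in queries:
            if abs(run['start_time'] - query['start_time']) <= tolerance:
                pairs.append((run['run_id'], query['statement_id']))
    return pairs


# Hypothetical records for illustration only
runs = [
    {'run_id': 'run-a', 'start_time': datetime(2025, 1, 1, 12, 0, 3)},
    {'run_id': 'run-b', 'start_time': datetime(2025, 1, 1, 12, 30, 0)},
]
queries = [
    {'statement_id': 'stmt-1', 'start_time': datetime(2025, 1, 1, 12, 0, 0)},
    {'statement_id': 'stmt-2', 'start_time': datetime(2025, 1, 1, 12, 45, 0)},
]
print(match_by_start_time(runs, queries))
```

Only `run-a` and `stmt-1` start within 5 seconds of each other, so they form the single pair. As with an inner join, a run with no nearby query is dropped, and a run with two nearby queries would appear twice, so it is worth checking the mapping for missing or duplicated runs.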

In [0]:
# Construct the table name in Python
table_name = f"{catalog}.{schema}.runs"

# Use the constructed table name in the SQL query
query = f"""
CREATE OR REPLACE TEMPORARY VIEW run_query_mapping AS
SELECT
  t2.executed_by,
  t1.run_id,
  t1.output_table,
  t1.n_rows,
  t1.start_time as run_start_time,
  t1.end_time as run_end_time,
  t2.start_time AS query_start_time,
  t2.end_time AS query_end_time,
  t2.total_duration_ms / 60000 AS total_duration_minutes,
  t2.statement_id
FROM {table_name} t1
INNER JOIN model_queries t2 ON
  abs(unix_timestamp(t1.start_time) - unix_timestamp(t2.start_time)) <= 5
WHERE 
  t1.n_rows >= 100
ORDER BY
  t1.start_time DESC
"""

# Run the query
spark.sql(query)

# Display the results
display(spark.sql("SELECT * FROM run_query_mapping"))

-------------
### Get data about the batch processing calls

In [0]:
%sql
CREATE OR REPLACE TEMPORARY VIEW token_count AS
SELECT 
  t1.account_id,
  t1.workspace_id,
  t1.served_entity_id,
  t2.endpoint_name,
  t1.requester,
  t1.request_time,
  t1.input_token_count,
  t1.output_token_count,
  t1.input_character_count,
  t1.output_character_count
FROM system.serving.endpoint_usage t1
INNER JOIN system.serving.served_entities t2 ON
  t1.served_entity_id = t2.served_entity_id
;

SELECT * FROM token_count

------------

## Get cost data and filter for results
- Models may have multiple cost configurations depending on usage, region, etc.
  - The system tables don't give us enough information about region and context window so **we make assumptions**
- This process gets the first available cost configuration for a model
  - ***This cost configuration may not match your use case exactly***
  - Even so, the resulting estimate should be in the right range; compare it against the pricing page linked above before relying on it
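
To make the "first available cost configuration" rule concrete, here is a small plain-Python sketch of the same selection: sort the configurations and keep the first one per model. The column names mirror those used in this notebook, but the rows and their values are made up for illustration, and `first_config_per_model` is a hypothetical name.

```python
from itertools import groupby
from operator import itemgetter


def first_config_per_model(rows):
    """Sort by model, endpoint_type, context_length and keep the first row per model."""
    sort_key = itemgetter('model', 'endpoint_type', 'context_length')
    ordered = sorted(rows, key=sort_key)
    return [next(group) for _, group in groupby(ordered, key=itemgetter('model'))]


# Hypothetical cost rows; values are invented for this example
cost_rows = [
    {'model': 'model-a', 'endpoint_type': 'pay-per-token', 'context_length': 'short', 'input_dbu_1m_tokens': 40.0},
    {'model': 'model-a', 'endpoint_type': 'pay-per-token', 'context_length': 'long', 'input_dbu_1m_tokens': 80.0},
    {'model': 'model-a', 'endpoint_type': 'batch', 'context_length': 'short', 'input_dbu_1m_tokens': 20.0},
    {'model': 'model-b', 'endpoint_type': 'pay-per-token', 'context_length': 'short', 'input_dbu_1m_tokens': 10.0},
]

for row in first_config_per_model(cost_rows):
    print(row['model'], row['endpoint_type'], row['context_length'])
```

In this example `model-a` resolves to its `batch` configuration purely because `batch` sorts first alphabetically. The selection depends on sort order, not on which configuration applies to your workload, which is why the estimate should be read as approximate.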

In [0]:
# get data for natively served model cost
cost_df = spark.sql(f'select * from {catalog}.{schema}.llm_cost_mapping')

# window over model name
window = Window.partitionBy('model').orderBy(['model','endpoint_type','context_length'])

# get row count by model name and filter for the first record returned
cost_df = cost_df.withColumn('row_number', f.row_number().over(window))
cost_df = cost_df.filter(f.col('row_number')==1)
cost_df.createOrReplaceTempView('cost_df')
display(cost_df)

-----------

## Save final table
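
The final query converts token counts into DBUs using the per-million-token rates from the cost mapping, then multiplies by a hard-coded price of 0.07 per DBU. The sketch below reproduces that arithmetic for a single record; `estimate_cost` is a hypothetical helper, and the token counts and DBU rates in the example are invented.

```python
DBU_PRICE = 0.07  # price per DBU as hard-coded in the notebook's final query


def estimate_cost(input_tokens, output_tokens,
                  input_dbu_1m_tokens, output_dbu_1m_tokens,
                  dbu_price=DBU_PRICE):
    """Convert token counts to DBUs, then to an estimated cost."""
    input_dbus = input_tokens / 1_000_000 * input_dbu_1m_tokens
    output_dbus = output_tokens / 1_000_000 * output_dbu_1m_tokens
    total_dbus = input_dbus + output_dbus
    return {
        'input_dbus': input_dbus,
        'output_dbus': output_dbus,
        'total_dbus': total_dbus,
        'input_dbu_cost': input_dbus * dbu_price,
        'output_dbu_cost': output_dbus * dbu_price,
        'total_dbu_cost': total_dbus * dbu_price,
    }


# Hypothetical token counts and DBU rates, for illustration only
estimate = estimate_cost(
    input_tokens=2_000_000,
    output_tokens=500_000,
    input_dbu_1m_tokens=50.0,
    output_dbu_1m_tokens=200.0,
)
print(estimate)
```

Keeping the DBU price as a parameter makes it easy to rerun the estimate with the rate that actually applies to your account.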

In [0]:
%sql
CREATE OR REPLACE TABLE identifier(:catalog || '.' || :schema || '.model_testing_results') AS
SELECT 
  t5.*,
  t5.input_dbus + t5.output_dbus as total_dbus,
  -- 0.07 is the hard-coded price per DBU; adjust it to match the rate that applies to you
  t5.input_dbus * 0.07 as input_dbu_cost,
  t5.output_dbus * 0.07 as output_dbu_cost,
  (t5.input_dbus + t5.output_dbus) * 0.07 as total_dbu_cost
FROM (
  SELECT 
    t4.*,
    t4.input_token_count / 1000000 * t4.input_dbu_1m_tokens as input_dbus,
    t4.output_token_count / 1000000 * t4.output_dbu_1m_tokens as output_dbus
  FROM (
    SELECT 
      t1.*,
      t2.requester,
      t2.request_time,
      t2.served_entity_id,
      t2.endpoint_name,
      t2.input_token_count,
      t2.output_token_count,
      t2.input_token_count + t2.output_token_count AS total_tokens,
      t2.input_character_count,
      t2.output_character_count,
      t2.input_character_count + t2.output_character_count AS total_characters,
      t3.input_dbu_1m_tokens,
      t3.output_dbu_1m_tokens
    FROM run_query_mapping t1
    INNER JOIN token_count t2 ON
      t2.request_time > t1.query_start_time 
      AND t2.request_time < t1.query_end_time
      AND t1.executed_by = t2.requester
    INNER JOIN cost_df t3 ON
      t2.endpoint_name = t3.served_endpoint_name
  ) t4
) t5
WHERE
  t5.input_dbu_1m_tokens IS NOT NULL
;

SELECT * FROM identifier(:catalog || '.' || :schema || '.model_testing_results')