Run this notebook on serverless compute, or on a cluster running Databricks Runtime (DBR) 16.3 or higher.

In [0]:
from pyspark.sql.functions import col, expr, regexp_replace
import random

# Notebook configuration: adjust these values for your own workspace.

# Tables are addressed as <catalog>.<schema>.<table>.
catalog = "main"
schema = "default"
source_table = ".".join((catalog, schema, "pilot_notes"))  # synthetic pilot notes are stored here
target_table = ".".join((catalog, schema, "labeled_pilot_notes"))  # notes plus LLM predictions

# Model endpoint name handed to ai_query().
# Can also try databricks-claude-3-7-sonnet, databricks-meta-llama-3-1-8b-instruct.
llm = "databricks-meta-llama-3-3-70b-instruct"

# Number of rows to limit your LLM to for unsupervised classification.
training_set_size = 1000

In [0]:
# Flight-phase label set. Each entry is a (label, description) tuple. The
# description is a one-sentence summary of the phase followed by
# "Examples: ..." with two sample activities. It is used both as the topic
# for synthetic note generation and as the per-label hint in the
# classification schema.
classes_and_descriptions = [
    (
        "PRE-FLIGHT",
        "Aircraft and crew prep before engine start. "
        "Examples: Inspecting cargo doors; completing pre-flight checklist.",
    ),
    (
        "TAXI_OUT",
        "Moving aircraft from gate to runway. "
        "Examples: Taxi via taxiway B; flaps set for takeoff.",
    ),
    (
        "TAKEOFF",
        "Aircraft accelerates and lifts off. "
        "Examples: Nose up at VR; gear retracted after positive climb.",
    ),
    (
        "CLIMB",
        "Gaining altitude to reach cruise. "
        "Examples: Throttle set to climb power; monitoring rate of ascent.",
    ),
    (
        "CRUISE",
        "Level flight at cruising altitude. "
        "Examples: Monitoring instruments; communicating with ATC.",
    ),
    (
        "DESCENT",
        "Controlled reduction in altitude. "
        "Examples: Initiating descent at TOD; adjusting cabin pressure.",
    ),
    (
        "APPROACH",
        "Final phase before landing. "
        "Examples: Aligning with runway; configuring flaps and gear.",
    ),
    (
        "LANDING",
        "Touchdown and deceleration. "
        "Examples: Main gear contact; deploying spoilers and brakes.",
    ),
    (
        "TAXI_IN",
        "Moving from runway to parking. "
        "Examples: Taxi to gate A4; shutting down unnecessary systems.",
    ),
    (
        "POST-FLIGHT",
        "Shutdown and inspection after arrival. "
        "Examples: Completing logbook; post-flight walkaround.",
    ),
    (
        # Catch-all label for notes that fit none of the phases above.
        "UNKNOWN",
        "Unclassified or ambiguous entry. "
        "Examples: Notes not tied to a specific flight phase.",
    ),
]

# Draw 1000 (category, prompt) pairs at random, with replacement, from the
# label set and load them into a two-column DataFrame.
data = []
for _ in range(1000):
    data.append(random.choice(classes_and_descriptions))
df = spark.createDataFrame(data, ["category", "prompt"])

# Instruction prepended to every row's prompt; it tells the model to write a
# short, matter-of-fact pilot note without any lead-in sentence.
note_instruction = (
    "No introductory lines like here is a summary of... or during this phase.., "
    "just role play a pilot writing very objective, direct, concise notes about the following"
    " topic in about 20 varied words (give or take): "
)

# Use ai_query() to generate pilot notes for each category.
# The output is capped at 100 tokens and sampled at temperature 1.
generation_sql = (
    f"ai_query('{llm}', concat('{note_instruction}', prompt), "
    "modelParameters => named_struct('max_tokens', 100, 'temperature', 1))"
)
synthetic_data_df = df.withColumn("pilot_notes", expr(generation_sql))

# Append mode: each run of this cell adds another batch to source_table
# instead of replacing what is already there.
(
    synthetic_data_df
    .select("category", "pilot_notes")
    .write
    .mode("append")
    .saveAsTable(source_table)
)

# Read the table back (this includes rows from earlier runs) and show it.
source_df = spark.read.table(source_table)
source_df.display()

In [0]:
import json
from pyspark.sql.functions import json_tuple

# Split the (label, description) pairs into the two shapes the schema uses:
# a flat list of labels and a list of {"value", "description"} records.
classes = [label for label, _ in classes_and_descriptions]
descriptions = [
    {"value": label, "description": text}
    for label, text in classes_and_descriptions
]

# "classification" must be one of the known labels.
classification_property = {
    "type": "string",
    "enum": classes,  # Restrict outputs to only the allowed classes
    "enumDescriptions": descriptions,  # Add descriptions to guide outputs
}

# "confidence_level" is a number bounded to the range [0, 1].
confidence_property = {
    "type": "number",
    "description": "Confidence level of the flight phase classification",
    "minimum": 0,
    "maximum": 1,
}

# Both fields are mandatory in the model's JSON answer.
response_object_schema = {
    "type": "object",
    "properties": {
        "classification": classification_property,
        "confidence_level": confidence_property,
    },
    "required": ["classification", "confidence_level"],
}

# Serialized response format that is later passed to ai_query().
classification_schema_with_confidence = json.dumps(
    {
        "type": "json_schema",
        "json_schema": {
            "name": "classification",
            "schema": response_object_schema,
        },
    }
)

# The schema JSON is spliced into a single-quoted SQL string literal below, so
# backslashes and single quotes are escaped first. Without this, an apostrophe
# in a class description would terminate the literal early, and a JSON escape
# such as \" would be consumed by the SQL parser and reach the model as invalid
# JSON. For the current label set this is a no-op.
# NOTE: assumes the default spark.sql.parser.escapedStringLiterals=false,
# i.e. backslash acts as the escape character inside SQL string literals.
response_format_sql = (
    classification_schema_with_confidence
    .replace("\\", "\\\\")
    .replace("'", "\\'")
)

# SQL expression that prompts the LLM to classify each row's pilot notes and to
# answer in the structured JSON format described by the schema.
query = (
    f"ai_query('{llm}', concat('Extract the flight phase classification of these pilot notes: ', pilot_notes), "
    f"responseFormat => '{response_format_sql}')"
)

# Cap the number of rows sent to the LLM at training_set_size. source_table is
# written in append mode, so it grows every time the generation cell is re-run;
# without the cap each run would send the whole accumulated table to the model.
# limit() imposes no ordering, so which rows are kept is arbitrary.
classification_input_df = source_df.limit(training_set_size)

# One ai_query() call per row; "output" holds the model's raw JSON answer.
intermediate_df = classification_input_df.withColumn("output", expr(query))

parsed_df = (
    intermediate_df
    # Pull the two fields out of the JSON answer (json_tuple yields strings).
    .withColumn("prediction", json_tuple("output", "classification"))
    .withColumn("confidence_level", json_tuple("output", "confidence_level"))
    # Compare the note text with the predicted label, once with ai_similarity()
    # and once with the levenshtein() edit distance.
    .withColumn("semantic_similarity", expr("ai_similarity(pilot_notes, prediction)"))
    .withColumn("levenshtein_distance", expr("levenshtein(pilot_notes, prediction)"))
)

# Overwrite mode: target_table always reflects the latest classification run.
parsed_df.write.mode("overwrite").saveAsTable(target_table)
spark.read.table(target_table).display()