In [0]:
from pyspark.sql.functions import col
from pyspark.sql.functions import lit
from pyspark.sql import functions as f
from pyspark.sql.functions import regexp_replace
import mlflow.deployments
import mlflow
from pyspark.sql.functions import udf
from pyspark.sql.functions import pandas_udf,StringType
import pandas as pd
import time
from pyspark.sql.functions import element_at, split

In [0]:

# Prompt sent to the LLM for each pilot record. Placeholders are filled with
# str.format (see render_prompt), so any literal brace added to this text must
# be doubled ("{{" / "}}"). The model is asked to answer as a Python-style list
# `[score, 'reason']`; generate_json parses the score out of that shape, so keep
# the example format below in sync with that parsing.
template = """You are an expert health advisor evaluating the health status of military pilots based on medical and readiness data. Given the following features and record for a pilot, assess their health status on a scale from 1 (unhealthy) to 10 (healthy). Consider factors such as age, diagnosis, history, mental health status, prescriptions, lab results, and procedural compliance.
**Pilot Information**
- Age: {age}
- Gender: {gender}
- Rank: {rank}
- Flight Hours Total: {flight_hours_total}
- Flight Hours Last 12 Months: {flight_hours_last_12mo}
- Aeromedical Class Current: {aeromedical_class_current}; Class I is the most stringent and will result in a higher probability of a disqualifying event than Class II or Class III.
- PHA Status: {pha_status}; when Overdue, this could be considered a risk factor
**Medical Records**
- Diagnosis: {diagnosis}
- Medications: {medication}
- Reason: {reason}
- Hospitalization Reason: {hospitalization_reason}
- Dental Readiness: {dental_readiness}; a score of 1 means perfectly ready, while a score of 3 means not ready
- Lab Results: {test_name} - {result_value} - {unit} ref low {ref_low} ref high {ref_high}
- Doctor Visits: {encounters_6mo}
- Visit Type: {visit_type}
**Mental Health**
- Mental Health Medication: {mental_health_medication}
- Therapy Sessions: {therapy_sessions}

Analyze the data and provide a health status score, considering the severity of the conditions, frequency of health visits, mental health management, and laboratory flags. If a field is left blank, assume that the pilot did not have that condition or procedure. Provide the score and your reasoning in the following list format, and do not output anything other than the list. The output should look like:
    [8.4, 'The reason why the score is 8.4 is because...']

Limit the answer to 200 tokens.
"""
def get_results(
    batch,
    *,
    model_name="databricks-meta-llama-3-1-8b-instruct",
    max_tokens=200,
    throttle_seconds=0.1,
):
    """Send one prompt to a Databricks serving endpoint and return the reply text.

    Args:
        batch: The prompt text, sent as the content of a single ``user``
            message. Despite the name, this is one prompt, not a list.
        model_name: Name of the serving endpoint passed to ``predict``.
        max_tokens: Upper bound on generated tokens, forwarded in the request.
        throttle_seconds: Pause after each successful request. It crudely
            rate-limits callers that loop over many prompts. Use 0 to skip.

    Returns:
        The ``content`` of the first choice's message in the response.
    """
    deploy_client = mlflow.deployments.get_deploy_client("databricks")
    inputs = {
        "messages": [{"role": "user", "content": batch}],
        "max_tokens": max_tokens,
    }
    response = deploy_client.predict(endpoint=model_name, inputs=inputs)
    reply = response["choices"][0]["message"]["content"]
    if throttle_seconds:
        time.sleep(throttle_seconds)
    return reply

@pandas_udf("string")
def get_all_results(content: pd.Series) -> pd.Series:
    """Vectorized Spark UDF: get an LLM reply for every prompt in ``content``.

    Each non-null prompt is sent individually through ``get_results``. Null
    prompts produce a null reply without calling the endpoint.

    Args:
        content: A pandas Series of prompt strings (one Spark batch).

    Returns:
        A Series of reply strings aligned 1:1 with ``content``. A Series-to-
        Series pandas UDF must return exactly one value per input row.
    """
    # get_results accepts a single prompt string, so call it once per row.
    # Concatenating its string result into a list would split it into
    # characters and break the required row count.
    replies = [
        None if pd.isna(prompt) else get_results(prompt) for prompt in content
    ]
    return pd.Series(replies, index=content.index, dtype="object")




@f.udf("string")
def render_prompt(m: dict) -> str:
    """Spark UDF that fills the module-level ``template`` for one row.

    ``m`` maps each template placeholder name to its value. In this notebook,
    generate_json builds it with string keys and string values, with nulls
    replaced by "". A placeholder absent from ``m`` raises KeyError.
    """
    return template.format_map(m)



def _write_delta_table(df, table_name):
    """Overwrite ``table_name`` with ``df`` as a Delta table.

    Uses the same writer options as the original notebook: at most 2000
    records per file and ``overwriteSchema`` explicitly set to "false".
    """
    (
        df.write.format("delta")
        .option("maxRecordsPerFile", 2000)
        .option("overwriteSchema", "false")
        .mode("overwrite")
        .saveAsTable(table_name)
    )


def generate_json(
    source_table="avengers.default.all_pilots_data",
    prompt_table="avengers.default.ai_generated_dataset_v3",
    output_table="avengers.default.ai_generated_dataset_complete_v3",
    endpoint="databricks-gpt-oss-120b",
):
    """Build per-pilot LLM prompts, score them with ``ai_query``, and save both stages.

    Steps:
      1. Read ``source_table``. Add a ``content`` column holding the profile
         fields as JSON, and a ``prompt`` column holding ``template`` filled
         via ``render_prompt``.
      2. Overwrite ``prompt_table`` with that frame.
      3. Run ``ai_query(endpoint, prompt)`` over ``prompt_table`` into a
         ``result`` column. Parse the score that starts the model's
         ``[score, 'reason']`` reply into ``readiness_score``. It stays a
         string and is null when no such list is found.
      4. Overwrite ``output_table`` with the scored frame.

    Args:
        source_table: Table with one row per pilot record.
        prompt_table: Intermediate table that receives the rendered prompts.
        output_table: Final table with model results and parsed scores.
        endpoint: Model serving endpoint passed to ``ai_query``.

    Returns:
        The scored data read back from ``output_table``. Returning the stored
        table rather than the lazy ``ai_query`` plan means later actions on
        the result do not call the model endpoint again.
    """
    profile_columns = [
        "age", "gender", "rank", "flight_hours_total", "flight_hours_last_12mo",
        "diagnosis", "aeromedical_class_current", "pha_status", "medication",
        "reason", "hospitalization_reason", "dental_readiness", "test_name",
        "result_value", "unit", "ref_low", "ref_high", "encounters_6mo",
        "visit_type", "mental_health_medication", "therapy_sessions",
    ]
    df = spark.table(source_table)
    sections_struct = [f.col(c).alias(c) for c in profile_columns]
    # create_map expects a flat key1, value1, key2, value2, ... argument list.
    # Values are cast to string, and nulls become "" so the prompt shows a
    # blank field instead of "None".
    record_map = f.create_map(*[
        item
        for c in profile_columns
        for item in (f.lit(c), f.coalesce(f.col(c).cast("string"), f.lit("")))
    ])
    df = df.withColumn("content", f.to_json(f.struct(*sections_struct)))
    df = df.withColumn("prompt", render_prompt(record_map))
    _write_delta_table(df, prompt_table)

    scored = spark.sql(
        f"SELECT *, ai_query('{endpoint}', prompt) AS result FROM {prompt_table}"
    )
    # Take the number that follows the first "[" anywhere in the reply. This
    # tolerates leading whitespace and code fences around the list. No match
    # gives "", which is mapped to null so unparsable replies are easy to find.
    raw_score = f.regexp_extract("result", r"\[\s*(\d+(?:\.\d+)?)", 1)
    scored = scored.withColumn("readiness_score", f.when(raw_score != "", raw_score))
    print("Writing Final Table")
    _write_delta_table(scored, output_table)
    return spark.table(output_table)

In [0]:
# Run the full pipeline: render prompts, score them with the model endpoint,
# and overwrite both Delta tables. Keep the scored frame for later cells.
df = generate_json()
