# SynapseML OpenAI Test Notebook

Build and run steps:
- `sbt clean package`
- `sbt packageSynapseML`
- `pip install target/generated/package/python/<wheel>`
- `pip install target/scala-2.12/generated/package/python/synapseml-1.0.15.dev1-py2.py3-none-any.whl`
- Start Jupyter and run this notebook.

Configuration is read from `.env` in the repo root: `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_API_KEY`, `OPENAI_DEPLOYMENT`, `OPENAI_EMBED_DEPLOYMENT`.

In [1]:
import json
import subprocess
import sys
import os

from pyspark.sql import SparkSession

# Stop any existing Spark session
try:
    spark.stop()
except:
    pass

local_jar_path = subprocess.check_output(
    "ls ../../../target/scala-2.12/synapseml_2.12-*-SNAPSHOT.jar | tail -1",
    shell=True,
    text=True,
).strip()

assert local_jar_path, "No JAR found."


snapshot_version = local_jar_path.split("synapseml_2.12-")[-1].split(".jar")[0]

print(f"Using local JAR: {local_jar_path}\n")
print(f"Using snapshot version: {snapshot_version}\n")

# # Configure Spark to avoid transport conflicts
os.environ['SPARK_LOCAL_IP'] = '127.0.0.1'

# spark = (SparkSession.builder
#     .appName("MyApp")
#     .config("spark.jars", local_jar_path)
#     .getOrCreate())

# spark = (SparkSession.builder
#     .appName("MyApp")
#     .config("spark.jars", local_jar_path)
#     .config("spark.jars.packages", f"com.microsoft.azure:synapseml_2.12:{snapshot_version}")
#     .getOrCreate())

spark = (SparkSession.builder
    .appName("synapseml")
    .config("spark.jars.packages", f"com.microsoft.azure:synapseml_2.12:{snapshot_version}")
    .config("spark.jars.repositories", "https://mmlspark.azureedge.net/maven")
    .config("spark.jars.excludes", "org.scala-lang:scala-reflect,org.apache.spark:spark-tags_2.12,org.scalactic:scalactic_2.12,org.scalatest:scalatest_2.12,com.fasterxml.jackson.core:jackson-databind")
    .config("spark.{driver,executor}.userClassPathFirst", "true")
    .config("spark.sql.parquet.enableVectorizedReader", "false")
    .getOrCreate())

print("✓ Spark session created successfully!")

ls: ../../../target/scala-2.12/synapseml_2.12-*-SNAPSHOT.jar: No such file or directory


AssertionError: No JAR found.

In [None]:
# Load configuration from .env in repo root
import os
def load_env(path='../../../.env'):
    if not os.path.exists(path):
        raise FileNotFoundError('Missing .env at repo root')
    with open(path, 'r') as f:
        for line in f:
            s = line.strip()
            if not s or s.startswith('#'):
                continue
            if '=' in s:
                k, v = s.split('=', 1)
                os.environ[k.strip()] = v.strip()

load_env()
required = ['AZURE_OPENAI_ENDPOINT','AZURE_OPENAI_API_KEY','OPENAI_DEPLOYMENT','OPENAI_EMBED_DEPLOYMENT']
missing = [k for k in required if not os.environ.get(k)]
if missing:
    raise RuntimeError(f'Missing keys in .env: {missing}')
print('Loaded .env configuration.')

In [None]:
# Quick tests: OpenAIPrompt (Responses API) and OpenAIEmbedding
# sys.path.append("../../../target/scala-2.12/generated/src/python")
from synapse.ml.services.openai.OpenAIDefaults import OpenAIDefaults
from synapse.ml.services.openai.OpenAIPrompt import OpenAIPrompt
from synapse.ml.services.openai.OpenAIEmbedding import OpenAIEmbedding

defaults = OpenAIDefaults()
defaults.set_URL(os.environ['AZURE_OPENAI_ENDPOINT'])
defaults.set_subscription_key(os.environ['AZURE_OPENAI_API_KEY'])
defaults.set_deployment_name(os.environ['OPENAI_DEPLOYMENT'])
defaults.set_embedding_deployment_name(os.environ['OPENAI_EMBED_DEPLOYMENT'])
defaults.set_temperature(0.1)

df = spark.createDataFrame([('apple','fruits'),('mercedes','cars'),('cake','dishes')], ['text','category'])

prompt = (OpenAIPrompt()
          .setPromptTemplate('Complete a comma-separated list of 5 {category}: {text}, ')
          .setResponseFormat('text')
          .setOutputCol('out'))
out_df = prompt.transform(df)
out_df.select('out').show(truncate=False)

# defaults.set_deployment_name(os.environ['OPENAI_EMBED_DEPLOYMENT'])
emb = OpenAIEmbedding().setTextCol('text').setOutputCol('embedding')
emb_df = emb.transform(df)
emb_df.select('text','embedding').show(truncate=False)

                                                                                

+-----------------------------------+
|out                                |
+-----------------------------------+
|apple, banana, orange, grape, mango|
|mercedes, bmw, audi, toyota, honda |
|cake, pizza, sushi, pasta, salad   |
+-----------------------------------+



                                                                                

+--------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

                                                                                