Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,8 @@ class OpenAIPrompt(override val uid: String) extends Transformer
val columnTypes = new StringStringMapParam(
this, "columnTypes", "A map from column names to their types. Supported types are 'text' and 'path'.")
private def validateColumnType(value: String) = {
if (value.equalsIgnoreCase("path") || value.equalsIgnoreCase("text")) {
logWarning(s"Column type '$value' is deprecated. Please use lowercase 'path' or 'text' instead.")
}
require(value == "text" || value == "path",
s"Unsupported column type: $value. Supported types are 'text' and 'path'.")
require(value != "responses" || this.getApiType == "responses",
s"Column type 'path' is only supported when apiType is set to 'responses'.")
}

def getColumnTypes: Map[String, String] = $(columnTypes)
Expand All @@ -208,6 +203,27 @@ class OpenAIPrompt(override val uid: String) extends Transformer
def setColumnTypes(v: java.util.HashMap[String, String]): this.type =
set(columnTypes, v.asScala.toMap)

val fileSizeLimitMB: Param[Double] = new Param[Double](
this, "fileSizeLimitMB",
"Maximum file size in megabytes for path columns. Files exceeding this limit will produce an error.")

def getFileSizeLimitMB: Double = $(fileSizeLimitMB)

private var fileSizeLimitBytes: Long = 0L

def setFileSizeLimitMB(value: Double): this.type = {
require(value > 0, "File size limit must be positive")
fileSizeLimitBytes = (value * 1024 * 1024).toLong
set(fileSizeLimitMB, value)
}

private def getFileSizeLimitBytes: Long = {
if (fileSizeLimitBytes == 0L && isSet(fileSizeLimitMB)) {
fileSizeLimitBytes = (getFileSizeLimitMB * 1024 * 1024).toLong
}
fileSizeLimitBytes
}

private val defaultSystemPrompt = "You are an AI chatbot who wants to answer user's questions and complete tasks. " +
"Follow their instructions carefully and be brief if they don't say otherwise."

Expand Down Expand Up @@ -266,9 +282,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
private def configureService(
service: OpenAIServicesBase with HasTextOutput,
df: DataFrame,
promptCol: Column,
createMessagesUDF: UserDefinedFunction,
attachmentsColumn: Column
messagesCol: Column
): (DataFrame, String, OpenAIServicesBase with HasTextOutput) = {
// All services are now HasMessagesInput (OpenAIChatCompletion, OpenAIResponses, AIFoundryChatCompletion)
// Legacy OpenAICompletion did not support MessagesInput which is no longer used in this class.
Expand All @@ -285,7 +299,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
val messageColName = getMessagesCol

(
df.withColumn(messageColName, createMessagesUDF(promptCol, attachmentsColumn)),
df.withColumn(messageColName, messagesCol),
messageColName,
messagesService.setMessagesCol(messageColName).asInstanceOf[OpenAIServicesBase with HasTextOutput]
)
Expand Down Expand Up @@ -413,21 +427,30 @@ class OpenAIPrompt(override val uid: String) extends Transformer
typedLit(Map.empty[String, String])
}

val createMessagesUDF = udf[Seq[OpenAICompositeMessage], String, Map[String, String]] {
// UDF returns (messages, error)
val createMessagesUDF = udf[(Seq[OpenAICompositeMessage], String), String, Map[String, String]] {
(userMessage, attachmentMap) =>
if (userMessage == null) {
null //scalastyle:ignore null
} else {
if (userMessage == null) (null, null) //scalastyle:ignore null
else Try {
createMessagesForRow(
userMessage,
Option(attachmentMap).getOrElse(Map.empty[String, String]),
pathColumnNames
)
} match {
case scala.util.Success(msgs) => (msgs, null) //scalastyle:ignore null
case scala.util.Failure(e) => (null, e.getMessage) //scalastyle:ignore null
}
}

val fileResultCol = createMessagesUDF(promptCol, attachmentsColumn)
val fileErrorStruct = toErrorStruct(fileResultCol.getField("_2"))
val dfWithFile = dfWithFilenames
.withColumn(getMessagesCol, fileResultCol.getField("_1"))
.withColumn(getErrorCol, fileErrorStruct)

val (dfTemplated, inputColName, serviceConfigured) =
configureService(service, dfWithFilenames, promptCol, createMessagesUDF, attachmentsColumn)
configureService(service, dfWithFile, F.col(getMessagesCol))
val result = generateText(serviceConfigured, dfTemplated)

val resultCleaned = filenameColMapping.values.foldLeft(result) { (df, colName) =>
Expand All @@ -438,6 +461,22 @@ class OpenAIPrompt(override val uid: String) extends Transformer
}, dataset.columns.length)
}

private def toErrorStruct(errorStr: Column): Column = {
val statusSchema = T.StructType(Seq(
T.StructField("protocolVersion", T.StructType(Seq(
T.StructField("protocol", T.StringType),
T.StructField("major", T.IntegerType),
T.StructField("minor", T.IntegerType)
))),
T.StructField("statusCode", T.IntegerType),
T.StructField("reasonPhrase", T.StringType)
))
F.when(errorStr.isNotNull, F.struct(
errorStr.as("response"),
F.lit(null).cast(statusSchema).as("status") //scalastyle:ignore null
))
}

private[openai] def stringMessageWrapper(str: String): Map[String, String] = {
if (this.getApiType == "responses") {
Map("type" -> "input_text", "text" -> str)
Expand Down Expand Up @@ -488,6 +527,13 @@ class OpenAIPrompt(override val uid: String) extends Transformer
private def prepareFile(filePathStr: String): (String, Array[Byte], String, String) = {
val filePath = new HPath(filePathStr)
val fileBytes = BinaryFileReader.readSingleFileBytes(filePath)

if (isSet(fileSizeLimitMB) && fileBytes.length > getFileSizeLimitBytes) {
val fileSizeMB = fileBytes.length / (1024.0 * 1024.0)
throw new IllegalArgumentException(
f"File '$filePathStr' size $fileSizeMB%.2f MB exceeds limit ${getFileSizeLimitMB}%.2f MB")
}

val fileName = filePath.getName
val extension = fileName.lastIndexOf('.') match {
case idx if idx >= 0 => fileName.substring(idx + 1).toLowerCase
Expand Down Expand Up @@ -520,7 +566,9 @@ class OpenAIPrompt(override val uid: String) extends Transformer
)
case "audio" =>
throw new IllegalArgumentException("Audio input is not supported in the current API version.")
case _ =>
case "unsupported" =>
throw new IllegalArgumentException(s"Unsupported file type: $mimeType.")
case "file" =>
Map(
"type" -> "input_file",
"filename" -> fileName,
Expand All @@ -539,7 +587,9 @@ class OpenAIPrompt(override val uid: String) extends Transformer
case "text" =>
stringMessageWrapper(s"Content: ${new String(fileBytes, StandardCharsets.UTF_8)}")
case _ =>
throw new IllegalArgumentException(s"Multimodal Input is not supported in Chat Completions API.")
throw new IllegalArgumentException(
s"File type $mimeType is not supported in Chat Completions API. " +
"Only text files are supported. Use apiType='responses' for file input.")
}
}

Expand All @@ -560,7 +610,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
else if (mimeType.startsWith("image/") && imageExtensions.contains(extension)) "image"
else if (mimeType.startsWith("audio/") && audioExtensions.contains(extension)) "audio"
else if (mimeType.startsWith("text/") || textExtensions.contains(extension)) "text"
else "file"
else "unsupported"
}

private[openai] def hasAIFoundryModel: Boolean = this.isDefined(model)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,92 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
assert(results(2).getString(0) != null)
}

test("file preparation error stores error in errorCol instead of failing") {
// Create a valid temp file for the working rows
val tempFile = Files.createTempFile("synapseml-openai-test", ".txt")
try {
Files.write(tempFile, "test content for file".getBytes(StandardCharsets.UTF_8))

val promptWithPath = new OpenAIPrompt()
.setSubscriptionKey(openAIAPIKey)
.setDeploymentName(deploymentName)
.setCustomServiceName(openAIServiceName)
.setApiVersion("2025-04-01-preview")
.setApiType("responses")
.setColumnType("filePath", "path")
.setOutputCol("outParsed")
.setPromptTemplate("{questions}: {filePath}")

val testDF = Seq(
("Summarize this file", tempFile.toString), // valid file
("Summarize this file", "/nonexistent/path/to/file.txt"), // non-existent local file
("Summarize this file", tempFile.toString) // valid file
).toDF("questions", "filePath")

val result = promptWithPath.transform(testDF)
val rows = result.select("outParsed", promptWithPath.getErrorCol).collect()
// display result
result.show(false)

// Helper to extract error message from error struct
def getErrorMessage(row: Row, idx: Int): String = {
val errorStruct = row.getAs[Row](idx)
if (errorStruct == null) null else errorStruct.getAs[String]("response") // scalastyle:ignore null
}

// First row: valid file - should have output, no error
assert(rows(0).getString(0) != null, "First row should have output")
assert(rows(0).get(1) == null, "First row should have no error")

// Second row: non-existent file - should have null output and error message
assert(rows(1).get(0) == null, "Second row should have null output for file not found")
val errorMsg = getErrorMessage(rows(1), 1)
assert(errorMsg != null, "Second row should have error message for file not found")
assert(errorMsg.nonEmpty, "Error message should not be empty")

// Third row: valid file - should have output, no error
assert(rows(2).getString(0) != null, "Third row should have output")
assert(rows(2).get(1) == null, "Third row should have no error")
} finally {
Files.deleteIfExists(tempFile)
}
}

test("file size limit produces error when exceeded") {
val tempFile = Files.createTempFile("synapseml-openai-size-test", ".txt")
try {
// Write content that exceeds 0.00001 MB (~10 bytes limit)
Files.write(tempFile, "This content is larger than 10 bytes".getBytes(StandardCharsets.UTF_8))

val promptWithLimit = new OpenAIPrompt()
.setSubscriptionKey(openAIAPIKey)
.setDeploymentName(deploymentName)
.setCustomServiceName(openAIServiceName)
.setApiVersion("2025-04-01-preview")
.setApiType("responses")
.setColumnType("filePath", "path")
.setFileSizeLimitMB(0.00001) // ~10 bytes limit
.setOutputCol("outParsed")
.setPromptTemplate("{questions}: {filePath}")

val testDF = Seq(
("Summarize this file", tempFile.toString)
).toDF("questions", "filePath")

val result = promptWithLimit.transform(testDF)
val rows = result.select("outParsed", promptWithLimit.getErrorCol).collect()

// Should have null output and error message about size limit
assert(rows(0).get(0) == null, "Should have null output when file exceeds size limit")
val errorStruct = rows(0).getAs[Row](1)
assert(errorStruct != null, "Should have error when file exceeds size limit")
val errorMsg = errorStruct.getAs[String]("response")
assert(errorMsg.contains("exceeds limit"), s"Error should mention exceeds limit: $errorMsg")
} finally {
Files.deleteIfExists(tempFile)
}
}

ignore("Custom EndPoint") {
lazy val accessToken: String = sys.env.getOrElse("CUSTOM_ACCESS_TOKEN", "")
lazy val customRootUrlValue: String = sys.env.getOrElse("CUSTOM_ROOT_URL", "")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.{col, coalesce}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row}

Expand Down Expand Up @@ -133,10 +133,17 @@ class SimpleHTTPTransformer(val uid: String)
.setInputCol(parsedInputCol)
.setOutputCol(unparsedOutputCol))

val parseErrors = Some(Lambda(_
.withColumn(getErrorCol, ErrorUtils.addErrorUDF(col(unparsedOutputCol)))
.withColumn(unparsedOutputCol, ErrorUtils.NullifyResponseUDF(col(getErrorCol), col(unparsedOutputCol)))
))
val parseErrors = Some(Lambda(df => {
val serviceError = ErrorUtils.addErrorUDF(col(unparsedOutputCol))
// Preserve existing errorCol if present (e.g., from prep errors), otherwise use service error
val errorExpr = if (df.columns.contains(getErrorCol)) {
coalesce(col(getErrorCol), serviceError)
} else {
serviceError
}
df.withColumn(getErrorCol, errorExpr)
.withColumn(unparsedOutputCol, ErrorUtils.NullifyResponseUDF(col(getErrorCol), col(unparsedOutputCol)))
}))

val outputParser =
Some(getOutputParser
Expand Down
Loading