-
Notifications
You must be signed in to change notification settings - Fork 821
/
OpenAIChatCompletion.scala
81 lines (58 loc) · 3.02 KB
/
OpenAIChatCompletion.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.
package com.microsoft.azure.synapse.ml.cognitive.openai
import com.microsoft.azure.synapse.ml.cognitive.{
CognitiveServicesBase, HasCognitiveServiceInput, HasInternalJsonOutputParser
}
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.spark.ml.ComplexParamsReadable
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import spray.json._
import spray.json.DefaultJsonProtocol._
import scala.language.existentials
/** Companion object for [[OpenAIChatCompletion]]; mixes in `ComplexParamsReadable`
  * (presumably to support loading persisted instances -- confirm against that trait).
  */
object OpenAIChatCompletion extends ComplexParamsReadable[OpenAIChatCompletion]
/** Transformer that posts a chat-format message list to an OpenAI
  * `chat/completions` deployment endpoint for every input row.
  *
  * The messages are read from the column named by `messagesCol`; every other
  * request field comes from `getOptionalParams`. The response is parsed with
  * the schema of `ChatCompletionResponse`.
  *
  * @param uid unique identifier of this pipeline stage
  */
class OpenAIChatCompletion(override val uid: String) extends CognitiveServicesBase(uid)
  with HasOpenAITextParams with HasCognitiveServiceInput
  with HasInternalJsonOutputParser with SynapseMLLogging {
  logClass()

  /** Name of the input column that holds the chat messages. */
  val messagesCol: Param[String] = new Param[String](
    this, "messagesCol", "The column messages to generate chat completions for," +
      " in the chat format. This column should have type Array(Struct(role: String, content: String)).")

  /** @return the configured name of the messages column */
  def getMessagesCol: String = $(messagesCol)

  /** Sets the name of the messages column.
    *
    * @param v column name
    * @return this instance, for call chaining
    */
  def setMessagesCol(v: String): this.type = set(messagesCol, v)

  /** Creates an instance with a freshly generated uid. */
  def this() = this(Identifiable.randomUID("OpenAIChatCompletion"))

  /** Path appended to the service root by `setCustomServiceName`; empty for this service. */
  def urlPath: String = ""

  override private[ml] def internalServiceType: String = "openai"

  /** Derives the service URL from an Azure OpenAI resource name.
    *
    * @param v resource name, used as the host prefix of `openai.azure.com`
    * @return this instance, for call chaining
    */
  override def setCustomServiceName(v: String): this.type = {
    setUrl(s"https://$v.openai.azure.com/" + urlPath.stripPrefix("/"))
  }

  /** Builds the per-row request URL up to (and including) `chat/completions`. */
  override protected def prepareUrlRoot: Row => String = { row =>
    // Tolerate a URL configured without a trailing slash; otherwise the host
    // and the "openai" path segment would be glued together.
    val base = getUrl.stripSuffix("/")
    s"$base/openai/deployments/${getValue(row, deploymentName)}/chat/completions"
  }

  /** Builds the JSON request body for a row from its messages and optional parameters. */
  override protected def prepareEntity: Row => Option[AbstractHttpEntity] = { row =>
    val messages = row.getAs[Seq[Row]](getMessagesCol)
    Some(getStringEntity(messages, getOptionalParams(row)))
  }

  override val subscriptionKeyHeaderName: String = "api-key"

  /** A row is skipped when the parent says so or when its messages value is null. */
  override def shouldSkip(row: Row): Boolean =
    super.shouldSkip(row) || Option(row.getAs[Seq[Row]](getMessagesCol)).isEmpty

  override protected def getVectorParamMap: Map[String, String] = super.getVectorParamMap
    .updated("messages", getMessagesCol)

  override def responseDataType: DataType = ChatCompletionResponse.schema

  /** Serialises the messages plus the optional parameters into a JSON entity.
    *
    * Each message row contributes its non-null `role` and `content` values, and
    * its `name` value when the row's schema declares such a field and it is non-null.
    *
    * @param messages       chat messages, one row per message
    * @param optionalParams additional request fields, merged at the top level
    * @return entity with content type `application/json`
    */
  private[this] def getStringEntity(messages: Seq[Row], optionalParams: Map[String, Any]): StringEntity = {
    val mappedMessages: Seq[Map[String, String]] = messages.map { m =>
      // "name" is optional: the documented column type has only role/content,
      // and a by-name lookup of an undeclared field would throw.
      val hasName = Option(m.schema).exists(_.fieldNames.contains("name"))
      val fields = if (hasName) Seq("role", "content", "name") else Seq("role", "content")
      // Drop null values so they are omitted from the payload rather than sent as null.
      fields.flatMap(n => Option(m.getAs[String](n)).map(n -> _)).toMap
    }
    val fullPayload = optionalParams.updated("messages", mappedMessages)
    new StringEntity(fullPayload.toJson.compactPrint, ContentType.APPLICATION_JSON)
  }
}