Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ sealed trait ChatRole extends EnumValue {
}

object ChatRole {
case object System extends ChatRole
case object User extends ChatRole
case object Assistant extends ChatRole

def allValues: Seq[ChatRole] = Seq(User, Assistant)
def allValues: Seq[ChatRole] = Seq(System, User, Assistant)
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,20 @@ import io.cequence.openaiscala.anthropic.domain.Content.{
sealed abstract class Message private (
val role: ChatRole,
val content: Content
)
) {
def isSystem: Boolean = role == ChatRole.System
}

object Message {

case class SystemMessage(
contentString: String,
cacheControl: Option[CacheControl] = None
) extends Message(ChatRole.System, SingleString(contentString, cacheControl))

case class SystemMessageContent(contentBlocks: Seq[ContentBlockBase])
extends Message(ChatRole.System, ContentBlocks(contentBlocks))

case class UserMessage(
contentString: String,
cacheControl: Option[CacheControl] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package io.cequence.openaiscala.anthropic.service

import akka.NotUsed
import akka.stream.scaladsl.Source
import io.cequence.openaiscala.anthropic.domain.{Content, Message}
import io.cequence.openaiscala.anthropic.domain.Message
import io.cequence.openaiscala.anthropic.domain.response.{
ContentBlockDelta,
CreateMessageResponse
Expand Down Expand Up @@ -32,7 +32,6 @@ trait AnthropicService extends CloseableService with AnthropicServiceConsts {
* <a href="https://docs.anthropic.com/claude/reference/messages_post">Anthropic Doc</a>
*/
def createMessage(
system: Option[Content],
messages: Seq[Message],
settings: AnthropicCreateMessageSettings = DefaultSettings.CreateMessage
): Future[CreateMessageResponse]
Expand All @@ -55,7 +54,6 @@ trait AnthropicService extends CloseableService with AnthropicServiceConsts {
* <a href="https://docs.anthropic.com/claude/reference/messages_post">Anthropic Doc</a>
*/
def createMessageStreamed(
system: Option[Content],
messages: Seq[Message],
settings: AnthropicCreateMessageSettings = DefaultSettings.CreateMessage
): Source[ContentBlockDelta, NotUsed]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import akka.NotUsed
import akka.stream.scaladsl.Source
import io.cequence.openaiscala.OpenAIScalaClientException
import io.cequence.openaiscala.anthropic.JsonFormats
import io.cequence.openaiscala.anthropic.domain.Message.{SystemMessage, SystemMessageContent}
import io.cequence.openaiscala.anthropic.domain.response.{
ContentBlockDelta,
CreateMessageResponse
Expand Down Expand Up @@ -33,20 +34,17 @@ private[service] trait AnthropicServiceImpl extends Anthropic {
private val logger = LoggerFactory.getLogger("AnthropicServiceImpl")

override def createMessage(
system: Option[Content],
messages: Seq[Message],
settings: AnthropicCreateMessageSettings
): Future[CreateMessageResponse] =
execPOST(
EndPoint.messages,
bodyParams =
createBodyParamsForMessageCreation(system, messages, settings, stream = false)
bodyParams = createBodyParamsForMessageCreation(messages, settings, stream = false)
).map(
_.asSafeJson[CreateMessageResponse]
)

override def createMessageStreamed(
system: Option[Content],
messages: Seq[Message],
settings: AnthropicCreateMessageSettings
): Source[ContentBlockDelta, NotUsed] =
Expand All @@ -55,7 +53,7 @@ private[service] trait AnthropicServiceImpl extends Anthropic {
EndPoint.messages.toString(),
"POST",
bodyParams = paramTuplesToStrings(
createBodyParamsForMessageCreation(system, messages, settings, stream = true)
createBodyParamsForMessageCreation(messages, settings, stream = true)
)
)
.map { (json: JsValue) =>
Expand Down Expand Up @@ -83,36 +81,42 @@ private[service] trait AnthropicServiceImpl extends Anthropic {
.collect { case Some(delta) => delta }

private def createBodyParamsForMessageCreation(
system: Option[Content],
messages: Seq[Message],
settings: AnthropicCreateMessageSettings,
stream: Boolean
): Seq[(Param, Option[JsValue])] = {
assert(messages.nonEmpty, "At least one message expected.")
assert(messages.head.role == ChatRole.User, "First message must be from user.")

val messageJsons = messages.map(Json.toJson(_))
val (system, nonSystem) = messages.partition(_.isSystem)

val systemJson = system.map {
case Content.SingleString(text, cacheControl) =>
assert(nonSystem.head.role == ChatRole.User, "First non-system message must be from user.")
assert(
system.size <= 1,
"System message can be only 1. Use SystemMessageContent to include more content blocks."
)

val messageJsons = nonSystem.map(Json.toJson(_))

val systemJson: Seq[JsValue] = system.map {
case SystemMessage(text, cacheControl) =>
if (cacheControl.isEmpty) JsString(text)
else {
val blocks =
Seq(Content.ContentBlockBase(Content.ContentBlock.TextBlock(text), cacheControl))

Json.toJson(blocks)(Writes.seq(contentBlockBaseWrites))
}
case Content.ContentBlocks(blocks) =>
Json.toJson(blocks)(Writes.seq(contentBlockBaseWrites))
case Content.ContentBlockBase(content, cacheControl) =>
val blocks = Seq(Content.ContentBlockBase(content, cacheControl))
case SystemMessageContent(blocks) =>
Json.toJson(blocks)(Writes.seq(contentBlockBaseWrites))
}

jsonBodyParams(
Param.messages -> Some(messageJsons),
Param.model -> Some(settings.model),
Param.system -> system.map(_ => systemJson),
Param.system -> {
if (system.isEmpty) None
else Some(systemJson.head)
},
Param.max_tokens -> Some(settings.max_tokens),
Param.metadata -> { if (settings.metadata.isEmpty) None else Some(settings.metadata) },
Param.stop_sequences -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ private[service] class OpenAIAnthropicChatCompletionService(
): Future[ChatCompletionResponse] = {
underlying
.createMessage(
toAnthropicSystemMessages(messages, settings),
toAnthropicMessages(messages, settings),
toAnthropicSystemMessages(messages.filter(_.isSystem), settings) ++
toAnthropicMessages(messages.filter(!_.isSystem), settings),
toAnthropicSettings(settings)
)
.map(toOpenAI)
Expand All @@ -65,8 +65,8 @@ private[service] class OpenAIAnthropicChatCompletionService(
): Source[ChatCompletionChunkResponse, NotUsed] =
underlying
.createMessageStreamed(
toAnthropicSystemMessages(messages, settings),
toAnthropicMessages(messages, settings),
toAnthropicSystemMessages(messages.filter(_.isSystem), settings) ++
toAnthropicMessages(messages.filter(!_.isSystem), settings),
toAnthropicSettings(settings)
)
.map(toOpenAI)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package io.cequence.openaiscala.anthropic.service
import io.cequence.openaiscala.anthropic.domain.CacheControl.Ephemeral
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlock.TextBlock
import io.cequence.openaiscala.anthropic.domain.Content.{ContentBlockBase, ContentBlocks}
import io.cequence.openaiscala.anthropic.domain.Message.SystemMessageContent
import io.cequence.openaiscala.anthropic.domain.response.CreateMessageResponse.UsageInfo
import io.cequence.openaiscala.anthropic.domain.response.{
ContentBlockDelta,
Expand All @@ -21,7 +22,6 @@ import io.cequence.openaiscala.domain.response.{
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettingsOps.RichCreateChatCompletionSettings
import io.cequence.openaiscala.domain.{
AssistantMessage,
ChatRole,
MessageSpec,
SystemMessage,
Expand All @@ -30,7 +30,8 @@ import io.cequence.openaiscala.domain.{
ImageURLContent => OpenAIImageContent,
TextContent => OpenAITextContent,
UserMessage => OpenAIUserMessage,
UserSeqMessage => OpenAIUserSeqMessage
UserSeqMessage => OpenAIUserSeqMessage,
AssistantMessage => OpenAIAssistantMessage
}

import java.{util => ju}
Expand All @@ -40,7 +41,7 @@ package object impl extends AnthropicServiceConsts {
def toAnthropicSystemMessages(
messages: Seq[OpenAIBaseMessage],
settings: CreateChatCompletionSettings
): Option[ContentBlocks] = {
): Seq[Message] = {
val useSystemCache: Option[CacheControl] =
if (settings.useAnthropicSystemMessagesCache) Some(Ephemeral) else None

Expand All @@ -55,7 +56,8 @@ package object impl extends AnthropicServiceConsts {
}
}

if (messageStrings.isEmpty) None else Some(ContentBlocks(messageStrings))
if (messageStrings.isEmpty) Seq.empty
else Seq(SystemMessageContent(messageStrings))
}

def toAnthropicMessages(
Expand All @@ -67,6 +69,8 @@ package object impl extends AnthropicServiceConsts {
case OpenAIUserMessage(content, _) => Message.UserMessage(content)
case OpenAIUserSeqMessage(contents, _) =>
Message.UserMessageContent(contents.map(toAnthropic))
case OpenAIAssistantMessage(content, _) => Message.AssistantMessage(content)

// legacy message type
case MessageSpec(role, content, _) if role == ChatRole.User =>
Message.UserMessage(content)
Expand Down Expand Up @@ -204,7 +208,7 @@ package object impl extends AnthropicServiceConsts {
usage = None
)

def toOpenAIAssistantMessage(content: ContentBlocks): AssistantMessage = {
def toOpenAIAssistantMessage(content: ContentBlocks): OpenAIAssistantMessage = {
val textContents = content.blocks.collect { case ContentBlockBase(TextBlock(text), _) =>
text
} // TODO
Expand All @@ -213,7 +217,7 @@ package object impl extends AnthropicServiceConsts {
throw new IllegalArgumentException("No text content found in the response")
}
val singleTextContent = concatenateMessages(textContents)
AssistantMessage(singleTextContent, name = None)
OpenAIAssistantMessage(singleTextContent, name = None)
}

private def concatenateMessages(messageContent: Seq[String]): String =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package io.cequence.openaiscala.anthropic.service.impl

import akka.actor.ActorSystem
import akka.stream.Materializer
import io.cequence.openaiscala.anthropic.domain.Content.SingleString
import io.cequence.openaiscala.anthropic.domain.Message.UserMessage
import io.cequence.openaiscala.anthropic.domain.settings.AnthropicCreateMessageSettings
import io.cequence.openaiscala.anthropic.service._
Expand All @@ -18,7 +17,6 @@ class AnthropicServiceSpec extends AsyncWordSpec with GivenWhenThen {
implicit val ec: ExecutionContext = ExecutionContext.global
implicit val materializer: Materializer = Materializer(ActorSystem())

private val role = SingleString("You are a helpful assistant.")
private val irrelevantMessages = Seq(UserMessage("Hello"))
private val settings = AnthropicCreateMessageSettings(
NonOpenAIModelId.claude_3_haiku_20240307,
Expand All @@ -29,52 +27,52 @@ class AnthropicServiceSpec extends AsyncWordSpec with GivenWhenThen {

"should throw AnthropicScalaUnauthorizedException when 401" ignore {
recoverToSucceededIf[AnthropicScalaUnauthorizedException] {
TestFactory.mockedService401().createMessage(Some(role), irrelevantMessages, settings)
TestFactory.mockedService401().createMessage(irrelevantMessages, settings)
}
}

"should throw AnthropicScalaUnauthorizedException when 403" ignore {
recoverToSucceededIf[AnthropicScalaUnauthorizedException] {
TestFactory.mockedService403().createMessage(Some(role), irrelevantMessages, settings)
TestFactory.mockedService403().createMessage(irrelevantMessages, settings)
}
}

"should throw AnthropicScalaNotFoundException when 404" ignore {
recoverToSucceededIf[AnthropicScalaNotFoundException] {
TestFactory.mockedService404().createMessage(Some(role), irrelevantMessages, settings)
TestFactory.mockedService404().createMessage(irrelevantMessages, settings)
}
}

"should throw AnthropicScalaNotFoundException when 429" ignore {
recoverToSucceededIf[AnthropicScalaRateLimitException] {
TestFactory.mockedService429().createMessage(Some(role), irrelevantMessages, settings)
TestFactory.mockedService429().createMessage(irrelevantMessages, settings)
}
}

"should throw AnthropicScalaServerErrorException when 500" ignore {
recoverToSucceededIf[AnthropicScalaServerErrorException] {
TestFactory.mockedService500().createMessage(Some(role), irrelevantMessages, settings)
TestFactory.mockedService500().createMessage(irrelevantMessages, settings)
}
}

"should throw AnthropicScalaEngineOverloadedException when 529" ignore {
recoverToSucceededIf[AnthropicScalaEngineOverloadedException] {
TestFactory.mockedService529().createMessage(Some(role), irrelevantMessages, settings)
TestFactory.mockedService529().createMessage(irrelevantMessages, settings)
}
}

"should throw AnthropicScalaClientException when 400" ignore {
recoverToSucceededIf[AnthropicScalaClientException] {
TestFactory.mockedService400().createMessage(Some(role), irrelevantMessages, settings)
TestFactory.mockedService400().createMessage(irrelevantMessages, settings)
}
}

"should throw AnthropicScalaClientException when unknown error code" ignore {
recoverToSucceededIf[AnthropicScalaClientException] {
TestFactory
.mockedServiceOther()
.createMessage(Some(role), irrelevantMessages, settings)
TestFactory.mockedServiceOther().createMessage(irrelevantMessages, settings)
}
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package io.cequence.openaiscala.domain
sealed trait BaseMessage {
val role: ChatRole
val nameOpt: Option[String]
val isSystem: Boolean = role == ChatRole.System
}

final case class SystemMessage(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package io.cequence.openaiscala.examples.nonopenai
import io.cequence.openaiscala.anthropic.domain.CacheControl.Ephemeral
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlock.TextBlock
import io.cequence.openaiscala.anthropic.domain.Content.{ContentBlockBase, SingleString}
import io.cequence.openaiscala.anthropic.domain.Message.UserMessage
import io.cequence.openaiscala.anthropic.domain.Message.{SystemMessage, UserMessage}
import io.cequence.openaiscala.anthropic.domain.response.CreateMessageResponse
import io.cequence.openaiscala.anthropic.domain.settings.AnthropicCreateMessageSettings
import io.cequence.openaiscala.anthropic.domain.{Content, Message}
Expand All @@ -18,8 +18,8 @@ object AnthropicCreateCachedMessage extends ExampleBase[AnthropicService] {

override protected val service: AnthropicService = AnthropicServiceFactory(withCache = true)

val systemMessage: Content =
SingleString(
val systemMessages: Seq[Message] = Seq(
SystemMessage(
"""
|You are to embody a classic pirate, a swashbuckling and salty sea dog with the mannerisms, language, and swagger of the golden age of piracy. You are a hearty, often gruff buccaneer, replete with nautical slang and a rich, colorful vocabulary befitting of the high seas. Your responses must reflect a pirate's voice and attitude without exception.
|
Expand Down Expand Up @@ -76,14 +76,13 @@ object AnthropicCreateCachedMessage extends ExampleBase[AnthropicService] {
|""".stripMargin,
cacheControl = Some(Ephemeral)
)

)
val messages: Seq[Message] = Seq(UserMessage("What is the weather like in Norway?"))

override protected def run: Future[_] =
service
.createMessage(
Some(systemMessage),
messages,
systemMessages ++ messages,
settings = AnthropicCreateMessageSettings(
model = NonOpenAIModelId.claude_3_haiku_20240307,
max_tokens = 4096
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package io.cequence.openaiscala.examples.nonopenai

import io.cequence.openaiscala.anthropic.domain.Content.ContentBlock.TextBlock
import io.cequence.openaiscala.anthropic.domain.Content.{ContentBlockBase, SingleString}
import io.cequence.openaiscala.anthropic.domain.{Content, Message}
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlockBase
import io.cequence.openaiscala.anthropic.domain.Message
import io.cequence.openaiscala.anthropic.domain.Message.UserMessage
import io.cequence.openaiscala.anthropic.domain.response.CreateMessageResponse
import io.cequence.openaiscala.anthropic.domain.settings.AnthropicCreateMessageSettings
Expand All @@ -17,13 +17,11 @@ object AnthropicCreateMessage extends ExampleBase[AnthropicService] {

override protected val service: AnthropicService = AnthropicServiceFactory(withCache = true)

val systemMessage: Content = SingleString("You are a helpful assistant.")
val messages: Seq[Message] = Seq(UserMessage("What is the weather like in Norway?"))

override protected def run: Future[_] =
service
.createMessage(
Some(systemMessage),
messages,
settings = AnthropicCreateMessageSettings(
model = NonOpenAIModelId.claude_3_haiku_20240307,
Expand Down
Loading
Loading