From 9f2012cbffba0f100cf2f0f72d6d52c036f9052d Mon Sep 17 00:00:00 2001 From: hansel Date: Wed, 17 Jul 2024 09:27:10 +0800 Subject: [PATCH 1/7] add parallel_tool_calls option to chat completion --- src/v1/chat_completion.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index 0a551aea..be3120bd 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -45,6 +45,8 @@ pub struct ChatCompletionRequest { #[serde(skip_serializing_if = "Option::is_none")] pub tools: Option>, #[serde(skip_serializing_if = "Option::is_none")] + pub parallel_tool_calls: Option, + #[serde(skip_serializing_if = "Option::is_none")] #[serde(serialize_with = "serialize_tool_choice")] pub tool_choice: Option, } @@ -67,6 +69,7 @@ impl ChatCompletionRequest { user: None, seed: None, tools: None, + parallel_tool_calls: None, tool_choice: None, } } @@ -87,6 +90,7 @@ impl_builder_methods!( user: String, seed: i64, tools: Vec, + parallel_tool_calls: bool, tool_choice: ToolChoiceType ); From 55b4b785fe1b20a68e06144504ddc427971ee722 Mon Sep 17 00:00:00 2001 From: hansel Date: Wed, 17 Jul 2024 09:39:21 +0800 Subject: [PATCH 2/7] allow content to have empty text, which will not serialize it --- src/v1/chat_completion.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index be3120bd..b9353b6c 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -115,7 +115,13 @@ impl serde::Serialize for Content { S: serde::Serializer, { match *self { - Content::Text(ref text) => serializer.serialize_str(text), + Content::Text(ref text) => { + if text.is_empty() { + serializer.serialize_none() + } else { + serializer.serialize_str(text) + } + } Content::ImageUrl(ref image_url) => image_url.serialize(serializer), } } From 7e10e3a293218210db6513c96399e74c0ebe3df3 Mon Sep 17 00:00:00 2001 From: hansel Date: Wed, 17 Jul 2024 09:45:50 +0800 Subject: [PATCH 3/7] add tool_calls to ChatCompletionMessage --- README.md | 1 + examples/chat_completion.rs | 1 + examples/function_call.rs | 1 + examples/function_call_role.rs | 1 + examples/vision.rs | 1 + src/v1/chat_completion.rs | 2 ++ 6 files changed, 7 insertions(+) diff --git a/README.md b/README.md index c0e6809b..cf440d4b 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ let req = ChatCompletionRequest::new( role: chat_completion::MessageRole::user, content: chat_completion::Content::Text(String::from("What is bitcoin?")), name: None, + tool_calls: None, }], ); ``` diff --git a/examples/chat_completion.rs b/examples/chat_completion.rs index d53134f3..6ef50355 100644 --- a/examples/chat_completion.rs +++ b/examples/chat_completion.rs @@ -13,6 +13,7 @@ async fn main() -> Result<(), Box> { role: chat_completion::MessageRole::user, content: chat_completion::Content::Text(String::from("What is bitcoin?")), name: None, + tool_calls: None, }], ); diff --git a/examples/function_call.rs b/examples/function_call.rs index 27c63751..d2ce8cb2 100644 --- a/examples/function_call.rs +++ b/examples/function_call.rs @@ -34,6 +34,7 @@ async fn main() -> Result<(), Box> { role: chat_completion::MessageRole::user, content: chat_completion::Content::Text(String::from("What is the price of Ethereum?")), name: None, + tool_calls: None, }], ) .tools(vec![chat_completion::Tool { diff --git a/examples/function_call_role.rs b/examples/function_call_role.rs index dcdf7ad0..a6bf0645 100644 --- a/examples/function_call_role.rs +++ b/examples/function_call_role.rs @@ -34,6 +34,7 @@ async fn main() -> Result<(), Box> { role: chat_completion::MessageRole::user, content: chat_completion::Content::Text(String::from("What is the price of Ethereum?")), name: None, + tool_calls: None, }], ) .tools(vec![chat_completion::Tool { diff --git a/examples/vision.rs b/examples/vision.rs index b62a6535..056feefc 100644 --- a/examples/vision.rs +++ b/examples/vision.rs @@ -28,6 +28,7 @@ async fn main() -> Result<(), Box> { }, ]), name: None, + tool_calls: None, }], ); diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index b9353b6c..1536f9de 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -156,6 +156,8 @@ pub struct ChatCompletionMessage { pub content: Content, #[serde(skip_serializing_if = "Option::is_none")] pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, } #[derive(Debug, Deserialize, Serialize)] From 4a11c7839372489416c4970931539709312f18e3 Mon Sep 17 00:00:00 2001 From: hansel Date: Wed, 17 Jul 2024 10:03:54 +0800 Subject: [PATCH 4/7] add tool role for ChatCompletionMessage --- README.md | 1 + examples/chat_completion.rs | 1 + examples/function_call.rs | 1 + examples/function_call_role.rs | 5 +++++ examples/vision.rs | 1 + src/v1/chat_completion.rs | 3 +++ 6 files changed, 12 insertions(+) diff --git a/README.md b/README.md index cf440d4b..fb088724 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,7 @@ let req = ChatCompletionRequest::new( content: chat_completion::Content::Text(String::from("What is bitcoin?")), name: None, tool_calls: None, + tool_call_id: None, }], ); ``` diff --git a/examples/chat_completion.rs b/examples/chat_completion.rs index 6ef50355..7a312410 100644 --- a/examples/chat_completion.rs +++ b/examples/chat_completion.rs @@ -14,6 +14,7 @@ async fn main() -> Result<(), Box> { content: chat_completion::Content::Text(String::from("What is bitcoin?")), name: None, tool_calls: None, + tool_call_id: None, }], ); diff --git a/examples/function_call.rs b/examples/function_call.rs index d2ce8cb2..ddf1d3dd 100644 --- a/examples/function_call.rs +++ b/examples/function_call.rs @@ -35,6 +35,7 @@ async fn main() -> Result<(), Box> { content: chat_completion::Content::Text(String::from("What is the price of Ethereum?")), name: None, tool_calls: None, + tool_call_id: None, }], ) .tools(vec![chat_completion::Tool { diff --git a/examples/function_call_role.rs b/examples/function_call_role.rs index a6bf0645..38463a51 100644 --- a/examples/function_call_role.rs +++ b/examples/function_call_role.rs @@ -35,6 +35,7 @@ async fn main() -> Result<(), Box> { content: chat_completion::Content::Text(String::from("What is the price of Ethereum?")), name: None, tool_calls: None, + tool_call_id: None, }], ) .tools(vec![chat_completion::Tool { @@ -89,6 +90,8 @@ async fn main() -> Result<(), Box> { "What is the price of Ethereum?", )), name: None, + tool_calls: None, + tool_call_id: None, }, chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::function, @@ -97,6 +100,8 @@ async fn main() -> Result<(), Box> { format!("{{\"price\": {}}}", price) }), name: Some(String::from("get_coin_price")), + tool_calls: None, + tool_call_id: None, }, ], ); diff --git a/examples/vision.rs b/examples/vision.rs index 056feefc..57397ae0 100644 --- a/examples/vision.rs +++ b/examples/vision.rs @@ -29,6 +29,7 @@ async fn main() -> Result<(), Box> { ]), name: None, tool_calls: None, + tool_call_id: None, }], ); diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index 1536f9de..28284be9 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -101,6 +101,7 @@ pub enum MessageRole { system, assistant, function, + tool, } #[derive(Debug, Deserialize, Clone, PartialEq, Eq)] @@ -158,6 +159,8 @@ pub struct ChatCompletionMessage { pub name: Option, #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, } #[derive(Debug, Deserialize, Serialize)] From 787ad45d4b99b85b8a4a6582ad71091704282f8c Mon Sep 17 00:00:00 2001 From: hansel Date: Wed, 17 Jul 2024 14:44:45 +0800 Subject: [PATCH 5/7] deserialize --- src/v1/chat_completion.rs | 63 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 59 insertions(+), 4 deletions(-) diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index 28284be9..faa42950 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -1,11 +1,11 @@ use serde::ser::SerializeMap; -use serde::{Deserialize, Serialize, Serializer}; +use serde::{Deserialize, Serialize, Serializer, Deserializer}; use serde_json::Value; use std::collections::HashMap; - +use serde::de::{self, MapAccess, SeqAccess, Visitor}; use crate::impl_builder_methods; use crate::v1::common; - +use std::fmt; #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] pub enum ToolChoiceType { None, @@ -104,7 +104,7 @@ pub enum MessageRole { tool, } -#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum Content { Text(String), ImageUrl(Vec), @@ -128,6 +128,61 @@ impl serde::Serialize for Content { } } +impl<'de> Deserialize<'de> for Content { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct ContentVisitor; + + impl<'de> Visitor<'de> for ContentVisitor { + type Value = Content; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a valid content type") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + Ok(Content::Text(value.to_string())) + } + + fn visit_seq(self, seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let image_urls: Vec = Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?; + Ok(Content::ImageUrl(image_urls)) + } + + fn visit_map(self, map: M) -> Result + where + M: serde::de::MapAccess<'de>, + { + let image_urls: Vec = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?; + Ok(Content::ImageUrl(image_urls)) + } + + fn visit_none(self) -> Result + where + E: de::Error, + { + Ok(Content::Text(String::new())) + } + + fn visit_unit(self) -> Result + where + E: de::Error, + { + Ok(Content::Text(String::new())) + } + } + + deserializer.deserialize_any(ContentVisitor) + } +} #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] #[allow(non_camel_case_types)] pub enum ContentType { From c6e231bdf49a3f6cc35874a4ceeceb2d4d93451c Mon Sep 17 00:00:00 2001 From: hansel Date: Wed, 17 Jul 2024 14:48:09 +0800 Subject: [PATCH 6/7] fmt --- src/v1/chat_completion.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index faa42950..983b9355 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -1,10 +1,10 @@ +use crate::impl_builder_methods; +use crate::v1::common; +use serde::de::{self, MapAccess, SeqAccess, Visitor}; use serde::ser::SerializeMap; -use serde::{Deserialize, Serialize, Serializer, Deserializer}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde_json::Value; use std::collections::HashMap; -use serde::de::{self, MapAccess, SeqAccess, Visitor}; -use crate::impl_builder_methods; -use crate::v1::common; use std::fmt; #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] pub enum ToolChoiceType { @@ -153,7 +153,8 @@ impl<'de> Deserialize<'de> for Content { where A: serde::de::SeqAccess<'de>, { - let image_urls: Vec = Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?; + let image_urls: Vec = + Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?; Ok(Content::ImageUrl(image_urls)) } @@ -161,7 +162,8 @@ impl<'de> Deserialize<'de> for Content { where M: serde::de::MapAccess<'de>, { - let image_urls: Vec = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?; + let image_urls: Vec = + Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?; Ok(Content::ImageUrl(image_urls)) } From a128d333b07194f4f384c384e6ceaa724f9478dd Mon Sep 17 00:00:00 2001 From: hansel Date: Wed, 17 Jul 2024 14:50:31 +0800 Subject: [PATCH 7/7] ref imports directly --- src/v1/chat_completion.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index 983b9355..b999b92c 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -151,7 +151,7 @@ impl<'de> Deserialize<'de> for Content { fn visit_seq(self, seq: A) -> Result where - A: serde::de::SeqAccess<'de>, + A: SeqAccess<'de>, { let image_urls: Vec = Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?; @@ -160,7 +160,7 @@ impl<'de> Deserialize<'de> for Content { fn visit_map(self, map: M) -> Result where - M: serde::de::MapAccess<'de>, + M: MapAccess<'de>, { let image_urls: Vec = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;