Skip to content

Commit

Permalink
Changed Bedrock model to use Converse API and created Agent node
Browse files Browse the repository at this point in the history
  • Loading branch information
emersonmde committed Jun 23, 2024
1 parent c97b664 commit 5b7fa80
Show file tree
Hide file tree
Showing 17 changed files with 716 additions and 409 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@ categories = ["development-tools", "api-bindings", "asynchronous", "concurrency"
authors = ["Matthew Emerson <emersonmde@protonmail.com>"]

[features]
default = ["macros", "openai"]
full = ["macros", "tracing", "openai", "opensearch", "bedrock", "ollama"]
macros = ["anchor-chain-macros", "ctor"]
default = ["openai"]
full = ["tracing", "openai", "opensearch", "bedrock", "ollama"]
tracing = ["dep:tracing"]
openai = ["async-openai"]
opensearch = ["dep:opensearch", "aws-config"]
bedrock = ["aws-sdk-bedrockruntime", "aws-config"]
bedrock = ["aws-sdk-bedrockruntime", "aws-config", "aws-smithy-types"]
ollama = ["reqwest"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand All @@ -39,14 +38,15 @@ tera = "1.20.0"
thiserror = "1.0.58"
tokio = { version = "1.36.0", features = ["full"] }
base64 = "0.22.0"
anchor-chain-macros = { path = "anchor-chain-macros" }
ctor = { version = "0.2.8" }
async-openai = { version = "0.23.2", optional = true }
tracing = { version = "0.1.40", optional = true }
reqwest = { version = "0.12.4", optional = true }
aws-config = { version = "1.5.1", features = ["behavior-version-latest"], optional = true }
aws-sdk-bedrockruntime = { version = "1.34.0", optional = true }
aws-smithy-types = { version = "1.2.0", optional = true }
opensearch = { version = "2.2.0", features = ["aws-auth"], optional = true }
anchor-chain-macros = { path = "anchor-chain-macros", optional = true}
ctor = { version = "0.2.8", optional = true }


[[example]]
Expand Down
35 changes: 22 additions & 13 deletions anchor-chain-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ fn try_tool(registry_name: syn::Ident, input: ItemFn) -> Result<TokenStream> {
if let Expr::Lit(expr) = &meta.value {
if let Lit::Str(lit_str) = &expr.lit {
let value = lit_str.value();
let trimed_value = value.trim();
if !trimed_value.is_empty() {
return Some(trimed_value.to_string());
let trimmed_value = value.trim();
if !trimmed_value.is_empty() {
return Some(trimmed_value.to_string());
}
}
}
Expand Down Expand Up @@ -77,12 +77,16 @@ fn try_tool(registry_name: syn::Ident, input: ItemFn) -> Result<TokenStream> {
.collect::<Vec<_>>();

let schema = json!({
"name": name,
"description": docs,
"input_schema": {
"type": "object",
"properties": properties,
"required": required
"toolSpec": {
"name": name,
"description": docs,
"inputSchema": {
"json": {
"type": "object",
"properties": properties,
"required": required
}
}
}
});

Expand All @@ -92,7 +96,7 @@ fn try_tool(registry_name: syn::Ident, input: ItemFn) -> Result<TokenStream> {
let struct_name = format_ident!("{}__AnchorChainTool", fn_name);
let register_fn_name = format_ident!("register_{}__anchor_chain_tool", fn_name);

// TODO: Move execute method to a stand alone function
// TODO: Move execute method to a stand-alone function
let expanded = quote! {
#input

Expand All @@ -115,7 +119,14 @@ fn try_tool(registry_name: syn::Ident, input: ItemFn) -> Result<TokenStream> {
#[doc(hidden)]
fn #register_fn_name() {
let schema_value: Value = serde_json::from_str(#schema_string).unwrap();
#registry_name.blocking_write().register_tool(stringify!(#fn_name), #struct_name::execute, schema_value);
#registry_name.blocking_write().register_tool(
anchor_chain::agents::tool_registry::ToolEntry::new(
stringify!(#fn_name),
#docs,
#struct_name::execute,
schema_value
)
);
}
};

Expand All @@ -131,8 +142,6 @@ fn extract_type(ty: &Type) -> Result<Value> {
"i32" | "i64" | "f32" | "f64" => Ok(json!({ "type": "number" })),
"bool" => Ok(json!({ "type": "boolean" })),
_ => {
// Check if it is a reference to &str
// TODO: Fix lifetime issue when deserializing &str
if type_name == "str" {
if let PathArguments::AngleBracketed(args) = &type_segment.arguments {
if args.args.is_empty() {
Expand Down
7 changes: 3 additions & 4 deletions examples/parallel_nodes.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use std::collections::HashMap;

use anchor_chain::{
to_boxed_future, ChainBuilder, Claude3Bedrock, OpenAIModel, ParallelNode, Prompt,
};
use anchor_chain::models::bedrock_converse::Claude3Bedrock;
use anchor_chain::{to_boxed_future, ChainBuilder, OpenAIModel, ParallelNode, Prompt};

#[tokio::main]
async fn main() {
let gpt3 = Box::new(OpenAIModel::new_gpt3_5_turbo("You are a helpful assistant").await);
let claude3 = Box::new(Claude3Bedrock::new("You are a helpful assistant").await);
let claude3 = Box::new(Claude3Bedrock::new().await);

let concat_fn = to_boxed_future(|outputs: Vec<String>| {
Ok(outputs
Expand Down
7 changes: 3 additions & 4 deletions examples/parallel_nodes_async.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
use futures::future::BoxFuture;
use std::collections::HashMap;

use anchor_chain::{
nodes::prompt::Prompt, parallel_node::ParallelNode, ChainBuilder, Claude3Bedrock, OpenAIModel,
};
use anchor_chain::models::bedrock_converse::Claude3Bedrock;
use anchor_chain::{nodes::prompt::Prompt, parallel_node::ParallelNode, ChainBuilder, OpenAIModel};

#[tokio::main]
async fn main() {
let gpt3 = Box::new(OpenAIModel::new_gpt3_5_turbo("You are a helpful assistant").await);
let claude3 = Box::new(Claude3Bedrock::new("You are a helpful assistant").await);
let claude3 = Box::new(Claude3Bedrock::new().await);

let select_output_fn = Box::new(|outputs: Vec<String>| -> BoxFuture<Result<String, _>> {
Box::pin(async move {
Expand Down
87 changes: 46 additions & 41 deletions examples/tool_usage.rs
Original file line number Diff line number Diff line change
@@ -1,63 +1,68 @@
use anchor_chain::{ChainBuilder, Claude3Bedrock, Prompt, ToolRegistry};
use anchor_chain_macros::tool;
use std::ops::Deref;
use std::time::{SystemTime, UNIX_EPOCH};

use once_cell::sync::Lazy;
use serde_json::Value;
use std::collections::HashMap;
use tokio::sync::RwLock;

use anchor_chain::{AgentExecutor, ChainBuilder, ToolRegistry};
use anchor_chain_macros::tool;

static TOOL_REGISTRY: Lazy<RwLock<ToolRegistry>> = Lazy::new(|| RwLock::new(ToolRegistry::new()));

/// This is a foo function
/// Generates the current weather in Celsius
///
/// # Parameters
/// - None
///
/// # Returns
/// - A float representing the current temperature in Celsius.
#[tool(TOOL_REGISTRY)]
fn get_weather() -> f64 {
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
let seed = now.as_secs();

(seed % 40) as f64
}

/// Converts a temperature from Celsius to Fahrenheit
///
/// # Parameters
/// - `celsius`: A float representing the temperature in Celsius.
///
/// This is another line
/// # Returns
/// - A float representing the temperature in Fahrenheit.
#[tool(TOOL_REGISTRY)]
fn foo(one: String, two: String) {
println!("Foobar {one} {two}")
fn celsius_to_fahrenheit(celsius: f64) -> f64 {
celsius * 1.8 + 32.0
}

/// This is a bar function
/// Provides a common sentiment based on the temperature
///
/// # Parameters
/// - `temp`: A float representing the temperature in Fahrenheit.
///
/// # Returns
/// - A `String` with a sentiment on the temperature.
#[tool(TOOL_REGISTRY)]
fn bar(x: i32, y: i32) -> i32 {
x + y
fn weather_sentiment(temp: f64) -> String {
if temp > 85.0 {
format!("{temp} is hot").to_string()
} else if temp < 60.0 {
format!("{temp} is cold").to_string()
} else {
format!("{temp} is moderate").to_string()
}
}

#[tokio::main]
async fn main() {
let params = serde_json::json!({"one": "baz", "two": "bam"});
TOOL_REGISTRY
.read()
.await
.execute_tool("foo", params)
.unwrap();
println!(
"Foo schema: {:?}",
TOOL_REGISTRY.read().await.get_schema("foo").unwrap()
);

let params = serde_json::json!({"x": 1, "y": 2});
let result = TOOL_REGISTRY
.read()
.await
.execute_tool("bar", params)
.unwrap();
println!("Bar result: {}", result);
println!(
"Bar schema: {:?}",
TOOL_REGISTRY.read().await.get_schema("bar").unwrap()
);

let claude3 = Claude3Bedrock::new("You are a helpful assistant").await;

let chain = ChainBuilder::new()
.link(Prompt::new("{{ input }}"))
.link(claude3)
.link(AgentExecutor::new_claude_agent(TOOL_REGISTRY.deref()).await)
.build();

let output = chain
.process(HashMap::from([(
"input",
"Write a hello world program in Rust",
)]))
.process("Is it hot outside?".to_string())
.await
.expect("Error processing chain");
println!("{}", output);
Expand Down
121 changes: 121 additions & 0 deletions src/agents/agent_executor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
use std::collections::HashMap;

use async_trait::async_trait;
use aws_sdk_bedrockruntime::types::{ContentBlock, Message as BedrockMessage};
use aws_smithy_types::Document;
use tokio::sync::RwLock;

use crate::agents::tool_registry::{
convert_document_to_value, convert_value_to_document, ToolHandler,
};
use crate::models::bedrock_converse::BedrockModel;
use crate::node::Stateful;
use crate::{AnchorChainError, BedrockConverse, Node, StateManager, ToolRegistry};

#[derive(Debug)]
enum AgentModel {
Claude3_5(BedrockConverse<BedrockMessage>),
}

#[derive(Debug, anchor_chain_macros::Stateless)]
pub struct AgentExecutor<'a> {
llm: AgentModel,
max_iterations: usize,
tool_registry: &'a RwLock<ToolRegistry<'a>>,
}

impl<'a> AgentExecutor<'a> {
pub async fn new_claude_agent(tool_registry: &'a RwLock<ToolRegistry<'a>>) -> Self {
let mut llm = BedrockConverse::new_with_system_prompt(
BedrockModel::Claude3_5,
"You are a helpful assistant",
)
.await;
llm.set_tool_registry(tool_registry).await;
llm.set_state(StateManager::new()).await;
AgentExecutor {
llm: AgentModel::Claude3_5(llm),
max_iterations: 10,
tool_registry,
}
}

async fn run_claude_agent(
&self,
llm: &BedrockConverse<BedrockMessage>,
input: String,
) -> Result<String, AnchorChainError> {
let mut output = Vec::new();
let input = format!(
"Given the tools available, answer the users question: {}",
input
)
.to_string();

let mut response = llm.process(input.clone()).await?.content;
println!("Response: {response:?}");

// TODO: Move to custom Node
for _ in 0..self.max_iterations {
println!("Content: {response:?}\n");
let mut tool_used = false;
for content in response.clone() {
match content {
ContentBlock::Text(text) => output.push(text),
ContentBlock::ToolUse(tool_request) => {
tool_used = true;
// TODO: handle error
let tool_result = self
.tool_registry
.read()
.await
.execute_tool(
tool_request.name(),
convert_document_to_value(&tool_request.input),
)
.unwrap();
println!("Result from tool function: {:?}\n", tool_result);
let tool_response = llm
.invoke_with_tool_response(
tool_request.tool_use_id,
Document::Object(HashMap::from([(
"return".to_string(),
convert_value_to_document(&tool_result),
)])),
None,
)
.await;
println!(
"Response after sending back tool result: {:?}\n",
tool_response
);
if let Ok(content) = tool_response {
response = content.content
}
}
ContentBlock::Image(_) => unimplemented!("Received unexpected Image response"),
ContentBlock::ToolResult(_) => unreachable!("Received ToolResult from model"),
_ => unimplemented!("Unknown response received from model"),
}
}
if !tool_used {
break;
}
}
println!("Final output: {:?}", output);
println!("\n============\n\n");
Ok(output.join("\n\n"))
}
}

#[async_trait]
impl<'a> Node for AgentExecutor<'a> {
type Input = String;
type Output = String;

async fn process(&self, input: Self::Input) -> Result<Self::Output, AnchorChainError> {
match &self.llm {
AgentModel::Claude3_5(claude) => self.run_claude_agent(claude, input).await,
}
}
}
2 changes: 2 additions & 0 deletions src/agents/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
#[cfg(feature = "bedrock")]
pub mod agent_executor;
pub mod tool_registry;
Loading

0 comments on commit 5b7fa80

Please sign in to comment.