Skip to content

Commit

Permalink
new: embeddings generation capabilities implemented for openai and ol…
Browse files Browse the repository at this point in the history
…lama generators
  • Loading branch information
evilsocket committed Jun 25, 2024
1 parent 150ccd0 commit 3fcbbd2
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 110 deletions.
65 changes: 13 additions & 52 deletions src/agent/generator/groq.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
use std::time::Duration;

use anyhow::Result;
use async_trait::async_trait;
use colored::Colorize;
use duration_string::DurationString;
use groq_api_rs::completion::{client::Groq, request::builder, response::ErrorResponse};
use lazy_static::lazy_static;
use regex::Regex;

use crate::agent::generator::Message;

use super::{Client, Options};
use super::{Client, Embeddings, Options};

lazy_static! {
static ref RETRY_TIME_PARSER: Regex =
Expand All @@ -23,7 +20,7 @@ pub struct GroqClient {

#[async_trait]
impl Client for GroqClient {
fn new(_: &str, _: u16, model_name: &str, _: u32) -> anyhow::Result<Self>
fn new(_: &str, _: u16, model_name: &str, _: u32) -> Result<Self>
where
Self: Sized,
{
Expand All @@ -35,7 +32,7 @@ impl Client for GroqClient {
Ok(Self { model, api_key })
}

async fn chat(&self, options: &Options) -> anyhow::Result<String> {
async fn chat(&self, options: &Options) -> Result<String> {
let mut chat_history = vec![
groq_api_rs::completion::message::Message::SystemMessage {
role: Some("system".to_string()),
Expand Down Expand Up @@ -104,52 +101,11 @@ impl Client for GroqClient {
if let Some(err_resp) = error.downcast_ref::<ErrorResponse>() {
// if rate limit exceeded, parse the retry time and retry
if err_resp.code == 429 {
if let Some(caps) = RETRY_TIME_PARSER
.captures_iter(&err_resp.error.message)
.next()
{
if caps.len() == 2 {
let mut retry_time_str = "".to_string();

caps.get(1)
.unwrap()
.as_str()
.clone_into(&mut retry_time_str);

// DurationString can't handle decimals like Xm3.838383s
if retry_time_str.contains('.') {
let (val, _) = retry_time_str.split_once('.').unwrap();
retry_time_str = format!("{}s", val);
}

if let Ok(retry_time) = retry_time_str.parse::<DurationString>() {
println!(
"{}: rate limit reached for this model, retrying in {} ...\n",
"WARNING".bold().yellow(),
retry_time,
);

tokio::time::sleep(
retry_time.checked_add(Duration::from_millis(1000)).unwrap(),
)
.await;

return self.chat(options).await;
} else {
eprintln!("can't parse '{}'", &retry_time_str);
}
} else {
eprintln!("cap len wrong");
}
return if self.check_rate_limit(&err_resp.error.message).await {
self.chat(options).await
} else {
eprintln!("regex failed");
}

eprintln!(
"{}: can't parse retry time from error response: {:?}",
"WARNING".bold().yellow(),
&err_resp
);
Err(anyhow!(error))
};
}
}

Expand All @@ -167,4 +123,9 @@ impl Client for GroqClient {

Ok(choice.message.content.to_string())
}

async fn embeddings(&self, _text: &str) -> Result<Embeddings> {
// TODO: extend the rust client to do this
todo!("groq embeddings generation not yet implemented")
}
}
55 changes: 54 additions & 1 deletion src/agent/generator/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use std::fmt::Display;
use std::{fmt::Display, time::Duration};

use anyhow::Result;
use async_trait::async_trait;
use colored::Colorize;
use duration_string::DurationString;
use lazy_static::lazy_static;
use regex::Regex;

use super::Invocation;

Expand All @@ -12,6 +16,11 @@ mod ollama;
#[cfg(feature = "openai")]
mod openai;

lazy_static! {
static ref RETRY_TIME_PARSER: Regex =
Regex::new(r"(?m)^.+try again in (.+)\. Visit.*").unwrap();
}

#[derive(Clone, Debug)]
pub struct Options {
pub system_prompt: String,
Expand Down Expand Up @@ -48,13 +57,57 @@ impl Display for Message {
}
}

pub type Embeddings = Vec<f64>;

#[async_trait]
pub trait Client {
fn new(url: &str, port: u16, model_name: &str, context_window: u32) -> Result<Self>
where
Self: Sized;

async fn chat(&self, options: &Options) -> Result<String>;
async fn embeddings(&self, text: &str) -> Result<Embeddings>;

async fn check_rate_limit(&self, error: &str) -> bool {
// if rate limit exceeded, parse the retry time and retry
if let Some(caps) = RETRY_TIME_PARSER.captures_iter(error).next() {
if caps.len() == 2 {
let mut retry_time_str = "".to_string();

caps.get(1)
.unwrap()
.as_str()
.clone_into(&mut retry_time_str);

// DurationString can't handle decimals like Xm3.838383s
if retry_time_str.contains('.') {
let (val, _) = retry_time_str.split_once('.').unwrap();
retry_time_str = format!("{}s", val);
}

if let Ok(retry_time) = retry_time_str.parse::<DurationString>() {
println!(
"{}: rate limit reached for this model, retrying in {} ...\n",
"WARNING".bold().yellow(),
retry_time,
);

tokio::time::sleep(
retry_time.checked_add(Duration::from_millis(1000)).unwrap(),
)
.await;

return true;
} else {
eprintln!("can't parse '{}'", &retry_time_str);
}
} else {
eprintln!("cap len wrong");
}
}

return false;
}
}

pub fn factory(
Expand Down
13 changes: 11 additions & 2 deletions src/agent/generator/ollama.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use anyhow::Result;
use async_trait::async_trait;

use ollama_rs::{
generation::{
chat::{request::ChatMessageRequest, ChatMessage},
Expand All @@ -8,7 +8,7 @@ use ollama_rs::{
Ollama,
};

use super::{Client, Message, Options};
use super::{Client, Embeddings, Message, Options};

pub struct OllamaClient {
model: String,
Expand Down Expand Up @@ -86,4 +86,13 @@ impl Client for OllamaClient {
Ok("".to_string())
}
}

async fn embeddings(&self, text: &str) -> Result<Embeddings> {
let resp = self
.client
.generate_embeddings(self.model.to_string(), text.to_string(), None)
.await?;

Ok(Embeddings::from(resp.embeddings))
}
}
87 changes: 32 additions & 55 deletions src/agent/generator/openai.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,10 @@
use std::time::Duration;

use colored::Colorize;
use duration_string::DurationString;
use anyhow::Result;
use async_trait::async_trait;
use openai_api_rust::chat::*;
use openai_api_rust::embeddings::EmbeddingsApi;
use openai_api_rust::*;

use async_trait::async_trait;
use lazy_static::lazy_static;
use regex::Regex;

use super::{Client, Message, Options};

lazy_static! {
static ref RETRY_TIME_PARSER: Regex =
Regex::new(r"(?m)^.+try again in (.+)\. Visit.*").unwrap();
}
use super::{Client, Embeddings, Message, Options};

pub struct OpenAIClient {
model: String,
Expand Down Expand Up @@ -78,52 +68,39 @@ impl Client for OpenAIClient {
let resp = self.client.chat_completion_create(&body);

if let Err(error) = resp {
// if rate limit exceeded, parse the retry time and retry
if let Some(caps) = RETRY_TIME_PARSER
.captures_iter(&format!("{}", &error))
.next()
{
if caps.len() == 2 {
let mut retry_time_str = "".to_string();

caps.get(1)
.unwrap()
.as_str()
.clone_into(&mut retry_time_str);

// DurationString can't handle decimals like Xm3.838383s
if retry_time_str.contains('.') {
let (val, _) = retry_time_str.split_once('.').unwrap();
retry_time_str = format!("{}s", val);
}

if let Ok(retry_time) = retry_time_str.parse::<DurationString>() {
println!(
"{}: rate limit reached for this model, retrying in {} ...\n",
"WARNING".bold().yellow(),
retry_time,
);

tokio::time::sleep(
retry_time.checked_add(Duration::from_millis(1000)).unwrap(),
)
.await;

return self.chat(options).await;
} else {
eprintln!("can't parse '{}'", &retry_time_str);
}
} else {
eprintln!("cap len wrong");
}
}

return Err(anyhow!(error));
return if self.check_rate_limit(&error.to_string()).await {
self.chat(options).await
} else {
Err(anyhow!(error))
};
}

let choice = resp.unwrap().choices;
let message = &choice[0].message.as_ref().unwrap();

Ok(message.content.to_string())
}

async fn embeddings(&self, text: &str) -> Result<Embeddings> {
let body = embeddings::EmbeddingsBody {
model: self.model.to_string(),
input: vec![text.to_string()],
user: None,
};
let resp = self.client.embeddings_create(&body);
if let Err(error) = resp {
return if self.check_rate_limit(&error.to_string()).await {
self.embeddings(text).await
} else {
Err(anyhow!(error))
};
}

let embeddings = resp.unwrap().data;
let embedding = embeddings.as_ref().unwrap().first().unwrap();

Ok(Embeddings::from(
embedding.embedding.as_ref().unwrap_or(&vec![]).clone(),
))
}
}

0 comments on commit 3fcbbd2

Please sign in to comment.