Skip to content

Commit

Permalink
new: handling auto-retry on rate limit exceeded for openai client
Browse files Browse the repository at this point in the history
  • Loading branch information
evilsocket committed Jun 23, 2024
1 parent 8f9c9d5 commit b14acc5
Showing 1 changed file with 58 additions and 6 deletions.
64 changes: 58 additions & 6 deletions src/agent/generator/openai.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
use std::time::Duration;

use colored::Colorize;
use duration_string::DurationString;
use openai_api_rust::chat::*;
use openai_api_rust::*;

use async_trait::async_trait;
use lazy_static::lazy_static;
use regex::Regex;

use super::{Client, Message, Options};

lazy_static! {
static ref RETRY_TIME_PARSER: Regex =
Regex::new(r"(?m)^.+try again in (.+)\. Visit.*").unwrap();
}

pub struct OpenAIClient {
model: String,
client: OpenAI,
Expand Down Expand Up @@ -64,14 +75,55 @@ impl Client for OpenAIClient {
user: None,
messages: chat_history,
};
let rs = self
.client
.chat_completion_create(&body)
.map_err(|e| anyhow!(e))?;
let resp = self.client.chat_completion_create(&body);

if let Err(error) = resp {
// if rate limit exceeded, parse the retry time and retry
if let Some(caps) = RETRY_TIME_PARSER
.captures_iter(&format!("{}", &error))
.next()
{
if caps.len() == 2 {
let mut retry_time_str = "".to_string();

caps.get(1)
.unwrap()
.as_str()
.clone_into(&mut retry_time_str);

// println!("{:?}", &rs);
// DurationString can't handle decimals like Xm3.838383s
if retry_time_str.contains('.') {
let (val, _) = retry_time_str.split_once('.').unwrap();
retry_time_str = format!("{}s", val);
}

if let Ok(retry_time) = retry_time_str.parse::<DurationString>() {
println!(
"{}: rate limit reached for this model, retrying in {} ...\n",
"WARNING".bold().yellow(),
retry_time,
);

tokio::time::sleep(
retry_time.checked_add(Duration::from_millis(1000)).unwrap(),
)
.await;

return self.chat(options).await;
} else {
eprintln!("can't parse '{}'", &retry_time_str);
}
} else {
eprintln!("cap len wrong");
}
} else {
eprintln!("regex failed");
}

return Err(anyhow!(error));
}

let choice = rs.choices;
let choice = resp.unwrap().choices;
let message = &choice[0].message.as_ref().unwrap();

Ok(message.content.to_string())
Expand Down

0 comments on commit b14acc5

Please sign in to comment.