diff --git a/src/agent/generator/openai.rs b/src/agent/generator/openai.rs index fed5c05..70b2a5e 100644 --- a/src/agent/generator/openai.rs +++ b/src/agent/generator/openai.rs @@ -1,10 +1,21 @@ +use std::time::Duration; + +use colored::Colorize; +use duration_string::DurationString; use openai_api_rust::chat::*; use openai_api_rust::*; use async_trait::async_trait; +use lazy_static::lazy_static; +use regex::Regex; use super::{Client, Message, Options}; +lazy_static! { + static ref RETRY_TIME_PARSER: Regex = + Regex::new(r"(?m)^.+try again in (.+)\. Visit.*").unwrap(); +} + pub struct OpenAIClient { model: String, client: OpenAI, @@ -64,14 +75,55 @@ impl Client for OpenAIClient { user: None, messages: chat_history, }; - let rs = self - .client - .chat_completion_create(&body) - .map_err(|e| anyhow!(e))?; + let resp = self.client.chat_completion_create(&body); + + if let Err(error) = resp { + // if rate limit exceeded, parse the retry time and retry + if let Some(caps) = RETRY_TIME_PARSER + .captures_iter(&format!("{}", &error)) + .next() + { + if caps.len() == 2 { + let mut retry_time_str = "".to_string(); + + caps.get(1) + .unwrap() + .as_str() + .clone_into(&mut retry_time_str); - // println!("{:?}", &rs); + // DurationString can't handle decimals like Xm3.838383s + if retry_time_str.contains('.') { + let (val, _) = retry_time_str.split_once('.').unwrap(); + retry_time_str = format!("{}s", val); + } + + if let Ok(retry_time) = retry_time_str.parse::() { + println!( + "{}: rate limit reached for this model, retrying in {} ...\n", + "WARNING".bold().yellow(), + retry_time, + ); + + tokio::time::sleep( + retry_time.checked_add(Duration::from_millis(1000)).unwrap(), + ) + .await; + + return self.chat(options).await; + } else { + eprintln!("can't parse '{}'", &retry_time_str); + } + } else { + eprintln!("cap len wrong"); + } + } else { + eprintln!("regex failed"); + } + + return Err(anyhow!(error)); + } - let choice = rs.choices; + let choice = resp.unwrap().choices; let message = &choice[0].message.as_ref().unwrap(); Ok(message.content.to_string())