2 changes: 1 addition & 1 deletion src/bin/hook.rs
@@ -99,7 +99,7 @@ impl Args {
.clone()
.unwrap_or("gpt-4o-mini".to_string())
.into();
-let used_tokens = commit::token_used(&model)?;
+let used_tokens = commit::calculate_token_usage(&model)?;
let max_tokens = config::APP_CONFIG
.max_tokens
.unwrap_or(model.context_size());
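For reference, a minimal sketch of how the renamed helper reads at this call site. The `APP_CONFIG.model` field is inferred from the visible `.clone().unwrap_or(...)` chain, and the final `saturating_sub` line is an assumption about how the hook derives its remaining budget; neither is shown in the hunk itself.

```rust
// Token budgeting in hook.rs after the rename (sketch, not the full file).
let model: Model = config::APP_CONFIG
    .model
    .clone()
    .unwrap_or("gpt-4o-mini".to_string())
    .into();
let used_tokens = commit::calculate_token_usage(&model)?;
let max_tokens = config::APP_CONFIG
    .max_tokens
    .unwrap_or(model.context_size());
// Assumption: the hook hands whatever is left of the window to the generator.
let remaining_tokens = max_tokens.saturating_sub(used_tokens);
```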
2 changes: 1 addition & 1 deletion src/client.rs
@@ -26,7 +26,7 @@ pub async fn call(request: Request) -> Result<Response> {
model: request.model
};

-let response = openai::call(openai_request).await?;
+let response = openai::generate_with_openai(openai_request).await?;
Ok(Response { response: response.response })
}

18 changes: 9 additions & 9 deletions src/commit.rs
@@ -11,15 +11,15 @@ use crate::multi_step_integration::{generate_commit_message_local, generate_comm
/// The instruction template included at compile time
const INSTRUCTION_TEMPLATE: &str = include_str!("../resources/prompt.md");

-/// Returns the instruction template for the AI model.
+/// Generates the instruction template for the AI model.
/// This template guides the model in generating appropriate commit messages.
///
/// # Returns
/// * `Result<String>` - The rendered template or an error
///
/// Note: This function is public only for testing purposes
#[doc(hidden)]
-pub fn get_instruction_template() -> Result<String> {
+pub fn generate_instruction_template() -> Result<String> {
profile!("Generate instruction template");
let max_length = config::APP_CONFIG
.max_commit_length
@@ -48,7 +48,7 @@ pub fn get_instruction_template() -> Result<String> {
#[doc(hidden)]
pub fn create_commit_request(diff: String, max_tokens: usize, model: Model) -> Result<openai::Request> {
profile!("Prepare OpenAI request");
-let template = get_instruction_template()?;
+let template = generate_instruction_template()?;
Ok(openai::Request {
system: template,
prompt: diff,
@@ -183,19 +183,19 @@ pub async fn generate(patch: String, remaining_tokens: usize, model: Model, sett
Some(custom_settings) => {
// Create a client with custom settings
match openai::create_openai_config(custom_settings) {
-Ok(config) => openai::call_with_config(request, config).await,
+Ok(config) => openai::generate_with_config(request, config).await,
Err(e) => Err(e)
}
}
None => {
// Use the default global config
-openai::call(request).await
+openai::generate_with_openai(request).await
}
}
}

-pub fn token_used(model: &Model) -> Result<usize> {
-  get_instruction_token_count(model)
+pub fn calculate_token_usage(model: &Model) -> Result<usize> {
+  calculate_instruction_token_usage(model)
}

/// Calculates the number of tokens used by the instruction template.
@@ -205,9 +205,9 @@ pub fn token_used(model: &Model) -> Result<usize> {
///
/// # Returns
/// * `Result<usize>` - The number of tokens used or an error
-pub fn get_instruction_token_count(model: &Model) -> Result<usize> {
+pub fn calculate_instruction_token_usage(model: &Model) -> Result<usize> {
profile!("Calculate instruction tokens");
-let template = get_instruction_template()?;
+let template = generate_instruction_template()?;
model.count_tokens(&template)
}

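Taken together, the renamed commit.rs helpers compose as below. A sketch only: the `ai::` paths match the test imports at the bottom of this diff, while the `anyhow::Result` alias and the `ai::openai::Request` path are assumptions.

```rust
use ai::commit::{calculate_token_usage, create_commit_request};
use ai::model::Model;

fn build_request(diff: String) -> anyhow::Result<ai::openai::Request> {
    let model = Model::GPT41Mini;
    // Measure what the rendered instruction template costs in tokens.
    let instruction_tokens = calculate_token_usage(&model)?;
    // Whatever remains of the context window is the budget for the diff.
    let budget = model.context_size().saturating_sub(instruction_tokens);
    create_commit_request(diff, budget, model)
}
```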
61 changes: 52 additions & 9 deletions src/multi_step_integration.rs
@@ -53,7 +53,7 @@ pub async fn generate_commit_message_multi_step(
let start_time = std::time::Instant::now();
let payload = format!("{{\"file_path\": \"{file_path}\", \"operation_type\": \"{operation}\", \"diff_content\": \"...\"}}");

-let result = call_analyze_function(client, model, file).await;
+let result = analyze_file_via_api(client, model, file).await;
let duration = start_time.elapsed();
(file, result, duration, payload)
}
@@ -135,7 +135,7 @@ pub async fn generate_commit_message_multi_step(

// Start step 2 and 3 in parallel
// First create the futures for both operations
-let score_future = call_score_function(client, model, files_data);
+let score_future = calculate_scores_via_api(client, model, files_data);

// Run the scoring operation
let scored_files = score_future.await?;
@@ -151,7 +151,7 @@ pub async fn generate_commit_message_multi_step(
let generate_payload = format!("{{\"files_with_scores\": [...], \"max_length\": {}}}", max_length.unwrap_or(72));

// Now create and run the generate and select steps in parallel
-let generate_future = call_generate_function(client, model, scored_files.clone(), max_length.unwrap_or(72));
+let generate_future = generate_candidates_via_api(client, model, scored_files.clone(), max_length.unwrap_or(72));

let candidates = generate_future.await?;
let generate_duration = generate_start_time.elapsed();
@@ -401,8 +401,22 @@ pub fn parse_diff(diff_content: &str) -> Result<Vec<ParsedFile>> {
Ok(files)
}

-/// Call the analyze function via OpenAI
-async fn call_analyze_function(client: &Client<OpenAIConfig>, model: &str, file: &ParsedFile) -> Result<Value> {
+/// Analyze file via OpenAI API using function calling to extract structured data
+///
+/// # Arguments
+/// * `client` - OpenAI client configured with API credentials
+/// * `model` - AI model name to use for analysis
+/// * `file` - Parsed file containing path, operation type, and diff content
+///
+/// # Returns
+/// * `Result<Value>` - JSON value containing file analysis (lines added/removed, category, summary)
+///
+/// # Errors
+/// Returns error if:
+/// - OpenAI API call fails
+/// - Model doesn't respond with expected function call format
+/// - JSON parsing of response fails
+async fn analyze_file_via_api(client: &Client<OpenAIConfig>, model: &str, file: &ParsedFile) -> Result<Value> {
let tools = vec![create_analyze_function_tool()?];

let system_message = ChatCompletionRequestSystemMessageArgs::default()
@@ -440,8 +454,22 @@ async fn call_analyze_function(client: &Client<OpenAIConfig>, model: &str, file:
}
}

-/// Call the score function via OpenAI
-async fn call_score_function(
+/// Calculate impact scores via OpenAI API using analyzed file data
+///
+/// # Arguments
+/// * `client` - OpenAI client configured with API credentials
+/// * `model` - AI model name to use for scoring
+/// * `files_data` - Vector of analyzed file data with categories and summaries
+///
+/// # Returns
+/// * `Result<Vec<FileWithScore>>` - Files with calculated impact scores (0.0 to 1.0)
+///
+/// # Errors
+/// Returns error if:
+/// - OpenAI API call fails
+/// - Model doesn't respond with expected function call format
+/// - JSON parsing of response fails
+async fn calculate_scores_via_api(
client: &Client<OpenAIConfig>, model: &str, files_data: Vec<FileDataForScoring>
) -> Result<Vec<FileWithScore>> {
let tools = vec![create_score_function_tool()?];
@@ -487,8 +515,23 @@ async fn call_score_function(
}
}

-/// Call the generate function via OpenAI
-async fn call_generate_function(
+/// Generate commit message candidates via OpenAI API using scored files
+///
+/// # Arguments
+/// * `client` - OpenAI client configured with API credentials
+/// * `model` - AI model name to use for generation
+/// * `files_with_scores` - Vector of files with calculated impact scores
+/// * `max_length` - Maximum length for generated commit messages
+///
+/// # Returns
+/// * `Result<Value>` - JSON value containing multiple commit message candidates with reasoning
+///
+/// # Errors
+/// Returns error if:
+/// - OpenAI API call fails
+/// - Model doesn't respond with expected function call format
+/// - JSON parsing of response fails
+async fn generate_candidates_via_api(
client: &Client<OpenAIConfig>, model: &str, files_with_scores: Vec<FileWithScore>, max_length: usize
) -> Result<Value> {
let tools = vec![create_generate_function_tool()?];
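The shape of the JSON `Value` that `analyze_file_via_api` returns is only described in its new doc comment, so here is a hedged sketch of consuming it; every field name is a guess taken from that comment, not confirmed by the diff.

```rust
use serde_json::Value;

// Field names ("category", "summary", "lines_added", "lines_removed") are
// assumptions based on the doc comment, not on the function's source.
fn summarize_analysis(analysis: &Value) -> Option<String> {
    let category = analysis.get("category")?.as_str()?;
    let summary = analysis.get("summary")?.as_str()?;
    let added = analysis.get("lines_added").and_then(Value::as_u64).unwrap_or(0);
    let removed = analysis.get("lines_removed").and_then(Value::as_u64).unwrap_or(0);
    Some(format!("[{category}] {summary} (+{added}/-{removed})"))
}
```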
43 changes: 37 additions & 6 deletions src/openai.rs
@@ -197,8 +197,24 @@ fn truncate_to_fit(text: &str, max_tokens: usize, model: &Model) -> Result<Strin
}
}

-/// Calls the OpenAI API with the provided configuration
-pub async fn call_with_config(request: Request, config: OpenAIConfig) -> Result<Response> {
+/// Generate commit message with OpenAI using provided configuration
+///
+/// Uses multi-step approach by default with fallback to single-step generation.
+/// Includes token management, timeout handling, and retry logic.
+///
+/// # Arguments
+/// * `request` - OpenAI request containing system prompt, user prompt, model, and token limits
+/// * `config` - OpenAI configuration with API key and other settings
+///
+/// # Returns
+/// * `Result<Response>` - Generated commit message response
+///
+/// # Errors
+/// Returns error if:
+/// - API key is invalid or missing
+/// - All generation attempts fail (multi-step and single-step)
+/// - Network or API communication errors occur
+pub async fn generate_with_config(request: Request, config: OpenAIConfig) -> Result<Response> {
profile!("OpenAI API call with custom config");

// Always try multi-step approach first (it's now the default)
@@ -377,13 +393,28 @@ pub async fn call_with_config(request: Request, config: OpenAIConfig) -> Result<
}
}

-/// Calls the OpenAI API with default configuration from settings
-pub async fn call(request: Request) -> Result<Response> {
+/// Generate commit message with OpenAI using default configuration from settings
+///
+/// Convenience function that creates OpenAI configuration from global app settings
+/// and delegates to `generate_with_config`.
+///
+/// # Arguments
+/// * `request` - OpenAI request containing system prompt, user prompt, model, and token limits
+///
+/// # Returns
+/// * `Result<Response>` - Generated commit message response
+///
+/// # Errors
+/// Returns error if:
+/// - Global configuration is invalid or missing API key
+/// - All generation attempts fail (multi-step and single-step)
+/// - Network or API communication errors occur
+pub async fn generate_with_openai(request: Request) -> Result<Response> {
profile!("OpenAI API call");

// Create OpenAI configuration using our settings
let config = create_openai_config(&config::APP_CONFIG)?;

-// Use the call_with_config function with the default config
-call_with_config(request, config).await
+// Use the generate_with_config function with the default config
+generate_with_config(request, config).await
}
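And a sketch of driving the renamed default-config entry point end to end. The `max_tokens` and `model` fields on `Request` are inferred from `create_commit_request`'s signature rather than shown in this diff; `response.response` comes from the client.rs hunk above, and the `anyhow::Result` alias is an assumption.

```rust
use ai::model::Model;
use ai::openai::{generate_with_openai, Request};

async fn demo() -> anyhow::Result<()> {
    let request = Request {
        system: "Write a conventional commit message for this diff.".into(),
        prompt: "diff --git a/src/lib.rs b/src/lib.rs\n...".into(),
        // Assumed fields, inferred from create_commit_request(diff, max_tokens, model).
        max_tokens: 256,
        model: Model::GPT41Mini,
    };
    let response = generate_with_openai(request).await?;
    println!("{}", response.response);
    Ok(())
}
```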
14 changes: 7 additions & 7 deletions tests/llm_input_generation_test.rs
@@ -1,4 +1,4 @@
-use ai::commit::{create_commit_request, get_instruction_template, token_used};
+use ai::commit::{calculate_token_usage, create_commit_request, generate_instruction_template};
use ai::model::Model;

/// Tests for the LLM input generation system
@@ -12,7 +12,7 @@ use ai::model::Model;
#[test]
fn test_template_generation_with_default_max_length() {
// Test that template generation works with default config
-let result = get_instruction_template();
+let result = generate_instruction_template();
assert!(result.is_ok(), "Template generation should succeed");

let template = result.unwrap();
@@ -38,7 +38,7 @@ fn test_token_counting_empty_template() {
fn test_token_counting_template() {
// Test that we can count tokens in the actual template
let model = Model::GPT41Mini;
-let result = token_used(&model);
+let result = calculate_token_usage(&model);

assert!(result.is_ok(), "Token counting should succeed");
let token_count = result.unwrap();
@@ -498,7 +498,7 @@ index 777..888 100644

#[test]
fn test_template_contains_required_sections() {
-let template = get_instruction_template().unwrap();
+let template = generate_instruction_template().unwrap();

// Verify template has all required sections for the LLM
let required_sections = vec![
@@ -547,8 +547,8 @@ index 123abc..456def 100644

// Test the full workflow
let model = Model::GPT41Mini;
-let template = get_instruction_template().unwrap();
-let token_count = token_used(&model).unwrap();
+let template = generate_instruction_template().unwrap();
+let token_count = calculate_token_usage(&model).unwrap();
let request = create_commit_request(simple_diff.clone(), 2000, model).unwrap();

// Verify all components work together
@@ -581,7 +581,7 @@ index abc..def 100644
.to_string();

// Calculate total tokens needed
-let template_tokens = token_used(&model).unwrap();
+let template_tokens = calculate_token_usage(&model).unwrap();
let diff_tokens = model.count_tokens(&diff).unwrap();
let total_input_tokens = template_tokens + diff_tokens;
