diff --git a/src/commit.rs b/src/commit.rs
index e3246629..efcf6032 100644
--- a/src/commit.rs
+++ b/src/commit.rs
@@ -234,7 +234,7 @@ mod tests {
     let result = generate(
       "diff --git a/test.txt b/test.txt\n+Hello World".to_string(),
       1024,
-      Model::GPT4oMini,
+      Model::GPT41Mini,
       Some(&settings)
     )
     .await;
@@ -268,7 +268,7 @@ mod tests {
     let result = generate(
       "diff --git a/test.txt b/test.txt\n+Hello World".to_string(),
       1024,
-      Model::GPT4oMini,
+      Model::GPT41Mini,
       Some(&settings)
     )
     .await;
diff --git a/src/config.rs b/src/config.rs
index 38c99ac1..aa65ed88 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -12,7 +12,7 @@ use console::Emoji;
 const DEFAULT_TIMEOUT: i64 = 30;
 const DEFAULT_MAX_COMMIT_LENGTH: i64 = 72;
 const DEFAULT_MAX_TOKENS: i64 = 2024;
-const DEFAULT_MODEL: &str = "gpt-4o-mini";
+const DEFAULT_MODEL: &str = "gpt-4.1"; // Matches Model::default()
 const DEFAULT_API_KEY: &str = "";
 
 #[derive(Debug, Default, Deserialize, PartialEq, Eq, Serialize)]
diff --git a/src/model.rs b/src/model.rs
index 5bbd1f22..4006b286 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -18,10 +18,10 @@ use crate::config::AppConfig;
 static TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
 
 // Model identifiers - using screaming case for constants
-const MODEL_GPT4: &str = "gpt-4";
-const MODEL_GPT4_OPTIMIZED: &str = "gpt-4o";
-const MODEL_GPT4_MINI: &str = "gpt-4o-mini";
 const MODEL_GPT4_1: &str = "gpt-4.1";
+const MODEL_GPT4_1_MINI: &str = "gpt-4.1-mini";
+const MODEL_GPT4_1_NANO: &str = "gpt-4.1-nano";
+const MODEL_GPT4_5: &str = "gpt-4.5";
 
 // TODO: Get this from config.rs or a shared constants module
 const DEFAULT_MODEL_NAME: &str = "gpt-4.1";
@@ -29,15 +29,15 @@ const DEFAULT_MODEL_NAME: &str = "gpt-4.1";
 /// Each model has different capabilities and token limits.
 #[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, Serialize, Deserialize, Default)]
 pub enum Model {
-  /// Standard GPT-4 model
-  GPT4,
-  /// Optimized GPT-4 model for better performance
-  GPT4o,
-  /// Mini version of optimized GPT-4 for faster processing
-  GPT4oMini,
   /// Default model - GPT-4.1 latest version
   #[default]
-  GPT41
+  GPT41,
+  /// Mini version of GPT-4.1 for faster processing
+  GPT41Mini,
+  /// Nano version of GPT-4.1 for very fast processing
+  GPT41Nano,
+  /// GPT-4.5 model for advanced capabilities
+  GPT45
 }
 
 impl Model {
@@ -59,10 +59,7 @@ impl Model {
 
     // Always use the proper tokenizer for accurate counts
     // We cannot afford to underestimate tokens as it may cause API failures
-    let tokenizer = TOKENIZER.get_or_init(|| {
-      let model_str: &str = self.into();
-      get_tokenizer(model_str)
-    });
+    let tokenizer = TOKENIZER.get_or_init(|| get_tokenizer(self.as_ref()));
 
     // Use direct tokenization for accurate token count
     let tokens = tokenizer.encode_ordinary(text);
@@ -75,8 +72,7 @@ impl Model {
   /// * `usize` - The maximum number of tokens the model can process
   pub fn context_size(&self) -> usize {
     profile!("Get context size");
-    let model_str: &str = self.into();
-    get_context_size(model_str)
+    get_context_size(self.as_ref())
   }
 
   /// Truncates the given text to fit within the specified token limit.
@@ -167,41 +163,80 @@ impl Model {
   }
 }
 
-impl From<&Model> for &str {
-  fn from(model: &Model) -> Self {
-    match model {
-      Model::GPT4o => MODEL_GPT4_OPTIMIZED,
-      Model::GPT4 => MODEL_GPT4,
-      Model::GPT4oMini => MODEL_GPT4_MINI,
-      Model::GPT41 => MODEL_GPT4_1
+impl AsRef<str> for Model {
+  fn as_ref(&self) -> &str {
+    match self {
+      Model::GPT41 => MODEL_GPT4_1,
+      Model::GPT41Mini => MODEL_GPT4_1_MINI,
+      Model::GPT41Nano => MODEL_GPT4_1_NANO,
+      Model::GPT45 => MODEL_GPT4_5
     }
   }
 }
 
+// Keep conversion to String for cases that need owned strings
+impl From<&Model> for String {
+  fn from(model: &Model) -> Self {
+    model.as_ref().to_string()
+  }
+}
+
+// Keep the old impl for backwards compatibility where possible
+impl Model {
+  pub fn as_str(&self) -> &str {
+    self.as_ref()
+  }
+}
+
 impl FromStr for Model {
   type Err = anyhow::Error;
 
   fn from_str(s: &str) -> Result<Self, Self::Err> {
-    match s.trim().to_lowercase().as_str() {
-      "gpt-4o" => Ok(Model::GPT4o),
-      "gpt-4" => Ok(Model::GPT4),
-      "gpt-4o-mini" => Ok(Model::GPT4oMini),
+    let normalized = s.trim().to_lowercase();
+    match normalized.as_str() {
       "gpt-4.1" => Ok(Model::GPT41),
-      model => bail!("Invalid model name: {}", model)
+      "gpt-4.1-mini" => Ok(Model::GPT41Mini),
+      "gpt-4.1-nano" => Ok(Model::GPT41Nano),
+      "gpt-4.5" => Ok(Model::GPT45),
+      // Backward compatibility for deprecated models - map to closest GPT-4.1 equivalent
+      "gpt-4" | "gpt-4o" => {
+        log::warn!(
+          "Model '{}' is deprecated. Mapping to 'gpt-4.1'. \
+           Please update your configuration with: git ai config set model gpt-4.1",
+          s
+        );
+        Ok(Model::GPT41)
+      }
+      "gpt-4o-mini" | "gpt-3.5-turbo" => {
+        log::warn!(
+          "Model '{}' is deprecated. Mapping to 'gpt-4.1-mini'. \
+           Please update your configuration with: git ai config set model gpt-4.1-mini",
+          s
+        );
+        Ok(Model::GPT41Mini)
+      }
+      model =>
+        bail!(
+          "Invalid model name: '{}'. Supported models: gpt-4.1, gpt-4.1-mini, gpt-4.1-nano, gpt-4.5",
+          model
+        ),
     }
   }
 }
 
 impl Display for Model {
   fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-    write!(f, "{}", <&str>::from(self))
+    write!(f, "{}", self.as_ref())
   }
 }
 
 // Implement conversion from string types to Model with fallback to default
 impl From<&str> for Model {
   fn from(s: &str) -> Self {
-    s.parse().unwrap_or_default()
+    s.parse().unwrap_or_else(|e| {
+      log::error!("Failed to parse model '{}': {}. Falling back to default model 'gpt-4.1'.", s, e);
+      Model::default()
+    })
   }
 }
diff --git a/src/openai.rs b/src/openai.rs
index a25b5532..8125e834 100644
--- a/src/openai.rs
+++ b/src/openai.rs
@@ -37,7 +37,7 @@ pub async fn generate_commit_message(diff: &str) -> Result<String> {
   if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
     if !api_key.is_empty() {
       // Use the commit function directly without parsing
-      match commit::generate(diff.to_string(), 256, Model::GPT4oMini, None).await {
+      match commit::generate(diff.to_string(), 256, Model::GPT41Mini, None).await {
         Ok(response) => return Ok(response.response.trim().to_string()),
         Err(e) => {
           log::warn!("Direct generation failed, falling back to local: {e}");
diff --git a/tests/llm_input_generation_test.rs b/tests/llm_input_generation_test.rs
index ad08eb51..d1c2f117 100644
--- a/tests/llm_input_generation_test.rs
+++ b/tests/llm_input_generation_test.rs
@@ -28,7 +28,7 @@ fn test_template_generation_with_default_max_length() {
 #[test]
 fn test_token_counting_empty_template() {
   // Token counting should work even with minimal content
-  let model = Model::GPT4oMini;
+  let model = Model::GPT41Mini;
   let result = model.count_tokens("");
   assert!(result.is_ok(), "Should handle empty string");
   assert_eq!(result.unwrap(), 0, "Empty string should have 0 tokens");
@@ -37,7 +37,7 @@ fn test_token_counting_empty_template() {
 #[test]
 fn test_token_counting_template() {
   // Test that we can count tokens in the actual template
-  let model = Model::GPT4oMini;
+  let model = Model::GPT41Mini;
   let result = token_used(&model);
 
   assert!(result.is_ok(), "Token counting should succeed");
@@ -52,7 +52,7 @@ fn test_token_counting_template() {
 fn test_create_request_with_zero_tokens() {
   // Edge case: what happens with 0 max_tokens?
   let diff = "diff --git a/test.txt b/test.txt\n+Hello World".to_string();
-  let result = create_commit_request(diff, 0, Model::GPT4oMini);
+  let result = create_commit_request(diff, 0, Model::GPT41Mini);
   assert!(result.is_ok(), "Should create request even with 0 tokens");
 
   let request = result.unwrap();
@@ -63,7 +63,7 @@ fn test_create_request_with_zero_tokens() {
 fn test_create_request_with_empty_diff() {
   // Corner case: empty diff
   let diff = "".to_string();
-  let result = create_commit_request(diff.clone(), 1000, Model::GPT4oMini);
+  let result = create_commit_request(diff.clone(), 1000, Model::GPT41Mini);
   assert!(result.is_ok(), "Should handle empty diff");
 
   let request = result.unwrap();
@@ -75,7 +75,7 @@ fn test_create_request_with_empty_diff() {
 fn test_create_request_with_whitespace_only_diff() {
   // Corner case: whitespace-only diff
   let diff = " \n\t\n ".to_string();
-  let result = create_commit_request(diff.clone(), 1000, Model::GPT4oMini);
+  let result = create_commit_request(diff.clone(), 1000, Model::GPT41Mini);
   assert!(result.is_ok(), "Should handle whitespace-only diff");
 
   let request = result.unwrap();
@@ -86,7 +86,7 @@ fn test_create_request_with_whitespace_only_diff() {
 fn test_create_request_preserves_model() {
   // Test that different models are preserved correctly
   let diff = "diff --git a/test.txt b/test.txt\n+Test".to_string();
-  let models = vec![Model::GPT4oMini, Model::GPT4o, Model::GPT4, Model::GPT41];
+  let models = vec![Model::GPT41Mini, Model::GPT45, Model::GPT41, Model::GPT41Nano];
 
   for model in models {
     let result = create_commit_request(diff.clone(), 1000, model);
@@ -103,7 +103,7 @@ fn test_create_request_with_max_u16_tokens() {
   let diff = "diff --git a/test.txt b/test.txt\n+Test".to_string();
   let max_tokens = usize::from(u16::MAX);
 
-  let result = create_commit_request(diff, max_tokens, Model::GPT4oMini);
+  let result = create_commit_request(diff, max_tokens, Model::GPT41Mini);
   assert!(result.is_ok(), "Should handle max u16 tokens");
 
   let request = result.unwrap();
@@ -116,7 +116,7 @@ fn test_create_request_with_overflow_tokens() {
   let diff = "diff --git a/test.txt b/test.txt\n+Test".to_string();
   let max_tokens = usize::from(u16::MAX) + 1000;
 
-  let result = create_commit_request(diff, max_tokens, Model::GPT4oMini);
+  let result = create_commit_request(diff, max_tokens, Model::GPT41Mini);
   assert!(result.is_ok(), "Should handle token overflow");
 
   let request = result.unwrap();
@@ -139,7 +139,7 @@ index 1234567..abcdefg 100644
 "#
   .to_string();
 
-  let result = create_commit_request(diff.clone(), 1000, Model::GPT4oMini);
+  let result = create_commit_request(diff.clone(), 1000, Model::GPT41Mini);
   assert!(result.is_ok(), "Should handle simple diff");
 
   let request = result.unwrap();
@@ -164,7 +164,7 @@ index 0000000..1234567
 "#
   .to_string();
 
-  let result = create_commit_request(diff.clone(), 2000, Model::GPT4o);
+  let result = create_commit_request(diff.clone(), 2000, Model::GPT45);
   assert!(result.is_ok(), "Should handle file addition");
 
   let request = result.unwrap();
@@ -186,7 +186,7 @@ index 1234567..0000000
 "#
   .to_string();
 
-  let result = create_commit_request(diff.clone(), 1500, Model::GPT4oMini);
+  let result = create_commit_request(diff.clone(), 1500, Model::GPT41Mini);
   assert!(result.is_ok(), "Should handle file deletion");
 
   let request = result.unwrap();
@@ -208,7 +208,7 @@ index 1234567..abcdefg 100644
 "#
   .to_string();
 
-  let result = create_commit_request(diff.clone(), 1000, Model::GPT4oMini);
+  let result = create_commit_request(diff.clone(), 1000, Model::GPT41Mini);
   assert!(result.is_ok(), "Should handle file rename");
 
   let request = result.unwrap();
@@ -218,7 +218,7 @@ index 1234567..abcdefg 100644
 
 #[test]
 fn test_token_counting_with_diff_content() {
-  let model = Model::GPT4oMini;
+  let model = Model::GPT41Mini;
 
   let small_diff = "diff --git a/a.txt b/a.txt\n+Hello";
   let medium_diff = r#"diff --git a/test.js b/test.js
@@ -275,7 +275,7 @@ index 0000000..5555555
 "#
   .to_string();
 
-  let result = create_commit_request(diff.clone(), 3000, Model::GPT4o);
+  let result = create_commit_request(diff.clone(), 3000, Model::GPT45);
   assert!(result.is_ok(), "Should handle multiple file changes");
 
   let request = result.unwrap();
@@ -292,7 +292,7 @@ Binary files a/image.png and b/image.png differ
 "#
   .to_string();
 
-  let result = create_commit_request(diff.clone(), 1000, Model::GPT4oMini);
+  let result = create_commit_request(diff.clone(), 1000, Model::GPT41Mini);
   assert!(result.is_ok(), "Should handle binary file diff");
 
   let request = result.unwrap();
@@ -313,7 +313,7 @@ index 1234567..abcdefg 100644
 "#
   .to_string();
 
-  let result = create_commit_request(diff.clone(), 2000, Model::GPT4oMini);
+  let result = create_commit_request(diff.clone(), 2000, Model::GPT41Mini);
   assert!(result.is_ok(), "Should handle special characters");
 
   let request = result.unwrap();
@@ -335,14 +335,14 @@ fn test_create_request_with_large_diff() {
     }
   }
 
-  let result = create_commit_request(diff.clone(), 8000, Model::GPT4o);
+  let result = create_commit_request(diff.clone(), 8000, Model::GPT45);
   assert!(result.is_ok(), "Should handle large diff");
 
   let request = result.unwrap();
   assert!(request.prompt.len() > 10000, "Large diff should be preserved");
 
   // Count tokens to ensure we can handle large inputs
-  let model = Model::GPT4o;
+  let model = Model::GPT45;
   let token_count = model.count_tokens(&diff).unwrap();
   assert!(token_count > 1000, "Large diff should have substantial token count");
 }
@@ -356,7 +356,7 @@ fn test_create_request_with_very_long_lines() {
     long_line
   );
 
-  let result = create_commit_request(diff.clone(), 5000, Model::GPT4o);
+  let result = create_commit_request(diff.clone(), 5000, Model::GPT45);
   assert!(result.is_ok(), "Should handle very long lines");
 
   let request = result.unwrap();
@@ -400,7 +400,7 @@ Binary files a/image.png and b/image.png differ
 "#
   .to_string();
 
-  let result = create_commit_request(diff.clone(), 4000, Model::GPT4o);
+  let result = create_commit_request(diff.clone(), 4000, Model::GPT45);
   assert!(result.is_ok(), "Should handle mixed operations");
 
   let request = result.unwrap();
@@ -415,7 +415,7 @@ Binary files a/image.png and b/image.png differ
 
 #[test]
 fn test_token_counting_consistency_with_complex_diff() {
-  let model = Model::GPT4oMini;
+  let model = Model::GPT41Mini;
 
   let complex_diff = r#"diff --git a/src/main.rs b/src/main.rs
 index abc123..def456 100644
@@ -486,7 +486,7 @@ index 777..888 100644
 "#
   .to_string();
 
-  let result = create_commit_request(diff.clone(), 5000, Model::GPT4o);
+  let result = create_commit_request(diff.clone(), 5000, Model::GPT45);
   assert!(result.is_ok(), "Should handle multiple programming languages");
 
   let request = result.unwrap();
@@ -513,13 +513,13 @@ fn test_template_contains_required_sections() {
 #[test]
 fn test_request_structure_completeness() {
   let diff = "diff --git a/test.txt b/test.txt\n+test".to_string();
-  let request = create_commit_request(diff.clone(), 1000, Model::GPT4oMini).unwrap();
+  let request = create_commit_request(diff.clone(), 1000, Model::GPT41Mini).unwrap();
 
   // Verify request has all required components
   assert!(!request.system.is_empty(), "System prompt should not be empty");
   assert_eq!(request.prompt, diff, "User prompt should match input diff");
   assert_eq!(request.max_tokens, 1000, "Max tokens should be set correctly");
-  assert_eq!(request.model, Model::GPT4oMini, "Model should be set correctly");
+  assert_eq!(request.model, Model::GPT41Mini, "Model should be set correctly");
 
   // Verify system prompt has reasonable length
   assert!(request.system.len() > 500, "System prompt should be substantial");
@@ -546,7 +546,7 @@ index 123abc..456def 100644
   .to_string();
 
   // Test the full workflow
-  let model = Model::GPT4oMini;
+  let model = Model::GPT41Mini;
   let template = get_instruction_template().unwrap();
   let token_count = token_used(&model).unwrap();
   let request = create_commit_request(simple_diff.clone(), 2000, model).unwrap();
@@ -562,7 +562,7 @@ index 123abc..456def 100644
 #[test]
 fn test_end_to_end_with_token_limits() {
   // Test that we can calculate tokens for both template and diff
-  let model = Model::GPT4o;
+  let model = Model::GPT45;
   let diff = r#"diff --git a/src/main.rs b/src/main.rs
 index abc..def 100644
 --- a/src/main.rs
diff --git a/tests/model_token_test.rs b/tests/model_token_test.rs
index e3ca955e..3dc1ddae 100644
--- a/tests/model_token_test.rs
+++ b/tests/model_token_test.rs
@@ -2,7 +2,7 @@ use ai::model::Model;
 
 #[test]
 fn test_token_counting_accuracy() {
-  let model = Model::GPT4;
+  let model = Model::GPT41;
 
   // Test various text lengths to ensure we're not underestimating
   let test_cases = vec![
@@ -46,7 +46,7 @@ fn test_token_counting_accuracy() {
 
 #[test]
 fn test_no_underestimation_for_context_limit() {
-  let model = Model::GPT4;
+  let model = Model::GPT41;
 
   // Create text that would be underestimated by the old heuristics
   // Old heuristic: ~4 chars per token, but actual can be much different
@@ -79,7 +79,7 @@ fn test_no_underestimation_for_context_limit() {
 
 #[test]
 fn test_token_counting_consistency() {
-  let model = Model::GPT4;
+  let model = Model::GPT41;
 
   // Test that the same text always returns the same token count
   let test_text = "The quick brown fox jumps over the lazy dog. This is a test sentence with various words.";
@@ -95,7 +95,7 @@ fn test_token_counting_consistency() {
 
 #[test]
 fn test_long_text_token_counting() {
-  let model = Model::GPT4;
+  let model = Model::GPT41;
 
   // Test with a longer text to ensure we're using the tokenizer properly
   let long_text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. ".repeat(50);
diff --git a/tests/model_validation_test.rs b/tests/model_validation_test.rs
new file mode 100644
index 00000000..6497b4e2
--- /dev/null
+++ b/tests/model_validation_test.rs
@@ -0,0 +1,102 @@
+use std::str::FromStr;
+
+use ai::model::Model;
+
+#[test]
+fn test_valid_model_names() {
+  // Test all supported model names
+  assert_eq!(Model::from_str("gpt-4.1").unwrap(), Model::GPT41);
+  assert_eq!(Model::from_str("gpt-4.1-mini").unwrap(), Model::GPT41Mini);
+  assert_eq!(Model::from_str("gpt-4.1-nano").unwrap(), Model::GPT41Nano);
+  assert_eq!(Model::from_str("gpt-4.5").unwrap(), Model::GPT45);
+}
+
+#[test]
+fn test_case_insensitive_parsing() {
+  // Test that model names are case-insensitive
+  assert_eq!(Model::from_str("GPT-4.1").unwrap(), Model::GPT41);
+  assert_eq!(Model::from_str("Gpt-4.1-Mini").unwrap(), Model::GPT41Mini);
+  assert_eq!(Model::from_str("GPT-4.1-NANO").unwrap(), Model::GPT41Nano);
+  assert_eq!(Model::from_str("gPt-4.5").unwrap(), Model::GPT45);
+}
+
+#[test]
+fn test_whitespace_handling() {
+  // Test that leading/trailing whitespace is trimmed
+  assert_eq!(Model::from_str(" gpt-4.1 ").unwrap(), Model::GPT41);
+  assert_eq!(Model::from_str("\tgpt-4.1-mini\n").unwrap(), Model::GPT41Mini);
+}
+
+#[test]
+fn test_deprecated_model_backward_compat() {
+  // Test that deprecated models map to their GPT-4.1 equivalents
+  // These should succeed but log warnings
+  assert_eq!(Model::from_str("gpt-4").unwrap(), Model::GPT41);
+  assert_eq!(Model::from_str("gpt-4o").unwrap(), Model::GPT41);
+  assert_eq!(Model::from_str("gpt-4o-mini").unwrap(), Model::GPT41Mini);
+  assert_eq!(Model::from_str("gpt-3.5-turbo").unwrap(), Model::GPT41Mini);
+}
+
+#[test]
+fn test_invalid_model_name() {
+  // Test that invalid model names return an error
+  let result = Model::from_str("does-not-exist");
+  assert!(result.is_err());
+  assert!(result
+    .unwrap_err()
+    .to_string()
+    .contains("Invalid model name"));
+}
+
+#[test]
+fn test_invalid_model_fallback() {
+  // Test that From<&str> falls back to default for invalid models
+  let model = Model::from("invalid-model");
+  assert_eq!(model, Model::default());
+  assert_eq!(model, Model::GPT41);
+}
+
+#[test]
+fn test_model_display() {
+  // Test that models display correctly
+  assert_eq!(Model::GPT41.to_string(), "gpt-4.1");
+  assert_eq!(Model::GPT41Mini.to_string(), "gpt-4.1-mini");
+  assert_eq!(Model::GPT41Nano.to_string(), "gpt-4.1-nano");
+  assert_eq!(Model::GPT45.to_string(), "gpt-4.5");
+}
+
+#[test]
+fn test_model_as_str() {
+  // Test the as_str() method
+  assert_eq!(Model::GPT41.as_str(), "gpt-4.1");
+  assert_eq!(Model::GPT41Mini.as_str(), "gpt-4.1-mini");
+  assert_eq!(Model::GPT41Nano.as_str(), "gpt-4.1-nano");
+  assert_eq!(Model::GPT45.as_str(), "gpt-4.5");
+}
+
+#[test]
+fn test_model_as_ref() {
+  // Test the AsRef implementation
+  fn takes_str_ref<S: AsRef<str>>(s: S) -> String {
+    s.as_ref().to_string()
+  }
+
+  assert_eq!(takes_str_ref(&Model::GPT41), "gpt-4.1");
+  assert_eq!(takes_str_ref(&Model::GPT41Mini), "gpt-4.1-mini");
+}
+
+#[test]
+fn test_model_from_string() {
+  // Test conversion from String
+  let s = String::from("gpt-4.1");
+  assert_eq!(Model::from(s), Model::GPT41);
+
+  let s = String::from("gpt-4.1-mini");
+  assert_eq!(Model::from(s), Model::GPT41Mini);
+}
+
+#[test]
+fn test_default_model() {
+  // Test that the default model is GPT41
+  assert_eq!(Model::default(), Model::GPT41);
+}
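
A minimal sketch (not part of the patch) of how a caller interacts with the `Model` API after this change. The `main` wrapper is illustrative; the crate path `ai::model::Model` is taken from the test imports above.

```rust
use std::str::FromStr;

use ai::model::Model;

fn main() {
  // Strict parsing: deprecated names ("gpt-4", "gpt-4o", "gpt-4o-mini",
  // "gpt-3.5-turbo") are mapped to their GPT-4.1 equivalents with a
  // warning instead of failing outright.
  let model = Model::from_str("gpt-4o-mini").expect("deprecated names still parse");
  assert_eq!(model, Model::GPT41Mini);

  // Infallible conversion: unknown names log an error and fall back to
  // the default model (gpt-4.1) rather than aborting.
  let fallback = Model::from("not-a-real-model");
  assert_eq!(fallback, Model::default());

  // AsRef<str> / as_str() expose the wire-format model name for API payloads.
  println!("using model: {}", fallback.as_str());
}
```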