src/commit.rs (4 changes: 2 additions & 2 deletions)

@@ -234,7 +234,7 @@ mod tests {
     let result = generate(
       "diff --git a/test.txt b/test.txt\n+Hello World".to_string(),
       1024,
-      Model::GPT4oMini,
+      Model::GPT41Mini,
       Some(&settings)
     )
     .await;
@@ -268,7 +268,7 @@ mod tests {
     let result = generate(
       "diff --git a/test.txt b/test.txt\n+Hello World".to_string(),
       1024,
-      Model::GPT4oMini,
+      Model::GPT41Mini,
       Some(&settings)
     )
     .await;

src/config.rs (2 changes: 1 addition & 1 deletion)

@@ -12,7 +12,7 @@ use console::Emoji;
 const DEFAULT_TIMEOUT: i64 = 30;
 const DEFAULT_MAX_COMMIT_LENGTH: i64 = 72;
 const DEFAULT_MAX_TOKENS: i64 = 2024;
-const DEFAULT_MODEL: &str = "gpt-4o-mini";
+const DEFAULT_MODEL: &str = "gpt-4.1"; // Matches Model::default()
 const DEFAULT_API_KEY: &str = "<PLACE HOLDER FOR YOUR API KEY>";

 #[derive(Debug, Default, Deserialize, PartialEq, Eq, Serialize)]

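Review note: the new `// Matches Model::default()` comment here and the TODO left in src/model.rs both point at the same risk, namely that the default model name now lives in two constants. A guard test would catch drift; the sketch below assumes `DEFAULT_MODEL` is widened to `pub(crate)` (it stays private in this PR), so treat it as a suggestion rather than part of the change:

#[cfg(test)]
mod defaults {
  // Assumes `DEFAULT_MODEL` is made pub(crate); it is private in this PR.
  use crate::config::DEFAULT_MODEL;
  use crate::model::Model;

  #[test]
  fn config_default_matches_model_default() {
    // Fail at test time if either constant drifts, instead of surfacing
    // later as a runtime warning through the From<&str> fallback.
    assert_eq!(DEFAULT_MODEL, Model::default().as_ref());
  }
}
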
src/model.rs (95 changes: 65 additions & 30 deletions)

@@ -18,26 +18,26 @@ use crate::config::AppConfig;
 static TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();

 // Model identifiers - using screaming case for constants
-const MODEL_GPT4: &str = "gpt-4";
-const MODEL_GPT4_OPTIMIZED: &str = "gpt-4o";
-const MODEL_GPT4_MINI: &str = "gpt-4o-mini";
 const MODEL_GPT4_1: &str = "gpt-4.1";
+const MODEL_GPT4_1_MINI: &str = "gpt-4.1-mini";
+const MODEL_GPT4_1_NANO: &str = "gpt-4.1-nano";
+const MODEL_GPT4_5: &str = "gpt-4.5";
+// TODO: Get this from config.rs or a shared constants module
+const DEFAULT_MODEL_NAME: &str = "gpt-4.1";

 /// Represents the available AI models for commit message generation.
 /// Each model has different capabilities and token limits.
 #[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, Serialize, Deserialize, Default)]
 pub enum Model {
-  /// Standard GPT-4 model
-  GPT4,
-  /// Optimized GPT-4 model for better performance
-  GPT4o,
-  /// Mini version of optimized GPT-4 for faster processing
-  GPT4oMini,
   /// Default model - GPT-4.1 latest version
   #[default]
-  GPT41
+  GPT41,
+  /// Mini version of GPT-4.1 for faster processing
+  GPT41Mini,
+  /// Nano version of GPT-4.1 for very fast processing
+  GPT41Nano,
+  /// GPT-4.5 model for advanced capabilities
+  GPT45
 }

 impl Model {
@@ -59,10 +59,7 @@ impl Model {
   // Always use the proper tokenizer for accurate counts
   // We cannot afford to underestimate tokens as it may cause API failures
-  let tokenizer = TOKENIZER.get_or_init(|| {
-    let model_str: &str = self.into();
-    get_tokenizer(model_str)
-  });
+  let tokenizer = TOKENIZER.get_or_init(|| get_tokenizer(self.as_ref()));

   // Use direct tokenization for accurate token count
   let tokens = tokenizer.encode_ordinary(text);
@@ -75,8 +72,7 @@ impl Model {
   /// * `usize` - The maximum number of tokens the model can process
   pub fn context_size(&self) -> usize {
     profile!("Get context size");
-    let model_str: &str = self.into();
-    get_context_size(model_str)
+    get_context_size(self.as_ref())
   }

   /// Truncates the given text to fit within the specified token limit.
@@ -167,41 +163,80 @@ impl Model {
   }
 }

-impl From<&Model> for &str {
-  fn from(model: &Model) -> Self {
-    match model {
-      Model::GPT4o => MODEL_GPT4_OPTIMIZED,
-      Model::GPT4 => MODEL_GPT4,
-      Model::GPT4oMini => MODEL_GPT4_MINI,
-      Model::GPT41 => MODEL_GPT4_1
+impl AsRef<str> for Model {
+  fn as_ref(&self) -> &str {
+    match self {
+      Model::GPT41 => MODEL_GPT4_1,
+      Model::GPT41Mini => MODEL_GPT4_1_MINI,
+      Model::GPT41Nano => MODEL_GPT4_1_NANO,
+      Model::GPT45 => MODEL_GPT4_5
     }
   }
 }

+// Keep conversion to String for cases that need owned strings
+impl From<&Model> for String {
+  fn from(model: &Model) -> Self {
+    model.as_ref().to_string()
+  }
+}
+
+// Keep the old impl for backwards compatibility where possible
+impl Model {
+  pub fn as_str(&self) -> &str {
+    self.as_ref()
+  }
+}
+
 impl FromStr for Model {
   type Err = anyhow::Error;

   fn from_str(s: &str) -> Result<Self> {
-    match s.trim().to_lowercase().as_str() {
-      "gpt-4o" => Ok(Model::GPT4o),
-      "gpt-4" => Ok(Model::GPT4),
-      "gpt-4o-mini" => Ok(Model::GPT4oMini),
-      "gpt-4.1" => Ok(Model::GPT41),
-      model => bail!("Invalid model name: {}", model)
+    let normalized = s.trim().to_lowercase();
+    match normalized.as_str() {
+      "gpt-4.1" => Ok(Model::GPT41),
+      "gpt-4.1-mini" => Ok(Model::GPT41Mini),
+      "gpt-4.1-nano" => Ok(Model::GPT41Nano),
+      "gpt-4.5" => Ok(Model::GPT45),
+      // Backward compatibility for deprecated models - map to closest GPT-4.1 equivalent
+      "gpt-4" | "gpt-4o" => {
+        log::warn!(
+          "Model '{}' is deprecated. Mapping to 'gpt-4.1'. \
+           Please update your configuration with: git ai config set model gpt-4.1",
+          s
+        );
+        Ok(Model::GPT41)
+      }
+      "gpt-4o-mini" | "gpt-3.5-turbo" => {
+        log::warn!(
+          "Model '{}' is deprecated. Mapping to 'gpt-4.1-mini'. \
+           Please update your configuration with: git ai config set model gpt-4.1-mini",
+          s
+        );
+        Ok(Model::GPT41Mini)
+      }
+      model =>
+        bail!(
+          "Invalid model name: '{}'. Supported models: gpt-4.1, gpt-4.1-mini, gpt-4.1-nano, gpt-4.5",
+          model
+        ),
     }
   }
 }

 impl Display for Model {
   fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-    write!(f, "{}", <&str>::from(self))
+    write!(f, "{}", self.as_ref())
   }
 }

 // Implement conversion from string types to Model with fallback to default
 impl From<&str> for Model {
   fn from(s: &str) -> Self {
-    s.parse().unwrap_or_default()
+    s.parse().unwrap_or_else(|e| {
+      log::error!("Failed to parse model '{}': {}. Falling back to default model 'gpt-4.1'.", s, e);
+      Model::default()
+    })
   }
 }

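Review note: a quick usage sketch of the reworked API (the test module name is illustrative; every call shown is introduced by this diff):

#[cfg(test)]
mod model_api {
  use crate::model::Model;

  #[test]
  fn parse_display_and_fallback() {
    // Canonical names parse to their variants...
    assert_eq!("gpt-4.1-nano".parse::<Model>().unwrap(), Model::GPT41Nano);
    // ...and AsRef<str>/Display round-trip back to the wire name.
    assert_eq!(Model::GPT45.as_ref(), "gpt-4.5");
    assert_eq!(Model::GPT41Mini.to_string(), "gpt-4.1-mini");

    // Deprecated names still parse, mapped to the closest GPT-4.1 model
    // (a log::warn! fires alongside each mapping).
    assert_eq!("gpt-4o".parse::<Model>().unwrap(), Model::GPT41);
    assert_eq!("gpt-4o-mini".parse::<Model>().unwrap(), Model::GPT41Mini);

    // From<&str> stays infallible: unknown names log an error and fall
    // back to Model::default() (GPT41) instead of panicking.
    assert_eq!(Model::from("not-a-model"), Model::GPT41);
  }
}
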
src/openai.rs (2 changes: 1 addition & 1 deletion)

@@ -37,7 +37,7 @@ pub async fn generate_commit_message(diff: &str) -> Result<String> {
   if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
     if !api_key.is_empty() {
       // Use the commit function directly without parsing
-      match commit::generate(diff.to_string(), 256, Model::GPT4oMini, None).await {
+      match commit::generate(diff.to_string(), 256, Model::GPT41Mini, None).await {
         Ok(response) => return Ok(response.response.trim().to_string()),
         Err(e) => {
           log::warn!("Direct generation failed, falling back to local: {e}");

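Review note: for anyone tracing the call path, a minimal caller sketch. The `git_ai` crate name, the binary layout, and the tokio runtime are assumptions; `generate_commit_message` and its signature come from the hunk above:

use anyhow::Result;

#[tokio::main]
async fn main() -> Result<()> {
  // With OPENAI_API_KEY set and non-empty, this now routes through
  // Model::GPT41Mini; on any error it falls back to the local path
  // inside generate_commit_message itself.
  let diff = "diff --git a/test.txt b/test.txt\n+Hello World";
  let message = git_ai::openai::generate_commit_message(diff).await?;
  println!("{message}");
  Ok(())
}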