Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "git-ai-summarize"
version = "0.1.3"
version = "0.1.4"
authors = [{ name = "Kevin Beaulieu", email = "opensource@kevinmbeaulieu.com" }]
description = "AI-powered git commands for summarizing changes"
readme = "README.md"
Expand Down
90 changes: 50 additions & 40 deletions src/git_ai_summarize/models.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,5 @@
import os
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_mistralai import ChatMistralAI
from langchain_fireworks import ChatFireworks
from langchain_together import ChatTogether
from langchain_google_vertexai import ChatVertexAI
from langchain_groq import ChatGroq
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain_ollama import ChatOllama
from langchain_ai21 import ChatAI21
from langchain_upstage import ChatUpstage
from langchain_databricks import ChatDatabricks
from langchain_ibm import ChatWatsonx
from langchain_xai import ChatXAI
from typing import List


Expand All @@ -24,12 +9,12 @@ def get_supported_providers() -> List[str]:
"anthropic",
"openai",
"google",
"mistralai",
"mistral",
"fireworks",
"together",
"vertexai",
"vertex",
"groq",
"nvidia_ai",
"nvidia",
"ollama",
"ai21",
"upstage",
Expand All @@ -41,30 +26,55 @@ def get_supported_providers() -> List[str]:

def get_model(provider_name: str | None, model_name: str | None) -> BaseChatModel:
"""Initialize and configure the LangChain components with specified model."""
providers = {
"anthropic": (ChatAnthropic, "ANTHROPIC_API_KEY", "https://www.anthropic.com", "anthropic_api_key"),
"openai": (ChatOpenAI, "OPENAI_API_KEY", "https://platform.openai.com/account/api-keys", "openai_api_key"),
"google": (ChatGoogleGenerativeAI, "GOOGLE_API_KEY", "https://developers.generativeai.google/", "google_api_key"),
"mistral": (ChatMistralAI, "MISTRAL_API_KEY", "https://console.mistral.ai/api-keys/", "mistral_api_key"),
"fireworks": (ChatFireworks, "FIREWORKS_API_KEY", "https://app.fireworks.ai/", "fireworks_api_key"),
"together": (ChatTogether, "TOGETHER_API_KEY", "https://api.together.xyz/", "together_api_key"),
"vertex": (ChatVertexAI, "GOOGLE_APPLICATION_CREDENTIALS", "https://cloud.google.com/vertex-ai", None),
"groq": (ChatGroq, "GROQ_API_KEY", "https://console.groq.com/", "groq_api_key"),
"nvidia": (ChatNVIDIA, "NVIDIA_API_KEY", "https://api.nvidia.com/", "nvidia_api_key"),
"ollama": (ChatOllama, None, "https://ollama.ai/", None),
"ai21": (ChatAI21, "AI21_API_KEY", "https://www.ai21.com/studio", "ai21_api_key"),
"upstage": (ChatUpstage, "UPSTAGE_API_KEY", "https://upstage.ai/", "upstage_api_key"),
"databricks": (ChatDatabricks, "DATABRICKS_TOKEN", "https://www.databricks.com/", "databricks_token"),
"watsonx": (ChatWatsonx, "WATSONX_API_KEY", "https://www.ibm.com/watsonx", "watsonx_api_key"),
"xai": (ChatXAI, "XAI_API_KEY", "https://xai.com/", "xai_api_key"),
}

max_tokens = 500

if provider_name not in providers:
if provider_name == 'anthropic':
from langchain_anthropic import ChatAnthropic
model_class, api_key_env, api_url, api_key_param = ChatAnthropic, "ANTHROPIC_API_KEY", "https://www.anthropic.com", "anthropic_api_key"
elif provider_name == 'openai':
from langchain_openai import ChatOpenAI
model_class, api_key_env, api_url, api_key_param = ChatOpenAI, "OPENAI_API_KEY", "https://platform.openai.com/account/api-keys", "openai_api_key"
elif provider_name == 'google':
from langchain_google_genai import ChatGoogleGenerativeAI
model_class, api_key_env, api_url, api_key_param = ChatGoogleGenerativeAI, "GOOGLE_API_KEY", "https://developers.generativeai.google/", "google_api_key"
elif provider_name == 'mistral':
from langchain_mistralai import ChatMistralAI
model_class, api_key_env, api_url, api_key_param = ChatMistralAI, "MISTRAL_API_KEY", "https://console.mistral.ai/api-keys/", "mistral_api_key"
elif provider_name == 'fireworks':
from langchain_fireworks import ChatFireworks
model_class, api_key_env, api_url, api_key_param = ChatFireworks, "FIREWORKS_API_KEY", "https://app.fireworks.ai/", "fireworks_api_key"
elif provider_name == 'together':
from langchain_together import ChatTogether
model_class, api_key_env, api_url, api_key_param = ChatTogether, "TOGETHER_API_KEY", "https://api.together.xyz/", "together_api_key"
elif provider_name == 'vertex':
from langchain_google_vertexai import ChatVertexAI
model_class, api_key_env, api_url, api_key_param = ChatVertexAI, "GOOGLE_APPLICATION_CREDENTIALS", "https://cloud.google.com/vertex-ai", None
elif provider_name == 'groq':
from langchain_groq import ChatGroq
model_class, api_key_env, api_url, api_key_param = ChatGroq, "GROQ_API_KEY", "https://console.groq.com/", "groq_api_key"
elif provider_name == 'nvidia':
from langchain_nvidia_ai_endpoints import ChatNVIDIA
model_class, api_key_env, api_url, api_key_param = ChatNVIDIA, "NVIDIA_API_KEY", "https://api.nvidia.com/", "nvidia_api_key"
elif provider_name == 'ollama':
from langchain_ollama import ChatOllama
model_class, api_key_env, api_url, api_key_param = ChatOllama, None, "https://ollama.ai/", None
elif provider_name == 'ai21':
from langchain_ai21 import ChatAI21
model_class, api_key_env, api_url, api_key_param = ChatAI21, "AI21_API_KEY", "https://www.ai21.com/studio", "ai21_api_key"
elif provider_name == 'upstage':
from langchain_upstage import ChatUpstage
model_class, api_key_env, api_url, api_key_param = ChatUpstage, "UPSTAGE_API_KEY", "https://upstage.ai/", "upstage_api_key"
elif provider_name == 'databricks':
from langchain_databricks import ChatDatabricks
model_class, api_key_env, api_url, api_key_param = ChatDatabricks, "DATABRICKS_TOKEN", "https://www.databricks.com/", "databricks_token"
elif provider_name == 'watsonx':
from langchain_watsonx import ChatWatsonx
model_class, api_key_env, api_url, api_key_param = ChatWatsonx, "WATSONX_API_KEY", "https://www.ibm.com/watsonx", "watsonx_api_key"
elif provider_name == 'xai':
from langchain_xai import ChatXAI
model_class, api_key_env, api_url, api_key_param = ChatXAI, "XAI_API_KEY", "https://xai.com/", "xai_api_key"
else:
raise ValueError(f"Unsupported LLM provider: {provider_name}")

model_class, api_key_env, api_url, api_key_param = providers[provider_name]
max_tokens = 500

if api_key_env:
api_key = os.getenv(api_key_env)
Expand Down