From 32b64692c5862c93c30a98ed06e0a324773b5e06 Mon Sep 17 00:00:00 2001 From: Kevin Beaulieu Date: Thu, 14 Nov 2024 23:10:31 -0500 Subject: [PATCH 1/2] Lazy import models to speed up startup --- src/git_ai_summarize/models.py | 90 +++++++++++++++++++--------------- 1 file changed, 50 insertions(+), 40 deletions(-) diff --git a/src/git_ai_summarize/models.py b/src/git_ai_summarize/models.py index a7eda82..3e2daa3 100644 --- a/src/git_ai_summarize/models.py +++ b/src/git_ai_summarize/models.py @@ -1,20 +1,5 @@ import os from langchain_core.language_models.chat_models import BaseChatModel -from langchain_anthropic import ChatAnthropic -from langchain_openai import ChatOpenAI -from langchain_google_genai import ChatGoogleGenerativeAI -from langchain_mistralai import ChatMistralAI -from langchain_fireworks import ChatFireworks -from langchain_together import ChatTogether -from langchain_google_vertexai import ChatVertexAI -from langchain_groq import ChatGroq -from langchain_nvidia_ai_endpoints import ChatNVIDIA -from langchain_ollama import ChatOllama -from langchain_ai21 import ChatAI21 -from langchain_upstage import ChatUpstage -from langchain_databricks import ChatDatabricks -from langchain_ibm import ChatWatsonx -from langchain_xai import ChatXAI from typing import List @@ -24,12 +9,12 @@ def get_supported_providers() -> List[str]: "anthropic", "openai", "google", - "mistralai", + "mistral", "fireworks", "together", - "vertexai", + "vertex", "groq", - "nvidia_ai", + "nvidia", "ollama", "ai21", "upstage", @@ -41,30 +26,55 @@ def get_supported_providers() -> List[str]: def get_model(provider_name: str | None, model_name: str | None) -> BaseChatModel: """Initialize and configure the LangChain components with specified model.""" - providers = { - "anthropic": (ChatAnthropic, "ANTHROPIC_API_KEY", "https://www.anthropic.com", "anthropic_api_key"), - "openai": (ChatOpenAI, "OPENAI_API_KEY", "https://platform.openai.com/account/api-keys", "openai_api_key"), - "google": (ChatGoogleGenerativeAI, "GOOGLE_API_KEY", "https://developers.generativeai.google/", "google_api_key"), - "mistral": (ChatMistralAI, "MISTRAL_API_KEY", "https://console.mistral.ai/api-keys/", "mistral_api_key"), - "fireworks": (ChatFireworks, "FIREWORKS_API_KEY", "https://app.fireworks.ai/", "fireworks_api_key"), - "together": (ChatTogether, "TOGETHER_API_KEY", "https://api.together.xyz/", "together_api_key"), - "vertex": (ChatVertexAI, "GOOGLE_APPLICATION_CREDENTIALS", "https://cloud.google.com/vertex-ai", None), - "groq": (ChatGroq, "GROQ_API_KEY", "https://console.groq.com/", "groq_api_key"), - "nvidia": (ChatNVIDIA, "NVIDIA_API_KEY", "https://api.nvidia.com/", "nvidia_api_key"), - "ollama": (ChatOllama, None, "https://ollama.ai/", None), - "ai21": (ChatAI21, "AI21_API_KEY", "https://www.ai21.com/studio", "ai21_api_key"), - "upstage": (ChatUpstage, "UPSTAGE_API_KEY", "https://upstage.ai/", "upstage_api_key"), - "databricks": (ChatDatabricks, "DATABRICKS_TOKEN", "https://www.databricks.com/", "databricks_token"), - "watsonx": (ChatWatsonx, "WATSONX_API_KEY", "https://www.ibm.com/watsonx", "watsonx_api_key"), - "xai": (ChatXAI, "XAI_API_KEY", "https://xai.com/", "xai_api_key"), - } - - max_tokens = 500 - - if provider_name not in providers: + if provider_name == 'anthropic': + from langchain_anthropic import ChatAnthropic + model_class, api_key_env, api_url, api_key_param = ChatAnthropic, "ANTHROPIC_API_KEY", "https://www.anthropic.com", "anthropic_api_key" + elif provider_name == 'openai': + from langchain_openai import ChatOpenAI + model_class, api_key_env, api_url, api_key_param = ChatOpenAI, "OPENAI_API_KEY", "https://platform.openai.com/account/api-keys", "openai_api_key" + elif provider_name == 'google': + from langchain_google_genai import ChatGoogleGenerativeAI + model_class, api_key_env, api_url, api_key_param = ChatGoogleGenerativeAI, "GOOGLE_API_KEY", "https://developers.generativeai.google/", "google_api_key" + elif provider_name == 'mistral': + from langchain_mistralai import ChatMistralAI + model_class, api_key_env, api_url, api_key_param = ChatMistralAI, "MISTRAL_API_KEY", "https://console.mistral.ai/api-keys/", "mistral_api_key" + elif provider_name == 'fireworks': + from langchain_fireworks import ChatFireworks + model_class, api_key_env, api_url, api_key_param = ChatFireworks, "FIREWORKS_API_KEY", "https://app.fireworks.ai/", "fireworks_api_key" + elif provider_name == 'together': + from langchain_together import ChatTogether + model_class, api_key_env, api_url, api_key_param = ChatTogether, "TOGETHER_API_KEY", "https://api.together.xyz/", "together_api_key" + elif provider_name == 'vertex': + from langchain_google_vertexai import ChatVertexAI + model_class, api_key_env, api_url, api_key_param = ChatVertexAI, "GOOGLE_APPLICATION_CREDENTIALS", "https://cloud.google.com/vertex-ai", None + elif provider_name == 'groq': + from langchain_groq import ChatGroq + model_class, api_key_env, api_url, api_key_param = ChatGroq, "GROQ_API_KEY", "https://console.groq.com/", "groq_api_key" + elif provider_name == 'nvidia': + from langchain_nvidia_ai_endpoints import ChatNVIDIA + model_class, api_key_env, api_url, api_key_param = ChatNVIDIA, "NVIDIA_API_KEY", "https://api.nvidia.com/", "nvidia_api_key" + elif provider_name == 'ollama': + from langchain_ollama import ChatOllama + model_class, api_key_env, api_url, api_key_param = ChatOllama, None, "https://ollama.ai/", None + elif provider_name == 'ai21': + from langchain_ai21 import ChatAI21 + model_class, api_key_env, api_url, api_key_param = ChatAI21, "AI21_API_KEY", "https://www.ai21.com/studio", "ai21_api_key" + elif provider_name == 'upstage': + from langchain_upstage import ChatUpstage + model_class, api_key_env, api_url, api_key_param = ChatUpstage, "UPSTAGE_API_KEY", "https://upstage.ai/", "upstage_api_key" + elif provider_name == 'databricks': + from langchain_databricks import ChatDatabricks + model_class, api_key_env, api_url, api_key_param = ChatDatabricks, "DATABRICKS_TOKEN", "https://www.databricks.com/", "databricks_token" + elif provider_name == 'watsonx': + from langchain_watsonx import ChatWatsonx + model_class, api_key_env, api_url, api_key_param = ChatWatsonx, "WATSONX_API_KEY", "https://www.ibm.com/watsonx", "watsonx_api_key" + elif provider_name == 'xai': + from langchain_xai import ChatXAI + model_class, api_key_env, api_url, api_key_param = ChatXAI, "XAI_API_KEY", "https://xai.com/", "xai_api_key" + else: raise ValueError(f"Unsupported LLM provider: {provider_name}") - model_class, api_key_env, api_url, api_key_param = providers[provider_name] + max_tokens = 500 if api_key_env: api_key = os.getenv(api_key_env) From da11e7f9d1141522c38ebf48c8cd83334c160a38 Mon Sep 17 00:00:00 2001 From: Kevin Beaulieu Date: Thu, 14 Nov 2024 23:11:10 -0500 Subject: [PATCH 2/2] Bump version to 0.1.4 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6c42fe5..f142647 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "git-ai-summarize" -version = "0.1.3" +version = "0.1.4" authors = [{ name = "Kevin Beaulieu", email = "opensource@kevinmbeaulieu.com" }] description = "AI-powered git commands for summarizing changes" readme = "README.md"