Skip to content

Commit

Permalink
fix: resolve Model Issues and add huggingface dependency (langflow-ai…
Browse files Browse the repository at this point in the history
…#2339)

* chore: adding default values to Azure OpenAI mandatory component

* fix: huggingface model component:
  - Change Huggingface-hub version from 0.20.0 to 0.22.0;
  - The internal model_id resolver was not working; create a field for model_id;

* feat: add HuggingFace as extra dependency

* chore: remove redundant attribution on children

* fix: remove user environment variables from ChatLiteLLMModelComponent

---------

Co-authored-by: joaoguilhermeS <j.guilherme.s.oliveira2@gmail.com>
(cherry picked from commit 805df82)
  • Loading branch information
berrytern authored and nicoloboschi committed Jul 2, 2024
1 parent 917c06e commit 9aa8799
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 66 deletions.
43 changes: 31 additions & 12 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ python = ">=3.10,<3.13"
beautifulsoup4 = "^4.12.2"
google-search-results = "^2.4.1"
google-api-python-client = "^2.130.0"
huggingface-hub = { version = "^0.20.0", extras = ["inference"] }
huggingface-hub = { version = "^0.22.0", extras = ["inference"] }
llama-cpp-python = { version = "~0.2.0", optional = true }
networkx = "^3.1"
fake-useragent = "^1.5.0"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Optional

from langchain_community.chat_models.litellm import ChatLiteLLM, ChatLiteLLMException

from langflow.base.constants import STREAM_INFO_TEXT
Expand All @@ -12,7 +10,6 @@
FloatInput,
IntInput,
MessageInput,
Output,
SecretStrInput,
StrInput,
)
Expand All @@ -35,7 +32,7 @@ class ChatLiteLLMModelComponent(LCModelComponent):
),
SecretStrInput(
name="api_key",
display_name="API key",
display_name="API Key",
advanced=False,
required=False,
),
Expand All @@ -60,24 +57,23 @@ class ChatLiteLLMModelComponent(LCModelComponent):
value=0.7,
),
DictInput(
name="model_kwargs",
display_name="Model kwargs",
name="kwargs",
display_name="Kwargs",
advanced=True,
required=False,
is_list=True,
value={},
),
FloatInput(
name="top_p",
display_name="Top p",
advanced=True,
required=False,
),
IntInput(
name="top_k",
display_name="Top k",
DictInput(
name="model_kwargs",
display_name="Model kwargs",
advanced=True,
required=False,
is_list=True,
value={},
),
FloatInput(name="top_p", display_name="Top p", advanced=True, required=False, value=0.5),
IntInput(name="top_k", display_name="Top k", advanced=True, required=False, value=35),
IntInput(
name="n",
display_name="N",
Expand Down Expand Up @@ -122,11 +118,6 @@ class ChatLiteLLMModelComponent(LCModelComponent):
),
]

outputs = [
Output(display_name="Text", name="text_output", method="text_response"),
Output(display_name="Language Model", name="model_output", method="build_model"),
]

def build_model(self) -> LanguageModel: # type: ignore[type-var]
try:
import litellm # type: ignore
Expand All @@ -137,28 +128,19 @@ def build_model(self) -> LanguageModel: # type: ignore[type-var]
raise ChatLiteLLMException(
"Could not import litellm python package. " "Please install it with `pip install litellm`"
)

provider_map = {
"OpenAI": "openai_api_key",
"Azure": "azure_api_key",
"Anthropic": "anthropic_api_key",
"Replicate": "replicate_api_key",
"Cohere": "cohere_api_key",
"OpenRouter": "openrouter_api_key",
}

# Set the API key based on the provider
api_keys: dict[str, Optional[str]] = {v: None for v in provider_map.values()}

if variable_name := provider_map.get(self.provider):
api_keys[variable_name] = self.api_key
else:
raise ChatLiteLLMException(
f"Provider {self.provider} is not supported. Supported providers are: {', '.join(provider_map.keys())}"
)

# Remove empty keys
if "" in self.kwargs:
del self.kwargs[""]
if "" in self.model_kwargs:
del self.model_kwargs[""]
# Report missing fields for Azure provider
if self.provider == "Azure":
if "api_base" not in self.kwargs:
raise Exception("Missing api_base on kwargs")
if "api_version" not in self.model_kwargs:
raise Exception("Missing api_version on model_kwargs")
output = ChatLiteLLM(
model=self.model,
model=f"{self.provider.lower()}/{self.model}",
client=None,
streaming=self.stream,
temperature=self.temperature,
Expand All @@ -168,12 +150,8 @@ def build_model(self) -> LanguageModel: # type: ignore[type-var]
n=self.n,
max_tokens=self.max_tokens,
max_retries=self.max_retries,
openai_api_key=api_keys["openai_api_key"],
azure_api_key=api_keys["azure_api_key"],
anthropic_api_key=api_keys["anthropic_api_key"],
replicate_api_key=api_keys["replicate_api_key"],
cohere_api_key=api_keys["cohere_api_key"],
openrouter_api_key=api_keys["openrouter_api_key"],
**self.kwargs,
)
output.client.api_key = self.api_key

return output # type: ignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import LanguageModel
from langflow.inputs import MessageTextInput
from langflow.io import BoolInput, DictInput, DropdownInput, MessageInput, Output
from langflow.io import BoolInput, DictInput, DropdownInput, MessageInput


class AmazonBedrockComponent(LCModelComponent):
Expand Down Expand Up @@ -64,10 +64,6 @@ class AmazonBedrockComponent(LCModelComponent):
),
BoolInput(name="stream", display_name="Stream", info=STREAM_INFO_TEXT, advanced=True),
]
outputs = [
Output(display_name="Text", name="text_output", method="text_response"),
Output(display_name="Language Model", name="model_output", method="build_model"),
]

def build_model(self) -> LanguageModel: # type: ignore[type-var]
model_id = self.model_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ class HuggingFaceEndpointsComponent(LCModelComponent):
inputs = [
MessageInput(name="input_value", display_name="Input"),
SecretStrInput(name="endpoint_url", display_name="Endpoint URL", password=True),
StrInput(
name="model_id",
display_name="Model Id",
info="Id field of endpoint_url response.",
),
DropdownInput(
name="task",
display_name="Task",
Expand Down Expand Up @@ -47,5 +52,5 @@ def build_model(self) -> LanguageModel: # type: ignore[type-var]
except Exception as e:
raise ValueError("Could not connect to HuggingFace Endpoints API.") from e

output = ChatHuggingFace(llm=llm)
output = ChatHuggingFace(llm=llm, model_id=self.model_id)
return output # type: ignore

0 comments on commit 9aa8799

Please sign in to comment.