Skip to content

Commit

Permalink
Add option to prefix config keys in configurable_alts (#13714)
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Nov 27, 2023
1 parent 4ce5254 commit 8a3e0c9
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 21 deletions.
2 changes: 1 addition & 1 deletion libs/core/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ tests:
poetry run pytest $(TEST_FILE)

test_watch:
poetry run ptw --snapshot-update --now . -- -x tests/unit_tests
poetry run ptw --snapshot-update --now . -- -vv -x tests/unit_tests


######################
Expand Down
8 changes: 7 additions & 1 deletion libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,15 +1204,21 @@ def configurable_fields(
def configurable_alternatives(
self,
which: ConfigurableField,
*,
default_key: str = "default",
prefix_keys: bool = False,
**kwargs: Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]],
) -> RunnableSerializable[Input, Output]:
from langchain_core.runnables.configurable import (
RunnableConfigurableAlternatives,
)

return RunnableConfigurableAlternatives(
which=which, default=self, alternatives=kwargs, default_key=default_key
which=which,
default=self,
alternatives=kwargs,
default_key=default_key,
prefix_keys=prefix_keys,
)


Expand Down
72 changes: 57 additions & 15 deletions libs/core/langchain_core/runnables/configurable.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def config_specs(self) -> List[ConfigurableFieldSpec]:
annotation=spec.annotation
or self.default.__fields__[field_name].annotation,
default=getattr(self.default, field_name),
is_shared=spec.is_shared,
)
if isinstance(spec, ConfigurableField)
else make_options_spec(
Expand Down Expand Up @@ -298,6 +299,12 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
]

default_key: str = "default"
"""The enum value to use for the default option. Defaults to "default"."""

prefix_keys: bool
"""Whether to prefix configurable fields of each alternative with a namespace
of the form <which.id>==<alternative_key>, eg. a key named "temperature" used by
the alternative named "gpt3" becomes "model==gpt3/temperature"."""

@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
Expand All @@ -313,21 +320,37 @@ def config_specs(self) -> List[ConfigurableFieldSpec]:
),
)
_enums_for_spec[self.which] = cast(Type[StrEnum], which_enum)
return [
ConfigurableFieldSpec(
id=self.which.id,
name=self.which.name,
description=self.which.description,
annotation=which_enum,
default=self.default_key,
),
*self.default.config_specs,
] + [
s
for alt in self.alternatives.values()
if isinstance(alt, RunnableSerializable)
for s in alt.config_specs
]
return get_unique_config_specs(
# which alternative
[
ConfigurableFieldSpec(
id=self.which.id,
name=self.which.name,
description=self.which.description,
annotation=which_enum,
default=self.default_key,
is_shared=self.which.is_shared,
),
]
# config specs of the default option
+ (
[
prefix_config_spec(s, f"{self.which.id}=={self.default_key}")
for s in self.default.config_specs
]
if self.prefix_keys
else self.default.config_specs
)
# config specs of the alternatives
+ [
prefix_config_spec(s, f"{self.which.id}=={alt_key}")
if self.prefix_keys
else s
for alt_key, alt in self.alternatives.items()
if isinstance(alt, RunnableSerializable)
for s in alt.config_specs
]
)

def configurable_fields(
self, **kwargs: AnyConfigurableField
Expand Down Expand Up @@ -355,6 +378,23 @@ def _prepare(
raise ValueError(f"Unknown alternative: {which}")


def prefix_config_spec(
spec: ConfigurableFieldSpec, prefix: str
) -> ConfigurableFieldSpec:
return (
ConfigurableFieldSpec(
id=f"{prefix}/{spec.id}",
name=spec.name,
description=spec.description,
annotation=spec.annotation,
default=spec.default,
is_shared=spec.is_shared,
)
if not spec.is_shared
else spec
)


def make_options_spec(
spec: Union[ConfigurableFieldSingleOption, ConfigurableFieldMultiOption],
description: Optional[str],
Expand All @@ -377,6 +417,7 @@ def make_options_spec(
description=spec.description or description,
annotation=enum,
default=spec.default,
is_shared=spec.is_shared,
)
else:
return ConfigurableFieldSpec(
Expand All @@ -385,4 +426,5 @@ def make_options_spec(
description=spec.description or description,
annotation=Sequence[enum], # type: ignore[valid-type]
default=spec.default,
is_shared=spec.is_shared,
)
1 change: 1 addition & 0 deletions libs/core/langchain_core/runnables/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def config_specs(self) -> List[ConfigurableFieldSpec]:
name="Session ID",
description="Unique identifier for a session.",
default="",
is_shared=True,
),
]
)
Expand Down
12 changes: 8 additions & 4 deletions libs/core/langchain_core/runnables/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ class ConfigurableField(NamedTuple):
name: Optional[str] = None
description: Optional[str] = None
annotation: Optional[Any] = None
is_shared: bool = False

def __hash__(self) -> int:
return hash((self.id, self.annotation))
Expand All @@ -271,6 +272,7 @@ class ConfigurableFieldSingleOption(NamedTuple):

name: Optional[str] = None
description: Optional[str] = None
is_shared: bool = False

def __hash__(self) -> int:
return hash((self.id, tuple(self.options.keys()), self.default))
Expand All @@ -285,6 +287,7 @@ class ConfigurableFieldMultiOption(NamedTuple):

name: Optional[str] = None
description: Optional[str] = None
is_shared: bool = False

def __hash__(self) -> int:
return hash((self.id, tuple(self.options.keys()), tuple(self.default)))
Expand All @@ -299,12 +302,13 @@ class ConfigurableFieldSpec(NamedTuple):
"""A field that can be configured by the user. It is a specification of a field."""

id: str
name: Optional[str]
description: Optional[str]

default: Any
annotation: Any

name: Optional[str] = None
description: Optional[str] = None
default: Any = None
is_shared: bool = False


def get_unique_config_specs(
specs: Iterable[ConfigurableFieldSpec],
Expand Down
112 changes: 112 additions & 0 deletions libs/core/tests/unit_tests/runnables/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,118 @@ def test_configurable_alts_factory() -> None:
assert fake_llm.with_config(configurable={"llm": "chat"}).invoke("...") == "b"


def test_configurable_fields_prefix_keys() -> None:
fake_chat = FakeListChatModel(responses=["b"]).configurable_fields(
responses=ConfigurableFieldMultiOption(
id="responses",
name="Chat Responses",
options={
"hello": "A good morning to you!",
"bye": "See you later!",
"helpful": "How can I help you?",
},
default=["hello", "bye"],
),
# (sleep is a configurable field in FakeListChatModel)
sleep=ConfigurableField(
id="chat_sleep",
is_shared=True,
),
)
fake_llm = (
FakeListLLM(responses=["a"])
.configurable_fields(
responses=ConfigurableField(
id="responses",
name="LLM Responses",
description="A list of fake responses for this LLM",
)
)
.configurable_alternatives(
ConfigurableField(id="llm", name="LLM"),
chat=fake_chat | StrOutputParser(),
prefix_keys=True,
)
)
prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields(
template=ConfigurableFieldSingleOption(
id="prompt_template",
name="Prompt Template",
description="The prompt template for this chain",
options={
"hello": "Hello, {name}!",
"good_morning": "A very good morning to you, {name}!",
},
default="hello",
)
)

chain = prompt | fake_llm

assert chain.config_schema().schema() == {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
"definitions": {
"LLM": {
"title": "LLM",
"description": "An enumeration.",
"enum": ["chat", "default"],
"type": "string",
},
"Chat_Responses": {
"title": "Chat Responses",
"description": "An enumeration.",
"enum": ["hello", "bye", "helpful"],
"type": "string",
},
"Prompt_Template": {
"title": "Prompt Template",
"description": "An enumeration.",
"enum": ["hello", "good_morning"],
"type": "string",
},
"Configurable": {
"title": "Configurable",
"type": "object",
"properties": {
"prompt_template": {
"title": "Prompt Template",
"description": "The prompt template for this chain",
"default": "hello",
"allOf": [{"$ref": "#/definitions/Prompt_Template"}],
},
"llm": {
"title": "LLM",
"default": "default",
"allOf": [{"$ref": "#/definitions/LLM"}],
},
# not prefixed because marked as shared
"chat_sleep": {
"title": "Chat Sleep",
"type": "number",
},
# prefixed for "chat" option
"llm==chat/responses": {
"title": "Chat Responses",
"default": ["hello", "bye"],
"type": "array",
"items": {"$ref": "#/definitions/Chat_Responses"},
},
# prefixed for "default" option
"llm==default/responses": {
"title": "LLM Responses",
"description": "A list of fake responses for this LLM",
"default": ["a"],
"type": "array",
"items": {"type": "string"},
},
},
},
},
}


def test_configurable_fields_example() -> None:
fake_chat = FakeListChatModel(responses=["b"]).configurable_fields(
responses=ConfigurableFieldMultiOption(
Expand Down

0 comments on commit 8a3e0c9

Please sign in to comment.