Skip to content

Commit

Permalink
Use an enum for configurable_alternatives to make the generated json …
Browse files Browse the repository at this point in the history
…schema nicer (#11350)
  • Loading branch information
nfcampos committed Oct 4, 2023
1 parent b499de2 commit b0893c7
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 52 deletions.
20 changes: 7 additions & 13 deletions libs/langchain/langchain/schema/runnable/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,7 @@ def output_schema(self) -> Type[BaseModel]:
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return []

def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]:
class _Config:
arbitrary_types_allowed = True

Expand All @@ -150,7 +148,7 @@ class _Config:
for spec in config_specs
},
)
if config_specs
if config_specs and "configurable" in include
else None
)

Expand All @@ -161,7 +159,7 @@ class _Config:
**{
field_name: (field_type, None)
for field_name, field_type in RunnableConfig.__annotations__.items()
if field_name in include
if field_name in [i for i in include if i != "configurable"]
},
)

Expand Down Expand Up @@ -873,7 +871,7 @@ def configurable_fields(
"available keys are {self.__fields__.keys()}"
)

return RunnableConfigurableFields(bound=self, fields=kwargs)
return RunnableConfigurableFields(default=self, fields=kwargs)

def configurable_alternatives(
self,
Expand All @@ -885,7 +883,7 @@ def configurable_alternatives(
)

return RunnableConfigurableAlternatives(
which=which, bound=self, alternatives=kwargs
which=which, default=self, alternatives=kwargs
)


Expand Down Expand Up @@ -2051,9 +2049,7 @@ def output_schema(self) -> type[BaseModel]:
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return self.bound.config_specs

def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]:
return self.bound.config_schema(include=include)

@classmethod
Expand Down Expand Up @@ -2132,9 +2128,7 @@ def output_schema(self) -> Type[BaseModel]:
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return self.bound.config_specs

def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]:
return self.bound.config_schema(include=include)

@classmethod
Expand Down
59 changes: 33 additions & 26 deletions libs/langchain/langchain/schema/runnable/configurable.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

import enum
from abc import abstractmethod
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Type,
Expand All @@ -32,7 +32,7 @@


class DynamicRunnable(RunnableSerializable[Input, Output]):
bound: RunnableSerializable[Input, Output]
default: RunnableSerializable[Input, Output]

class Config:
arbitrary_types_allowed = True
Expand All @@ -47,19 +47,19 @@ def get_lc_namespace(cls) -> List[str]:

@property
def InputType(self) -> Type[Input]:
return self.bound.InputType
return self.default.InputType

@property
def OutputType(self) -> Type[Output]:
return self.bound.OutputType
return self.default.OutputType

@property
def input_schema(self) -> Type[BaseModel]:
return self.bound.input_schema
return self.default.input_schema

@property
def output_schema(self) -> Type[BaseModel]:
return self.bound.output_schema
return self.default.output_schema

@abstractmethod
def _prepare(
Expand Down Expand Up @@ -88,8 +88,8 @@ def batch(
configs = get_config_list(config, len(inputs))
prepared = [self._prepare(c) for c in configs]

if all(p is self.bound for p in prepared):
return self.bound.batch(
if all(p is self.default for p in prepared):
return self.default.batch(
inputs, config, return_exceptions=return_exceptions, **kwargs
)

Expand Down Expand Up @@ -131,8 +131,8 @@ async def abatch(
configs = get_config_list(config, len(inputs))
prepared = [self._prepare(c) for c in configs]

if all(p is self.bound for p in prepared):
return await self.bound.abatch(
if all(p is self.default for p in prepared):
return await self.default.abatch(
inputs, config, return_exceptions=return_exceptions, **kwargs
)

Expand Down Expand Up @@ -202,18 +202,18 @@ def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
id=spec.id,
name=spec.name,
description=spec.description
or self.bound.__fields__[field_name].field_info.description,
or self.default.__fields__[field_name].field_info.description,
annotation=spec.annotation
or self.bound.__fields__[field_name].annotation,
default=getattr(self.bound, field_name),
or self.default.__fields__[field_name].annotation,
default=getattr(self.default, field_name),
)
for field_name, spec in self.fields.items()
]

def configurable_fields(
self, **kwargs: ConfigurableField
) -> RunnableSerializable[Input, Output]:
return self.bound.configurable_fields(**{**self.fields, **kwargs})
return self.default.configurable_fields(**{**self.fields, **kwargs})

def _prepare(
self, config: Optional[RunnableConfig] = None
Expand All @@ -227,49 +227,56 @@ def _prepare(
}

if configurable:
return self.bound.__class__(**{**self.bound.dict(), **configurable})
return self.default.__class__(**{**self.default.dict(), **configurable})
else:
return self.bound
return self.default


# Before Python 3.11 native StrEnum is not available
class StrEnum(str, enum.Enum):
pass


class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
which: ConfigurableField

alternatives: Dict[str, RunnableSerializable[Input, Output]]

default_key: str = "default"

@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
alt_keys = self.alternatives.keys()
which_keys = tuple(Literal[k] for k in alt_keys) + ( # type: ignore
Literal["default"],
which_enum = StrEnum( # type: ignore[call-overload]
self.which.name or self.which.id,
((v, v) for v in list(self.alternatives.keys()) + [self.default_key]),
)
return [
ConfigurableFieldSpec(
id=self.which.id,
name=self.which.name,
description=self.which.description,
annotation=Union[which_keys], # type: ignore
default="default",
annotation=which_enum,
default=self.default_key,
),
*self.bound.config_specs,
*self.default.config_specs,
] + [s for alt in self.alternatives.values() for s in alt.config_specs]

def configurable_fields(
self, **kwargs: ConfigurableField
) -> RunnableSerializable[Input, Output]:
return self.__class__(
which=self.which,
bound=self.bound.configurable_fields(**kwargs),
default=self.default.configurable_fields(**kwargs),
alternatives=self.alternatives,
)

def _prepare(
self, config: Optional[RunnableConfig] = None
) -> Runnable[Input, Output]:
config = config or {}
which = config.get("configurable", {}).get(self.which.id)
if not which:
return self.bound
which = str(config.get("configurable", {}).get(self.which.id, self.default_key))
if which == self.default_key:
return self.default
elif which in self.alternatives:
return self.alternatives[which]
else:
Expand Down
4 changes: 1 addition & 3 deletions libs/langchain/langchain/schema/runnable/fallbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,7 @@ def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
for spec in step.config_specs
)

def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]:
return self.runnable.config_schema(include=include)

@classmethod
Expand Down
25 changes: 15 additions & 10 deletions libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ def test_configurable_fields() -> None:

assert fake_llm_configurable.invoke("...") == "a"

assert fake_llm_configurable.config_schema().schema() == {
assert fake_llm_configurable.config_schema(include=["configurable"]).schema() == {
"title": "RunnableConfigurableFieldsConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
Expand Down Expand Up @@ -606,7 +606,7 @@ def test_configurable_fields() -> None:
text="Hello, John!"
)

assert prompt_configurable.config_schema().schema() == {
assert prompt_configurable.config_schema(include=["configurable"]).schema() == {
"title": "RunnableConfigurableFieldsConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
Expand Down Expand Up @@ -638,7 +638,7 @@ def test_configurable_fields() -> None:

assert chain_configurable.invoke({"name": "John"}) == "a"

assert chain_configurable.config_schema().schema() == {
assert chain_configurable.config_schema(include=["configurable"]).schema() == {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
Expand Down Expand Up @@ -690,7 +690,9 @@ def test_configurable_fields() -> None:
"llm3": "a",
}

assert chain_with_map_configurable.config_schema().schema() == {
assert chain_with_map_configurable.config_schema(
include=["configurable"]
).schema() == {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
Expand Down Expand Up @@ -760,22 +762,25 @@ def test_configurable_fields_example() -> None:

assert chain_configurable.invoke({"name": "John"}) == "a"

assert chain_configurable.config_schema().schema() == {
assert chain_configurable.config_schema(include=["configurable"]).schema() == {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
"definitions": {
"LLM": {
"title": "LLM",
"description": "An enumeration.",
"enum": ["chat", "default"],
"type": "string",
},
"Configurable": {
"title": "Configurable",
"type": "object",
"properties": {
"llm": {
"title": "LLM",
"default": "default",
"anyOf": [
{"enum": ["chat"], "type": "string"},
{"enum": ["default"], "type": "string"},
],
"allOf": [{"$ref": "#/definitions/LLM"}],
},
"llm_responses": {
"title": "LLM Responses",
Expand All @@ -791,7 +796,7 @@ def test_configurable_fields_example() -> None:
"type": "string",
},
},
}
},
},
}

Expand Down

0 comments on commit b0893c7

Please sign in to comment.