[BUG] Logic error when nesting Chain, and no error message #25603

Open
5 tasks done
luoling1993 opened this issue Aug 21, 2024 · 5 comments
Labels
🤖:bug Related to a bug, vulnerability, unexpected error with an existing feature Ɑ: core Related to langchain-core

luoling1993 commented Aug 21, 2024

Checked other resources

  • I added a very descriptive title to this issue.
  • I searched the LangChain documentation with the integrated search.
  • I used the GitHub search to find a similar question and didn't find it.
  • I am sure that this is a bug in LangChain rather than my code.
  • The bug is not resolved by updating to the latest stable version of LangChain (or the specific integration package).

Example Code

from typing import Any, Dict, List, Literal

from langchain.chains.base import Chain
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.pydantic_v1 import root_validator


class _RegexExpandChain(Chain):
    param1: str | None = None

    @root_validator(pre=True)
    def validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        param1 = values.get("param1")
        if param1 is None:
            param1 = "default_regex_value"
        values["param1"] = param1
        return values

    @property
    def input_keys(self) -> List[str]:
        return ["question"]

    @property
    def output_keys(self) -> List[str]:
        return ["response"]

    def _call(self, inputs: Dict[str, Any], run_manager: CallbackManagerForChainRun) -> Dict[str, Any]:
        return {"response": "regex"}


class _LLMExpandChain(Chain):
    params2: str | None = None

    @root_validator(pre=True)
    def validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        params2 = values.get("params2")
        if params2 is None:
            params2 = "default_llm_value"
        values["params2"] = params2
        return values

    @property
    def input_keys(self) -> List[str]:
        return ["question"]

    @property
    def output_keys(self) -> List[str]:
        return ["response"]

    def _call(self, inputs: Dict[str, Any], run_manager: CallbackManagerForChainRun) -> Dict[str, Any]:
        return {"response": "llm"}


class ExpandChain(Chain):
    expand_type: Literal["regex", "llm"]

    chain: _RegexExpandChain  | _LLMExpandChain | None = None  # will always use _RegexExpandChain
    # chain: Chain | None = None  # will be correct

    @root_validator(pre=True)
    def validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        expand_type = values.get("expand_type")
        if expand_type not in ["regex", "llm"]:
            raise ValueError("expand_type must be either 'regex' or 'llm'")

        if expand_type == "regex":
            chain = _RegexExpandChain()
        else:
            chain = _LLMExpandChain()

        values["chain"] = chain
        return values

    @property
    def input_keys(self) -> List[str]:
        return ["question"]

    @property
    def output_keys(self) -> List[str]:
        return ["response"]

    def _call(self, inputs: Dict[str, Any], run_manager: CallbackManagerForChainRun) -> Dict[str, Any]:
        return self.chain.invoke(input=inputs, run_manager=run_manager)


if __name__ == "__main__":
    test_chain = ExpandChain(expand_type="regex")
    rsp = test_chain.invoke(input={"question": "test"})

    # chain: _RegexExpandChain | _LLMExpandChain | None = None
    # will always return {'question': 'test', 'response': 'regex'}
    #
    # chain: _LLMExpandChain | _RegexExpandChain | None = None
    # will always return {'question': 'test', 'response': 'llm'}
    print(rsp)

Error Message and Stack Trace (if applicable)

No error message is raised.

Description

See example code and comment for details

System Info

langchain==0.2.14
langchain-community==0.2.12
langchain-core==0.2.33
langchain-openai==0.1.22
langchain-text-splitters==0.2.2

platform: ubuntu20.04
python version: python3.10

@dosubot dosubot bot added Ɑ: core Related to langchain-core 🤖:bug Related to a bug, vulnerability, unexpected error with an existing feature labels Aug 21, 2024

luoling1993 commented Aug 21, 2024

Because no error is raised and nothing prevents this usage, I consider this a critical bug.

gbaian10 commented Aug 21, 2024

This is caused by how pydantic handles Union fields (see the Smart Union section of its documentation).

You can change pre=True to pre=False in ExpandChain to solve your problem.

class ExpandChain(Chain):
    expand_type: Literal["regex", "llm"]

    chain: _RegexExpandChain  | _LLMExpandChain | None = None

    @root_validator(pre=False)  # change pre = False

The following is a simple example.

from pydantic.v1 import BaseModel, root_validator

def test(pre: bool, n):
    class Convert(BaseModel):
        number: str | int
        
        @root_validator(pre=pre)
        def convert_int(cls, values: dict):
            v = values.get("number")
            print(f"{pre = }, source {type(v) = }")
            if isinstance(v, str):
                v = int(v)
            values["number"] = v
            return values

    output = Convert(number=n).number
    print(f"{pre = }, output {type(output) = }")
    print("=" * 10)

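test(pre=True, n="1")   # output is str: the int set by the validator is coerced back to str
test(pre=True, n=1)     # output is str: the int is coerced to the first union member
test(pre=False, n="1")  # output is int: the conversion runs after field validation
test(pre=False, n=1)    # output is int: 1 is validated to "1", then converted to int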

In short, because you used pre=True, your root validator runs before field validation and sets chain to the instance you selected.

However, afterwards, Pydantic attempts to convert the chain into your specified _RegexExpandChain | _LLMExpandChain | None, coercing from left to right and returning the first successful object.
Since both _RegexExpandChain and _LLMExpandChain are subclasses of Chain, it will successfully convert to _RegexExpandChain, resulting in unexpected outcomes (as explained in the Pydantic documentation).
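
The left-to-right behaviour described above can be sketched without pydantic or LangChain. This is a simplified stdlib model, not pydantic's actual code: RegexExpand, LLMExpand, coerce_to and validate_left_to_right are hypothetical names, and "coercion" is assumed to mean rebuilding the target from the fields it recognises while dropping the rest.

```python
from dataclasses import asdict, dataclass, fields
from typing import Any, Sequence, Type


@dataclass
class RegexExpand:  # hypothetical stand-in for _RegexExpandChain
    param1: str = "default_regex_value"


@dataclass
class LLMExpand:  # hypothetical stand-in for _LLMExpandChain
    param2: str = "default_llm_value"


def coerce_to(target: Type[Any], value: Any) -> Any:
    """Keep instances of target; otherwise rebuild target from matching fields."""
    if isinstance(value, target):
        return value
    known = {f.name for f in fields(target)}
    # Fields the target does not declare are silently dropped (assumption).
    data = {k: v for k, v in asdict(value).items() if k in known}
    return target(**data)


def validate_left_to_right(value: Any, members: Sequence[Type[Any]]) -> Any:
    """Try each union member in declaration order and return the first success."""
    for member in members:
        try:
            return coerce_to(member, value)
        except (TypeError, ValueError):
            continue
    raise ValueError("value matches no union member")


chosen = LLMExpand()  # what the pre=True validator put into values["chain"]

# Declared as RegexExpand | LLMExpand: the first member wins.
print(type(validate_left_to_right(chosen, [RegexExpand, LLMExpand])).__name__)
# Declared as LLMExpand | RegexExpand: the instance check succeeds first.
print(type(validate_left_to_right(chosen, [LLMExpand, RegexExpand])).__name__)
```

In this model the first call prints RegexExpand and the second prints LLMExpand, which mirrors the order dependence reported in the issue.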

luoling1993 commented

But why does it return exactly what I want when I declare chain: Chain | None = None?

gbaian10 commented Aug 21, 2024

> But why does it return exactly what I want when I set up chain: Chain | None = None

@luoling1993

Because they are all Chains, they already meet the requirements during Pydantic validation and won't undergo additional conversion.

However, if it's _RegexExpandChain | _LLMExpandChain | None, it will attempt to convert _LLMExpandChain into _RegexExpandChain.

from pydantic.v1 import BaseModel

class Convert(BaseModel):
    v: int | str
    # Pydantic first attempts to convert the input to an int;
    # if it fails, it tries to convert it to a str,
    # and if that also fails, it raises an error.

print(type(Convert(v="1").v))  # int, because "1" can be converted to int
print(type(Convert(v="a").v))  # str, because "a" cannot be converted to int
print(type(Convert(v=[1]).v))  # raises ValidationError, because [1] converts to neither int nor str

In this example, even though "1" already satisfies the condition of int | str, under the handling of Union, it will still attempt type conversion.
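
The same reasoning explains the Chain | None case. The sketch below is a stdlib-only model under stated assumptions, not pydantic's implementation: Chain, RegexExpand, LLMExpand and validate_as are hypothetical stand-ins, and conversion is modelled as rebuilding the target from the fields it declares. Real pydantic may return a copy of the instance, but of the same class.

```python
from dataclasses import asdict, dataclass, fields
from typing import Any, Type


@dataclass
class Chain:  # hypothetical stand-in for the common base class
    pass


@dataclass
class RegexExpand(Chain):
    param1: str = "default_regex_value"


@dataclass
class LLMExpand(Chain):
    param2: str = "default_llm_value"


def validate_as(target: Type[Any], value: Any) -> Any:
    """Pass instances of target through; otherwise rebuild target from known fields."""
    if isinstance(value, target):
        return value  # already satisfies the annotation, no conversion
    known = {f.name for f in fields(target)}
    return target(**{k: v for k, v in asdict(value).items() if k in known})


chosen = LLMExpand()

# Annotated with the base class: the instance check passes, the class is kept.
print(type(validate_as(Chain, chosen)).__name__)
# Annotated with a sibling class first: the value is rebuilt as that sibling.
print(type(validate_as(RegexExpand, chosen)).__name__)
```

In this model the first line prints LLMExpand and the second prints RegexExpand.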

gbaian10 commented

Another option is to set smart_union = True. Pydantic then checks all union members and prefers one whose type already matches the value, instead of trying them one by one and returning the first that succeeds.

class ExpandChain(Chain):
    expand_type: Literal["regex", "llm"]

    chain: _RegexExpandChain | _LLMExpandChain | None = None
    
    class Config:
        smart_union = True

    @root_validator(pre=True)
