on_retriever_end() not called with ConversationalRetrievalChain #7290

Closed
4 of 14 tasks
mssalvatore opened this issue Jul 6, 2023 · 12 comments
Labels
🤖:bug Related to a bug, vulnerability, unexpected error with an existing feature

Comments

@mssalvatore
Contributor

mssalvatore commented Jul 6, 2023

System Info

LangChain: v0.0.225
OS: Ubuntu 22.04

Who can help?

@agola11
@hwchase17

Information

  • The official example notebooks/scripts
  • My own modified scripts

Related Components

  • LLMs/Chat Models
  • Embedding Models
  • Prompts / Prompt Templates / Prompt Selectors
  • Output Parsers
  • Document Loaders
  • Vector Stores / Retrievers
  • Memory
  • Agents / Agent Executors
  • Tools / Toolkits
  • Chains
  • Callbacks/Tracing
  • Async

Reproduction

Code

from typing import Any, Optional, Sequence
from uuid import UUID

import langchain
from chromadb.config import Settings
from langchain.callbacks.streaming_stdout import BaseCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import LlamaCpp
from langchain.memory import ConversationBufferMemory
from langchain.schema.document import Document
from langchain.vectorstores import Chroma

langchain.debug = True


class DocumentCallbackHandler(BaseCallbackHandler):
    def on_retriever_end(
        self,
        documents: Sequence[Document],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        print(f"on_retriever_end() CALLED with {len(documents)} documents")


def setup():
    llm = LlamaCpp(
        model_path="models/GPT4All-13B-snoozy.ggml.q5_1.bin",
        n_ctx=4096,
        n_batch=8192,
        callbacks=[],
        verbose=False,
        use_mlock=True,
        n_gpu_layers=60,
        n_threads=8,
    )

    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    db = Chroma(
        persist_directory="./db",
        embedding_function=embeddings,
        client_settings=Settings(
            chroma_db_impl="duckdb+parquet",
            persist_directory="./db",
            anonymized_telemetry=False,
        ),
    )

    retriever = db.as_retriever(search_kwargs={"k": 4})
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    return ConversationalRetrievalChain.from_llm(
        llm=llm, retriever=retriever, memory=memory, callbacks=[DocumentCallbackHandler()]
    )


def main():
    qa = setup()
    while True:
        question = input("\nEnter your question: ")
        answer = qa(question)["answer"]
        print(f"\n> Answer: {answer}")


if __name__ == "__main__":
    main()

Output

ggml_init_cublas: found 1 CUDA devices:
  Device 0: Quadro RTX 6000
llama.cpp: loading model from models/GPT4All-13B-snoozy.ggml.q5_1.bin
llama_model_load_internal: format     = ggjt v2 (pre #1508)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 4096
llama_model_load_internal: n_embd     = 5120
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 40
llama_model_load_internal: n_layer    = 40
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 9 (mostly Q5_1)
llama_model_load_internal: n_ff       = 13824
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 13B
llama_model_load_internal: ggml ctx size =    0.09 MB
llama_model_load_internal: using CUDA for GPU acceleration
llama_model_load_internal: mem required  = 2165.28 MB (+ 1608.00 MB per state)
llama_model_load_internal: allocating batch_size x 1 MB = 512 MB VRAM for the scratch buffer
llama_model_load_internal: offloading 40 repeating layers to GPU
llama_model_load_internal: offloading non-repeating layers to GPU
llama_model_load_internal: offloading v cache to GPU
llama_model_load_internal: offloading k cache to GPU
llama_model_load_internal: offloaded 43/43 layers to GPU
llama_model_load_internal: total VRAM used: 11314 MB
....................................................................................................
llama_init_from_file: kv self size  = 3200.00 MB
AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | VSX = 0 | 

Enter your question: Should Hamlet end his life?
[chain/start] [1:chain:ConversationalRetrievalChain] Entering Chain run with input:
{
  "question": "Should Hamlet end his life?",
  "chat_history": []
}
[chain/start] [1:chain:ConversationalRetrievalChain > 3:chain:StuffDocumentsChain] Entering Chain run with input:
[inputs]
[chain/start] [1:chain:ConversationalRetrievalChain > 3:chain:StuffDocumentsChain > 4:chain:LLMChain] Entering Chain run with input:
{
  "question": "Should Hamlet end his life?",
  "context": "Enter Hamlet.\n\nEnter Hamlet.\n\nEnter Hamlet.\n\nHaply the seas, and countries different,\n    With variable objects, shall expel\n    This something-settled matter in his heart,\n    Whereon his brains still beating puts him thus\n    From fashion of himself. What think you on't?\n  Pol. It shall do well. But yet do I believe\n    The origin and commencement of his grief\n    Sprung from neglected love.- How now, Ophelia?\n    You need not tell us what Lord Hamlet said.\n    We heard it all.- My lord, do as you please;"
}
[llm/start] [1:chain:ConversationalRetrievalChain > 3:chain:StuffDocumentsChain > 4:chain:LLMChain > 5:llm:LlamaCpp] Entering LLM run with input:
{
  "prompts": [
    "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.\n\nEnter Hamlet.\n\nEnter Hamlet.\n\nEnter Hamlet.\n\nHaply the seas, and countries different,\n    With variable objects, shall expel\n    This something-settled matter in his heart,\n    Whereon his brains still beating puts him thus\n    From fashion of himself. What think you on't?\n  Pol. It shall do well. But yet do I believe\n    The origin and commencement of his grief\n    Sprung from neglected love.- How now, Ophelia?\n    You need not tell us what Lord Hamlet said.\n    We heard it all.- My lord, do as you please;\n\nQuestion: Should Hamlet end his life?\nHelpful Answer:"
  ]
}

llama_print_timings:        load time =  1100.49 ms
llama_print_timings:      sample time =    13.20 ms /    17 runs   (    0.78 ms per token)
llama_print_timings: prompt eval time =  1100.33 ms /   208 tokens (    5.29 ms per token)
llama_print_timings:        eval time =  1097.70 ms /    16 runs   (   68.61 ms per token)
llama_print_timings:       total time =  2270.30 ms
[llm/end] [1:chain:ConversationalRetrievalChain > 3:chain:StuffDocumentsChain > 4:chain:LLMChain > 5:llm:LlamaCpp] [2.27s] Exiting LLM run with output:
{
  "generations": [
    [
      {
        "text": " I'm sorry, I don't know the answer to that question.",
        "generation_info": null
      }
    ]
  ],
  "llm_output": null,
  "run": null
}
[chain/end] [1:chain:ConversationalRetrievalChain > 3:chain:StuffDocumentsChain > 4:chain:LLMChain] [2.27s] Exiting Chain run with output:
{
  "text": " I'm sorry, I don't know the answer to that question."
}
[chain/end] [1:chain:ConversationalRetrievalChain > 3:chain:StuffDocumentsChain] [2.27s] Exiting Chain run with output:
{
  "output_text": " I'm sorry, I don't know the answer to that question."
}
[chain/end] [1:chain:ConversationalRetrievalChain] [5.41s] Exiting Chain run with output:
{
  "answer": " I'm sorry, I don't know the answer to that question."
}

> Answer:  I'm sorry, I don't know the answer to that question.

Expected behavior

I expect the on_retriever_end() callback to be called immediately after documents are retrieved. I'm not sure what I'm doing wrong.

@dosubot dosubot bot added the 🤖:bug Related to a bug, vulnerability, unexpected error with an existing feature label Jul 6, 2023
@mssalvatore
Contributor Author

mssalvatore commented Jul 11, 2023

I've reproduced similar behavior with a simpler example:

from typing import Any, Optional
from uuid import UUID

import langchain
from langchain.callbacks.streaming_stdout import BaseCallbackHandler
from langchain.chains.llm import LLMChain
from langchain.llms import LlamaCpp

langchain.debug = True


class LLMTokenHandler(BaseCallbackHandler):
    def on_llm_new_token(
        self,
        token: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        print(f"on_llm_new_token() CALLED with {token}")


llm = LlamaCpp(
    model_path="models/GPT4All-13B-snoozy.ggml.q5_1.bin",
    n_ctx=4096,
    n_batch=8192,
    callbacks=[],
    verbose=False,
    use_mlock=True,
    n_gpu_layers=60,
    n_threads=8,
)

prompt_template = "What is the definition of {word}?"

llm_chain = LLMChain(
    llm=llm,
    prompt=langchain.PromptTemplate.from_template(prompt_template),
    callbacks=[LLMTokenHandler()],
)
llm_chain("befuddle")

I've done some investigating and I think what's happening is that callbacks passed into chains are not inheritable, so they are being dropped.

https://github.com/hwchase17/langchain/blob/9e067b8cc917f3b753d944ecf9e1ed080b153250/langchain/chains/llm.py#L102-L107

@agola11, @hwchase17 The LLM is being passed only run_manager.get_child(). It seems strange to me that if I pass a callback for on_llm_new_token() to the LLMChain it gets ignored. Since passing in a callback_manager is deprecated, I haven't found any way to make the handler I'm passing in inheritable. Am I just approaching this the wrong way?
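
To make the scoping concrete, here is a toy model of the local vs. inheritable handler split as I understand it (the class below is purely illustrative, not the real LangChain CallbackManager):

from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyCallbackManager:
    # Handlers registered via a chain's constructor end up here (local only).
    handlers: List[str] = field(default_factory=list)
    # Handlers registered at request time also end up here (propagated).
    inheritable_handlers: List[str] = field(default_factory=list)

    def get_child(self) -> "ToyCallbackManager":
        # Only inheritable handlers survive into child runs.
        return ToyCallbackManager(
            handlers=list(self.inheritable_handlers),
            inheritable_handlers=list(self.inheritable_handlers),
        )


# Constructor callbacks become local handlers on the chain's manager...
chain_manager = ToyCallbackManager(handlers=["LLMTokenHandler"])

# ...so the child manager handed to the LLM run is empty, and
# on_llm_new_token() is never dispatched to the constructor handler.
print(chain_manager.get_child().handlers)  # []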

@mssalvatore
Contributor Author

@hwchase17 I'm happy to submit a fix, but I first need to understand a little more about the design choices and intent, and what an appropriate solution would be.

@kcho02

kcho02 commented Aug 10, 2023

I am facing the same issue. Is this resolved by any chance?

@mssalvatore
Contributor Author

@kcho02 To the best of my knowledge, no. You can pass the callback directly to the LLM, but if you use the same LLM object in two or more different chains, this may be undesirable.
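
For example, with the simpler LLMChain reproduction above (a sketch only; LLMTokenHandler is the handler class defined there, and the second prompt is just a placeholder):

import langchain
from langchain.chains.llm import LLMChain
from langchain.llms import LlamaCpp

# Attaching the handler to the LLM itself does get on_llm_new_token() called...
llm = LlamaCpp(
    model_path="models/GPT4All-13B-snoozy.ggml.q5_1.bin",
    callbacks=[LLMTokenHandler()],
)

definition_chain = LLMChain(
    llm=llm,
    prompt=langchain.PromptTemplate.from_template("What is the definition of {word}?"),
)
# ...but every other chain built on the same `llm` object now triggers the
# handler too, which may not be what you want.
summary_chain = LLMChain(
    llm=llm,
    prompt=langchain.PromptTemplate.from_template("Summarize the following text: {text}"),
)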

@kcho02

kcho02 commented Aug 10, 2023

I see. Thanks for the response. Until it's fixed, callback_manager seems to be the way to go for now.

@erpic

erpic commented Aug 30, 2023

I noticed from the doc https://python.langchain.com/docs/modules/callbacks/#where-to-pass-in-callbacks that the callback could be provided either when calling the constructor or when running a request.

I was also facing this issue when passing callbacks to the constructor, but it works for me (on_retriever_end() gets called) when I pass them at run time.

So, in the first example above, try something like:

answer = qa(question, callbacks=[DocumentCallbackHandler()])["answer"]
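
For instance, applied to the main() loop of the original reproduction, something along these lines:

def main():
    qa = setup()
    handler = DocumentCallbackHandler()
    while True:
        question = input("\nEnter your question: ")
        # Request-time callbacks are attached to this run and inherited by
        # the retriever child run, so on_retriever_end() gets called.
        answer = qa(question, callbacks=[handler])["answer"]
        print(f"\n> Answer: {answer}")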

@mssalvatore
Contributor Author

@erpic It's a workaround, but it requires me to couple the component that calls the chain to the callbacks.

@hwchase17 I'm happy to submit a fix, but I first need to understand a little more about the design choices and intent, and what an appropriate solution would be.

@pai4451

pai4451 commented Oct 6, 2023

Hi @mssalvatore @erpic, I am facing a similar issue. How did you fix it? I tried passing callbacks at run time, but it does not work when chatting over multiple rounds.

@mssalvatore
Contributor Author

@pai4451 I worked around it by writing a wrapper around the LLM. It's not really an approach I can recommend.
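
Roughly along these lines (a sketch of the idea rather than my exact code; the class name and fields here are made up):

from typing import Any, List, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.llms.base import LLM


class CallbackInjectingLLM(LLM):
    """Thin wrapper that delegates to an inner LLM and carries its own handlers."""

    inner_llm: LLM
    extra_handlers: List[BaseCallbackHandler] = []

    @property
    def _llm_type(self) -> str:
        return f"wrapped-{self.inner_llm._llm_type}"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        # Hand the wrapper's handlers to the inner LLM as request-time
        # callbacks, so they are attached to that run and actually fire.
        return self.inner_llm(prompt, stop=stop, callbacks=self.extra_handlers)


# Usage sketch: each chain gets its own wrapped instance, so handlers don't
# leak between chains that share the underlying model.
# wrapped_llm = CallbackInjectingLLM(inner_llm=llm, extra_handlers=[LLMTokenHandler()])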

@pai4451

pai4451 commented Oct 10, 2023

@pai4451 I worked around it by writing a wrapper around the LLM. It's not really an approach I can recommend.

@mssalvatore I found an implementation that exactly matches my needs.
https://gist.github.com/jvelezmagic/03ddf4c452d011aae36b2a0f73d72f68?permalink_comment_id=4654242#gistcomment-4654242

@MarcSkovMadsen

MarcSkovMadsen commented Oct 20, 2023

I work around this issue by inserting the hack below.

def get_chain(callbacks):
    retriever = db.as_retriever(callbacks=callbacks)
    model = ChatOpenAI(callbacks=callbacks)

    def format_docs(docs):
        text = "\n\n".join([d.page_content for d in docs])
        return text

    def hack(docs):
        # https://github.com/langchain-ai/langchain/issues/7290
        for callback in callbacks:
            callback.on_retriever_end(docs, run_id=uuid4())
        return docs

    return (
        {"context": retriever | hack | format_docs, "question": RunnablePassthrough()}
        | prompt
        | model
    )

I use it to document how the PanelCallbackHandler and ChatInterface, to be released in Panel 1.30, work. Dev Docs | Prod docs

Full Code
from uuid import uuid4

import requests

from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma

import panel as pn

TEXT = "https://raw.githubusercontent.com/langchain-ai/langchain/master/docs/docs/modules/state_of_the_union.txt"

TEMPLATE = """Answer the question based only on the following context:

{context}

Question: {question}
"""

pn.extension(design="material")

prompt = ChatPromptTemplate.from_template(TEMPLATE)


@pn.cache
def get_vector_store():
    full_text = requests.get(TEXT).text
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    texts = text_splitter.split_text(full_text)
    embeddings = OpenAIEmbeddings()
    db = Chroma.from_texts(texts, embeddings)
    return db


db = get_vector_store()


def get_chain(callbacks):
    retriever = db.as_retriever(callbacks=callbacks)
    model = ChatOpenAI(callbacks=callbacks)

    def format_docs(docs):
        text = "\n\n".join([d.page_content for d in docs])
        return text

    def hack(docs):
        # https://github.com/langchain-ai/langchain/issues/7290
        for callback in callbacks:
            callback.on_retriever_end(docs, run_id=uuid4())
        return docs

    return (
        {"context": retriever | hack | format_docs, "question": RunnablePassthrough()}
        | prompt
        | model
    )


async def callback(contents, user, instance):
    callback_handler = pn.chat.langchain.PanelCallbackHandler(instance)
    chain = get_chain(callbacks=[callback_handler])
    await chain.ainvoke(contents)


pn.chat.ChatInterface(callback=callback).servable()

@marklysze

Thanks, this hack worked for me...

@dosubot dosubot bot added the stale Issue has not had recent activity or appears to be solved. Stale issues will be automatically closed label Apr 18, 2024
@dosubot dosubot bot closed this as not planned Won't fix, can't repro, duplicate, stale Apr 25, 2024
@dosubot dosubot bot removed the stale Issue has not had recent activity or appears to be solved. Stale issues will be automatically closed label Apr 25, 2024