Skip to content

Commit

Permalink
refactor(querybot): remove unused imports and update function parameters
Browse files Browse the repository at this point in the history
- Removed unused imports from `querybot.py` including
  `BaseCallbackManager`, `StreamingStdOutCallbackHandler`, and
  `ChatOpenAI`.
- Updated `make_or_load_vector_index` function to take `service_context`
  as a parameter instead of creating it within the function.
- Added `service_context` as an attribute to the `QueryBot` class.
- Removed the `make_service_context` function as it is no longer needed.
  • Loading branch information
ericmjl committed Oct 30, 2023
1 parent af0df87 commit 935e3da
Showing 1 changed file with 3 additions and 22 deletions.
25 changes: 3 additions & 22 deletions llamabot/bot/querybot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
import tiktoken
from copy import deepcopy
from dotenv import load_dotenv
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from llama_index import (
Document,
Expand Down Expand Up @@ -116,7 +113,7 @@ def __init__(
# Override vector_index if doc_paths is specified.
if doc_paths is not None:
vector_index = make_or_load_vector_index(
doc_paths, chunk_sizes=chunk_sizes, use_cache=use_cache
service_context, doc_paths, chunk_sizes=chunk_sizes, use_cache=use_cache
)

# Set object attributes.
Expand All @@ -138,6 +135,7 @@ def __init__(
self.chunk_sizes = chunk_sizes
self.response_tokens = response_tokens
self.history_tokens = history_tokens
self.service_context = service_context

def __call__(
self,
Expand Down Expand Up @@ -255,23 +253,6 @@ def retrieve(
return source_nodes


def make_service_context():
"""Make a service context for the QueryBot.
:returns: A service context.
"""
chat = ChatOpenAI(
model_name=default_language_model(),
temperature=0.0,
streaming=True,
verbose=True,
callback_manager=BaseCallbackManager([StreamingStdOutCallbackHandler()]),
)
llm_predictor = LLMPredictor(llm=chat)
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor)
return service_context


# @validate_call
def load_index(persist_dir: Path, service_context: ServiceContext):
"""Load an index from disk.
Expand Down Expand Up @@ -341,6 +322,7 @@ def make_vector_index(


def make_or_load_vector_index(
service_context: ServiceContext,
doc_paths: List[Path] | List[str],
chunk_sizes: list[int] = [200, 500, 1000, 2000],
chunk_overlap: int = 0,
Expand Down Expand Up @@ -373,7 +355,6 @@ def make_or_load_vector_index(
# Make persist_dir based on the file hash's hexdigest.
persist_dir = CACHE_DIR / file_hash_hexdigest
persist_dir.mkdir(parents=True, exist_ok=True)
service_context = make_service_context()

# Step 2: Create the index's split documents.
split_docs = []
Expand Down

0 comments on commit 935e3da

Please sign in to comment.