Use set, inferred max token limits wherever chat models are used
- User-configured max token limits weren't being passed to
  `send_message_to_model_wrapper`
- One of the offline model load code paths wasn't reachable. Remove it
  to simplify the code
- When the max prompt size isn't set, infer the max tokens from the free
  VRAM available on the machine
- Use the minimum of the app-configured max tokens, the VRAM-based max
  tokens, and the model's context window (see the worked example below)
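
A rough worked example of that minimum-of-three rule, using hypothetical numbers and assuming `get_device_memory()` reports free VRAM in bytes (so the `/ 2e6` heuristic in the diff below amounts to roughly 2 MB per token):

    # Hypothetical machine: 8 GB of free VRAM, a 32K-token context window,
    # and an app-configured cap of 3000 tokens.
    vram_based_max_tokens = int(8e9 / 2e6)   # ~2 MB per token heuristic -> 4000
    min(3000, vram_based_max_tokens, 32768)  # -> 3000, the configured cap is the tightest bound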
debanjum committed Apr 20, 2024
1 parent 002cd14 commit 175169c
Showing 4 changed files with 7 additions and 13 deletions.
src/khoj/database/adapters/__init__.py (3 additions, 1 deletion)
@@ -777,7 +777,9 @@ def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
 
     if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
         if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
-            state.offline_chat_processor_config = OfflineChatProcessorModel(conversation_config.chat_model)
+            chat_model = conversation_config.chat_model
+            max_tokens = conversation_config.max_prompt_size
+            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
 
     return conversation_config

src/khoj/processor/conversation/offline/utils.py (1 addition, 4 deletions)
@@ -68,7 +68,4 @@ def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
 def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int:
     """Infer max prompt size based on device memory and max context window supported by the model"""
     vram_based_n_ctx = int(get_device_memory() / 2e6)  # based on heuristic
-    if configured_max_tokens:
-        return min(configured_max_tokens, model_context_window)
-    else:
-        return min(vram_based_n_ctx, model_context_window)
+    return min(configured_max_tokens, vram_based_n_ctx, model_context_window)
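
A few illustrative calls to the simplified helper, assuming a machine where `get_device_memory()` reports 8 GB so `vram_based_n_ctx` works out to 4000 (all numbers hypothetical):

    infer_max_tokens(model_context_window=32768, configured_max_tokens=3000)  # -> 3000, capped by the configured limit
    infer_max_tokens(model_context_window=32768)                              # -> 4000, capped by free VRAM
    infer_max_tokens(model_context_window=2048)                               # -> 2048, capped by the model's context window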
src/khoj/processor/conversation/utils.py (2 additions, 2 deletions)
@@ -13,7 +13,7 @@
 
 from khoj.database.adapters import ConversationAdapters
 from khoj.database.models import ClientApplication, KhojUser
-from khoj.processor.conversation.offline.utils import download_model
+from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
 from khoj.utils.helpers import is_none_or_empty, merge_dicts
 
 logger = logging.getLogger(__name__)
@@ -145,7 +145,7 @@ def generate_chatml_messages_with_context(
     # Set max prompt size from user config or based on pre-configured for model and machine specs
     if not max_prompt_size:
         if loaded_model:
-            max_prompt_size = min(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
+            max_prompt_size = infer_max_tokens(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
         else:
             max_prompt_size = model_to_prompt_size.get(model_name, 2000)

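Put differently, an offline llama.cpp model is now bounded by its loaded context window, the per-model prompt-size table, and free VRAM, while API-backed models keep the table lookup with a conservative 2000-token fallback. A hypothetical trace of the two branches, reusing the 8 GB / ~4000-token assumption from above:

    # Offline model loaded with n_ctx() == 8192, no entry in model_to_prompt_size:
    #   max_prompt_size = infer_max_tokens(8192, math.inf)  -> 4000 (VRAM-bound)
    # API-backed model, also missing from model_to_prompt_size:
    #   max_prompt_size = model_to_prompt_size.get(model_name, 2000)  -> 2000
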
src/khoj/routers/helpers.py (1 addition, 6 deletions)
@@ -409,7 +409,7 @@ async def send_message_to_model_wrapper(
     openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
     api_key = openai_chat_config.api_key
     truncated_messages = generate_chatml_messages_with_context(
-        user_message=message, system_message=system_message, model_name=chat_model
+        user_message=message, system_message=system_message, model_name=chat_model, max_prompt_size=max_tokens
     )
 
     openai_response = send_message_to_model(
@@ -457,11 +457,6 @@ def generate_chat_response(
 
     conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
     if conversation_config.model_type == "offline":
-        if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
-            chat_model = conversation_config.chat_model
-            max_tokens = conversation_config.max_prompt_size
-            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
-
         loaded_model = state.offline_chat_processor_config.loaded_model
         chat_response = converse_offline(
             references=compiled_references,
