Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions langchain/prompts/loading.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Load prompts from disk."""
import importlib
import json
import logging
from pathlib import Path
from typing import Union

Expand All @@ -12,17 +13,20 @@
from langchain.utilities.loading import try_load_from_hub

URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
logger = logging.getLogger(__file__)


def load_prompt_from_config(config: dict) -> BasePromptTemplate:
"""Get the right type from the config and load it accordingly."""
prompt_type = config.pop("_type", "prompt")
if prompt_type == "prompt":
return _load_prompt(config)
elif prompt_type == "few_shot":
return _load_few_shot_prompt(config)
else:
raise ValueError
"""Load prompt from Config Dict."""
if "_type" not in config:
logger.warning("No `_type` key found, defaulting to `prompt`.")
config_type = config.pop("_type", "prompt")

if config_type not in type_to_loader_dict:
raise ValueError(f"Loading {config_type} prompt not supported")

prompt_loader = type_to_loader_dict[config_type]
return prompt_loader(config)


def _load_template(var_name: str, config: dict) -> dict:
Expand Down Expand Up @@ -150,3 +154,10 @@ def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
raise ValueError(f"Got unsupported file type {file_path.suffix}")
# Load the prompt from the config now.
return load_prompt_from_config(config)


type_to_loader_dict = {
"prompt": _load_prompt,
"few_shot": _load_few_shot_prompt,
# "few_shot_with_templates": _load_few_shot_with_templates_prompt,
}