diff --git a/libs/langchain/langchain/document_loaders/obsidian.py b/libs/langchain/langchain/document_loaders/obsidian.py index 85a64860d11dd1..7d6bf6e89e3e9c 100644 --- a/libs/langchain/langchain/document_loaders/obsidian.py +++ b/libs/langchain/langchain/document_loaders/obsidian.py @@ -1,7 +1,8 @@ +import functools import logging import re from pathlib import Path -from typing import List +from typing import Any, Dict, List import yaml from langchain_core.documents import Document @@ -15,6 +16,7 @@ class ObsidianLoader(BaseLoader): """Load `Obsidian` files from directory.""" FRONT_MATTER_REGEX = re.compile(r"^---\n(.*?)\n---\n", re.DOTALL) + TEMPLATE_VARIABLE_REGEX = re.compile(r"{{(.*?)}}", re.DOTALL) TAG_REGEX = re.compile(r"[^\S\/]#([a-zA-Z_]+[-_/\w]*)") DATAVIEW_LINE_REGEX = re.compile(r"^\s*(\w+)::\s*(.*)$", re.MULTILINE) DATAVIEW_INLINE_BRACKET_REGEX = re.compile(r"\[(\w+)::\s*(.*)\]", re.MULTILINE) @@ -35,6 +37,27 @@ def __init__( self.encoding = encoding self.collect_metadata = collect_metadata + def _replace_template_var( + self, placeholders: Dict[str, str], match: re.Match + ) -> str: + """Replace a template variable with a placeholder.""" + placeholder = f"__TEMPLATE_VAR_{len(placeholders)}__" + placeholders[placeholder] = match.group(1) + return placeholder + + def _restore_template_vars(self, obj: Any, placeholders: Dict[str, str]) -> Any: + """Restore template variables replaced with placeholders to original values.""" + if isinstance(obj, str): + for placeholder, value in placeholders.items(): + obj = obj.replace(placeholder, f"{{{{{value}}}}}") + elif isinstance(obj, dict): + for key, value in obj.items(): + obj[key] = self._restore_template_vars(value, placeholders) + elif isinstance(obj, list): + for i, item in enumerate(obj): + obj[i] = self._restore_template_vars(item, placeholders) + return obj + def _parse_front_matter(self, content: str) -> dict: """Parse front matter metadata from the content and return it as a dict.""" if not self.collect_metadata: @@ -44,8 +67,17 @@ def _parse_front_matter(self, content: str) -> dict: if not match: return {} + placeholders: Dict[str, str] = {} + replace_template_var = functools.partial( + self._replace_template_var, placeholders + ) + front_matter_text = self.TEMPLATE_VARIABLE_REGEX.sub( + replace_template_var, match.group(1) + ) + try: - front_matter = yaml.safe_load(match.group(1)) + front_matter = yaml.safe_load(front_matter_text) + front_matter = self._restore_template_vars(front_matter, placeholders) # If tags are a string, split them into a list if "tags" in front_matter and isinstance(front_matter["tags"], str): diff --git a/libs/langchain/tests/unit_tests/document_loaders/sample_documents/obsidian/template_var_frontmatter.md b/libs/langchain/tests/unit_tests/document_loaders/sample_documents/obsidian/template_var_frontmatter.md new file mode 100644 index 00000000000000..7bab90737c31fc --- /dev/null +++ b/libs/langchain/tests/unit_tests/document_loaders/sample_documents/obsidian/template_var_frontmatter.md @@ -0,0 +1,12 @@ +--- +aString: {{var}} +anArray: +- element +- {{varElement}} +aDict: + dictId1: 'val' + dictId2: '{{varVal}}' +tags: [ 'tag', '{{varTag}}' ] +--- + +Frontmatter contains template variables. diff --git a/libs/langchain/tests/unit_tests/document_loaders/test_obsidian.py b/libs/langchain/tests/unit_tests/document_loaders/test_obsidian.py index 50f29d849e17ba..e25bf80199d82e 100644 --- a/libs/langchain/tests/unit_tests/document_loaders/test_obsidian.py +++ b/libs/langchain/tests/unit_tests/document_loaders/test_obsidian.py @@ -17,7 +17,7 @@ def test_page_content_loaded() -> None: """Verify that all docs have page_content""" - assert len(docs) == 5 + assert len(docs) == 6 assert all(doc.page_content for doc in docs) @@ -27,7 +27,7 @@ def test_disable_collect_metadata() -> None: str(OBSIDIAN_EXAMPLE_PATH), collect_metadata=False ) docs_wo = loader_without_metadata.load() - assert len(docs_wo) == 5 + assert len(docs_wo) == 6 assert all(doc.page_content for doc in docs_wo) assert all(set(doc.metadata) == STANDARD_METADATA_FIELDS for doc in docs_wo) @@ -45,6 +45,24 @@ def test_metadata_with_frontmatter() -> None: assert set(doc.metadata["tags"].split(",")) == {"journal/entry", "obsidian"} +def test_metadata_with_template_vars_in_frontmatter() -> None: + """Verify frontmatter fields with template variables are loaded.""" + doc = next( + doc for doc in docs if doc.metadata["source"] == "template_var_frontmatter.md" + ) + FRONTMATTER_FIELDS = { + "aString", + "anArray", + "aDict", + "tags", + } + assert set(doc.metadata) == FRONTMATTER_FIELDS | STANDARD_METADATA_FIELDS + assert doc.metadata["aString"] == "{{var}}" + assert doc.metadata["anArray"] == "['element', '{{varElement}}']" + assert doc.metadata["aDict"] == "{'dictId1': 'val', 'dictId2': '{{varVal}}'}" + assert set(doc.metadata["tags"].split(",")) == {"tag", "{{varTag}}"} + + def test_metadata_with_bad_frontmatter() -> None: """Verify a doc with non-yaml frontmatter.""" doc = next(doc for doc in docs if doc.metadata["source"] == "bad_frontmatter.md")