Skip to content

Commit

Permalink
Fixes error loading Obsidian templates (#13888)
Browse files Browse the repository at this point in the history
- **Description:** Obsidian templates can include
[variables](https://help.obsidian.md/Plugins/Templates#Template+variables)
using double curly braces. `ObsidianLoader` uses PyYAML to parse the
frontmatter of documents. This parsing throws an error when encountering
variables' curly braces. This is avoided by temporarily substituting
safe strings before parsing.
  - **Issue:** #13887
  - **Tag maintainer:** @hwchase17
  • Loading branch information
ealt committed Dec 4, 2023
1 parent f6d68d7 commit e09b876
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 4 deletions.
36 changes: 34 additions & 2 deletions libs/langchain/langchain/document_loaders/obsidian.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import functools
import logging
import re
from pathlib import Path
from typing import List
from typing import Any, Dict, List

import yaml
from langchain_core.documents import Document
Expand All @@ -15,6 +16,7 @@ class ObsidianLoader(BaseLoader):
"""Load `Obsidian` files from directory."""

FRONT_MATTER_REGEX = re.compile(r"^---\n(.*?)\n---\n", re.DOTALL)
TEMPLATE_VARIABLE_REGEX = re.compile(r"{{(.*?)}}", re.DOTALL)
TAG_REGEX = re.compile(r"[^\S\/]#([a-zA-Z_]+[-_/\w]*)")
DATAVIEW_LINE_REGEX = re.compile(r"^\s*(\w+)::\s*(.*)$", re.MULTILINE)
DATAVIEW_INLINE_BRACKET_REGEX = re.compile(r"\[(\w+)::\s*(.*)\]", re.MULTILINE)
Expand All @@ -35,6 +37,27 @@ def __init__(
self.encoding = encoding
self.collect_metadata = collect_metadata

def _replace_template_var(
self, placeholders: Dict[str, str], match: re.Match
) -> str:
"""Replace a template variable with a placeholder."""
placeholder = f"__TEMPLATE_VAR_{len(placeholders)}__"
placeholders[placeholder] = match.group(1)
return placeholder

def _restore_template_vars(self, obj: Any, placeholders: Dict[str, str]) -> Any:
"""Restore template variables replaced with placeholders to original values."""
if isinstance(obj, str):
for placeholder, value in placeholders.items():
obj = obj.replace(placeholder, f"{{{{{value}}}}}")
elif isinstance(obj, dict):
for key, value in obj.items():
obj[key] = self._restore_template_vars(value, placeholders)
elif isinstance(obj, list):
for i, item in enumerate(obj):
obj[i] = self._restore_template_vars(item, placeholders)
return obj

def _parse_front_matter(self, content: str) -> dict:
"""Parse front matter metadata from the content and return it as a dict."""
if not self.collect_metadata:
Expand All @@ -44,8 +67,17 @@ def _parse_front_matter(self, content: str) -> dict:
if not match:
return {}

placeholders: Dict[str, str] = {}
replace_template_var = functools.partial(
self._replace_template_var, placeholders
)
front_matter_text = self.TEMPLATE_VARIABLE_REGEX.sub(
replace_template_var, match.group(1)
)

try:
front_matter = yaml.safe_load(match.group(1))
front_matter = yaml.safe_load(front_matter_text)
front_matter = self._restore_template_vars(front_matter, placeholders)

# If tags are a string, split them into a list
if "tags" in front_matter and isinstance(front_matter["tags"], str):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
---
aString: {{var}}
anArray:
- element
- {{varElement}}
aDict:
dictId1: 'val'
dictId2: '{{varVal}}'
tags: [ 'tag', '{{varTag}}' ]
---

Frontmatter contains template variables.
22 changes: 20 additions & 2 deletions libs/langchain/tests/unit_tests/document_loaders/test_obsidian.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

def test_page_content_loaded() -> None:
"""Verify that all docs have page_content"""
assert len(docs) == 5
assert len(docs) == 6
assert all(doc.page_content for doc in docs)


Expand All @@ -27,7 +27,7 @@ def test_disable_collect_metadata() -> None:
str(OBSIDIAN_EXAMPLE_PATH), collect_metadata=False
)
docs_wo = loader_without_metadata.load()
assert len(docs_wo) == 5
assert len(docs_wo) == 6
assert all(doc.page_content for doc in docs_wo)
assert all(set(doc.metadata) == STANDARD_METADATA_FIELDS for doc in docs_wo)

Expand All @@ -45,6 +45,24 @@ def test_metadata_with_frontmatter() -> None:
assert set(doc.metadata["tags"].split(",")) == {"journal/entry", "obsidian"}


def test_metadata_with_template_vars_in_frontmatter() -> None:
    """Check that frontmatter values holding template variables are loaded."""
    target_source = "template_var_frontmatter.md"
    doc = next(d for d in docs if d.metadata["source"] == target_source)
    metadata = doc.metadata

    # The fixture's frontmatter keys must appear alongside the standard fields.
    frontmatter_fields = {"aString", "anArray", "aDict", "tags"}
    assert set(metadata) == frontmatter_fields | STANDARD_METADATA_FIELDS

    # A bare scalar variable keeps its double curly braces.
    assert metadata["aString"] == "{{var}}"
    # List and dict values are compared as strings.
    assert metadata["anArray"] == "['element', '{{varElement}}']"
    assert metadata["aDict"] == "{'dictId1': 'val', 'dictId2': '{{varVal}}'}"
    # Tags are a comma-separated string; their order is not asserted.
    assert set(metadata["tags"].split(",")) == {"tag", "{{varTag}}"}


def test_metadata_with_bad_frontmatter() -> None:
"""Verify a doc with non-yaml frontmatter."""
doc = next(doc for doc in docs if doc.metadata["source"] == "bad_frontmatter.md")
Expand Down

0 comments on commit e09b876

Please sign in to comment.