Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes error loading Obsidian templates #13888

Merged
merged 6 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions libs/langchain/langchain/document_loaders/obsidian.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import functools
import logging
import re
from pathlib import Path
from typing import List
from typing import Any, Dict, List

import yaml
from langchain_core.documents import Document
Expand All @@ -15,6 +16,7 @@ class ObsidianLoader(BaseLoader):
"""Load `Obsidian` files from directory."""

FRONT_MATTER_REGEX = re.compile(r"^---\n(.*?)\n---\n", re.DOTALL)
TEMPLATE_VARIABLE_REGEX = re.compile(r"{{(.*?)}}", re.DOTALL)
TAG_REGEX = re.compile(r"[^\S\/]#([a-zA-Z_]+[-_/\w]*)")
DATAVIEW_LINE_REGEX = re.compile(r"^\s*(\w+)::\s*(.*)$", re.MULTILINE)
DATAVIEW_INLINE_BRACKET_REGEX = re.compile(r"\[(\w+)::\s*(.*)\]", re.MULTILINE)
Expand All @@ -35,6 +37,27 @@ def __init__(
self.encoding = encoding
self.collect_metadata = collect_metadata

def _replace_template_var(
    self, placeholders: Dict[str, str], match: re.Match
) -> str:
    """Swap one matched template variable for a unique placeholder token.

    The token is numbered by how many placeholders have been recorded so far.
    The text captured by group 1 of ``match`` (the variable's inner text) is
    stored under that token in ``placeholders``, which is mutated in place.

    Args:
        placeholders: Mapping of placeholder token to captured variable text;
            receives one new entry per call.
        match: Regex match whose first capture group is the variable text.

    Returns:
        The placeholder token to substitute for the matched text.
    """
    index = len(placeholders)
    token = f"__TEMPLATE_VAR_{index}__"
    placeholders.update({token: match.group(1)})
    return token

def _restore_template_vars(self, obj: Any, placeholders: Dict[str, str]) -> Any:
    """Restore template variables replaced with placeholders to original values.

    Walks ``obj`` recursively. In every string, each placeholder token is
    replaced by its original text wrapped in double braces (``{{value}}``).
    Dict keys are restored as well as dict values, so a key that was itself a
    template variable does not stay as a placeholder token. Values of any
    other type are returned unchanged.

    Args:
        obj: A string, dict, list, or any other value to process.
        placeholders: Mapping of placeholder token to the variable text it
            stands for.

    Returns:
        The restored value. Dicts and lists are updated in place and the same
        object is returned; strings are returned as new strings.
    """
    if isinstance(obj, str):
        for placeholder, value in placeholders.items():
            obj = obj.replace(placeholder, f"{{{{{value}}}}}")
    elif isinstance(obj, dict):
        # Collect the restored pairs first: keys may change, and a dict must
        # not gain or lose keys while it is being iterated.
        restored_items = [
            (
                self._restore_template_vars(key, placeholders),
                self._restore_template_vars(value, placeholders),
            )
            for key, value in obj.items()
        ]
        # Refill the same dict so callers holding a reference see the result.
        obj.clear()
        obj.update(restored_items)
    elif isinstance(obj, list):
        obj[:] = [self._restore_template_vars(item, placeholders) for item in obj]
    return obj

def _parse_front_matter(self, content: str) -> dict:
"""Parse front matter metadata from the content and return it as a dict."""
if not self.collect_metadata:
Expand All @@ -44,8 +67,17 @@ def _parse_front_matter(self, content: str) -> dict:
if not match:
return {}

placeholders: Dict[str, str] = {}
replace_template_var = functools.partial(
self._replace_template_var, placeholders
)
front_matter_text = self.TEMPLATE_VARIABLE_REGEX.sub(
replace_template_var, match.group(1)
)

try:
front_matter = yaml.safe_load(match.group(1))
front_matter = yaml.safe_load(front_matter_text)
front_matter = self._restore_template_vars(front_matter, placeholders)

# If tags are a string, split them into a list
if "tags" in front_matter and isinstance(front_matter["tags"], str):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
---
aString: {{var}}
anArray:
- element
- {{varElement}}
aDict:
dictId1: 'val'
dictId2: '{{varVal}}'
tags: [ 'tag', '{{varTag}}' ]
---

Frontmatter contains template variables.
22 changes: 20 additions & 2 deletions libs/langchain/tests/unit_tests/document_loaders/test_obsidian.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

def test_page_content_loaded() -> None:
"""Verify that all docs have page_content"""
assert len(docs) == 5
assert len(docs) == 6
assert all(doc.page_content for doc in docs)


Expand All @@ -27,7 +27,7 @@ def test_disable_collect_metadata() -> None:
str(OBSIDIAN_EXAMPLE_PATH), collect_metadata=False
)
docs_wo = loader_without_metadata.load()
assert len(docs_wo) == 5
assert len(docs_wo) == 6
assert all(doc.page_content for doc in docs_wo)
assert all(set(doc.metadata) == STANDARD_METADATA_FIELDS for doc in docs_wo)

Expand All @@ -45,6 +45,24 @@ def test_metadata_with_frontmatter() -> None:
assert set(doc.metadata["tags"].split(",")) == {"journal/entry", "obsidian"}


def test_metadata_with_template_vars_in_frontmatter() -> None:
    """Check that frontmatter fields holding template variables are loaded."""
    candidates = (
        d for d in docs if d.metadata["source"] == "template_var_frontmatter.md"
    )
    metadata = next(candidates).metadata

    frontmatter_keys = {"aString", "anArray", "aDict", "tags"}
    assert set(metadata) == frontmatter_keys | STANDARD_METADATA_FIELDS

    # Scalar, list and dict values are compared against these exact strings.
    assert metadata["aString"] == "{{var}}"
    assert metadata["anArray"] == "['element', '{{varElement}}']"
    assert metadata["aDict"] == "{'dictId1': 'val', 'dictId2': '{{varVal}}'}"

    # Tags are a comma-separated string; compare as a set so order is ignored.
    tag_set = set(metadata["tags"].split(","))
    assert tag_set == {"tag", "{{varTag}}"}


def test_metadata_with_bad_frontmatter() -> None:
"""Verify a doc with non-yaml frontmatter."""
doc = next(doc for doc in docs if doc.metadata["source"] == "bad_frontmatter.md")
Expand Down
Loading