Skip to content

Commit

Permalink
core[minor]: Adds an in-memory implementation of RecordManager (#13200)
Browse files Browse the repository at this point in the history
**Description:**
langchain offers three technologies to save data:
-
[vectorstore](https://python.langchain.com/docs/modules/data_connection/vectorstores/)
- [docstore](https://js.langchain.com/docs/api/schema/classes/Docstore)
- [record
manager](https://python.langchain.com/docs/modules/data_connection/indexing)

If you want to combine these technologies in a sample persistence
stategy you need a common implementation for each. `DocStore` propose
`InMemoryDocstore`.

We propose the class `MemoryRecordManager` to complete the system.

This is the prelude to another full-request, which needs a consistent
combination of persistence components.

**Tag maintainer:**
@baskaryan

**Twitter handle:**
@pprados

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
  • Loading branch information
2 people authored and hinthornw committed Jun 20, 2024
1 parent 505f489 commit 3cb9c4c
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 110 deletions.
3 changes: 2 additions & 1 deletion libs/core/langchain_core/indexing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
if it's unchanged.
"""
from langchain_core.indexing.api import IndexingResult, aindex, index
from langchain_core.indexing.base import RecordManager
from langchain_core.indexing.base import InMemoryRecordManager, RecordManager

__all__ = [
"aindex",
"index",
"IndexingResult",
"InMemoryRecordManager",
"RecordManager",
]
104 changes: 103 additions & 1 deletion libs/core/langchain_core/indexing/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import time
from abc import ABC, abstractmethod
from typing import List, Optional, Sequence
from typing import Dict, List, Optional, Sequence, TypedDict


class RecordManager(ABC):
Expand Down Expand Up @@ -215,3 +216,104 @@ async def adelete_keys(self, keys: Sequence[str]) -> None:
Args:
keys: A list of keys to delete.
"""


class _Record(TypedDict):
group_id: Optional[str]
updated_at: float


class InMemoryRecordManager(RecordManager):
"""An in-memory record manager for testing purposes."""

def __init__(self, namespace: str) -> None:
super().__init__(namespace)
# Each key points to a dictionary
# of {'group_id': group_id, 'updated_at': timestamp}
self.records: Dict[str, _Record] = {}
self.namespace = namespace

def create_schema(self) -> None:
"""In-memory schema creation is simply ensuring the structure is initialized."""

async def acreate_schema(self) -> None:
"""In-memory schema creation is simply ensuring the structure is initialized."""

def get_time(self) -> float:
"""Get the current server time as a high resolution timestamp!"""
return time.time()

async def aget_time(self) -> float:
"""Get the current server time as a high resolution timestamp!"""
return self.get_time()

def update(
self,
keys: Sequence[str],
*,
group_ids: Optional[Sequence[Optional[str]]] = None,
time_at_least: Optional[float] = None,
) -> None:
if group_ids and len(keys) != len(group_ids):
raise ValueError("Length of keys must match length of group_ids")
for index, key in enumerate(keys):
group_id = group_ids[index] if group_ids else None
if time_at_least and time_at_least > self.get_time():
raise ValueError("time_at_least must be in the past")
self.records[key] = {"group_id": group_id, "updated_at": self.get_time()}

async def aupdate(
self,
keys: Sequence[str],
*,
group_ids: Optional[Sequence[Optional[str]]] = None,
time_at_least: Optional[float] = None,
) -> None:
self.update(keys, group_ids=group_ids, time_at_least=time_at_least)

def exists(self, keys: Sequence[str]) -> List[bool]:
return [key in self.records for key in keys]

async def aexists(self, keys: Sequence[str]) -> List[bool]:
return self.exists(keys)

def list_keys(
self,
*,
before: Optional[float] = None,
after: Optional[float] = None,
group_ids: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
) -> List[str]:
result = []
for key, data in self.records.items():
if before and data["updated_at"] >= before:
continue
if after and data["updated_at"] <= after:
continue
if group_ids and data["group_id"] not in group_ids:
continue
result.append(key)
if limit:
return result[:limit]
return result

async def alist_keys(
self,
*,
before: Optional[float] = None,
after: Optional[float] = None,
group_ids: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
) -> List[str]:
return self.list_keys(
before=before, after=after, group_ids=group_ids, limit=limit
)

def delete_keys(self, keys: Sequence[str]) -> None:
for key in keys:
if key in self.records:
del self.records[key]

async def adelete_keys(self, keys: Sequence[str]) -> None:
self.delete_keys(keys)
105 changes: 0 additions & 105 deletions libs/core/tests/unit_tests/indexing/in_memory.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import pytest_asyncio

from tests.unit_tests.indexing.in_memory import InMemoryRecordManager
from langchain_core.indexing import InMemoryRecordManager


@pytest.fixture()
Expand Down
3 changes: 1 addition & 2 deletions libs/core/tests/unit_tests/indexing/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
from langchain_core.document_loaders.base import BaseLoader
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.indexing import aindex, index
from langchain_core.indexing import InMemoryRecordManager, aindex, index
from langchain_core.indexing.api import _abatch, _HashedDocument
from langchain_core.vectorstores import VST, VectorStore
from tests.unit_tests.indexing.in_memory import InMemoryRecordManager


class ToyLoader(BaseLoader):
Expand Down
1 change: 1 addition & 0 deletions libs/core/tests/unit_tests/indexing/test_public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ def test_all() -> None:
"aindex",
"index",
"IndexingResult",
"InMemoryRecordManager",
"RecordManager",
]

0 comments on commit 3cb9c4c

Please sign in to comment.