Skip to content

Commit

Permalink
feat: add doc_ids parameter to Table.get
Browse files Browse the repository at this point in the history
See #504
Closes #486
  • Loading branch information
keenborder786 committed May 20, 2023
1 parent cf311a3 commit 6f03ec6
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
5 changes: 4 additions & 1 deletion tests/test_tinydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,10 @@ def test_get_ids(db: TinyDB):
assert db.get(doc_id=el.doc_id) == el
assert db.get(doc_id=float('NaN')) is None # type: ignore


def test_get_multiple_ids(db: TinyDB):
el = db.all()
assert db.get(doc_id=[x.doc_id for x in el]) == el

def test_get_invalid(db: TinyDB):
with pytest.raises(RuntimeError):
db.get()
Expand Down
29 changes: 21 additions & 8 deletions tinydb/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,30 +279,43 @@ def search(self, cond: QueryLike) -> List[Document]:
def get(
self,
cond: Optional[QueryLike] = None,
doc_id: Optional[int] = None,
) -> Optional[Document]:
doc_id: Optional[Union[int , List]] = None,
doc_ids: Optional[List] = None
) -> Optional[Union[Document , List[Document]]]:
"""
Get exactly one document specified by a query or a document ID.
However if muliple document IDs are given then returns all docu-
ments in a list.
Returns ``None`` if the document doesn't exist.
:param cond: the condition to check against
:param doc_id: the document's ID
:param doc_ids: the document's IDs(multiple)
:returns: the document or ``None``
"""

:returns: the document(s) or ``None``
"""
table = self._read_table()
if doc_id is not None:
# Retrieve a document specified by its ID
table = self._read_table()
raw_doc = table.get(str(doc_id), None)

if raw_doc is None:
return None

# Convert the raw data to the document class
return self.document_class(raw_doc, doc_id)
elif doc_ids is not None:
# Filter the table by extracting out all those documents which have doc id
# specified in the doc_id list.
set_doc_id = set(doc_ids) # Since Doc Ids will be unique, making it a set to make sure constant lookup
raw_docs = dict(filter(lambda item: int(item[0]) in set_doc_id, table.items()))
if raw_docs is None:
return None

## Now return the filtered documents in form of list
return list(map(lambda x:self.document_class(raw_docs[str(x)] , int(x)) , raw_docs.keys()))

elif cond is not None:
# Find a document specified by a query
# The trailing underscore in doc_id_ is needed so MyPy
Expand All @@ -318,7 +331,7 @@ def get(

return None

raise RuntimeError('You have to pass either cond or doc_id')
raise RuntimeError('You have to pass either cond or doc_id or doc_ids')

def contains(
self,
Expand Down

0 comments on commit 6f03ec6

Please sign in to comment.