Skip to content

Commit

Permalink
[Fix] Find db: If specified search path is name of file
Browse files Browse the repository at this point in the history
  • Loading branch information
klieret committed May 3, 2019
1 parent d441e45 commit 8391977
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
11 changes: 9 additions & 2 deletions ankipandas/convenience_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def load_revs(

# todo: decorator messes up sphinx signature
@lru_cache(32)
def _find_database(search_path, maxdepth=6, filename="collection.anki2",
def _find_database(search_path, maxdepth=6,
filename="collection.anki2",
break_on_first=False, user=None):
"""
Like find_database but only for one search path at a time. Also doesn't
Expand All @@ -116,8 +117,14 @@ def _find_database(search_path, maxdepth=6, filename="collection.anki2",
Returns:
collection.defaultdict({user: [list of results]})
"""
search_path = pathlib.Path(search_path)
if not os.path.exists(str(search_path)):
return collections.defaultdict(list)
if search_path.is_file():
if search_path.name == filename:
return collections.defaultdict(
list, {search_path.parent.name: [search_path]}
)
found = collections.defaultdict(list)
for root, dirs, files in os.walk(str(search_path)):
if filename in files:
Expand Down Expand Up @@ -190,7 +197,7 @@ def find_database(
else:
if found:
break

if user:
if user not in found:
raise ValueError(
Expand Down
20 changes: 20 additions & 0 deletions ankipandas/test/test_convenience.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3

# std
import collections
import unittest
from pathlib import Path
import tempfile
Expand Down Expand Up @@ -64,6 +65,25 @@ def test__find_database(self):
b = sorted(map(str, self.dbs[d]))
self.assertListEqual(a, b)

def test__find_database_filename(self):
# If doesn't exist
self.assertEqual(
convenience._find_database(
Path("abc/myfilename.txt"), filename="myfilename.txt"
),
{}
)
tmpdir = tempfile.TemporaryDirectory()
dir_path = Path(tmpdir.name) / "myfolder"
file_path = dir_path / "myfilename.txt"
dir_path.mkdir()
file_path.touch()
self.assertEqual(
convenience._find_database(file_path, filename="myfilename.txt"),
collections.defaultdict(list, {"myfolder": [file_path]})
)
tmpdir.cleanup()

def test_find_database(self):
with self.assertRaises(ValueError):
convenience.find_database(
Expand Down

0 comments on commit 8391977

Please sign in to comment.