Skip to content

Commit

Permalink
Testing and fixing remaining core functions
Browse files Browse the repository at this point in the history
  • Loading branch information
klieret committed Apr 19, 2019
1 parent 3a0430d commit 18adbaa
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 57 deletions.
26 changes: 16 additions & 10 deletions ankipandas/core_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,9 @@ def merge_card_info(db: sqlite3.Connection, df: pd.DataFrame, inplace=False,


def add_nids(db: sqlite3.Connection, df: pd.DataFrame, inplace=False,
id_column="cid"):
id_column="nid"):
""" Add note IDs to a dataframe that only contains card ids.
Example: ``add_nids(db, cards, id_column="nid")``
Args:
db: Database
Expand Down Expand Up @@ -361,9 +362,12 @@ def add_nids(db: sqlite3.Connection, df: pd.DataFrame, inplace=False,


def add_mids(db: sqlite3.Connection, df: pd.DataFrame, inplace=False,
id_column="cid"):
id_column="nid"):
""" Add note IDs to a dataframe that only contains note ids.
Example: ``add_mids(db, notes, id_column="id")``,
``add_mids(db, cards_with_merged_notes, id_column="nid")``.
Args:
db: Database
df: Dataframe to merge information into
Expand Down Expand Up @@ -397,7 +401,7 @@ def add_mids(db: sqlite3.Connection, df: pd.DataFrame, inplace=False,

def add_model_names(db: sqlite3.Connection, df: pd.DataFrame, inplace=False,
id_column="mid", new_column="mname"):
"""
""" Add model names to a dataframe that contains model IDs.
Args:
db: Database
Expand All @@ -409,16 +413,17 @@ def add_model_names(db: sqlite3.Connection, df: pd.DataFrame, inplace=False,
Returns:
New dataframe if inplace==True, else None
"""
if not id_column in df.columns:
if id_column not in df.columns:
raise ValueError(
"Could not find id column '{}'. You can specify a custom one using"
" the id_column option.".format(id_column)
)
if inplace:
df[new_column] = df[id_column].map(get_model_names(db))
df[new_column] = df[id_column].astype(str).map(get_model_names(db))
else:
df = copy.deepcopy(df)
add_model_names(db, df, inplace=True)
add_model_names(db, df, inplace=True, id_column=id_column,
new_column=new_column)
return df

# Cards
Expand All @@ -428,6 +433,7 @@ def add_model_names(db: sqlite3.Connection, df: pd.DataFrame, inplace=False,
def add_deck_names(db: sqlite3.Connection, df: pd.DataFrame, inplace=False,
id_column="did", new_column="dname"):
"""
Add deck names to a dataframe that contains deck IDs.
Args:
db: Database
Expand All @@ -439,16 +445,16 @@ def add_deck_names(db: sqlite3.Connection, df: pd.DataFrame, inplace=False,
Returns:
New dataframe if inplace==True, else None
"""
if not id_column in df.columns:
if id_column not in df.columns:
raise ValueError(
"Could not find id column '{}'. You can specify a custom one using"
" the id_column option.".format(id_column)
)
if inplace:
df[new_column] = df[id_column].map(get_deck_names(db))
df[new_column] = df[id_column].astype(str).map(get_deck_names(db))
else:
df = copy.deepcopy(df)
add_model_names(db, df, id_column=id_column, new_column=new_column,
add_deck_names(db, df, id_column=id_column, new_column=new_column,
inplace=True)
return df

Expand All @@ -470,7 +476,7 @@ def add_fields_as_columns(db: sqlite3.Connection, df: pd.DataFrame,
Returns:
New dataframe if inplace==True, else None
"""
if not id_column in df.columns:
if id_column not in df.columns:
raise ValueError(
"Could not find id column '{}'. You can specify a custom one using"
" the id_column option.".format(id_column)
Expand Down
169 changes: 122 additions & 47 deletions ankipandas/test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,77 +18,81 @@ def setUp(self):
"collection.anki2"
)

# todo: could also get this from anki_fields.txt
self.card_cols = [
'id',
'nid',
'did',
'ord',
'mod',
'usn',
'type',
'queue',
'due',
'ivl',
'factor',
'reps',
'lapses',
'left',
'odue',
'odid',
'flags',
'data'
]
self.note_cols = [
'id',
'guid',
'mid',
'mod',
'usn',
'tags',
'flds',
'sfld',
'csum',
'flags',
'data',
]
self.revlog_cols = [
'id',
'cid',
'usn',
'ease',
'ivl',
'lastIvl',
'factor',
'time',
'type'
]

def tearDown(self):
close_db(self.db)

def test_get_cards(self):
cards = get_cards(self.db)
self.assertEqual(len(cards), 3)
# todo: could also get this from anki_fields.txt
self.assertEqual(
list(sorted(cards.columns)),
sorted([
'id',
'nid',
'did',
'ord',
'mod',
'usn',
'type',
'queue',
'due',
'ivl',
'factor',
'reps',
'lapses',
'left',
'odue',
'odid',
'flags',
'data'
])
sorted(self.card_cols)
)

def test_get_notes(self):
notes = get_notes(self.db)
self.assertEqual(len(notes), 2)
self.assertEqual(
list(sorted(notes.columns)),
sorted([
'id',
'guid',
'mid',
'mod',
'usn',
'tags',
'flds',
'sfld',
'csum',
'flags',
'data',
])
sorted(self.note_cols)
)

def test_get_revlog(self):
revlog = get_revlog(self.db)
# todo assert length
self.assertEqual(
list(sorted(revlog.columns)),
sorted([
'id',
'cid',
'usn',
'ease',
'ivl',
'lastIvl',
'factor',
'time',
'type'
])
sorted(self.revlog_cols)
)

def test_get_deck_info(self):
dinfo = get_deck_info(self.db)
get_deck_info(self.db)
# todo

def test_get_deck_names(self):
Expand All @@ -99,7 +103,7 @@ def test_get_deck_names(self):
)

def test_get_model_info(self):
minfo = get_model_info(self.db)
get_model_info(self.db)
# todo

def test_get_model_names(self):
Expand All @@ -120,6 +124,77 @@ def test_get_field_names(self):
fnames["Basic"], ["Front", "Back"]
)

def test_merge_note_info(self):
cards = get_cards(self.db)
merged = merge_note_info(self.db, cards)
self.assertListEqual(
sorted(list(merged.columns)),
sorted(list(
set(self.card_cols) | set(self.note_cols) |
{"ndata", "nflags", "nmod", "nusn"} # clashes
))
)

def test_merge_card_info(self):
revlog = get_revlog(self.db)
merged = merge_card_info(self.db, revlog)
self.assertListEqual(
sorted(list(merged.columns)),
sorted(list(
set(self.revlog_cols) | set(self.card_cols) |
{"civl", "ctype", "cusn", "cid", "cfactor"} # clashes
))
)

def test_add_nids(self):
cards = get_cards(self.db)
cards = add_nids(self.db, cards)
self.assertIn("nid", list(cards.columns))
self.assertListEqual(
sorted(list(cards["nid"].unique())),
sorted(list(get_notes(self.db)["id"].unique()))
)

def test_add_mids(self):
notes = get_notes(self.db)
notes = add_mids(self.db, notes)
self.assertEqual(
len(notes["mid"].unique()),
2 # we don't have notesfor every model
)

def test_add_model_names(self):
notes = get_notes(self.db)
notes = add_mids(self.db, notes)
notes = add_model_names(self.db, notes)
self.assertEqual(
sorted(list(notes["mname"].unique())),
["Basic", 'Basic (and reversed card)']
)

def test_add_deck_names(self):
cards = get_cards(self.db)
cards = add_deck_names(self.db, cards)
self.assertEqual(
sorted(list(cards["dname"].unique())),
["Default"]
)

def test_add_fields_as_columns(self):
notes = get_notes(self.db)
notes = add_fields_as_columns(self.db, notes)
notes = add_model_names(self.db, notes)
self.assertEqual(
sorted(list(notes.columns)),
sorted(self.note_cols + ["mname", "Front", "Back"])
)
self.assertEqual(
list(notes.query("mname=='Basic'")["Front"].unique()),
["Basic: Front"]
)


class TestUtils(unittest.TestCase):
def test__replace_df_inplace(self):
df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
df_new = pd.DataFrame({"a": [1]})
Expand Down

0 comments on commit 18adbaa

Please sign in to comment.