Skip to content

Commit

Permalink
Fix and test convenience functions
Browse files Browse the repository at this point in the history
  • Loading branch information
klieret committed Apr 19, 2019
1 parent c88d6e5 commit 7dd18e1
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 7 deletions.
5 changes: 2 additions & 3 deletions ankipandas/convenience_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import pathlib
import pandas as pd
from functools import lru_cache
from typing import Iterable

# ours
import ankipandas.core_functions as apd
Expand Down Expand Up @@ -64,8 +63,8 @@ def load_cards(
if merge_notes:
apd.merge_note_info(db, df, inplace=True)
apd.add_model_names(db, df, inplace=True)
if expand_fields:
apd.add_fields_as_columns(db, df, inplace=True)
if expand_fields:
apd.add_fields_as_columns(db, df, inplace=True)
apd.close_db(db)
return df

Expand Down
48 changes: 45 additions & 3 deletions ankipandas/core_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ def close_db(db):
# Basic getters
# ==============================================================================

def _get_table(db: sqlite3.Connection, table):
df = pd.read_sql_query("SELECT * FROM {}".format(table), db)
# print(df.columns)
# df.set_index("id", inplace=True)
return df


def get_cards(db: sqlite3.Connection):
"""
Get all cards as a dataframe.
Expand All @@ -56,7 +63,7 @@ def get_cards(db: sqlite3.Connection):
Returns:
pandas.DataFrame
"""
return pd.read_sql_query("SELECT * FROM cards ", db)
return _get_table(db, "cards")


def get_notes(db: sqlite3.Connection):
Expand All @@ -69,7 +76,7 @@ def get_notes(db: sqlite3.Connection):
Returns:
pandas.DataFrame
"""
return pd.read_sql_query("SELECT * FROM notes ", db)
return _get_table(db, "notes")


def get_revlog(db: sqlite3.Connection):
Expand All @@ -82,7 +89,7 @@ def get_revlog(db: sqlite3.Connection):
Returns:
pandas.DataFrame
"""
return pd.read_sql_query("SELECT * FROM revlog ", db)
return _get_table(db, "revlog")


@lru_cache(cache_size)
Expand All @@ -109,6 +116,41 @@ def get_info(db: sqlite3.Connection):
return ret


# Basic Setters
# ==============================================================================

def _write_table(db: sqlite3.Connection, df: pd.DataFrame, table: str,
mode: str, id_column="id") -> None:
"""
Args:
db: Database
df: The dataframe to write
table: Table to write to: 'notes', 'cards', 'revlog'
mode: 'update': Update only existing entries, 'append': Only append new
entries, but do not modify, 'replace': Append, modify and delete
Returns:
"""
if table not in ["notes", "cards", "revlog"]:
raise ValueError(
"Writing to table '{}' is not supported.".format(table)
)

df_old = pd.read_sql_query("SELECT * FROM {}".format(table), db)
old_indices = set(df_old[id_column])
new_indices = set(df[id_column])
if mode == "update":
indices = set(old_indices)
elif mode == "append":
indices = set(new_indices) - set(old_indices)
elif mode == "replace":
indices = set(new_indices)
else:
raise ValueError("Unknown mode '{}'.".format(mode))
# df = df[df[id_column]]

# Trivially derived getters
# ==============================================================================

Expand Down
18 changes: 18 additions & 0 deletions ankipandas/test/test_convenience.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,5 +175,23 @@ def test_load_notes_expand(self):
sorted(note_cols + ["mname", "Front", "Back"])
)

def test_load_cards(self):
cards = convenience.load_cards(self.path)
self.assertEqual(
sorted(list(cards.columns)),
sorted(list(set(
card_cols + note_cols + ["dname", "mname", "Front", "Back"] +
["ndata", "nflags", "nmod", "nusn"]
))) # clashes
)

def test_load_cards_nomerge(self):
cards = convenience.load_cards(self.path, merge_notes=False)
self.assertEqual(
sorted(list(cards.columns)),
sorted(card_cols + ["dname"])
)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion ankipandas/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.0.dev2
0.0.dev3

0 comments on commit 7dd18e1

Please sign in to comment.