Skip to content

Commit

Permalink
Fix and test merge_dfs
Browse files Browse the repository at this point in the history
  • Loading branch information
klieret committed Apr 19, 2019
1 parent b476319 commit 3a0430d
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 10 deletions.
33 changes: 24 additions & 9 deletions ankipandas/core_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,16 @@ def get_field_names(db: sqlite3.Connection):
# todo: add a decorator that makes an 'inplace' option possible


def _replace_inplace(df, df_new):
def _replace_df_inplace(df, df_new):
""" Replace dataframe 'in place'. """
df.drop(df.index, inplace=True)
for col in df_new.columns:
df[col] = df_new[col]
drop_cols = set(df.columns) - set(df_new.columns)
df.drop(drop_cols, axis=1, inplace=True)


def merge_dfs(df: pd.DataFrame, df_add: pd.DataFrame, id_df:str,
def merge_dfs(df: pd.DataFrame, df_add: pd.DataFrame, id_df: str,
inplace=False, id_add="id", prepend="",
prepend_clash_only=True, columns=None, drop_columns=None):
"""
Expand All @@ -230,6 +232,16 @@ def merge_dfs(df: pd.DataFrame, df_add: pd.DataFrame, id_df:str,
Returns:
New merged dataframe
"""
# Careful: Do not drop the id column until later (else we can't merge)
# Still, we want to remove as much as possible here, because it's probably
# better performing
if columns:
df_add = df_add.drop(
set(df_add.columns)-(set(columns) | {id_add}), axis=1
)
if drop_columns:
df_add = df_add.drop(set(drop_columns) - {id_add}, axis=1)
# Careful: Rename columns after dropping unwanted ones
if prepend_clash_only:
col_clash = set(df.columns) & set(df_add.columns)
rename_dict = {
Expand All @@ -240,15 +252,17 @@ def merge_dfs(df: pd.DataFrame, df_add: pd.DataFrame, id_df:str,
col: prepend + col for col in df_add.columns
}
df_add = df_add.rename(columns=rename_dict)
if columns:
columns = set(columns) | {id_add}
df_add.drop(set(df_add.columns)-columns, axis=1, inplace=True)
if drop_columns:
drop_columns = set(drop_columns) - {id_add}
df_add.drop(drop_columns, axis=1, inplace=True)
# Careful: Might have renamed id_add as well
if id_add in rename_dict:
id_add = rename_dict[id_add]
df_merge = df.merge(df_add, left_on=id_df, right_on=id_add)
# Now remove id_add if it was to be removed
# Careful: 'in' doesn't work with None
if (columns and id_add not in columns) or \
(drop_columns and id_add in drop_columns):
df_merge.drop(id_add, axis=1, inplace=True)
if inplace:
_replace_inplace(df, df_merge)
_replace_df_inplace(df, df_merge)
else:
return df_merge

Expand Down Expand Up @@ -461,6 +475,7 @@ def add_fields_as_columns(db: sqlite3.Connection, df: pd.DataFrame,
"Could not find id column '{}'. You can specify a custom one using"
" the id_column option.".format(id_column)
)
# fixme: What if one field column is one that is already in use?
mids = df["mid"].unique()
if inplace:
for mid in mids:
Expand Down
93 changes: 92 additions & 1 deletion ankipandas/test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@
# std
import unittest
import pathlib
import copy

# ours
from ankipandas.core_functions import *
# for hidden
import ankipandas.core_functions as core_functions


class TestCoreFunctions(unittest.TestCase):
def setUp(self):
    # Open the small example collection that ships with the test data.
    data_dir = pathlib.Path(__file__).parent / "data"
    collection = data_dir / "few_basic_cards" / "collection.anki2"
    self.db = load_db(collection)

def tearDown(self):
Expand All @@ -20,6 +24,7 @@ def tearDown(self):
def test_get_cards(self):
cards = get_cards(self.db)
self.assertEqual(len(cards), 3)
# todo: could also get this from anki_fields.txt
self.assertEqual(
list(sorted(cards.columns)),
sorted([
Expand Down Expand Up @@ -97,6 +102,92 @@ def test_get_model_info(self):
minfo = get_model_info(self.db)
# todo

def test_get_model_names(self):
    # The mapping returned for the test collection must list the
    # "Basic" and "Cloze" models among its values and have five entries.
    model_names = get_model_names(self.db)
    for expected in ("Basic", "Cloze"):
        self.assertIn(expected, model_names.values())
    self.assertEqual(len(model_names), 5)

def test_get_field_names(self):
    fields_by_mid = get_field_names(self.db)
    names_by_mid = get_model_names(self.db)
    # Re-key the field lists by model name instead of model id so that the
    # assertions below can refer to a model by its readable name.
    fields_by_name = {}
    for mid in names_by_mid:
        fields_by_name[names_by_mid[mid]] = fields_by_mid[mid]
    n_models = len(get_model_names(self.db))
    self.assertEqual(len(fields_by_name), n_models)
    self.assertListEqual(fields_by_name["Basic"], ["Front", "Back"])

def test__replace_df_inplace(self):
    # A 2x2 frame is overwritten with a 1x1 frame; afterwards the original
    # object must have exactly the row and column of the replacement.
    target = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
    replacement = pd.DataFrame({"a": [1]})
    core_functions._replace_df_inplace(target, replacement)
    n_rows = len(target)
    n_cols = len(target.columns)
    self.assertEqual(n_rows, 1)
    self.assertEqual(n_cols, 1)
    self.assertListEqual(list(target["a"].values), [1])


class TestMergeDfs(unittest.TestCase):
    # Tests for merge_dfs on two small hand-written dataframes.

    def setUp(self):
        # Left frame: ids repeat, and "clash" is a column name that also
        # exists in the right frame.
        left = {
            "id_df": [1, 2, 3, 1, 1],
            "clash": ["a", "b", "c", "a", "a"]
        }
        # Right frame: one row per id plus assorted extra columns.
        right = {
            "id_add": [1, 2, 3],
            "value": [4, 5, 6],
            "drop": [7, 8, 9],
            "ignore": [10, 11, 12],
            "clash": [1, 1, 1]
        }
        self.df = pd.DataFrame(left)
        self.df_add = pd.DataFrame(right)

    def test_merge_dfs(self):
        # Select a subset of columns, drop some of them again (including
        # the id column of the added frame) and prefix clashing names.
        options = dict(
            id_df="id_df",
            id_add="id_add",
            prepend="_",
            columns=["value", "drop", "clash"],
            drop_columns=["id_add", "drop"]
        )
        merged = merge_dfs(self.df, self.df_add, **options)
        expected_columns = ["_clash", "clash", "id_df", "value"]
        self.assertListEqual(sorted(merged.columns), expected_columns)
        self.assertListEqual(sorted(merged["value"]), [4, 4, 4, 5, 6])

    def test_merge_dfs_prepend_all(self):
        # With prepend_clash_only=False every column of the added frame
        # gets the prefix.
        options = dict(
            id_df="id_df",
            id_add="id_add",
            prepend="_",
            prepend_clash_only=False
        )
        merged = merge_dfs(self.df, self.df_add, **options)
        expected_columns = [
            '_clash', '_drop', '_id_add', '_ignore', '_value', 'clash',
            'id_df'
        ]
        self.assertListEqual(sorted(merged.columns), expected_columns)

    def test_merge_dfs_inplace(self):
        # Work on a copy so that the fixture frame stays untouched.
        target = copy.deepcopy(self.df)
        merge_dfs(
            target,
            self.df_add,
            id_df="id_df",
            id_add="id_add",
            inplace=True
        )
        expected_columns = [
            'clash_x', 'clash_y', 'drop', 'id_add', 'id_df', 'ignore',
            'value'
        ]
        self.assertListEqual(sorted(target.columns), expected_columns)
        self.assertListEqual(sorted(target["value"]), [4, 4, 4, 5, 6])


if __name__ == "__main__":
Expand Down

0 comments on commit 3a0430d

Please sign in to comment.