Skip to content

Commit

Permalink
Preliminary implementation in classes; fixes to add_fields_as_columns
Browse files Browse the repository at this point in the history
  • Loading branch information
klieret committed Apr 21, 2019
1 parent 6d60ff3 commit c05429f
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 13 deletions.
1 change: 1 addition & 0 deletions ankipandas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from ankipandas.core_functions import *
from ankipandas.convenience_functions import *
from ankipandas.ankidf import *
170 changes: 170 additions & 0 deletions ankipandas/ankidf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
#!/usr/bin/env python3

# std
import sqlite3

# 3rd
import pandas as pd

# ours
import ankipandas.convenience_functions as convenience
import ankipandas.core_functions as core


# todo: inplace == false as default
class AnkiDataFrame(pd.DataFrame):
_attributes = ("db", "db_path", "table")

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if len(args) == 1 and isinstance(args[0], AnkiDataFrame):
args[0]._copy_attrs(self)
self.db = None
self.db_path = None
self.table = None

def _load_db(self, path):
self.db = core.load_db(path)
self.db_path = path

def _get_table(self, path, table):
if not path:
path = self.db_path
self._load_db(path)
table = core._get_table(self.db, table)
core._replace_df_inplace(self, table)
self.table = table

@property
def nid_column(self):
if self.table == "notes":
return "id"
else:
return "nid"

@property
def cid_column(self):
if self.table == "cards":
return "id"
else:
return "cid"

@property
def _constructor(self):
def __constructor(*args, **kw):
df = self.__class__(*args, **kw)
self._copy_attrs(df)
return df
return __constructor

def _copy_attrs(self, df):
for attr in self._attributes:
df.__dict__[attr] = getattr(self, attr, None)

def merge_note_info(self, columns=None, drop_columns=None, prepend="n",
prepend_clash_only=True, nid_column="nid"):
core.merge_note_info(
db=self.db,
df=self,
inplace=True,
columns=columns,
drop_columns=drop_columns,
nid_column=nid_column,
prepend=prepend,
prepend_clash_only=prepend_clash_only
)

def merge_card_info(self, columns=None, drop_columns=None, prepend="c",
prepend_clash_only=True, cid_column="cid"):
core.merge_card_info(
db=self.db,
df=self,
inplace=True,
prepend=prepend,
prepend_clash_only=prepend_clash_only,
columns=columns,
drop_columns=drop_columns,
cid_column=cid_column
)

def add_nids(self, cid_column=None):
if not cid_column:
cid_column = self.cid_column
core.add_nids(
db=self.db,
df=self,
inplace=True,
cid_column=cid_column
)

def add_mids(self, nid_column=None):
if not nid_column:
nid_column = self.nid_column
# Todo: Perhaps call add_nids, if nid column not found
core.add_mids(
db=self.db,
df=self,
inplace=True,
nid_column=nid_column
)

def add_model_names(self, mid_column="mid", new_column="mname"):
# Todo: Perhaps call add_mids, if nid column not found
core.add_model_names(
db=self.db,
df=self.df,
inplace=True,
mid_column=mid_column,
new_column=new_column
)

def add_deck_names(self, new_column="dname", did_column="did"):
core.add_deck_names(
self.db,
self,
inplace=True,
did_column=did_column,
new_column=new_column
)

def add_fields_as_columns(self, mid_column="mid", prepend="",
flds_column="flds"):
core.add_fields_as_columns(
db=self.db,
df=self,
inplace=True,
mid_column=mid_column,
prepend=prepend,
flds_column=flds_column
)

def fields_as_columns_to_flds(self, mid_column="mid", prepended="",
drop=False):
core.fields_as_columns_to_flds(
db=self.db,
df=self,
mid_column=mid_column,
prepended=prepended,
drop=drop
)


class Cards(AnkiDataFrame):
def __init__(self, *args, path=None, **kwargs):
super().__init__(*args, **kwargs)
if len(args) == 0 and len(kwargs) == 0 and path:
self._get_table(path, "cards")


class Notes(AnkiDataFrame):
def __init__(self, *args, path=None, **kwargs):
super().__init__(*args, **kwargs)
if len(args) == 0 and len(kwargs) == 0 and path:
self._get_table(path, "notes")


class Revlog(AnkiDataFrame):
def __init__(self, *args, path=None, **kwargs):
super().__init__(*args, **kwargs)
if len(args) == 0 and len(kwargs) == 0 and path:
self._get_table(path, "revlog")
2 changes: 1 addition & 1 deletion ankipandas/convenience_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def load_revs(
if merge_cards:
apd.merge_card_info(db, df, inplace=True)
if merge_notes:
apd.add_nids(db, df, nid_column="cid", inplace=True)
apd.add_nids(db, df, cid_column="cid", inplace=True)
apd.merge_note_info(db, df, inplace=True)
if expand_fields:
apd.add_fields_as_columns(db, df, inplace=True)
Expand Down
34 changes: 22 additions & 12 deletions ankipandas/core_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,11 +267,13 @@ def get_field_names(db: sqlite3.Connection):

def _replace_df_inplace(df, df_new):
""" Replace dataframe 'in place'. """
df.drop(df.index, inplace=True)
if df.index.any():
df.drop(df.index, inplace=True)
for col in df_new.columns:
df[col] = df_new[col]
drop_cols = set(df.columns) - set(df_new.columns)
df.drop(drop_cols, axis=1, inplace=True)
if drop_cols:
df.drop(drop_cols, axis=1, inplace=True)


def merge_dfs(df: pd.DataFrame, df_add: pd.DataFrame, id_df: str,
Expand Down Expand Up @@ -397,15 +399,15 @@ def merge_card_info(db: sqlite3.Connection, df: pd.DataFrame, inplace=False,


def add_nids(db: sqlite3.Connection, df: pd.DataFrame, inplace=False,
nid_column="nid"):
cid_column="cid"):
""" Add note IDs to a dataframe that only contains card ids.
Example: ``add_nids(db, cards, id_column="nid")``
Args:
db: Database
df: Dataframe to merge information into
inplace: If False, return new dataframe, else update old one
nid_column: Column with card ID
cid_column: Column with card ID
Returns:
New dataframe if inplace==True, else None
Expand All @@ -418,7 +420,7 @@ def add_nids(db: sqlite3.Connection, df: pd.DataFrame, inplace=False,
return merge_dfs(
df=df,
df_add=get_cards(db),
id_df=nid_column,
id_df=cid_column,
inplace=inplace,
columns=["nid"],
id_add="id",
Expand Down Expand Up @@ -528,7 +530,8 @@ def add_deck_names(db: sqlite3.Connection, df: pd.DataFrame, inplace=False,


def add_fields_as_columns(db: sqlite3.Connection, df: pd.DataFrame,
inplace=False, mid_column="mid", prepend=""):
inplace=False, mid_column="mid", prepend="",
flds_column="flds"):
"""
Args:
Expand All @@ -537,6 +540,7 @@ def add_fields_as_columns(db: sqlite3.Connection, df: pd.DataFrame,
inplace: If False, return new dataframe, else update old one
mid_column: Column with model ID
prepend: Prepend string to all new column names
flds_column: Column that contains the joined fields
Returns:
New dataframe if inplace==True, else None
Expand All @@ -546,21 +550,27 @@ def add_fields_as_columns(db: sqlite3.Connection, df: pd.DataFrame,
"Could not find model id column '{}'. You can specify a custom one "
"using the mid_column option.".format(mid_column)
)
if flds_column not in df.columns:
raise ValueError(
"Could not find fields column '{}'. You can specify a custom one "
"using the flds_column option.".format(flds_column)
)
# fixme: What if one field column is one that is already in use?
if inplace:
mids = df["mid"].unique()
for mid in mids:
df_model = df[df["mid"] == mid]
fields = df_model["flds"].str.split("\x1f", expand=True)
df_model = df[df[mid_column] == mid]
fields = df_model[flds_column].str.split("\x1f", expand=True)
for ifield, field in enumerate(get_field_names(db)[str(mid)]):
df.loc[df["mid"] == mid, prepend + field] = fields[ifield]
df.loc[df[mid_column] == mid, prepend + field] = fields[ifield]
else:
df = copy.deepcopy(df)
add_fields_as_columns(db, df, mid_column=mid_column, prepend=prepend,
inplace=True)
inplace=True, flds_column=flds_column)
return df


# todo: docstring
# fixme: what if fields aren't found?
def fields_as_columns_to_flds(db: sqlite3.Connection, df: pd.DataFrame,
inplace=False, mid_column="mid", prepended="",
Expand All @@ -571,13 +581,13 @@ def fields_as_columns_to_flds(db: sqlite3.Connection, df: pd.DataFrame,
"using the mid_column option.".format(mid_column)
)
if inplace:
mids = df["mid"].unique()
mids = df[mid_column].unique()
to_drop = []
for mid in mids:
fields = get_field_names(db)[str(mid)]
if prepended:
fields = [prepended + field for field in fields]
df.loc[df["mid"] == mid, "flds"] = \
df.loc[df[mid_column] == mid, "flds"] = \
pd.Series(df[fields].values.tolist()).str.join("\x1f")
if drop:
# Careful: Do not delete the fields here yet, other models
Expand Down

0 comments on commit c05429f

Please sign in to comment.