Skip to content

Commit

Permalink
Fix cont_cat_split for multi-label classification (#2759)
Browse files Browse the repository at this point in the history
* Fix cont_cat_split

* Fix cont_cat_split

* Add tests

* Sync lib with notebooks

* Re-run GitHub Actions checks
  • Loading branch information
albertvillanova committed Sep 9, 2020
1 parent 3414477 commit 2b2b970
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
2 changes: 1 addition & 1 deletion fastai/tabular/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def cont_cat_split(df, max_card=20, dep_var=None):
"Helper function that returns column names of cont and cat variables from given `df`."
cont_names, cat_names = [], []
for label in df:
if label == dep_var: continue
if label in L(dep_var): continue
if df[label].dtype == int and df[label].unique().shape[0] > max_card or df[label].dtype == float:
cont_names.append(label)
else: cat_names.append(label)
Expand Down
28 changes: 27 additions & 1 deletion nbs/40_tabular.core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -418,13 +418,39 @@
" \"Helper function that returns column names of cont and cat variables from given `df`.\"\n",
" cont_names, cat_names = [], []\n",
" for label in df:\n",
" if label == dep_var: continue\n",
" if label in L(dep_var): continue\n",
" if df[label].dtype == int and df[label].unique().shape[0] > max_card or df[label].dtype == float:\n",
" cont_names.append(label)\n",
" else: cat_names.append(label)\n",
" return cont_names, cat_names"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df = pd.DataFrame({'cat1': [1, 2, 3, 4], 'cont1': [1., 2., 3., 2.], 'cat2': ['a', 'b', 'b', 'a'], \n",
" 'y1': [1, 0, 1, 0], 'y2': [1, 1, 1, 0]})\n",
"\n",
"# Test all columns\n",
"cont, cat = cont_cat_split(df)\n",
"test_eq((cont, cat), (['cont1'], ['cat1', 'cat2', 'y1', 'y2']))\n",
"\n",
"# Test exclusion of dependent variable\n",
"cont, cat = cont_cat_split(df, dep_var='y1')\n",
"test_eq((cont, cat), (['cont1'], ['cat1', 'cat2', 'y2']))\n",
"\n",
"# Test exclusion of multi-label dependent variables\n",
"cont, cat = cont_cat_split(df, dep_var=['y1', 'y2'])\n",
"test_eq((cont, cat), (['cont1'], ['cat1', 'cat2']))\n",
"\n",
"# Test maximal cardinality bound for int variable\n",
"cont, cat = cont_cat_split(df, max_card=2, dep_var=['y1', 'y2'])\n",
"test_eq((cont, cat), (['cat1', 'cont1'], ['cat2']))"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down

0 comments on commit 2b2b970

Please sign in to comment.