Fix Pandas Categorical FutureWarning
warner-benjamin committed Oct 14, 2023
1 parent 4dceef2 commit 28b4d2a
Showing 6 changed files with 29 additions and 29 deletions.
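
The change itself is mechanical: pandas 2.1 deprecated the `is_categorical_dtype` helper, which now emits a `FutureWarning`, and the documented replacement is an `isinstance` check against `CategoricalDtype`. A minimal sketch of the before/after pattern (illustrative, not fastai code):

```python
import pandas as pd
from pandas.api.types import CategoricalDtype

col = pd.Series(["a", "b", "a"], dtype="category")

# Old check: pandas.api.types.is_categorical_dtype(col)
# -> FutureWarning on pandas >= 2.1, slated for removal.

# New check, as used throughout this commit:
is_cat = hasattr(col, "dtype") and isinstance(col.dtype, CategoricalDtype)
print(is_cat)  # True

# The hasattr guard keeps the check safe for plain lists,
# which have no .dtype attribute at all.
print(hasattr([1, 2, 3], "dtype"))  # False
```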
fastai/data/transforms.py (4 additions, 4 deletions)
@@ -162,7 +162,7 @@ def ColSplitter(col='is_valid', on=None):
     def _inner(o):
         assert isinstance(o, pd.DataFrame), "ColSplitter only works when your items are a pandas DataFrame"
         c = o.iloc[:,col] if isinstance(col, int) else o[col]
-        if on is None: valid_idx = c.values.astype('bool') 
+        if on is None: valid_idx = c.values.astype('bool')
         elif is_listy(on): valid_idx = c.isin(on)
         else: valid_idx = c == on
         return IndexSplitter(mask2idxs(valid_idx))(o)
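
For context, the three branches above are three ways of selecting validation rows. A standalone sketch with a hypothetical DataFrame (illustrative only, not part of the commit):

```python
import pandas as pd

df = pd.DataFrame({"x": [0, 1, 2, 3], "is_valid": [False, False, True, True]})

# on is None: the column itself is the boolean validation mask
mask_bool = df["is_valid"].values.astype("bool")

# on is list-like: rows whose value appears in `on` are validation rows
mask_listy = df["x"].isin([2, 3])

# otherwise: rows equal to `on` are validation rows
mask_scalar = df["x"] == 3

print(mask_bool.tolist())    # [False, False, True, True]
print(mask_listy.tolist())   # [False, False, True, True]
print(mask_scalar.tolist())  # [False, False, False, True]
```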
@@ -222,7 +222,7 @@ def __call__(self, o, **kwargs):
 class CategoryMap(CollBase):
     "Collection of categories with the reverse mapping in `o2i`"
     def __init__(self, col, sort=True, add_na=False, strict=False):
-        if is_categorical_dtype(col):
+        if hasattr(col, 'dtype') and isinstance(col.dtype, CategoricalDtype):
             items = L(col.cat.categories, use_list=True)
             #Remove non-used categories while keeping order
             if strict: items = L(o for o in items if o in col.unique())
@@ -256,7 +256,7 @@ def setups(self, dsets):
         if self.vocab is None and dsets is not None: self.vocab = CategoryMap(dsets, sort=self.sort, add_na=self.add_na)
         self.c = len(self.vocab)
 
-    def encodes(self, o): 
+    def encodes(self, o):
         try:
             return TensorCategory(self.vocab.o2i[o])
         except KeyError as e:
@@ -279,7 +279,7 @@ def setups(self, dsets):
             for b in dsets: vals = vals.union(set(b))
             self.vocab = CategoryMap(list(vals), add_na=self.add_na)
 
-    def encodes(self, o): 
+    def encodes(self, o):
         if not all(elem in self.vocab.o2i.keys() for elem in o):
             diff = [elem for elem in o if elem not in self.vocab.o2i.keys()]
             diff_str = "', '".join(diff)
fastai/imports.py (1 addition, 1 deletion)
@@ -24,7 +24,7 @@
 
 # External modules
 import requests,yaml,matplotlib.pyplot as plt,pandas as pd,scipy
-from pandas.api.types import is_categorical_dtype,is_numeric_dtype
+from pandas.api.types import CategoricalDtype,is_numeric_dtype
 from numpy import array,ndarray
 from scipy import ndimage
 from pdb import set_trace
fastai/metrics.py (1 addition, 1 deletion)
@@ -428,7 +428,7 @@ def accumulate(self, learn):
             c,t = self.get_correct_ngrams(pred, targ, i+1, max_n=self.vocab_sz)
             if c == 0:
                 smooth_mteval *= 2
-                c = 1 / smooth_mteval # exp smoothing, method 3 from https://aclanthology.org/W14-3346/
+                c = 1 / smooth_mteval # exp smoothing, method 3 from http://acl2014.org/acl2014/W14-33/pdf/W14-3346.pdf
             self.corrects[i] += c
             self.counts[i] += t
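
The surrounding loop applies smoothing method 3 from the linked Chen & Cherry (2014) paper: each successive n-gram order with zero matches receives half the previous pseudo-count. A standalone trace with made-up counts:

```python
# Hypothetical correct n-gram counts for n = 1..4 on a short sentence
counts = [5, 0, 0, 0]

smooth_mteval = 1
corrects = []
for c in counts:
    if c == 0:
        smooth_mteval *= 2
        c = 1 / smooth_mteval  # 1/2 for the first zero order, then 1/4, 1/8, ...
    corrects.append(c)

print(corrects)  # [5, 0.5, 0.25, 0.125]
```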

fastai/tabular/core.py (2 additions, 2 deletions)
@@ -230,7 +230,7 @@ def name(self): return f"{super().name} -- {getattr(self,'__stored_args__',{})}"
 
 # %% ../../nbs/40_tabular.core.ipynb 58
 def _apply_cats (voc, add, c):
-    if not is_categorical_dtype(c):
+    if not (hasattr(c, 'dtype') and isinstance(c.dtype, CategoricalDtype)):
         return pd.Categorical(c, categories=voc[c.name][add:]).codes+add
     return c.cat.codes+add #if is_categorical_dtype(c) else c.map(voc[c.name].o2i)
 def _decode_cats(voc, c): return c.map(dict(enumerate(voc[c.name].items)))
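
In the non-categorical branch above, values are mapped to integer codes against the stored vocab, with `add` reserving the low codes (e.g. for an `#na#` token). A sketch with a hypothetical vocab (the names are illustrative, not from the commit):

```python
import pandas as pd

vocab = ["#na#", "a", "b", "c"]   # hypothetical vocab; index 0 is the NA slot
col = pd.Series(["b", "a", "x"], name="cat1")

add = 1                            # mirrors the `add` offset above
codes = pd.Categorical(col, categories=vocab[add:]).codes + add
print(list(codes))  # [2, 1, 0]: unseen 'x' gets code -1, shifted into the NA slot 0
```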
@@ -346,7 +346,7 @@ def show_batch(x: Tabular, y, its, max_n=10, ctxs=None):
 # %% ../../nbs/40_tabular.core.ipynb 94
 @delegates()
 class TabDataLoader(TfmdDL):
-    "A transformed `DataLoader` for Tabular data" 
+    "A transformed `DataLoader` for Tabular data"
     def __init__(self, dataset, bs=16, shuffle=False, after_batch=None, num_workers=0, **kwargs):
         if after_batch is None: after_batch = L(TransformBlock().batch_tfms)+ReadTabBatch(dataset)
         super().__init__(dataset, bs=bs, shuffle=shuffle, after_batch=after_batch, num_workers=num_workers, **kwargs)
nbs/05_data.transforms.ipynb (6 additions, 6 deletions)
@@ -563,8 +563,8 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"fnames = [path/'train/3/9932.png', path/'valid/7/7189.png', \n",
-" path/'valid/7/7320.png', path/'train/7/9833.png', \n",
+"fnames = [path/'train/3/9932.png', path/'valid/7/7189.png',\n",
+" path/'valid/7/7320.png', path/'train/7/9833.png',\n",
 " path/'train/3/7666.png', path/'valid/3/925.png',\n",
 " path/'train/7/724.png', path/'valid/3/93055.png']\n",
 "splitter = GrandparentSplitter()"
@@ -684,7 +684,7 @@
 " def _inner(o):\n",
 " assert isinstance(o, pd.DataFrame), \"ColSplitter only works when your items are a pandas DataFrame\"\n",
 " c = o.iloc[:,col] if isinstance(col, int) else o[col]\n",
-" if on is None: valid_idx = c.values.astype('bool') \n",
+" if on is None: valid_idx = c.values.astype('bool')\n",
 " elif is_listy(on): valid_idx = c.isin(on)\n",
 " else: valid_idx = c == on\n",
 " return IndexSplitter(mask2idxs(valid_idx))(o)\n",
@@ -981,7 +981,7 @@
 "class CategoryMap(CollBase):\n",
 " \"Collection of categories with the reverse mapping in `o2i`\"\n",
 " def __init__(self, col, sort=True, add_na=False, strict=False):\n",
-" if is_categorical_dtype(col):\n",
+" if hasattr(col, 'dtype') and isinstance(col.dtype, CategoricalDtype):\n",
 " items = L(col.cat.categories, use_list=True)\n",
 " #Remove non-used categories while keeping order\n",
 " if strict: items = L(o for o in items if o in col.unique())\n",
@@ -1082,7 +1082,7 @@
 " if self.vocab is None and dsets is not None: self.vocab = CategoryMap(dsets, sort=self.sort, add_na=self.add_na)\n",
 " self.c = len(self.vocab)\n",
 "\n",
-" def encodes(self, o): \n",
+" def encodes(self, o):\n",
 " try:\n",
 " return TensorCategory(self.vocab.o2i[o])\n",
 " except KeyError as e:\n",
@@ -1169,7 +1169,7 @@
 " for b in dsets: vals = vals.union(set(b))\n",
 " self.vocab = CategoryMap(list(vals), add_na=self.add_na)\n",
 "\n",
-" def encodes(self, o): \n",
+" def encodes(self, o):\n",
 " if not all(elem in self.vocab.o2i.keys() for elem in o):\n",
 " diff = [elem for elem in o if elem not in self.vocab.o2i.keys()]\n",
 " diff_str = \"', '\".join(diff)\n",
nbs/40_tabular.core.ipynb (15 additions, 15 deletions)
@@ -274,7 +274,7 @@
 "outputs": [],
 "source": [
 "#|hide\n",
-"test_eq(df.columns, ['Year', 'Month', 'Week', 'Day', 'Dayofweek', 'Dayofyear', 'Is_month_end', 'Is_month_start', \n",
+"test_eq(df.columns, ['Year', 'Month', 'Week', 'Day', 'Dayofweek', 'Dayofyear', 'Is_month_end', 'Is_month_start',\n",
 " 'Is_quarter_end', 'Is_quarter_start', 'Is_year_end', 'Is_year_start', 'Elapsed'])\n",
 "test_eq(df[df.Elapsed.isna()].shape,(1, 13))\n",
 "\n",
@@ -385,8 +385,8 @@
 "source": [
 "#|hide\n",
 "# Test Order of columns when date isn't in first position\n",
-"test_eq(df.columns, ['f1', 'f2', 'f3', 'f4', 'Year', 'Month', 'Week', 'Day', \n",
-" 'Dayofweek', 'Dayofyear', 'Is_month_end', 'Is_month_start', \n",
+"test_eq(df.columns, ['f1', 'f2', 'f3', 'f4', 'Year', 'Month', 'Week', 'Day',\n",
+" 'Dayofweek', 'Dayofyear', 'Is_month_end', 'Is_month_start',\n",
 " 'Is_quarter_end', 'Is_quarter_start', 'Is_year_end', 'Is_year_start', 'Elapsed'])\n",
 "\n",
 "# Test that week dtype is consistent with other datepart fields\n",
@@ -582,9 +582,9 @@
 "outputs": [],
 "source": [
 "# Example with simple numpy types\n",
-"df = pd.DataFrame({'cat1': [1, 2, 3, 4], 'cont1': [1., 2., 3., 2.], 'cat2': ['a', 'b', 'b', 'a'], \n",
-" 'i8': pd.Series([1, 2, 3, 4], dtype='int8'), \n",
-" 'u8': pd.Series([1, 2, 3, 4], dtype='uint8'), \n",
+"df = pd.DataFrame({'cat1': [1, 2, 3, 4], 'cont1': [1., 2., 3., 2.], 'cat2': ['a', 'b', 'b', 'a'],\n",
+" 'i8': pd.Series([1, 2, 3, 4], dtype='int8'),\n",
+" 'u8': pd.Series([1, 2, 3, 4], dtype='uint8'),\n",
 " 'f16': pd.Series([1, 2, 3, 4], dtype='float16'),\n",
 " 'y1': [1, 0, 1, 0], 'y2': [2, 1, 1, 0]})\n",
 "cont_names, cat_names = cont_cat_split(df)"
@@ -692,7 +692,7 @@
 "#|hide\n",
 "cont, cat = cont_cat_split(df, max_card=0)\n",
 "test_eq((cont, cat), (\n",
-" ['ui32', 'i64', 'f16', 'd1_Year', 'd1_Month', 'd1_Week', 'd1_Day', 'd1_Dayofweek', 'd1_Dayofyear', 'd1_Elapsed'], \n",
+" ['ui32', 'i64', 'f16', 'd1_Year', 'd1_Month', 'd1_Week', 'd1_Day', 'd1_Dayofweek', 'd1_Dayofyear', 'd1_Elapsed'],\n",
 " ['cat1', 'd1_date', 'd1_Is_month_end', 'd1_Is_month_start', 'd1_Is_quarter_end', 'd1_Is_quarter_start', 'd1_Is_year_end', 'd1_Is_year_start']\n",
 " ))"
 ]
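
As a rough mental model (a simplification, not the fastai source): `cont_cat_split` treats numeric columns whose cardinality exceeds `max_card` as continuous and the rest as categorical, which is why `max_card=0` above pushes every plain numeric column into the continuous list:

```python
import pandas as pd
from pandas.api.types import is_numeric_dtype

def split_by_cardinality(df, max_card=20):
    "Simplified stand-in for cont_cat_split's core rule (not the fastai implementation)"
    cont, cat = [], []
    for label in df.columns:
        if is_numeric_dtype(df[label]) and df[label].nunique() > max_card:
            cont.append(label)
        else:
            cat.append(label)
    return cont, cat

df = pd.DataFrame({"cat1": [1, 2, 1, 2], "cont1": [0.1, 0.2, 0.3, 0.4]})
print(split_by_cardinality(df, max_card=2))  # (['cont1'], ['cat1'])
```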
@@ -1287,7 +1287,7 @@
 "source": [
 "#|export\n",
 "def _apply_cats (voc, add, c):\n",
-" if not is_categorical_dtype(c):\n",
+" if not (hasattr(c, 'dtype') and isinstance(c.dtype, CategoricalDtype)):\n",
 " return pd.Categorical(c, categories=voc[c.name][add:]).codes+add\n",
 " return c.cat.codes+add #if is_categorical_dtype(c) else c.map(voc[c.name].o2i)\n",
 "def _decode_cats(voc, c): return c.map(dict(enumerate(voc[c.name].items)))"
@@ -1743,7 +1743,7 @@
 "outputs": [],
 "source": [
 "#|hide\n",
-"fill1,fill2,fill3 = (FillMissing(fill_strategy=s) \n",
+"fill1,fill2,fill3 = (FillMissing(fill_strategy=s)\n",
 " for s in [FillStrategy.median, FillStrategy.constant, FillStrategy.mode])\n",
 "df = pd.DataFrame({'a':[0,1,np.nan,1,2,3,4]})\n",
 "df1 = df.copy(); df2 = df.copy()\n",
@@ -1768,7 +1768,7 @@
 "outputs": [],
 "source": [
 "#|hide\n",
-"fill = FillMissing() \n",
+"fill = FillMissing()\n",
 "df = pd.DataFrame({'a':[0,1,np.nan,1,2,3,4], 'b': [0,1,2,3,4,5,6]})\n",
 "to = TabularPandas(df, fill, cont_names=['a', 'b'])\n",
 "test_eq(fill.na_dict, {'a': 1.5})\n",
@@ -1785,7 +1785,7 @@
 "outputs": [],
 "source": [
 "#|hide\n",
-"fill = FillMissing() \n",
+"fill = FillMissing()\n",
 "df = pd.DataFrame({'a':[0,1,np.nan,1,2,3,4], 'b': [0,1,2,3,4,5,6]})\n",
 "to = TabularPandas(df, fill, cont_names=['a', 'b'])\n",
 "test_eq(hasattr(to.procs.fill_missing, 'to'), False)"
@@ -1934,7 +1934,7 @@
 "#|export\n",
 "@delegates()\n",
 "class TabDataLoader(TfmdDL):\n",
-" \"A transformed `DataLoader` for Tabular data\" \n",
+" \"A transformed `DataLoader` for Tabular data\"\n",
 " def __init__(self, dataset, bs=16, shuffle=False, after_batch=None, num_workers=0, **kwargs):\n",
 " if after_batch is None: after_batch = L(TransformBlock().batch_tfms)+ReadTabBatch(dataset)\n",
 " super().__init__(dataset, bs=bs, shuffle=shuffle, after_batch=after_batch, num_workers=num_workers, **kwargs)\n",
@@ -3642,12 +3642,12 @@
 "outputs": [],
 "source": [
 "@MultiCategorize\n",
-"def encodes(self, to:Tabular): \n",
+"def encodes(self, to:Tabular):\n",
 " #to.transform(to.y_names, partial(_apply_cats, {n: self.vocab for n in to.y_names}, 0))\n",
 " return to\n",
-" \n",
+"\n",
 "@MultiCategorize\n",
-"def decodes(self, to:Tabular): \n",
+"def decodes(self, to:Tabular):\n",
 " #to.transform(to.y_names, partial(_decode_cats, {n: self.vocab for n in to.y_names}))\n",
 " return to"
 ]
