Skip to content

Commit

Permalink
update grouped splitter
Browse files Browse the repository at this point in the history
  • Loading branch information
YSaxon committed Sep 16, 2020
1 parent 8b1ecf4 commit 3b9c7fe
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 15 deletions.
1 change: 1 addition & 0 deletions fastai/_nbdev.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@
"FileSplitter": "05_data.transforms.ipynb",
"ColSplitter": "05_data.transforms.ipynb",
"RandomSubsetSplitter": "05_data.transforms.ipynb",
"GroupedSplitter": "05_data.transforms.ipynb",
"parent_label": "05_data.transforms.ipynb",
"RegexLabeller": "05_data.transforms.ipynb",
"ColReader": "05_data.transforms.ipynb",
Expand Down
33 changes: 29 additions & 4 deletions fastai/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

__all__ = ['get_files', 'FileGetter', 'image_extensions', 'get_image_files', 'ImageGetter', 'get_text_files',
'ItemGetter', 'AttrGetter', 'RandomSplitter', 'TrainTestSplitter', 'IndexSplitter', 'GrandparentSplitter',
'FuncSplitter', 'MaskSplitter', 'FileSplitter', 'ColSplitter', 'RandomSubsetSplitter', 'parent_label',
'RegexLabeller', 'ColReader', 'CategoryMap', 'Categorize', 'Category', 'MultiCategorize', 'MultiCategory',
'OneHotEncode', 'EncodedMultiCategorize', 'RegressionSetup', 'get_c', 'ToTensor', 'IntToFloatTensor',
'broadcast_vec', 'Normalize']
'FuncSplitter', 'MaskSplitter', 'FileSplitter', 'ColSplitter', 'RandomSubsetSplitter', 'GroupedSplitter',
'parent_label', 'RegexLabeller', 'ColReader', 'CategoryMap', 'Categorize', 'Category', 'MultiCategorize',
'MultiCategory', 'OneHotEncode', 'EncodedMultiCategorize', 'RegressionSetup', 'get_c', 'ToTensor',
'IntToFloatTensor', 'broadcast_vec', 'Normalize']

# Cell
from ..torch_basics import *
Expand Down Expand Up @@ -164,6 +164,31 @@ def _inner(o):
return idxs[:train_len],idxs[train_len:train_len+valid_len]
return _inner

# Cell
def GroupedSplitter(groupkey, valid_pct=0.2, seed=None):
    "Split items in `o` between train/val with `valid_pct` randomly, ensuring that subgroups are never split between sets. Groups are defined by a group-key extractor function, or by a colname if `o` is a DataFrame"
    def _inner(o):
        if callable(groupkey):
            ids = pd.DataFrame(o)
            # `map` applies `groupkey` to each individual item; `DataFrame.apply`
            # would pass each whole *column* to `groupkey`, which only happens to
            # work for vectorizable functions (e.g. `lambda x: x%10`).
            ids['group_keys'] = ids.iloc[:,0].map(groupkey)
            keycol = 'group_keys'
        else:
            assert isinstance(o, pd.DataFrame), "`groupkey` can be a colname if `o` is a DataFrame, otherwise `groupkey` must be a function item->key that can extract a groupkey from an item in `o`"
            assert groupkey in o, "`groupkey` is not a colname in the DataFrame `o`"
            keycol = groupkey
            ids = o
        # Size of each group; `size` counts rows directly (unlike `count`, it is
        # neither per-column nor NaN-sensitive).
        gk = ids.groupby(keycol).size().to_frame('n')
        # Shuffle the groups (reproducibly if `seed` is given), then take whole
        # groups from the front until the cumulative item count is as close as
        # possible to the desired validation-set size.
        shuffled_gk = gk.sample(frac=1, random_state=seed)
        cumsum = shuffled_gk['n'].cumsum()
        desired_valid = len(ids)*valid_pct
        valid_rows = (cumsum-desired_valid).abs().argmin()+1
        shuffled_gk['is_valid'] = ([True] * valid_rows +
                                   [False]*(len(shuffled_gk) - valid_rows))
        # Propagate the per-group flag back onto the items, then split on it.
        split_df = ids.join(shuffled_gk.loc[:,'is_valid'], on=keycol)
        return ColSplitter()(split_df)
    return _inner

# Cell
def parent_label(o):
"Label `item` with the parent folder name."
Expand Down
53 changes: 42 additions & 11 deletions nbs/05_data.transforms.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -662,32 +662,39 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 500,
"metadata": {},
"outputs": [],
"source": [
"# export\n",
"def GroupedSplitter(groupkey_extractor,valid_pct=0.2, seed=None):\n",
" \"Split `items` between train/val with `valid_pct` randomly, ensuring that subgroups are not split between sets. Groups are defined by a group key extractor function\"\n",
"def GroupedSplitter(groupkey,valid_pct=0.2, seed=None):\n",
" \"Split `items` between train/val with `valid_pct` randomly, ensuring that subgroups are not split between sets. Groups are defined by a group key extractor function, or by a colname if `o` is a DataFrame\"\n",
" def _inner(o):\n",
" ids=pd.DataFrame(o)\n",
" ids['keys']=ids.apply(key_f)\n",
" gk=ids.groupby('keys').count()\n",
" if callable(groupkey):\n",
" ids=pd.DataFrame(o)\n",
" ids['group_keys']=ids.apply(groupkey)\n",
" keycol='group_keys'\n",
" else:\n",
" assert isinstance(o, pd.DataFrame), \"`groupkey` can be a colname if `o` is a DataFrame, otherwise `groupkey` must be a function item->key that can extract a groupkey from an item in `o`\"\n",
" assert groupkey in o, \"`groupkey` is not a colname in the DataFrame `o`\"\n",
" keycol=groupkey\n",
" ids=o\n",
" gk=ids.groupby(keycol).count()\n",
" shuffled_gk=gk.sample(frac=1,random_state=seed)\n",
" cumsum=shuffled_gk.cumsum()\n",
" desired_valid=len(o)*valid_pct\n",
" abs_diff=abs(cumsum-desired_valid)\n",
" max_row_for_validation=abs_diff.iloc[:,0].argmin()\n",
" shuffled_gk['is_valid']=False\n",
" shuffled_gk.iloc[:max_row_for_validation+1,1]=True\n",
" split_df=ids.join(shuffled_gk['is_valid'],on='keys')\n",
" valid_rows=abs_diff.iloc[:,0].argmin()+1\n",
" shuffled_gk['is_valid']=([True] * valid_rows + \n",
" [False]*(len(shuffled_gk) - valid_rows))\n",
" split_df=ids.join(shuffled_gk.loc[:,'is_valid'],on=keycol)\n",
" return ColSplitter()(split_df)\n",
" return _inner"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 501,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -705,6 +712,30 @@
"test_eq(f(src)[0], trn)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"f = GroupedSplitter('keys',seed=41)\n",
"src = list(range(1000))\n",
"key_f=lambda x:x%10\n",
"src2=pd.DataFrame(src)\n",
"src2['keys']=src2.apply(key_f)\n",
"src2['confounding_col_1']='test'\n",
"src2.insert(0,'confounding_col_0','test')\n",
"trn,val=f(src2)\n",
"assert 0<len(trn)<len(src2)\n",
"assert all(o not in val for o in trn)\n",
"k_trn=np.unique([key_f(o) for o in trn])\n",
"k_val=np.unique([key_f(o) for o in val])\n",
"assert all(k not in k_val for k in k_trn)\n",
"test_eq(len(trn), len(src2)-len(val))\n",
"# # test random seed consistency\n",
"test_eq(f(src2)[0], trn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down

0 comments on commit 3b9c7fe

Please sign in to comment.