Skip to content
This repository has been archived by the owner on Aug 18, 2020. It is now read-only.

Commit

Permalink
End of summary
Browse files Browse the repository at this point in the history
  • Loading branch information
sgugger committed Jan 27, 2020
1 parent 17cb7fe commit 6df649a
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 7 deletions.
17 changes: 16 additions & 1 deletion fastai2/data/block.py
Expand Up @@ -117,6 +117,19 @@ def _apply_pipeline(p, x):
raise e
return x

# Cell
from .load import _collate_types

def _find_fail_collate(s):
    "Diagnose why samples `s` fail to collate into a batch; return an explanation string, or `None` if they collate fine."
    s = L(*s)
    # First sample acts as a template: every element must be a type `default_collate` can stack.
    for x in s[0]:
        if not isinstance(x, _collate_types): return f"{type(x).__name__} is not collatable"
    # Try collating each tuple position separately to pinpoint which member fails (e.g. mismatched shapes).
    for i in range_of(s[0]):
        try: _ = default_collate(s.itemgot(i))
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
            # `shape` may be absent (non-tensor members), hence the `getattr` default of `None`.
            shapes = [getattr(o[i], 'shape', None) for o in s]
            return f"Could not collate the {i}-th members of your tuples because got the following shapes\n{','.join([str(sh) for sh in shapes])}"

# Cell
@patch
def summary(self: DataBlock, source, bs=4, **kwargs):
Expand Down Expand Up @@ -147,7 +160,9 @@ def summary(self: DataBlock, source, bs=4, **kwargs):
b = dls.train.create_batch(s)
b = retain_types(b, s[0] if is_listy(s) else s)
except Exception as e:
print("It's not possible to collate your items in a batch, make sure all parts of your samples are tensors of the same size")
print("Error! It's not possible to collate your items in a batch")
why = _find_fail_collate(s)
print("Make sure all parts of your samples are tensors of the same size" if why is None else why)
raise e

if len([f for f in dls.train.after_batch.fs if f.name != 'noop'])!=0:
Expand Down
6 changes: 4 additions & 2 deletions fastai2/data/core.py
Expand Up @@ -56,8 +56,10 @@ def _retain_dl(self,b):
def new(self, dataset=None, cls=None, **kwargs):
    "Create a new dataloader like this one (without re-running setup), carrying over the inferred `_n_inp`/`_types` metadata."
    res = super().new(dataset, cls, do_setup=False, **kwargs)
    if not hasattr(self, '_n_inp') or not hasattr(self, '_types'):
        try:
            # One pass through the data infers the number of inputs and batch types.
            self._one_pass()
            res._n_inp,res._types = self._n_inp,self._types
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
            # Best-effort: warn the user but still return the new dataloader.
            print("Could not do one pass in your dataloader, there is something wrong in it")
    return res

def before_iter(self):
Expand Down
6 changes: 4 additions & 2 deletions nbs/03_data.core.ipynb
Expand Up @@ -152,8 +152,10 @@
" def new(self, dataset=None, cls=None, **kwargs):\n",
" res = super().new(dataset, cls, do_setup=False, **kwargs)\n",
" if not hasattr(self, '_n_inp') or not hasattr(self, '_types'):\n",
" self._one_pass()\n",
" res._n_inp,res._types = self._n_inp,self._types\n",
" try: \n",
" self._one_pass()\n",
" res._n_inp,res._types = self._n_inp,self._types\n",
" except: print(\"Could not do one pass in your dataloader, there is something wrong in it\")\n",
" return res\n",
"\n",
" def before_iter(self):\n",
Expand Down
24 changes: 23 additions & 1 deletion nbs/06_data.block.ipynb
Expand Up @@ -404,6 +404,26 @@
" return x"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"from fastai2.data.load import _collate_types\n",
"\n",
"def _find_fail_collate(s):\n",
" s = L(*s)\n",
" for x in s[0]: \n",
" if not isinstance(x, _collate_types): return f\"{type(x).__name__} is not collatable\"\n",
" for i in range_of(s[0]):\n",
" try: _ = default_collate(s.itemgot(i))\n",
" except:\n",
" shapes = [getattr(o[i], 'shape', None) for o in s]\n",
" return f\"Could not collate the {i}-th members of your tuples because got the following shapes\\n{','.join([str(s) for s in shapes])}\""
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -440,7 +460,9 @@
" b = dls.train.create_batch(s)\n",
" b = retain_types(b, s[0] if is_listy(s) else s)\n",
" except Exception as e:\n",
" print(\"It's not possible to collate your items in a batch, make sure all parts of your samples are tensors of the same size\")\n",
" print(\"Error! It's not possible to collate your items in a batch\")\n",
" why = _find_fail_collate(s)\n",
" print(\"Make sure all parts of your samples are tensors of the same size\" if why is None else why)\n",
" raise e\n",
" \n",
" if len([f for f in dls.train.after_batch.fs if f.name != 'noop'])!=0:\n",
Expand Down
2 changes: 1 addition & 1 deletion nbs/50_datablock_examples.ipynb

Large diffs are not rendered by default.

0 comments on commit 6df649a

Please sign in to comment.