Skip to content

Commit

Permalink
adjustable workers (#2721)
Browse files Browse the repository at this point in the history
Faster inference for text and tabular
  • Loading branch information
muellerzr committed Sep 6, 2020
1 parent 93b995d commit 89af725
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 12 deletions.
6 changes: 3 additions & 3 deletions fastai/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,11 @@ def validate(self, ds_idx=1, dl=None, cbs=None):

@delegates(GatherPredsCallback.__init__)
def get_preds(self, ds_idx=1, dl=None, with_input=False, with_decoded=False, with_loss=False, act=None,
inner=False, reorder=True, cbs=None, **kwargs):
if dl is None: dl = self.dls[ds_idx].new(shuffled=False, drop_last=False)
inner=False, reorder=True, cbs=None, n_workers=defaults.cpus, **kwargs):
if dl is None: dl = self.dls[ds_idx].new(shuffled=False, drop_last=False, num_workers=n_workers)
if reorder and hasattr(dl, 'get_idxs'):
idxs = dl.get_idxs()
dl = dl.new(get_idxs = _ConstantFunc(idxs))
dl = dl.new(get_idxs = _ConstantFunc(idxs), num_workers=n_workers)
cb = GatherPredsCallback(with_input=with_input, with_loss=with_loss, **kwargs)
ctx_mgrs = self.validation_context(cbs=L(cbs)+[cb], inner=inner)
if with_loss: ctx_mgrs.append(self.loss_not_reduced())
Expand Down
4 changes: 2 additions & 2 deletions fastai/tabular/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
@log_args(but_as=Learner.__init__)
class TabularLearner(Learner):
"`Learner` for tabular data"
def predict(self, row):
def predict(self, row, n_workers=defaults.cpus):
"Predict on a Pandas Series"
dl = self.dls.test_dl(row.to_frame().T)
dl = self.dls.test_dl(row.to_frame().T, num_workers=0)
dl.dataset.conts = dl.dataset.conts.astype(np.float32)
inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)
b = (*tuplify(inp),*tuplify(dec_preds))
Expand Down
2 changes: 1 addition & 1 deletion fastai/text/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def predict(self, text, n_words=1, no_unk=True, temperature=1., min_p=None, no_b
decoder=decode_spec_tokens, only_last_word=False):
"Return `text` and the `n_words` that come after"
self.model.reset()
idxs = idxs_all = self.dls.test_dl([text]).items[0].to(self.dls.device)
idxs = idxs_all = self.dls.test_dl([text], num_workers=0).items[0].to(self.dls.device)
if no_unk: unk_idx = self.dls.vocab.index(UNK)
for _ in (range(n_words) if no_bar else progress_bar(range(n_words), leave=False)):
with self.no_bar(): preds,_ = self.get_preds(dl=[(idxs[None],)])
Expand Down
6 changes: 3 additions & 3 deletions nbs/13a_learner.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -426,11 +426,11 @@
"\n",
" @delegates(GatherPredsCallback.__init__)\n",
" def get_preds(self, ds_idx=1, dl=None, with_input=False, with_decoded=False, with_loss=False, act=None,\n",
" inner=False, reorder=True, cbs=None, **kwargs):\n",
" if dl is None: dl = self.dls[ds_idx].new(shuffled=False, drop_last=False)\n",
" inner=False, reorder=True, cbs=None, n_workers=defaults.cpus, **kwargs):\n",
" if dl is None: dl = self.dls[ds_idx].new(shuffled=False, drop_last=False, num_workers=n_workers)\n",
" if reorder and hasattr(dl, 'get_idxs'):\n",
" idxs = dl.get_idxs()\n",
" dl = dl.new(get_idxs = _ConstantFunc(idxs))\n",
" dl = dl.new(get_idxs = _ConstantFunc(idxs), num_workers=n_workers)\n",
" cb = GatherPredsCallback(with_input=with_input, with_loss=with_loss, **kwargs)\n",
" ctx_mgrs = self.validation_context(cbs=L(cbs)+[cb], inner=inner)\n",
" if with_loss: ctx_mgrs.append(self.loss_not_reduced())\n",
Expand Down
2 changes: 1 addition & 1 deletion nbs/37_text.learner.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@
" decoder=decode_spec_tokens, only_last_word=False):\n",
" \"Return `text` and the `n_words` that come after\"\n",
" self.model.reset()\n",
" idxs = idxs_all = self.dls.test_dl([text]).items[0].to(self.dls.device)\n",
" idxs = idxs_all = self.dls.test_dl([text], num_workers=0).items[0].to(self.dls.device)\n",
" if no_unk: unk_idx = self.dls.vocab.index(UNK)\n",
" for _ in (range(n_words) if no_bar else progress_bar(range(n_words), leave=False)):\n",
" with self.no_bar(): preds,_ = self.get_preds(dl=[(idxs[None],)])\n",
Expand Down
4 changes: 2 additions & 2 deletions nbs/43_tabular.learner.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@
"@log_args(but_as=Learner.__init__)\n",
"class TabularLearner(Learner):\n",
" \"`Learner` for tabular data\"\n",
" def predict(self, row):\n",
" def predict(self, row, n_workers=defaults.cpus):\n",
" \"Predict on a Pandas Series\"\n",
" dl = self.dls.test_dl(row.to_frame().T)\n",
" dl = self.dls.test_dl(row.to_frame().T, num_workers=0)\n",
" dl.dataset.conts = dl.dataset.conts.astype(np.float32)\n",
" inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)\n",
" b = (*tuplify(inp),*tuplify(dec_preds))\n",
Expand Down

0 comments on commit 89af725

Please sign in to comment.