Skip to content

Commit

Permalink
Merge pull request #3912 from MSciesiek/fix_casting_types_for_mps
Browse files Browse the repository at this point in the history
Fix casting types for mps
  • Loading branch information
jph00 committed May 26, 2023
2 parents 7176ffb + 8476639 commit b4d8ac0
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion fastai/vision/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,8 @@ def mask_tensor(
if p==1.: return x
if batch: return x if random.random() < p else x.new_zeros(*x.size()) + neutral
if neutral != 0: x.add_(-neutral)
mask = x.new_empty(*x.size()).bernoulli_(p)
# Extra casting to float and long to prevent crashes on mps accelerator (issue #3911)
mask = x.new_empty(*x.size()).float().bernoulli_(p).long()
x.mul_(mask)
return x.add_(neutral) if neutral != 0 else x

Expand Down
3 changes: 2 additions & 1 deletion nbs/09_vision.augment.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2073,7 +2073,8 @@
" if p==1.: return x\n",
" if batch: return x if random.random() < p else x.new_zeros(*x.size()) + neutral\n",
" if neutral != 0: x.add_(-neutral)\n",
" mask = x.new_empty(*x.size()).bernoulli_(p)\n",
" # Extra casting to float and long to prevent crashes on mps accelerator (issue #3911)\n",
" mask = x.new_empty(*x.size()).float().bernoulli_(p).long()\n",
" x.mul_(mask)\n",
" return x.add_(neutral) if neutral != 0 else x"
]
Expand Down

0 comments on commit b4d8ac0

Please sign in to comment.