GH-48: added word dropout | moved dropouts into new flair.nn module
aakbik committed Sep 17, 2018 · 1 parent 72f82f3 · commit 70586f9
Showing 2 changed files with 60 additions and 30 deletions.
@@ -0,0 +1,37 @@
import torch.nn


class LockedDropout(torch.nn.Module):
    """
    Implementation of locked (or variational) dropout. Randomly drops out entire parameters in embedding space.
    """
    def __init__(self, dropout_rate=0.5):
        super(LockedDropout, self).__init__()
        self.dropout_rate = dropout_rate

    def forward(self, x):
        if not self.training or not self.dropout_rate:
            return x

        m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - self.dropout_rate)
        mask = torch.autograd.Variable(m, requires_grad=False) / (1 - self.dropout_rate)
        mask = mask.expand_as(x)
        return mask * x


class WordDropout(torch.nn.Module):
    """
    Implementation of word dropout. Randomly drops out entire words (or characters) in embedding space.
    """
    def __init__(self, dropout_rate=0.05):
        super(WordDropout, self).__init__()
        self.dropout_rate = dropout_rate

    def forward(self, x):
        if not self.training or not self.dropout_rate:
            return x

        m = x.data.new(x.size(0), 1, 1).bernoulli_(1 - self.dropout_rate)
        mask = torch.autograd.Variable(m, requires_grad=False)
        mask = mask.expand_as(x)
        return mask * x
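
For context, a minimal usage sketch (not part of this commit). It assumes the new file is importable as flair.nn and that inputs are shaped (sequence_length, batch_size, embedding_dim), which is what the mask construction above implies: LockedDropout samples one mask per batch element and embedding unit and reuses it across all time steps, while WordDropout zeroes out entire time steps.

    import torch
    from flair.nn import LockedDropout, WordDropout  # module introduced by this commit (import path assumed)

    # 20 time steps, batch of 8 sentences, 100-dimensional embeddings
    embeddings = torch.randn(20, 8, 100)

    word_dropout = WordDropout(dropout_rate=0.05)     # zeroes whole time steps (words)
    locked_dropout = LockedDropout(dropout_rate=0.5)  # reuses one unit mask across all time steps

    # both modules only drop values in training mode; in eval() they return the input unchanged
    word_dropout.train()
    locked_dropout.train()

    output = locked_dropout(word_dropout(embeddings))
    print(output.shape)  # torch.Size([20, 8, 100])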