Commit fb7543d

GH-48: added word dropout implementation

aakbik committed Aug 8, 2018
1 parent 446c183 commit fb7543d
Showing 3 changed files with 35 additions and 9 deletions.
27 changes: 24 additions & 3 deletions flair/models/sequence_tagger_model.py
@@ -43,6 +43,7 @@ def __init__(self,
tag_type: str,
use_crf: bool = True,
use_rnn: bool = True,
use_word_dropout: bool = False,
rnn_layers: int = 1
):

@@ -66,9 +67,12 @@ def __init__(self,
self.nlayers: int = rnn_layers
self.hidden_word = None

# self.dropout = nn.Dropout(0.5)
self.dropout: nn.Module = LockedDropout(0.5)

self.use_word_dropout: bool = use_word_dropout
if self.use_word_dropout:
self.word_dropout = WordDropout(0.2)

rnn_input_dim: int = self.embeddings.embedding_length

self.relearn_embeddings: bool = True
@@ -90,8 +94,6 @@ def __init__(self,
dropout=0.5,
bidirectional=True)

self.nonlinearity = nn.Tanh()

# final linear map to tag space
if self.use_rnn:
self.linear = nn.Linear(hidden_size * 2, len(tag_dictionary))
@@ -198,6 +200,10 @@ def forward(self, sentences: List[Sentence]) -> Tuple[List, List]:
# --------------------------------------------------------------------
# FF PART
# --------------------------------------------------------------------
if self.use_word_dropout:
# print(sentence_tensor)
sentence_tensor = self.word_dropout(sentence_tensor)
# print(sentence_tensor)
sentence_tensor = self.dropout(sentence_tensor)

if self.relearn_embeddings:
@@ -502,3 +508,18 @@ def forward(self, x):
mask = torch.autograd.Variable(m, requires_grad=False) / (1 - self.dropout_rate)
mask = mask.expand_as(x)
return mask * x


class WordDropout(nn.Module):
def __init__(self, dropout_rate=0.5):
super(WordDropout, self).__init__()
self.dropout_rate = dropout_rate

def forward(self, x):
if not self.training or not self.dropout_rate:
return x

m = x.data.new(x.size(0), 1, 1).bernoulli_(1 - self.dropout_rate)
mask = torch.autograd.Variable(m, requires_grad=False)
mask = mask.expand_as(x)
return mask * x
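
A minimal sketch (not part of the commit), assuming the sentence tensor is shaped (words, batch, embedding_dim), of the masking behaviour the new WordDropout module implements — entire word vectors are zeroed out, rather than individual embedding dimensions as standard dropout would:

import torch

x = torch.ones(5, 1, 4)  # assumed shape: (words, batch, embedding_dim)
# 1 = keep the word, 0 = drop the whole word vector (here with rate 0.2)
keep = x.data.new(x.size(0), 1, 1).bernoulli_(1 - 0.2)
print(keep.expand_as(x) * x)  # dropped rows are all zeros, kept rows are unchanged

Note that, unlike the LockedDropout mask above it, this mask is not rescaled by 1 / (1 - dropout_rate), so kept word embeddings pass through unchanged.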
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
name='flair',
version='0.2.0.post1',
version='0.2.1',
description='A very simple framework for state-of-the-art NLP',
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
15 changes: 10 additions & 5 deletions train.py
@@ -27,25 +27,30 @@

# comment in these lines to use contextual string embeddings
#
# CharLMEmbeddings('news-forward'),
# CharLMEmbeddings('news-forward-fast'),
#
# CharLMEmbeddings('news-backward'),
# CharLMEmbeddings('news-backward-fast'),
]

embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types)

# initialize sequence tagger
from flair.models import SequenceTagger

use_word_dropout = True
print(use_word_dropout)
tagger: SequenceTagger = SequenceTagger(hidden_size=256,
embeddings=embeddings,
tag_dictionary=tag_dictionary,
tag_type=tag_type,
use_crf=True)
use_crf=True,
use_word_dropout=use_word_dropout)

print(tagger.modules().__next__())

# initialize trainer
from flair.trainers.sequence_tagger_trainer import SequenceTaggerTrainer

trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=True)
trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=False)

trainer.train('resources/taggers/example-ner', learning_rate=0.1, mini_batch_size=32, max_epochs=20)
trainer.train('resources/taggers/example-ner', learning_rate=0.1, mini_batch_size=32, max_epochs=50)
