GH-48: added word dropout | moved dropouts into new flair.nn module
aakbik committed Sep 17, 2018
1 parent 72f82f3 commit 70586f9
Showing 2 changed files with 60 additions and 30 deletions.
53 changes: 23 additions & 30 deletions flair/models/sequence_tagger_model.py
@@ -1,7 +1,8 @@
import warnings

import torch.autograd as autograd
import torch.nn as nn
import torch.nn
import flair.nn
import torch
import numpy as np

@@ -58,7 +59,7 @@ def pad_tensors(tensor_list, type_=torch.FloatTensor):
return template, lens_


class SequenceTagger(nn.Module):
class SequenceTagger(torch.nn.Module):

def __init__(self,
hidden_size: int,
@@ -67,7 +68,8 @@ def __init__(self,
tag_type: str,
use_crf: bool = True,
use_rnn: bool = True,
rnn_layers: int = 1
rnn_layers: int = 1,
use_word_dropout: bool = False,
):

super(SequenceTagger, self).__init__()
@@ -90,40 +92,42 @@ def __init__(self,
self.nlayers: int = rnn_layers
self.hidden_word = None

# self.dropout = nn.Dropout(0.5)
self.dropout: nn.Module = LockedDropout(0.5)
# dropouts
self.dropout: torch.nn.Module = flair.nn.LockedDropout(0.5)

self.use_word_dropout: bool = use_word_dropout
if self.use_word_dropout:
self.word_dropout = flair.nn.WordDropout(0.05)

rnn_input_dim: int = self.embeddings.embedding_length

self.relearn_embeddings: bool = True

if self.relearn_embeddings:
self.embedding2nn = nn.Linear(rnn_input_dim, rnn_input_dim)
self.embedding2nn = torch.nn.Linear(rnn_input_dim, rnn_input_dim)

# bidirectional LSTM on top of embedding layer
self.rnn_type = 'LSTM'
if self.rnn_type in ['LSTM', 'GRU']:

if self.nlayers == 1:
self.rnn = getattr(nn, self.rnn_type)(rnn_input_dim, hidden_size,
self.rnn = getattr(torch.nn, self.rnn_type)(rnn_input_dim, hidden_size,
num_layers=self.nlayers,
bidirectional=True)
else:
self.rnn = getattr(nn, self.rnn_type)(rnn_input_dim, hidden_size,
self.rnn = getattr(torch.nn, self.rnn_type)(rnn_input_dim, hidden_size,
num_layers=self.nlayers,
dropout=0.5,
bidirectional=True)

self.nonlinearity = nn.Tanh()

# final linear map to tag space
if self.use_rnn:
self.linear = nn.Linear(hidden_size * 2, len(tag_dictionary))
self.linear = torch.nn.Linear(hidden_size * 2, len(tag_dictionary))
else:
self.linear = nn.Linear(self.embeddings.embedding_length, len(tag_dictionary))
self.linear = torch.nn.Linear(self.embeddings.embedding_length, len(tag_dictionary))

if self.use_crf:
self.transitions = nn.Parameter(
self.transitions = torch.nn.Parameter(
torch.randn(self.tagset_size, self.tagset_size))
self.transitions.data[self.tag_dictionary.get_idx_for_item(START_TAG), :] = -10000
self.transitions.data[:, self.tag_dictionary.get_idx_for_item(STOP_TAG)] = -10000
@@ -220,6 +224,10 @@ def forward(self, sentences: List[Sentence]) -> Tuple[List, List]:
# --------------------------------------------------------------------
sentence_tensor = self.dropout(sentence_tensor)

# use word dropout if set
if self.use_word_dropout:
sentence_tensor = self.word_dropout(sentence_tensor)

if self.relearn_embeddings:
sentence_tensor = self.embedding2nn(sentence_tensor)

@@ -362,7 +370,7 @@ def neg_log_likelihood(self, sentences: List[Sentence], tag_type: str):
tag_tensor = autograd.Variable(torch.cuda.LongTensor(sentence_tags))
else:
tag_tensor = autograd.Variable(torch.LongTensor(sentence_tags))
score += nn.functional.cross_entropy(sentence_feats, tag_tensor)
score += torch.nn.functional.cross_entropy(sentence_feats, tag_tensor)

return score

@@ -572,19 +580,4 @@ def load(model: str):

if model_file is not None:
tagger: SequenceTagger = SequenceTagger.load_from_file(model_file)
return tagger


class LockedDropout(nn.Module):
def __init__(self, dropout_rate=0.5):
super(LockedDropout, self).__init__()
self.dropout_rate = dropout_rate

def forward(self, x):
if not self.training or not self.dropout_rate:
return x

m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - self.dropout_rate)
mask = torch.autograd.Variable(m, requires_grad=False) / (1 - self.dropout_rate)
mask = mask.expand_as(x)
return mask * x
return tagger
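
A minimal usage sketch for the new flag, assuming the usual flair constructor arguments: the parameters elided between hidden_size and tag_type in the hunks above are the embeddings and the tag dictionary, and the `embeddings` / `tag_dictionary` objects below are hypothetical placeholders prepared elsewhere (e.g. from a flair corpus).

from flair.models.sequence_tagger_model import SequenceTagger

# `embeddings` and `tag_dictionary` are hypothetical placeholders assumed to
# have been built elsewhere; only the new flag is of interest here.
tagger = SequenceTagger(hidden_size=256,
                        embeddings=embeddings,
                        tag_dictionary=tag_dictionary,
                        tag_type='ner',
                        use_crf=True,
                        use_word_dropout=True)  # new in this commit; defaults to False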
37 changes: 37 additions & 0 deletions flair/nn.py
@@ -0,0 +1,37 @@
import torch.nn


class LockedDropout(torch.nn.Module):
"""
Implementation of locked (or variational) dropout. Randomly drops out entire parameters in embedding space.
"""
def __init__(self, dropout_rate=0.5):
super(LockedDropout, self).__init__()
self.dropout_rate = dropout_rate

def forward(self, x):
if not self.training or not self.dropout_rate:
return x

m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - self.dropout_rate)
mask = torch.autograd.Variable(m, requires_grad=False) / (1 - self.dropout_rate)
mask = mask.expand_as(x)
return mask * x


class WordDropout(torch.nn.Module):
"""
Implementation of word dropout. Randomly drops out entire words (or characters) in embedding space.
"""
def __init__(self, dropout_rate=0.05):
super(WordDropout, self).__init__()
self.dropout_rate = dropout_rate

def forward(self, x):
if not self.training or not self.dropout_rate:
return x

m = x.data.new(x.size(0), 1, 1).bernoulli_(1 - self.dropout_rate)
mask = torch.autograd.Variable(m, requires_grad=False)
mask = mask.expand_as(x)
return mask * x
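
To make the difference between the two new modules concrete, a small sketch assuming the (sequence_length, batch_size, embedding_dim) layout implied by the mask shapes above; the 0.5 rate for WordDropout is exaggerated only to make the effect visible (the default is 0.05).

import torch
import flair.nn

x = torch.ones(4, 2, 6)  # 4 tokens, batch of 2, 6-dimensional embeddings

locked = flair.nn.LockedDropout(0.5)
word = flair.nn.WordDropout(0.5)
locked.train()  # both masks are only applied in training mode
word.train()

# LockedDropout: for each sequence in the batch, the same embedding dimensions
# are zeroed at every timestep, and the rest are rescaled by 1 / (1 - p).
print(locked(x))

# WordDropout: entire timesteps (whole words) are zeroed for the whole batch,
# with no rescaling of the surviving embeddings.
print(word(x))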
