Merge pull request #126 from zalandoresearch/GH-48-word-dropout
GH-48: Don't use Locked and Word Dropout in text classifier per default
Alan Akbik committed Sep 28, 2018
2 parents ac0b922 + 6b225f1 commit ce1778c
Showing 3 changed files with 17 additions and 8 deletions.
flair/data.py (2 changes: 1 addition & 1 deletion)

@@ -623,7 +623,7 @@ def _print_statistics_for(sentences, name):
 
         size_dict = {}
         for l, c in classes_to_count.items():
-            size_dict = { l: c }
+            size_dict[l] = c
         size_dict['total'] = len(sentences)
 
         stats = {
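
The data.py change fixes a counting bug: reassigning size_dict inside the loop replaced the whole dictionary on every iteration, so only the last class survived, whereas indexing into the existing dict accumulates all class counts. A minimal sketch of the difference, using made-up label counts:

    classes_to_count = {'POSITIVE': 120, 'NEGATIVE': 80}  # hypothetical label counts

    # old behaviour: the dict is rebuilt on every iteration, keeping only the last class
    size_dict = {}
    for l, c in classes_to_count.items():
        size_dict = {l: c}
    print(size_dict)  # {'NEGATIVE': 80}

    # fixed behaviour: each class is added to the same dict
    size_dict = {}
    for l, c in classes_to_count.items():
        size_dict[l] = c
    print(size_dict)  # {'POSITIVE': 120, 'NEGATIVE': 80}
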
flair/embeddings.py (8 changes: 6 additions & 2 deletions)

@@ -540,7 +540,7 @@ class DocumentLSTMEmbeddings(DocumentEmbeddings):
 
     def __init__(self, token_embeddings: List[TokenEmbeddings], hidden_states=128, num_layers=1,
                  reproject_words: bool = True, reproject_words_dimension: int = None, bidirectional: bool = False,
-                 use_first_representation: bool = False, use_word_dropout: bool = True):
+                 use_first_representation: bool = False, use_word_dropout: bool = False, use_locked_dropout: bool = False):
         """The constructor takes a list of embeddings to be combined.
         :param token_embeddings: a list of token embeddings
         :param hidden_states: the number of hidden states in the lstm
@@ -553,6 +553,7 @@ def __init__(self, token_embeddings: List[TokenEmbeddings], hidden_states=128, n
         :param use_first_representation: boolean value, indicating whether to concatenate the first and last
         representation of the lstm to be used as final document embedding.
         :param use_word_dropout: boolean value, indicating whether to use word dropout or not.
+        :param use_locked_dropout: boolean value, indicating whether to use locked dropout or not.
         """
         super().__init__()
 
@@ -586,7 +587,10 @@ def __init__(self, token_embeddings: List[TokenEmbeddings], hidden_states=128, n
                                 bidirectional=self.bidirectional)
 
         # dropouts
-        self.dropout: torch.nn.Module = LockedDropout(0.5)
+        if use_locked_dropout:
+            self.dropout: torch.nn.Module = LockedDropout(0.5)
+        else:
+            self.dropout = torch.nn.Dropout(0.5)
 
         self.use_word_dropout: bool = use_word_dropout
         if self.use_word_dropout:
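
With this change, DocumentLSTMEmbeddings uses a plain torch.nn.Dropout(0.5) by default and applies locked dropout and word dropout only when explicitly requested. A rough usage sketch, assuming flair's pretrained GloVe word embeddings as the token-level input (any TokenEmbeddings instance works the same way):

    from flair.embeddings import WordEmbeddings, DocumentLSTMEmbeddings

    glove = WordEmbeddings('glove')

    # new defaults: standard dropout only, no word dropout, no locked dropout
    document_embeddings = DocumentLSTMEmbeddings([glove], hidden_states=128, bidirectional=False)

    # opting back in to the previous regularization behaviour
    document_embeddings_regularized = DocumentLSTMEmbeddings(
        [glove],
        use_word_dropout=True,    # randomly drop entire token embeddings
        use_locked_dropout=True,  # LockedDropout(0.5) instead of torch.nn.Dropout(0.5)
    )
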
flair/trainers/text_classification_trainer.py (15 changes: 10 additions & 5 deletions)

@@ -122,7 +122,8 @@ def train(self,
             log.info("EPOCH {0}: lr {1:.4f} - bad epochs {2}".format(epoch + 1, learning_rate, scheduler.num_bad_epochs))
 
             dev_metric = train_metric = None
-            dev_loss = train_loss = '_'
+            dev_loss = '_'
+            train_loss = current_loss
 
             if eval_on_train:
                 train_metric, train_loss = self._calculate_evaluation_results_for(
@@ -208,19 +209,23 @@ def evaluate(self, sentences: List[Sentence], eval_class_metrics: bool = False,
                    range(0, len(sentences), mini_batch_size)]
 
         y_pred = []
-        y_true = convert_labels_to_one_hot([sentence.get_label_names() for sentence in sentences], self.label_dict)
+        y_true = []
 
         for batch in batches:
             scores = self.model.forward(batch)
             labels = self.model.obtain_labels(scores)
             loss = self.model.calculate_loss(scores, batch)
 
+            eval_loss += loss
+
+            y_true.extend([sentence.get_label_names() for sentence in batch])
+            y_pred.extend([[label.value for label in sent_labels] for sent_labels in labels])
+
             if not embeddings_in_memory:
                 clear_embeddings(batch)
 
-            eval_loss += loss
-
-            y_pred.extend(convert_labels_to_one_hot([[label.value for label in sent_labels] for sent_labels in labels], self.label_dict))
+        y_true = convert_labels_to_one_hot(y_true, self.label_dict)
+        y_pred = convert_labels_to_one_hot(y_pred, self.label_dict)
 
         metrics = [calculate_micro_avg_metric(y_true, y_pred, self.label_dict)]
         if eval_class_metrics:
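
In the trainer, the evaluation loss is now accumulated right after it is computed, gold and predicted label names are collected batch by batch, and the one-hot conversion happens once over the whole evaluation set instead of inside the loop. For readers unfamiliar with the helper, below is a hypothetical stand-in for convert_labels_to_one_hot; the real flair utility may differ in its details, but the idea is to turn each sentence's label names into a multi-hot vector over the label dictionary:

    from typing import Dict, List

    def to_one_hot(labels_per_sentence: List[List[str]], label_to_idx: Dict[str, int]) -> List[List[int]]:
        # hypothetical stand-in for flair's convert_labels_to_one_hot
        vectors = []
        for labels in labels_per_sentence:
            vector = [0] * len(label_to_idx)
            for label in labels:
                vector[label_to_idx[label]] = 1  # mark every label attached to the sentence
            vectors.append(vector)
        return vectors

    label_to_idx = {'POSITIVE': 0, 'NEGATIVE': 1}                    # toy label dictionary
    y_true = to_one_hot([['POSITIVE'], ['NEGATIVE']], label_to_idx)  # [[1, 0], [0, 1]]
    y_pred = to_one_hot([['POSITIVE'], ['POSITIVE']], label_to_idx)  # [[1, 0], [1, 0]]
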
