Skip to content

Commit

Permalink
GH-48: Update eval method.
Browse files Browse the repository at this point in the history
  • Loading branch information
tabergma committed Sep 28, 2018
1 parent 03258f3 commit 6b225f1
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 0 additions & 2 deletions flair/models/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ def forward(self, input, hidden, ordered_sequence_lengths=None):
encoded = self.encoder(input)
emb = self.drop(encoded)

self.rnn.flatten_parameters()

output, hidden = self.rnn(emb, hidden)

if self.proj is not None:
Expand Down
4 changes: 3 additions & 1 deletion flair/trainers/text_classification_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def evaluate(self, sentences: List[Sentence], eval_class_metrics: bool = False,
range(0, len(sentences), mini_batch_size)]

y_pred = []
y_true = convert_labels_to_one_hot([sentence.get_label_names() for sentence in sentences], self.label_dict)
y_true = []

for batch in batches:
scores = self.model.forward(batch)
Expand All @@ -218,11 +218,13 @@ def evaluate(self, sentences: List[Sentence], eval_class_metrics: bool = False,

eval_loss += loss

y_true.extend([sentence.get_label_names() for sentence in batch])
y_pred.extend([[label.value for label in sent_labels] for sent_labels in labels])

if not embeddings_in_memory:
clear_embeddings(batch)

y_true = convert_labels_to_one_hot(y_true, self.label_dict)
y_pred = convert_labels_to_one_hot(y_pred, self.label_dict)

metrics = [calculate_micro_avg_metric(y_true, y_pred, self.label_dict)]
Expand Down

0 comments on commit 6b225f1

Please sign in to comment.