Merge pull request #693 from amansrivastava17/add_probablity_for_multiclass

Add a method to return the probability for each class in multi-class classification
stefan-it committed May 2, 2019
2 parents 5cbc518 + ed49277 commit a9d6b9a
Showing 2 changed files with 53 additions and 3 deletions.
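The change in a nutshell: passing multi_class_prob=True to TextClassifier.predict attaches one Label per class to each sentence, each carrying that class's softmax probability, instead of only the top-scoring label. A minimal usage sketch, illustrative only, assuming a trained TextClassifier instance named model:

from flair.data import Sentence

sentence = Sentence("Berlin is a really nice city.")
model.predict(sentence, multi_class_prob=True)  # adds one Label per class to the sentence in place

for label in sentence.labels:
    print(label.value, label.score)  # class name and its softmax probability
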
18 changes: 15 additions & 3 deletions flair/models/text_classification_model.py
@@ -103,12 +103,13 @@ def forward_labels_and_loss(
return labels, loss

def predict(
self, sentences: Union[Sentence, List[Sentence]], mini_batch_size: int = 32
self, sentences: Union[Sentence, List[Sentence]], mini_batch_size: int = 32, multi_class_prob: bool = False,
) -> List[Sentence]:
"""
Predicts the class labels for the given sentences. The labels are directly added to the sentences.
:param sentences: list of sentences
:param mini_batch_size: mini batch size to use
:param multi_class_prob: return the probability for each class in multi-class classification
:return: the list of sentences containing the labels
"""
with torch.no_grad():
@@ -124,7 +125,7 @@ def predict(

for batch in batches:
scores = self.forward(batch)
predicted_labels = self._obtain_labels(scores)
predicted_labels = self._obtain_labels(scores, predict_prob=multi_class_prob)

for (sentence, labels) in zip(batch, predicted_labels):
sentence.labels = labels
@@ -264,7 +265,7 @@ def _calculate_loss(

return self._calculate_single_label_loss(scores, sentences)

def _obtain_labels(self, scores: List[List[float]]) -> List[List[Label]]:
def _obtain_labels(self, scores: List[List[float]], predict_prob: bool = False) -> List[List[Label]]:
"""
Predicts the labels of sentences.
:param scores: the prediction scores from the model
@@ -274,6 +275,9 @@ def _obtain_labels(self, scores: List[List[float]]) -> List[List[Label]]:
if self.multi_label:
return [self._get_multi_label(s) for s in scores]

elif predict_prob:
return [self._predict_label_prob(s) for s in scores]

return [self._get_single_label(s) for s in scores]

def _get_multi_label(self, label_scores) -> List[Label]:
@@ -296,6 +300,14 @@ def _get_single_label(self, label_scores) -> List[Label]:

return [Label(label, conf.item())]

def _predict_label_prob(self, label_scores) -> List[Label]:
softmax = torch.nn.functional.softmax(label_scores, dim=0)
label_probs = []
for idx, conf in enumerate(softmax):
label = self.label_dictionary.get_item_for_index(idx)
label_probs.append(Label(label, conf.item()))
return label_probs

def _calculate_multi_label_loss(
self, label_scores, sentences: List[Sentence]
) -> float:
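For reference, a standalone sketch of the softmax step performed in _predict_label_prob above; the score values below are made up for illustration, flair applies the same normalization to the model's raw class scores:

import torch
import torch.nn.functional as F

scores = torch.tensor([2.0, 0.5, -1.0])  # raw class scores (logits) for a three-class problem
probs = F.softmax(scores, dim=0)         # normalizes the scores so they sum to 1.0
for idx, prob in enumerate(probs):
    print(idx, round(prob.item(), 3))    # index into the label dictionary and its probability
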
38 changes: 38 additions & 0 deletions tests/test_model_integration.py
@@ -412,6 +412,44 @@ def test_train_load_use_classifier(results_base_path, tasks_base_path):
shutil.rmtree(results_base_path)


@pytest.mark.integration
def test_train_load_use_classifier_with_prob(results_base_path, tasks_base_path):
corpus = NLPTaskDataFetcher.load_corpus("imdb", base_path=tasks_base_path)
label_dict = corpus.make_label_dictionary()

word_embedding: WordEmbeddings = WordEmbeddings("turian")
document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
[word_embedding], 128, 1, False, 64, False, False
)

model = TextClassifier(document_embeddings, label_dict, False)

trainer = ModelTrainer(model, corpus)
trainer.train(
results_base_path, EvaluationMetric.MICRO_F1_SCORE, max_epochs=2, test_mode=True
)

sentence = Sentence("Berlin is a really nice city.")

for s in model.predict(sentence, multi_class_prob=True):
for l in s.labels:
assert l.value is not None
assert 0.0 <= l.score <= 1.0
assert type(l.score) is float

loaded_model = TextClassifier.load(results_base_path / "final-model.pt")

sentence = Sentence("I love Berlin")
sentence_empty = Sentence(" ")

loaded_model.predict(sentence, multi_class_prob=True)
loaded_model.predict([sentence, sentence_empty], multi_class_prob=True)
loaded_model.predict([sentence_empty], multi_class_prob=True)

# clean up results directory
shutil.rmtree(results_base_path)


@pytest.mark.integration
def test_train_load_use_classifier_multi_label(results_base_path, tasks_base_path):
# corpus = NLPTaskDataFetcher.load_corpus('multi_class', base_path=tasks_base_path)
