Skip to content

Commit

Permalink
Merge pull request #707 from ashutoshsingh0223/variable_confidence_th…
Browse files Browse the repository at this point in the history
…reshold_multilabel_classfication

capability to change threshold during multi label classification
  • Loading branch information
Alan Akbik committed May 8, 2019
2 parents bed55b4 + c4b7a62 commit db96d22
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion flair/models/text_classification_model.py
Expand Up @@ -33,13 +33,15 @@ def __init__(
document_embeddings: flair.embeddings.DocumentEmbeddings,
label_dictionary: Dictionary,
multi_label: bool,
multi_label_threshold: float = 0.5
):

super(TextClassifier, self).__init__()

self.document_embeddings: flair.embeddings.DocumentRNNEmbeddings = document_embeddings
self.label_dictionary: Dictionary = label_dictionary
self.multi_label = multi_label
self.multi_label_threshold = multi_label_threshold

self.decoder = nn.Linear(
self.document_embeddings.embedding_length, len(self.label_dictionary)
Expand Down Expand Up @@ -287,7 +289,7 @@ def _get_multi_label(self, label_scores) -> List[Label]:

results = list(map(lambda x: sigmoid(x), label_scores))
for idx, conf in enumerate(results):
if conf > 0.5:
if conf > self.multi_label_threshold:
label = self.label_dictionary.get_item_for_index(idx)
labels.append(Label(label, conf.item()))

Expand Down

0 comments on commit db96d22

Please sign in to comment.