Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Serialize multi_label_threshold for classification models #2368

Merged
merged 5 commits into from Aug 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions flair/models/pairwise_classification_model.py
Expand Up @@ -132,6 +132,7 @@ def _get_state_dict(self):
"label_dictionary": self.label_dictionary,
"label_type": self.label_type,
"multi_label": self.multi_label,
"multi_label_threshold": self.multi_label_threshold,
"loss_weights": self.loss_weights,
"embed_separately": self.embed_separately,
}
Expand All @@ -145,6 +146,7 @@ def _init_model_with_state_dict(state):
label_dictionary=state["label_dictionary"],
label_type=state["label_type"],
multi_label=state["multi_label"],
multi_label_threshold=0.5 if "multi_label_threshold" not in state.keys() else state["multi_label_threshold"],
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of default 0.5 the defaults can be read from the class:

super_defaults = inspect.signature(super(TextPairClassifier, self).__init__)
multi_label_threshold_default = super_defaults.parameters['multi_label_threshold'].default

and then use multi_label_threshold_default instead of 0.5.

loss_weights=state["loss_weights"],
embed_separately=state["embed_separately"],
)
Expand Down
4 changes: 3 additions & 1 deletion flair/models/text_classification_model.py
Expand Up @@ -87,6 +87,7 @@ def _get_state_dict(self):
"label_dictionary": self.label_dictionary,
"label_type": self.label_type,
"multi_label": self.multi_label,
"multi_label_threshold": self.multi_label_threshold,
"weight_dict": self.weight_dict,
}
return model_state
Expand All @@ -101,6 +102,7 @@ def _init_model_with_state_dict(state):
label_dictionary=state["label_dictionary"],
label_type=label_type,
multi_label=state["multi_label"],
multi_label_threshold=0.5 if "multi_label_threshold" not in state.keys() else state["multi_label_threshold"],
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of default 0.5 the defaults can be read from the class:

super_defaults = inspect.signature(super(TextClassifier, self).__init__)
multi_label_threshold_default = super_defaults.parameters['multi_label_threshold'].default

and then use multi_label_threshold_default instead of 0.5.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool didn't know about this!

loss_weights=weights,
)
model.load_state_dict(state["state_dict"])
Expand Down Expand Up @@ -140,4 +142,4 @@ def _fetch_model(model_name) -> str:

@property
def label_type(self):
return self._label_type
return self._label_type
2 changes: 1 addition & 1 deletion flair/nn/model.py
Expand Up @@ -502,6 +502,7 @@ def predict(
if label_score > label_threshold or multi_class_prob:
label.set_value(value=label_value, score=label_score)
sentence.add_complex_label(label_name, copy.deepcopy(label))

else:
softmax = torch.nn.functional.softmax(scores, dim=-1)
conf, idx = torch.max(softmax, dim=-1)
Expand Down Expand Up @@ -553,7 +554,6 @@ def _get_multi_label(self, label_scores) -> List[Label]:
label_score = conf.item()
if label_score > label_threshold:
labels.append(Label(label_value, label_score))

return labels

def _get_single_label(self, label_scores) -> List[Label]:
Expand Down