Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Serialize multi_label_threshold for classification models #2368

Merged
merged 5 commits into from Aug 17, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions flair/models/pairwise_classification_model.py
Expand Up @@ -132,6 +132,7 @@ def _get_state_dict(self):
"label_dictionary": self.label_dictionary,
"label_type": self.label_type,
"multi_label": self.multi_label,
"multi_label_threshold": self.multi_label_threshold,
"loss_weights": self.loss_weights,
"embed_separately": self.embed_separately,
}
Expand All @@ -145,6 +146,7 @@ def _init_model_with_state_dict(state):
label_dictionary=state["label_dictionary"],
label_type=state["label_type"],
multi_label=state["multi_label"],
multi_label_threshold=0.5 if "multi_label_threshold" not in state.keys() else state["multi_label_threshold"],
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of default 0.5 the defaults can be read from the class:

super_defaults = inspect.signature(super(TextPairClassifier, self).__init__)
multi_label_threshold_default = super_defaults.parameters['multi_label_threshold'].default

and then use multi_label_threshold_default instead of 0.5.

loss_weights=state["loss_weights"],
embed_separately=state["embed_separately"],
)
Expand Down
4 changes: 3 additions & 1 deletion flair/models/text_classification_model.py
Expand Up @@ -87,6 +87,7 @@ def _get_state_dict(self):
"label_dictionary": self.label_dictionary,
"label_type": self.label_type,
"multi_label": self.multi_label,
"multi_label_threshold": self.multi_label_threshold,
"weight_dict": self.weight_dict,
}
return model_state
Expand All @@ -101,6 +102,7 @@ def _init_model_with_state_dict(state):
label_dictionary=state["label_dictionary"],
label_type=label_type,
multi_label=state["multi_label"],
multi_label_threshold=0.5 if "multi_label_threshold" not in state.keys() else state["multi_label_threshold"],
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of default 0.5 the defaults can be read from the class:

super_defaults = inspect.signature(super(TextClassifier, self).__init__)
multi_label_threshold_default = super_defaults.parameters['multi_label_threshold'].default

and then use multi_label_threshold_default instead of 0.5.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, I didn't know about this!

loss_weights=weights,
)
model.load_state_dict(state["state_dict"])
Expand Down Expand Up @@ -140,4 +142,4 @@ def _fetch_model(model_name) -> str:

@property
def label_type(self):
return self._label_type
return self._label_type
30 changes: 14 additions & 16 deletions flair/nn/model.py
Expand Up @@ -361,7 +361,7 @@ def __init__(self,

# set up multi-label logic
self.multi_label = multi_label
self.multi_label_threshold = {'default': multi_label_threshold} if type(multi_label_threshold) is float else multi_label_threshold
self.multi_label_threshold = multi_label_threshold

# loss weights and loss function
self.weight_dict = loss_weights
Expand Down Expand Up @@ -476,17 +476,17 @@ def predict(
if len(label_candidates) > 0:

if self.multi_label or multi_class_prob:
sigmoided = torch.sigmoid(scores) # size: (n_sentences, n_classes)
n_labels = sigmoided.size(1)
for s_idx, (sentence, label) in enumerate(zip(sentences, label_candidates)):
for l_idx in range(n_labels):
label_value = self.label_dictionary.get_item_for_index(l_idx)
if label_value == 'O': continue
label_threshold = self.multi_label_threshold['default'] if label_value not in self.multi_label_threshold else self.multi_label_threshold[label_value]
label_score = sigmoided[s_idx, l_idx].item()
if label_score > label_threshold or multi_class_prob:
label.set_value(value=label_value, score=label_score)
sigmoided = torch.sigmoid(scores)
s_idx = 0
for sentence, label in zip(sentences, label_candidates):
for idx in range(sigmoided.size(1)):
if sigmoided[s_idx, idx] > self.multi_label_threshold or multi_class_prob:
label_value = self.label_dictionary.get_item_for_index(idx)
if label_value == 'O': continue
label.set_value(value=label_value, score=sigmoided[s_idx, idx].item())
sentence.add_complex_label(label_name, copy.deepcopy(label))
s_idx += 1

else:
softmax = torch.nn.functional.softmax(scores, dim=-1)
conf, idx = torch.max(softmax, dim=-1)
Expand Down Expand Up @@ -526,11 +526,9 @@ def _get_multi_label(self, label_scores) -> List[Label]:

results = list(map(lambda x: sigmoid(x), label_scores))
for idx, conf in enumerate(results):
label_value = self.label_dictionary.get_item_for_index(idx)
label_threshold = self.multi_label_threshold['default'] if label_value not in self.multi_label_threshold else self.multi_label_threshold[label_value]
label_score = conf.item()
if label_score > label_threshold:
labels.append(Label(label_value, label_score))
if conf > self.multi_label_threshold:
label = self.label_dictionary.get_item_for_index(idx)
labels.append(Label(label, conf.item()))

return labels

Expand Down